module medagentgym.env


function

This Python module provides utility functions for interacting with and operating in a medical reasoning environment. They cover information retrieval from data sources, code execution and validation, debugging assistance, and terminal command execution.

function send_msg_to_user

send_msg_to_user(text: str)

Sends a user-facing message. Useful for providing informative updates or responses based on operations performed by the agent.

Example:

send_msg_to_user("Based on the results of my search, the city was built in 1960.")

function report_infeasible

report_infeasible(text: str)

Reports instructions that cannot be executed, notifying the user.

Example:

report_infeasible("I cannot complete this task because the required data file is missing.")

function request_info

request_info(data_path: str, info_type: str, keyterm: str) -> str

Retrieves specific information from a given data source, primarily CSV files.

Arguments:

  • data_path: Path to data (file or directory).
  • info_type: Type of requested info (e.g., 'column_names', 'column_values', or 'term').
  • keyterm: Keyword or column to search.

Returns:

  • Output: A formatted string containing relevant information or suggestions for next steps.

Example:

request_info("./mimic_iii/", "term", "aspirin")
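
The three info_type modes can be pictured with a small standard-library sketch. This is not the module's implementation: the helper name lookup_csv_info, the restriction to a single CSV file, and the wording of the returned strings are all assumptions made for illustration.

```python
import csv
from pathlib import Path


def lookup_csv_info(data_path: str, info_type: str, keyterm: str = "") -> str:
    """Hypothetical stand-in for request_info, limited to one CSV file."""
    with Path(data_path).open(newline="") as handle:
        reader = csv.DictReader(handle)
        columns = list(reader.fieldnames or [])
        rows = list(reader)

    if info_type == "column_names":
        return "Columns: " + ", ".join(columns)
    if info_type == "column_values":
        if keyterm not in columns:
            return f"Column '{keyterm}' not found. Try 'column_names' first."
        values = sorted({row[keyterm] for row in rows if row[keyterm] is not None})
        return f"Values in '{keyterm}': " + ", ".join(values)
    if info_type == "term":
        # Collect every column holding the keyword, matched case-insensitively.
        hits = sorted(
            {col for row in rows for col, val in row.items()
             if isinstance(val, str) and keyterm.lower() in val.lower()}
        )
        return f"'{keyterm}' appears in: " + (", ".join(hits) or "no columns")
    return f"Unknown info_type '{info_type}'."
```

Returning a hint such as "Try 'column_names' first" reflects the documented idea that the output may suggest a next step, although the real suggestions may look different.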

function _run_code_file

_run_code_file(code_file: Path, timeout: Optional[int]) -> Tuple[int, str, str, float]

Executes a Python file with a specified timeout and provides detailed execution feedback.

Arguments:

  • code_file: Path to the Python script to execute.
  • timeout: Maximum execution time in seconds.

Returns:

  • Tuple containing return code, stdout, stderr, and execution duration.
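
One way to produce the documented four-element tuple is to run the script in a child interpreter with subprocess. The sketch below is an assumption about the approach, not the module's code: the name run_code_file_sketch, the default timeout, and the use of -1 as the return code for a timeout are illustrative choices.

```python
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional, Tuple


def run_code_file_sketch(
    code_file: Path, timeout: Optional[int] = 30
) -> Tuple[int, str, str, float]:
    """Run a script in a child interpreter.

    Returns (return code, stdout, stderr, elapsed seconds).
    """
    start = time.perf_counter()
    try:
        proc = subprocess.run(
            [sys.executable, str(code_file)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        code, out, err = proc.returncode, proc.stdout, proc.stderr
    except subprocess.TimeoutExpired:
        # Assumed convention: -1 marks a timeout. The real module may differ.
        code, out, err = -1, "", f"Timed out after {timeout} seconds"
    return code, out, err, time.perf_counter() - start
```

Running the file in a separate process keeps a crash or an infinite loop in the submitted script from taking the environment down with it.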

function validate_code

validate_code(code: str) -> Dict[str, Any]

Validates Python code by compilation and execution, providing comprehensive results or errors.

Arguments:

  • code: Python code to validate.

Example:

validate_code("import numpy as np\na=[1,2,3]\nans=np.mean(a)")
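
The "compilation and execution" sequence can be sketched in two stages: compile first to catch syntax errors cheaply, then execute to catch runtime errors. The function name validate_code_sketch and the dictionary keys below are hypothetical, and reading the result from a variable named ans is an assumption based on the example above.

```python
import traceback
from typing import Any, Dict


def validate_code_sketch(code: str) -> Dict[str, Any]:
    """Compile, then execute, a code string and report the first failure."""
    try:
        compiled = compile(code, "<agent_code>", "exec")
    except SyntaxError as exc:
        return {"status": "compile_error", "message": f"{exc.msg} (line {exc.lineno})"}

    namespace: Dict[str, Any] = {}
    try:
        exec(compiled, namespace)  # Runs in-process: only suitable for trusted code.
    except Exception:
        return {"status": "runtime_error", "message": traceback.format_exc()}
    # Assumed convention: the code stores its final result in `ans`.
    return {"status": "success", "answer": namespace.get("ans")}
```

Executing in-process is the simplest option for a sketch. An environment that runs untrusted agent code would more plausibly write it to a file and run it in a subprocess with a timeout.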

function debug

debug(code, error_msg, debugger, history) -> Dict[str, Any]

Facilitates debugging of Python code by summarizing context and suggesting solutions using a provided debugging interface.

Arguments:

  • code: Python code snippet to debug.
  • error_msg: Error message encountered during code execution.
  • debugger: Debugging interface to use for suggestions.
  • history: History of previous interactions for context.

Returns:

  • Dictionary containing the original code, error message, and debugging suggestions.
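
The flow described above (summarize context, ask the debugger, return everything together) might look like the following sketch. The debugger interface is not specified in this document, so the sketch assumes it is a callable that takes a prompt string and returns a suggestion. The names debug_sketch and stub_debugger, the dictionary keys, and the three-turn history window are hypothetical.

```python
from typing import Any, Callable, Dict, List


def debug_sketch(
    code: str,
    error_msg: str,
    debugger: Callable[[str], str],
    history: List[str],
) -> Dict[str, Any]:
    """Build a prompt from code, error, and recent history, then ask the debugger."""
    recent = "\n".join(history[-3:])  # Keep only the last few turns as context.
    prompt = (
        f"Recent context:\n{recent}\n\n"
        f"Code:\n{code}\n\n"
        f"Error:\n{error_msg}\n\n"
        "Explain the likely cause and suggest a fix."
    )
    return {"code": code, "error": error_msg, "suggestion": debugger(prompt)}


# A stub stands in for the debugging model so the sketch runs offline.
def stub_debugger(prompt: str) -> str:
    return "Check that the column name exists before indexing."


result = debug_sketch("df['age']", "KeyError: 'age'", stub_debugger, ["loaded patients.csv"])
```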

function terminal

terminal(cmd: str) -> Dict[str, Any]

Executes a command in the system terminal and provides feedback.

Arguments:

  • cmd: Terminal command to execute.

Returns:

  • Dictionary containing command execution status, output message, and execution time.

Example:

terminal("ls -l")
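
A dictionary with status, message, and execution time can be assembled around subprocess.run, as in the sketch below. The function name terminal_sketch, the key names, the status strings, and the timeout parameter are assumptions and may not match the module.

```python
import subprocess
import time
from typing import Any, Dict


def terminal_sketch(cmd: str, timeout: int = 30) -> Dict[str, Any]:
    """Run a shell command and summarize the outcome in a dictionary."""
    start = time.perf_counter()
    try:
        proc = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, timeout=timeout
        )
        succeeded = proc.returncode == 0
        status = "SUCCESS" if succeeded else "FAILED"
        # Report stdout on success and stderr on failure.
        message = proc.stdout if succeeded else proc.stderr
    except subprocess.TimeoutExpired:
        status, message = "TIMEOUT", f"Command exceeded {timeout} seconds"
    return {
        "status": status,
        "message": message,
        "execution_time": time.perf_counter() - start,
    }
```

Because shell=True passes the string to the system shell unchanged, a function like this should only receive commands from a sandboxed or otherwise trusted agent.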

class MedAgentEnv

MedAgentEnv is a Gymnasium environment class designed for encapsulating coding-based medical reasoning tasks. It supports interactions through a structured action-observation loop, enabling reinforcement learning agents to execute tasks based on medical data.

Class Definition:

MedAgentEnv(gym.Env, ABC)

  • Inherits: gym.Env, ABC
  • Purpose: Provides a structured Gymnasium environment tailored for coding-based biomedical reasoning tasks.

Initialization Parameters:

  • task_entrypoint (type[AbstractMedAgentTask]): A class or callable returning a task instance based on provided task IDs.
  • task_kwargs (dict): Dictionary containing task-specific configurations (e.g., data paths, debugger settings).
  • action_mapping (Optional[Callable]): Function mapping action strings to executable Python functions (default: BasicActionSet().action_set).
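
To illustrate what an action_mapping callable could do, the sketch below turns an action string such as terminal("ls -l") into a call to a registered Python function. How BasicActionSet actually parses actions is not described here, so this is an assumption: make_action_mapping and the shout action are hypothetical names.

```python
import ast
from typing import Any, Callable, Dict


def make_action_mapping(functions: Dict[str, Callable[..., Any]]) -> Callable[[str], Any]:
    """Return a callable that parses 'name(literal, ...)' and dispatches it."""

    def run_action(action: str) -> Any:
        node = ast.parse(action.strip(), mode="eval").body
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
            raise ValueError(f"Not a simple function call: {action!r}")
        if node.func.id not in functions:
            raise ValueError(f"Unknown action: {node.func.id}")
        if any(kw.arg is None for kw in node.keywords):
            raise ValueError("Keyword unpacking (**kwargs) is not supported")
        # literal_eval accepts only literals, so arbitrary expressions are rejected.
        args = [ast.literal_eval(arg) for arg in node.args]
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        return functions[node.func.id](*args, **kwargs)

    return run_action


mapping = make_action_mapping({"shout": lambda text: text.upper()})
```

Parsing with ast instead of calling eval on the raw string restricts the agent to the registered functions and to literal arguments.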

Attributes:

  • observation_space (gym.spaces.Dict): Defines the structure and type of the environment's observations, including chat messages, task goals, historical actions, and elapsed time.
  • action_space (gym.spaces.Unicode): Specifies the action type expected by the environment.
  • chat (Chat): Handles the interaction history between the agent and the environment.
  • terminate_on_infeasible (bool): Controls automatic termination upon infeasible instructions.
  • debugger: Instance of a configured debugging model.
  • env_history (list): Stores the historical sequence of environment observations.
  • need_context (bool): Flag indicating whether contextual information is required for code validation.

function reset

reset(task_id, *args, **kwargs)

Initializes or resets the environment to begin a new task.

Parameters:

  • task_id: Identifier for the specific task to initialize.

Returns:

  • obs: Initial observation dictionary.
  • info: Dictionary containing task goal and initial task information.

function step

step(action: str, **kwargs)

Executes a single action in the environment and returns the resulting transition.

Parameters:

  • action: String specifying the action to execute.

Returns:

  • obs: Resulting observation after executing the action.
  • reward: Numeric reward for the action taken.
  • terminated: Boolean indicating if the episode has terminated.
  • truncated: Boolean indicating if the episode was truncated prematurely.
  • info: Dictionary containing execution details and task-specific feedback.
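
The five return values are typically consumed in a loop that stops as soon as either terminated or truncated is true. The sketch below shows that loop against a stand-in class, since the real environment needs a task and data to run. ToyEnv, its three-step budget, and its "submit" action are hypothetical and only mimic the return shapes documented above.

```python
from typing import Any, Dict, Tuple


class ToyEnv:
    """Hypothetical stand-in that mimics the reset/step return shapes only."""

    def reset(self, task_id: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        self.steps = 0
        return {"goal": f"solve {task_id}"}, {"task_id": task_id}

    def step(self, action: str) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        self.steps += 1
        terminated = action == "submit"  # The agent finished the task.
        truncated = not terminated and self.steps >= 3  # Step budget exhausted.
        reward = 1.0 if terminated else 0.0
        return {"last_action": action}, reward, terminated, truncated, {"steps": self.steps}


env = ToyEnv()
obs, info = env.reset(task_id="demo")
total_reward = 0.0
for action in ["explore", "submit", "never reached"]:
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
```

Keeping terminated and truncated separate lets training code distinguish an episode that ended because the task was resolved from one that was cut off by a limit.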

function _task_validate

_task_validate(obs)

Validates task completion and calculates reward.

Parameters:

  • obs: Current observation after action execution.

Returns:

  • reward: Numeric reward based on task completion criteria.
  • done: Boolean indicating task completion.
  • user_message: Message to the user upon task validation.
  • task_info: Additional validation details.

function _get_obs

_get_obs(results=None)

Generates the observation dictionary for the current environment state.

Parameters:

  • results: Optional data from the latest action execution.

Returns:

  • Observation dictionary including environment messages and feedback.

Usage Example

task_kwargs = {
    "data_path": "./data",
    "debugger_config": {
        "model_name": "gpt-4",
        "temperature": 0.1,
        "max_new_tokens": 512,
        "deployment_name": "debugger-model",
        "log_probs": False
    }
}
env = MedAgentEnv(task_entrypoint=MyEHRTask, task_kwargs=task_kwargs)
obs, info = env.reset(task_id="ehr_task_01")