module medagentgym.utils

This Python module implements a Ray-based parallel execution framework for running experiments, with built-in timeout management and dependency handling. Tasks that exceed their allocated execution time are reliably cancelled, so stalled experiments do not tie up computational resources.


function run_exp

run_exp(exp_arg, *dependencies, avg_step_timeout=60)

Purpose: Run the experiment described by exp_arg within the Ray distributed computing framework.

Arguments:

  • exp_arg: The experiment argument object, which must implement a run() method.
  • *dependencies: Upstream tasks (typically Ray object references) that must complete before the experiment runs.
  • avg_step_timeout (default=60): Average time per step, intended for computing a timeout. The code that uses it is currently commented out, so it has no effect.

Returns:

  • The result of exp_arg.run().

Usage:

Typically invoked as a Ray remote task.
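A minimal stdlib sketch of the run_exp contract, assuming only what is described above: exp_arg exposes run(), and *dependencies serve purely to delay execution. The Protocol name Runnable and the DemoExp class are hypothetical, and the Ray decorator is omitted so the example runs without a Ray cluster.

```python
from typing import Any, Protocol


class Runnable(Protocol):
    """Minimal contract assumed for exp_arg: it exposes a run() method."""

    def run(self) -> Any: ...


def run_exp(exp_arg: Runnable, *dependencies: Any, avg_step_timeout: float = 60) -> Any:
    # Assumption: in the module this function is launched as a Ray remote task,
    # so Ray waits for every ObjectRef in *dependencies before the body starts.
    # The dependency values are not used here; they only enforce ordering.
    # avg_step_timeout is accepted but unused, mirroring the disabled timeout logic.
    return exp_arg.run()


class DemoExp:
    """Hypothetical experiment object used only for this illustration."""

    def __init__(self, name: str) -> None:
        self.name = name

    def run(self) -> str:
        return f"{self.name} done"


print(run_exp(DemoExp("exp-1")))  # prints "exp-1 done"
```

With Ray available, the same function could be launched as ray.remote(run_exp).remote(exp_arg, upstream_ref). Ray resolves an ObjectRef passed as a top-level argument before the task body starts, which is what lets the dependencies act as ordering constraints.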


function parse_and_truncate_error

parse_and_truncate_error(error_msg: str) -> str

Purpose: Clean and truncate error messages to ensure concise and informative outputs.

Arguments:

  • error_msg: Original error message string.

Returns:

  • Cleaned and truncated error message string.
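The cleaning rules are not specified above, so the following is a hypothetical sketch of one reasonable approach: strip ANSI colour codes, collapse whitespace, and, for long messages, keep both the beginning and the end. The max_len parameter and the " ... " marker are assumptions, not part of the documented signature.

```python
import re

# Matches ANSI escape sequences such as "\x1b[31m" (colour codes in tracebacks).
ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")


def parse_and_truncate_error(error_msg: str, max_len: int = 500) -> str:
    """Hypothetical sketch: clean an error message and cap it at max_len chars."""
    cleaned = ANSI_ESCAPE.sub("", error_msg)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()  # collapse newlines and runs of spaces
    if len(cleaned) <= max_len:
        return cleaned
    marker = " ... "
    if max_len <= len(marker):
        return cleaned[:max_len]  # too short to fit the marker; plain cut
    keep = max_len - len(marker)
    head = keep // 2
    tail = keep - head  # tail >= 1 here, so cleaned[-tail:] is a true suffix
    return cleaned[:head] + marker + cleaned[-tail:]
```

Keeping the tail matters for Python tracebacks, whose last line usually names the exception type and message, while the head preserves the context in which the error started.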

function _episode_timeout

_episode_timeout(exp_arg, avg_step_timeout=60)

Purpose: Calculate the maximum allowable runtime for an experiment based on provided arguments.

Arguments:

  • exp_arg: Experiment argument object, possibly containing environment parameters.
  • avg_step_timeout (default=60): Average allowable time per experimental step.

Logic:

  • If exp_arg.env_args.max_steps is available, the timeout is max_steps * avg_step_timeout; otherwise, it defaults to a global timeout of 10 hours (36,000 seconds).

Returns:

  • Timeout value in seconds.

function poll_for_timeout

poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_interval: float = 1.0)

Purpose: Monitor running tasks, cancelling any task exceeding the specified timeout period.

Arguments:

  • tasks: Dictionary mapping task IDs to their Ray object references.
  • timeout: Maximum allowable execution time in seconds for any task.
  • poll_interval (default=1.0): Interval, in seconds, between checks of task status.

Operation:

  • Regularly polls the tasks using ray.wait().
  • Checks the elapsed time of each task that is still running and, once the timeout is exceeded, cancels it, either gracefully or by forced termination.

Returns:

  • Dictionary mapping task IDs to their respective execution results or raised exceptions.
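The description says an overdue task is either cancelled or force-terminated but does not state how the choice is made. The sketch below isolates that per-poll decision as a pure function, assuming a hypothetical grace_period: a task past the timeout first gets a graceful cancel, and one still running after the grace period as well is force-killed. In a Ray-based loop, the unfinished references would come from ray.wait(refs, num_returns=len(refs), timeout=poll_interval), and each action would correspond to ray.cancel(ref, force=False) or ray.cancel(ref, force=True).

```python
from typing import Optional


def timeout_actions(
    elapsed: dict[str, Optional[float]],
    timeout: float,
    grace_period: float = 60.0,
) -> dict[str, str]:
    """Decide what to do with each unfinished task on one polling pass.

    elapsed maps task IDs to seconds since start (None = not started yet).
    grace_period is a hypothetical parameter used only for illustration.
    """
    actions: dict[str, str] = {}
    for task_id, seconds in elapsed.items():
        if seconds is None or seconds <= timeout:
            continue  # not started yet, or still within budget
        if seconds <= timeout + grace_period:
            actions[task_id] = "cancel"  # graceful: e.g. ray.cancel(ref, force=False)
        else:
            actions[task_id] = "force_kill"  # escalate: e.g. ray.cancel(ref, force=True)
    return actions
```

Separating the decision from the polling loop keeps the escalation policy testable without a running Ray cluster.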


function get_elapsed_time

get_elapsed_time(task_ref: ray.ObjectRef)

Purpose: Retrieve the elapsed execution time for a given Ray task.

Arguments:

  • task_ref: Ray object reference of the task.

Returns:

  • Elapsed time in seconds if the task has started; otherwise, None.

Details:

  • Uses the Ray state API to determine the task's start time.
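The arithmetic behind the elapsed-time lookup can be sketched on its own, assuming the state API reports the start time as Unix epoch milliseconds (Ray's TaskState record has a start_time_ms field) and that an unset value means the task has not started. The helper name elapsed_since and the injectable now_s argument are hypothetical; how the module actually fetches the record is not shown here.

```python
import time
from typing import Optional


def elapsed_since(start_time_ms: Optional[int], now_s: Optional[float] = None) -> Optional[float]:
    """Return seconds elapsed since a millisecond start timestamp, or None if unset."""
    if start_time_ms is None:
        return None  # task has not started yet
    if now_s is None:
        now_s = time.time()  # wall-clock seconds since the Unix epoch
    return now_s - start_time_ms / 1000.0  # convert ms to s before subtracting
```

Passing now_s explicitly makes the function deterministic in tests; in normal use it defaults to the current time.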

function execute_task_graph

execute_task_graph(exp_args_list, avg_step_timeout=60)

Purpose: Orchestrate the execution of a graph of tasks, respecting task dependencies and managing the overall execution timeout.

Arguments:

  • exp_args_list: List of experiment argument objects, each potentially specifying dependencies (depends_on).
  • avg_step_timeout (default=60): Average per-step timeout, used to compute each task's timeout.

Workflow:

  • Collects the task dependencies into a map.
  • Launches Ray remote tasks, passing each task its dependencies so that it starts only after they complete.
  • Computes the global timeout as the maximum timeout across all tasks and monitors execution against it.

Returns:

  • Results dictionary from executed tasks, keyed by their experiment IDs.
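The workflow can be sketched without Ray as a planning step: order the experiments so that every task comes after the tasks it depends on, then take the largest per-task timeout as the global timeout. The attribute names exp_id, depends_on and env_args.max_steps, and the max_steps * avg_step_timeout rule, are assumptions for illustration, and graphlib.TopologicalSorter (Python 3.9+) stands in for however the module actually walks its dependency map.

```python
from graphlib import TopologicalSorter
from types import SimpleNamespace
from typing import Any


def plan_task_graph(exp_args_list: list[Any], avg_step_timeout: float = 60) -> tuple[list[str], float]:
    """Return a launch order that respects depends_on, plus the global timeout."""
    # Map each (hypothetical) exp_id to the IDs it depends on; missing/None -> [].
    dependency_map = {e.exp_id: list(getattr(e, "depends_on", None) or []) for e in exp_args_list}
    # static_order yields each node after all of its predecessors; raises CycleError on cycles.
    launch_order = list(TopologicalSorter(dependency_map).static_order())

    def timeout_for(e: Any) -> float:
        max_steps = getattr(getattr(e, "env_args", None), "max_steps", None)
        return 10 * 60 * 60 if max_steps is None else max_steps * avg_step_timeout

    global_timeout = max(timeout_for(e) for e in exp_args_list)
    return launch_order, global_timeout


exps = [
    SimpleNamespace(exp_id="c", depends_on=["a", "b"], env_args=SimpleNamespace(max_steps=10)),
    SimpleNamespace(exp_id="a", depends_on=[], env_args=SimpleNamespace(max_steps=30)),
    SimpleNamespace(exp_id="b", depends_on=["a"], env_args=SimpleNamespace(max_steps=5)),
]
print(plan_task_graph(exps))  # (['a', 'b', 'c'], 1800)
```

In a Ray version, tasks would be launched in launch_order, each receiving the ObjectRefs of its dependencies as extra arguments to run_exp, and the resulting references would then be monitored against global_timeout.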