module medagentgym.utils
This Python module implements a Ray-based parallel execution framework for running experiments, with built-in timeout management and dependency handling. Tasks that exceed their allotted execution time are cancelled so that they do not tie up computational resources.
function run_exp
run_exp(exp_arg, *dependencies, avg_step_timeout=60)
Purpose: Execute a given experimental argument (exp_arg) within the Ray distributed computing framework.
Arguments:
exp_arg
: The experiment argument object, which must implement a run() method.
*dependencies
: Any dependencies that must complete before the experiment runs.
avg_step_timeout
(default=60): Average time per step, intended for computing timeouts (the timeout logic that uses it is currently commented out).
Returns:
Result of executing exp_arg.run().
Usage:
Typically invoked as a Ray remote task.
function parse_and_truncate_error
parse_and_truncate_error(error_msg: str) -> str
Purpose: Clean and truncate error messages to ensure concise and informative outputs.
Arguments:
error_msg
: Original error message string.
Returns:
- Cleaned and truncated error message string.
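The reference above does not say which cleaning rules or length limit parse_and_truncate_error applies. As a minimal sketch of the general idea, the hypothetical helper below (clean_and_truncate, with an assumed max_chars parameter and an assumed truncation marker) strips terminal colour codes and blank lines, then keeps the tail of the message, where a Python traceback states the exception type.

```python
import re

# Matches ANSI colour/style escape sequences such as "\x1b[31m".
ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")
MARKER = "... [truncated] ...\n"  # assumed marker, not taken from the module


def clean_and_truncate(error_msg: str, max_chars: int = 200) -> str:
    """Hypothetical stand-in for parse_and_truncate_error."""
    # Remove colour codes, blank lines, and trailing whitespace.
    cleaned = ANSI_ESCAPE.sub("", error_msg)
    lines = [line.rstrip() for line in cleaned.splitlines() if line.strip()]
    cleaned = "\n".join(lines)
    if len(cleaned) <= max_chars:
        return cleaned
    # Keep the end of the message, which usually names the exception.
    keep = max(max_chars - len(MARKER), 0)
    tail = cleaned[-keep:] if keep else ""
    return MARKER + tail


raw = "\x1b[31mTraceback\x1b[0m\n\n" + "x" * 500 + "\nValueError: bad input"
short = clean_and_truncate(raw, max_chars=100)
print(short)
```

Keeping the tail rather than the head is a design choice of this sketch; the actual function may keep a different part of the message.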
function _episode_timeout
_episode_timeout(exp_arg, avg_step_timeout=60)
Purpose: Calculate the maximum allowable runtime for an experiment based on provided arguments.
Arguments:
exp_arg
: Experiment argument object, possibly containing environment parameters.
avg_step_timeout
(default=60): Average allowable time per experimental step.
Logic:
- Uses exp_arg.env_args.max_steps if available; otherwise, defaults to a global timeout of 10 hours.
Returns:
- Timeout value in seconds.
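The logic above can be sketched as follows. This is an illustration under assumptions, not the module's code: it assumes the budget is max_steps multiplied by avg_step_timeout, and it tolerates a missing env_args attribute, which the real function may not. The name episode_timeout and the SimpleNamespace objects are stand-ins.

```python
from types import SimpleNamespace

GLOBAL_TIMEOUT_S = 10 * 60 * 60  # 10-hour fallback described above


def episode_timeout(exp_arg, avg_step_timeout=60):
    """Hypothetical stand-in for _episode_timeout."""
    env_args = getattr(exp_arg, "env_args", None)
    max_steps = getattr(env_args, "max_steps", None)
    if max_steps is None:
        return GLOBAL_TIMEOUT_S
    # Assumption: total budget = number of steps x average seconds per step.
    return max_steps * avg_step_timeout


with_steps = SimpleNamespace(env_args=SimpleNamespace(max_steps=30))
without_steps = SimpleNamespace(env_args=SimpleNamespace())

print(episode_timeout(with_steps))     # 30 steps at 60 s each
print(episode_timeout(without_steps))  # falls back to the 10-hour default
```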
function poll_for_timeout
poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_interval: float = 1.0)
Purpose: Monitor running tasks, cancelling any task exceeding the specified timeout period.
Arguments:
tasks
: Dictionary mapping task IDs to their Ray object references.
timeout
: Maximum allowable execution time in seconds for any task.
poll_interval
(default=1.0): Interval, in seconds, between task status checks.
Operation:
- Regularly polls tasks using ray.wait().
- Checks elapsed time for running tasks, requesting a regular cancellation or a forced termination if the timeout is exceeded.
Returns:
Dictionary mapping task IDs to their respective execution results or exceptions.
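The shape of such a polling loop can be illustrated without a Ray cluster. The sketch below swaps Ray for the standard library's concurrent.futures: wait() plays the role of ray.wait(), and a TimeoutError is recorded in place of a cancelled Ray task. Two simplifications are assumptions of this sketch: elapsed time is measured from a single shared start rather than per task, and a running thread cannot be forcibly stopped the way a Ray task can. The name poll_with_deadline is hypothetical.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def poll_with_deadline(tasks, timeout, poll_interval=0.05):
    """tasks maps task ID -> Future; returns task ID -> result or exception."""
    started = time.monotonic()  # simplification: one start time for all tasks
    pending = dict(tasks)
    results = {}
    while pending:
        # Block for at most poll_interval, like ray.wait() with a timeout.
        done, _ = wait(list(pending.values()), timeout=poll_interval,
                       return_when=FIRST_COMPLETED)
        for task_id, future in list(pending.items()):
            if future in done:
                exc = future.exception()
                results[task_id] = exc if exc is not None else future.result()
                del pending[task_id]
            elif time.monotonic() - started > timeout:
                future.cancel()  # only stops work that has not started yet
                results[task_id] = TimeoutError(f"{task_id} exceeded {timeout}s")
                del pending[task_id]
    return results


def work(seconds):
    time.sleep(seconds)
    return seconds


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = {"fast": pool.submit(work, 0.05), "slow": pool.submit(work, 1.0)}
    outcome = poll_with_deadline(futures, timeout=0.3)

print(outcome)
```

As in the function documented above, the returned dictionary holds either a result or an exception for every task ID, so callers can inspect failures without the whole batch aborting.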
function get_elapsed_time
get_elapsed_time(task_ref: ray.ObjectRef)
Purpose: Retrieve the elapsed execution time for a given Ray task.
Arguments:
task_ref
: Ray object reference of the task.
Returns:
- Elapsed time in seconds if the task has started; otherwise, None.
Details:
- Uses the Ray state API to determine the task start time.
function execute_task_graph
execute_task_graph(exp_args_list, avg_step_timeout=60)
Purpose: Orchestrate execution of a graph of tasks, respecting task dependencies and managing overall execution timeout.
Arguments:
exp_args_list
: List of experiment argument objects, each potentially specifying dependencies (depends_on).
avg_step_timeout
(default=60): Average step timeout for each task.
Workflow:
- Builds a map of task dependencies.
- Launches Ray remote tasks in an order that respects their dependencies.
- Computes a global timeout as the maximum across all tasks, then monitors execution against it.
Returns:
- Results dictionary from executed tasks, keyed by their experiment IDs.
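The planning part of this workflow can be sketched in plain Python, leaving out the Ray calls. The sketch assumes each experiment object exposes an exp_id attribute and that depends_on holds experiment IDs; both are assumptions of this example, as are the function name plan_task_graph and the per-task budget of max_steps multiplied by avg_step_timeout.

```python
from types import SimpleNamespace


def plan_task_graph(exp_args_list, avg_step_timeout=60):
    """Return (launch order, global timeout in seconds) for the experiments."""
    # Dependency map: experiment ID -> argument object.
    by_id = {exp.exp_id: exp for exp in exp_args_list}
    launch_order = []

    def launch(exp_id, stack=()):
        if exp_id in launch_order:
            return
        if exp_id in stack:
            raise ValueError(f"dependency cycle at {exp_id}")
        # Launch every dependency before the task itself.
        for dep_id in getattr(by_id[exp_id], "depends_on", ()):
            launch(dep_id, stack + (exp_id,))
        # A Ray version would submit the remote task here, passing the
        # dependencies' object references as extra arguments.
        launch_order.append(exp_id)

    for exp_id in by_id:
        launch(exp_id)

    # One global deadline: the largest per-task budget.
    global_timeout = max(
        (exp.env_args.max_steps * avg_step_timeout for exp in exp_args_list),
        default=0,
    )
    return launch_order, global_timeout


def make(exp_id, max_steps, depends_on=()):
    return SimpleNamespace(exp_id=exp_id, depends_on=depends_on,
                           env_args=SimpleNamespace(max_steps=max_steps))


exps = [make("eval", 10, depends_on=("train",)),
        make("train", 50),
        make("report", 5, depends_on=("eval",))]
order, deadline = plan_task_graph(exps)
print(order, deadline)
```

Using the maximum rather than the sum gives every task the same generous deadline, which matches the "maximum global timeout across all tasks" wording above.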