
Switch to dask #16


Merged
merged 13 commits on Sep 16, 2024
92 changes: 92 additions & 0 deletions src/agentlab/experiments/graph_execution.py
@@ -0,0 +1,92 @@
from dask import compute, delayed
from dask.distributed import Client
from browsergym.experiments.loop import ExpArgs
import logging


def run_experiments(n_workers, exp_args_list: list[ExpArgs], exp_dir):
"""Run a list of experiments in parallel while respecting dependencies."""

logging.info(f"Saving experiments to {exp_dir}")
for exp_args in exp_args_list:
exp_args.agent_args.prepare()
exp_args.prepare(exp_root=exp_dir)

try:
execute_task_graph(Client(n_workers=n_workers), exp_args_list)
finally:
logging.info("All jobs are finished. Calling agent_args.close() on all agents...")
for exp_args in exp_args_list:
exp_args.agent_args.close()
logging.info("Experiment finished.")


def _run(exp_arg: ExpArgs, *dependencies):
"""Capture dependencies to ensure they are run before the current task."""
return exp_arg.run()


def execute_task_graph(dask_client, exp_args_list: list[ExpArgs]):
"""Execute a task graph in parallel while respecting dependencies."""
exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list}

with dask_client:
tasks = {}

def get_task(exp_arg: ExpArgs):
if exp_arg.exp_id not in tasks:
dependencies = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on]
tasks[exp_arg.exp_id] = delayed(_run)(exp_arg, *dependencies)
return tasks[exp_arg.exp_id]

for exp_arg in exp_args_list:
get_task(exp_arg)

task_ids, task_list = zip(*tasks.items())
results = compute(*task_list)

return {task_id: result for task_id, result in zip(task_ids, results)}


def add_dependencies(exp_args_list: list[ExpArgs], task_dependencies: dict[str, list[str]] = None):
"""Add dependencies to a list of ExpArgs.

Args:
exp_args_list: list[ExpArgs]
A list of experiments to run.
task_dependencies: dict
A dictionary mapping task names to a list of task names that they
depend on. If None or empty, no dependencies are added.

Returns:
list[ExpArgs]
The modified exp_args_list with dependencies added.
"""

if task_dependencies is None or all([len(dep) == 0 for dep in task_dependencies.values()]):
# nothing to be done
return exp_args_list

exp_args_map = {exp_args.env_args.task_name: exp_args for exp_args in exp_args_list}
if len(exp_args_map) != len(exp_args_list):
raise ValueError(
(
"Task names are not unique in exp_args_map, "
"you can't run multiple seeds with task dependencies."
)
)

for task_name in exp_args_map.keys():
if task_name not in task_dependencies:
raise ValueError(f"Task {task_name} is missing from task_dependencies")

# turn dependencies from task names to exp_ids
for task_name, exp_args in exp_args_map.items():

exp_args.depends_on = tuple(
exp_args_map[dep_name].exp_id
for dep_name in task_dependencies[task_name]
if dep_name in exp_args_map # ignore dependencies that are not to be run
)

return exp_args_list
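For reference, a minimal standalone sketch of the scheduling trick `execute_task_graph` relies on: each experiment becomes a `delayed(_run)` node, and passing the dependency nodes as extra positional arguments makes dask finish them first, even though their values are ignored. The `simulate_run` helper and the sleep are illustrative stand-ins for `ExpArgs.run()`, and the sketch uses dask's default local scheduler rather than a distributed `Client`.

```python
from time import sleep

from dask import compute, delayed


def simulate_run(name, *dependencies):
    # Mirrors _run above: the extra positional args only force scheduling order.
    sleep(0.1)
    return name


# task1 -> (task2, task3) -> task4, wired the same way get_task() wires ExpArgs.
t1 = delayed(simulate_run)("task1")
t2 = delayed(simulate_run)("task2", t1)
t3 = delayed(simulate_run)("task3", t1)
t4 = delayed(simulate_run)("task4", t2, t3)

print(compute(t1, t2, t3, t4))  # ('task1', 'task2', 'task3', 'task4')
```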
8 changes: 3 additions & 5 deletions src/agentlab/experiments/launch_exp.py
@@ -4,7 +4,8 @@
from pathlib import Path

from browsergym.experiments.loop import ExpArgs, yield_all_exp_results
from joblib import Parallel, delayed
from agentlab.experiments.graph_execution import execute_task_graph
from dask.distributed import Client


def import_object(path: str):
@@ -24,10 +25,7 @@ def run_experiments(n_jobs, exp_args_list: list[ExpArgs], exp_dir):
exp_args.prepare(exp_root=exp_dir)

try:
prefer = "processes"
Parallel(n_jobs=n_jobs, prefer=prefer)(
delayed(exp_args.run)() for exp_args in exp_args_list
)
execute_task_graph(Client(n_workers=n_jobs), exp_args_list)
finally:
# will close servers even if there is an exception or ctrl+c
# servers won't be closed if the script is killed with kill -9 or segfaults.
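To make the change above concrete, here is a hedged before/after sketch of the parallelism swap: joblib's `Parallel` ran independent jobs, while a dask `Client` plus `delayed`/`compute` (the pattern `execute_task_graph` uses) can honour dependency ordering. The `work` function and worker counts below are illustrative, not part of the PR.

```python
from dask import compute, delayed
from dask.distributed import Client
from joblib import Parallel
from joblib import delayed as joblib_delayed


def work(i):
    return i * i


if __name__ == "__main__":
    # Before: joblib ran each exp_args.run() independently, with no ordering guarantees.
    old_results = Parallel(n_jobs=2, prefer="processes")(joblib_delayed(work)(i) for i in range(4))

    # After: a dask Client registers itself as the default scheduler, so compute()
    # runs the delayed graph on its workers and can respect depends_on ordering.
    with Client(n_workers=2):
        new_results = compute(*[delayed(work)(i) for i in range(4)])

    print(old_results, list(new_results))  # [0, 1, 4, 9] [0, 1, 4, 9]
```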
82 changes: 82 additions & 0 deletions tests/experiments/test_graph_execution.py
@@ -0,0 +1,82 @@
from dask.distributed import Client
import pytest
from agentlab.experiments.graph_execution import execute_task_graph, add_dependencies
from time import time, sleep
from browsergym.experiments.loop import ExpArgs, EnvArgs


# Mock implementation of the ExpArgs class with timestamp checks
class MockedExpArgs:
def __init__(self, exp_id, depends_on=None):
self.exp_id = exp_id
self.depends_on = depends_on if depends_on else []
self.start_time = None
self.end_time = None

def run(self):
self.start_time = time()
sleep(0.5) # Simulate task execution time
self.end_time = time()
return self


def test_execute_task_graph():
# Define a list of ExpArgs with dependencies
exp_args_list = [
MockedExpArgs(exp_id="task1", depends_on=[]),
MockedExpArgs(exp_id="task2", depends_on=["task1"]),
MockedExpArgs(exp_id="task3", depends_on=["task1"]),
MockedExpArgs(exp_id="task4", depends_on=["task2", "task3"]),
]

# Execute the task graph
results = execute_task_graph(Client(n_workers=3), exp_args_list)

exp_args_list = [results[task_id] for task_id in ["task1", "task2", "task3", "task4"]]

# Verify that all tasks were executed in the proper order
assert exp_args_list[0].start_time < exp_args_list[1].start_time
assert exp_args_list[0].start_time < exp_args_list[2].start_time
assert exp_args_list[1].end_time < exp_args_list[3].start_time
assert exp_args_list[2].end_time < exp_args_list[3].start_time

# Verify that parallel tasks (task2 and task3) started within a short time of each other
parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time)
assert parallel_start_diff < 0.1 # Allow for a small delay

# Ensure that the entire task graph took the expected amount of time
total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time
assert total_time >= 1.5 # Since the critical path involves at least 1.5 seconds of work


def test_add_dependencies():
# Prepare a simple list of ExpArgs

def make_exp_args(task_name, exp_id):
return ExpArgs(agent_args=None, env_args=EnvArgs(task_name=task_name), exp_id=exp_id)

exp_args_list = [
make_exp_args("task1", "1"),
make_exp_args("task2", "2"),
make_exp_args("task3", "3"),
]

# Define simple task_dependencies
task_dependencies = {"task1": ["task2"], "task2": [], "task3": ["task1"]}

# Call the function
modified_list = add_dependencies(exp_args_list, task_dependencies)

# Verify dependencies
assert modified_list[0].depends_on == ("2",) # task1 depends on task2
assert modified_list[1].depends_on == () # task2 has no dependencies
assert modified_list[2].depends_on == ("1",) # task3 depends on task1

# assert raise if task_dependencies is wrong
task_dependencies = {"task1": ["task2"], "task2": [], "task4": ["task3"]}
with pytest.raises(ValueError):
add_dependencies(exp_args_list, task_dependencies)


if __name__ == "__main__":
test_execute_task_graph()