From f7a55d7e311638edf8106876bc99becacae04f9e Mon Sep 17 00:00:00 2001
From: recursix
Date: Fri, 3 Jan 2025 17:41:28 +0000
Subject: [PATCH] Implement parallel processing for studies using
 ProcessPoolExecutor and enhance logging for server initialization

---
 src/agentlab/experiments/study.py | 39 +++++++++++++++++++++++++++++++
 tests/experiments/test_study.py   | 11 ++++++++-
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py
index f49ccd0e..c04f1a74 100644
--- a/src/agentlab/experiments/study.py
+++ b/src/agentlab/experiments/study.py
@@ -1,3 +1,4 @@
+from concurrent.futures import ProcessPoolExecutor
 import gzip
 import logging
 import os
@@ -498,6 +499,8 @@ def _init_worker(server_queue: Queue):
             A queue of object implementing BaseServer to initialize (or anything with a init
             method).
     """
+    print("initializing server instance with on process", os.getpid())
+    print(f"using queue {server_queue}")
     server_instance = server_queue.get()  # type: "WebArenaInstanceVars"
     logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
     server_instance.init()
@@ -510,6 +513,42 @@ def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n
 
 @dataclass
 class ParallelStudies(SequentialStudies):
+    parallel_servers: list[BaseServer] | int = None
+
+    def _run(
+        self,
+        n_jobs=1,
+        parallel_backend="ray",
+        strict_reproducibility=False,
+        n_relaunch=3,
+    ):
+        parallel_servers = self.parallel_servers
+        if isinstance(parallel_servers, int):
+            parallel_servers = [BaseServer() for _ in range(parallel_servers)]
+
+        server_queue = Manager().Queue()
+        for server in parallel_servers:
+            server_queue.put(server)
+
+        with ProcessPoolExecutor(
+            max_workers=len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)
+        ) as executor:
+            # Create list of arguments for each study
+            study_args = [
+                (study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+                for study in self.studies
+            ]
+
+            # Submit all tasks and wait for completion
+            futures = [executor.submit(_run_study, *args) for args in study_args]
+
+            # Wait for all futures to complete and raise any exceptions
+            for future in futures:
+                future.result()
+
+
+@dataclass
+class ParallelStudies_alt(SequentialStudies):
     parallel_servers: list[BaseServer] | int = None
 
 
diff --git a/tests/experiments/test_study.py b/tests/experiments/test_study.py
index 0bc24161..7685a135 100644
--- a/tests/experiments/test_study.py
+++ b/tests/experiments/test_study.py
@@ -4,6 +4,10 @@
 from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
 from agentlab.experiments.study import ParallelStudies, make_study, Study
 from agentlab.experiments.multi_server import WebArenaInstanceVars
+import logging
+
+
+logging.getLogger().setLevel(logging.INFO)
 
 
 def _make_agent_args_list():
@@ -28,13 +32,18 @@ def manual_test_launch_parallel_study_webarena():
     server_instance_2 = server_instance_1.clone()
     server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
     parallel_servers = [server_instance_1, server_instance_2]
+    # parallel_servers = [server_instance_2]
     for server in parallel_servers:
         print(server)
 
     study = make_study(
-        agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
+        agent_args_list,
+        benchmark="webarena_tiny",
+        parallel_servers=parallel_servers,
+        ignore_dependencies=True,
     )
+    study.override_max_steps(2)
 
     assert isinstance(study, ParallelStudies)
     study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
 