Skip to content

Commit

Permalink
Merge pull request #195 from ServiceNow/fix-daemonic-process-issue
Browse files Browse the repository at this point in the history
Implement parallel processing for studies using ProcessPoolExecutor a…
  • Loading branch information
recursix authored Jan 6, 2025
2 parents 78ad38f + f7a55d7 commit 73baabe
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
39 changes: 39 additions & 0 deletions src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import ProcessPoolExecutor
import gzip
import logging
import os
Expand Down Expand Up @@ -498,6 +499,8 @@ def _init_worker(server_queue: Queue):
A queue of object implementing BaseServer to initialize (or anything with a init
method).
"""
print("initializing server instance with on process", os.getpid())
print(f"using queue {server_queue}")
server_instance = server_queue.get() # type: "WebArenaInstanceVars"
logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
server_instance.init()
Expand All @@ -510,6 +513,42 @@ def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n

@dataclass
class ParallelStudies(SequentialStudies):
    """Run the sub-studies in parallel, one worker process per server.

    Each worker process is bound to exactly one server instance: the servers
    are pushed onto a shared queue and each pool worker pops one in its
    initializer (`_init_worker`), so a server is never shared between
    concurrently running studies.
    """

    # Either a pre-built list of server objects (anything with an `init()`
    # method, e.g. WebArenaInstanceVars) or an int asking for that many
    # default BaseServer instances. Must be set before `_run` is called.
    parallel_servers: list[BaseServer] | int | None = None

    def _run(
        self,
        n_jobs=1,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    ):
        """Dispatch every sub-study to a ProcessPoolExecutor and wait.

        Args:
            n_jobs: Number of jobs each sub-study may use internally.
            parallel_backend: Backend forwarded to each sub-study's run.
            strict_reproducibility: Forwarded to each sub-study's run.
            n_relaunch: Forwarded to each sub-study's run.

        Raises:
            ValueError: If `parallel_servers` was never configured.
            Exception: Re-raises the first exception of any failed sub-study.
        """
        parallel_servers = self.parallel_servers
        if parallel_servers is None:
            # Fail fast with a clear message instead of the opaque TypeError
            # that iterating over None would produce below.
            raise ValueError(
                "ParallelStudies.parallel_servers must be set to a list of "
                "servers or an int before running."
            )
        if isinstance(parallel_servers, int):
            parallel_servers = [BaseServer() for _ in range(parallel_servers)]

        # Keep an explicit reference to the Manager so its background process
        # is guaranteed to stay alive while workers hold queue proxies.
        manager = Manager()
        server_queue = manager.Queue()
        for server in parallel_servers:
            server_queue.put(server)

        # One worker per server: _init_worker pops exactly one server from the
        # queue in each process, pairing workers and servers one-to-one.
        with ProcessPoolExecutor(
            max_workers=len(parallel_servers),
            initializer=_init_worker,
            initargs=(server_queue,),
        ) as executor:
            futures = [
                executor.submit(
                    _run_study, study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch
                )
                for study in self.studies
            ]

            # Wait for completion; result() re-raises any worker exception.
            for future in futures:
                future.result()


@dataclass
class ParallelStudies_alt(SequentialStudies):

parallel_servers: list[BaseServer] | int = None

Expand Down
11 changes: 10 additions & 1 deletion tests/experiments/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
from agentlab.experiments.study import ParallelStudies, make_study, Study
from agentlab.experiments.multi_server import WebArenaInstanceVars
import logging


logging.getLogger().setLevel(logging.INFO)


def _make_agent_args_list():
Expand All @@ -28,13 +32,18 @@ def manual_test_launch_parallel_study_webarena():
server_instance_2 = server_instance_1.clone()
server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
parallel_servers = [server_instance_1, server_instance_2]
# parallel_servers = [server_instance_2]

for server in parallel_servers:
print(server)

study = make_study(
agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
agent_args_list,
benchmark="webarena_tiny",
parallel_servers=parallel_servers,
ignore_dependencies=True,
)
study.override_max_steps(2)
assert isinstance(study, ParallelStudies)

study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
Expand Down

0 comments on commit 73baabe

Please sign in to comment.