Skip to content

Commit

Permalink
Use Loky for parallelism (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Sep 23, 2024
1 parent 5fe7dfd commit a6a0a22
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 29 deletions.
31 changes: 29 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions prediction_market_agent_tooling/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def patch_logger() -> None:
Function to patch loggers according to the deployed environment.
Patches Loguru's logger, Python's default logger, warnings library and also monkey-patch print function as many libraries just use it.
"""
if not getattr(logger, "_patched", False):
logger._patched = True # type: ignore[attr-defined] # Hacky way to store a flag on the logger object, to not patch it multiple times.
else:
return

config = LogConfig()

if config.LOG_FORMAT == LogFormat.GCP:
Expand Down Expand Up @@ -116,6 +121,4 @@ def simple_warning_format(message, category, filename, lineno, line=None): # ty
) # Escape new lines, because otherwise logs will be broken.


if not getattr(logger, "_patched", False):
patch_logger()
logger._patched = True # type: ignore[attr-defined] # Hacky way to store a flag on the logger object, to not patch it multiple times.
patch_logger()
35 changes: 12 additions & 23 deletions prediction_market_agent_tooling/tools/parallelism.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import concurrent
from concurrent.futures import Executor
from concurrent.futures.process import ProcessPoolExecutor
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Callable, Generator, TypeVar

# Max workers to 5 to avoid rate limiting on some APIs, create a custom executor if you need more workers.
DEFAULT_THREADPOOL_EXECUTOR = ThreadPoolExecutor(max_workers=5)
DEFAULT_PROCESSPOOL_EXECUTOR = ProcessPoolExecutor(max_workers=5)
from loky import get_reusable_executor

from prediction_market_agent_tooling.loggers import patch_logger

A = TypeVar("A")
B = TypeVar("B")
Expand All @@ -15,14 +11,11 @@
def par_map(
items: list[A],
func: Callable[[A], B],
executor: Executor = DEFAULT_THREADPOOL_EXECUTOR,
max_workers: int = 5,
) -> "list[B]":
"""Applies the function to each element using the specified executor. Awaits for all results.
If executor is ProcessPoolExecutor, make sure the function passed is pickable, e.g. no lambda functions
"""
futures: list[concurrent.futures._base.Future[B]] = [
executor.submit(func, item) for item in items
]
"""Applies the function to each element using the specified executor. Awaits for all results."""
executor = get_reusable_executor(max_workers=max_workers, initializer=patch_logger)
futures = [executor.submit(func, item) for item in items]
results = []
for fut in futures:
results.append(fut.result())
Expand All @@ -32,13 +25,9 @@ def par_map(
def par_generator(
    items: list[A],
    func: Callable[[A], B],
    max_workers: int = 5,
) -> Generator[B, None, None]:
    """Apply `func` to each item in parallel and yield results as they complete.

    Every item is submitted as its own task to Loky's reusable executor, which
    is created with `patch_logger` as the worker initializer.

    Args:
        items: Inputs to process.
        func: Callable applied to each item.
        max_workers: Maximum number of workers requested from the executor.

    Yields:
        The result of `func(item)` for every item, in completion order
        (which may differ from the order of `items`).

    Raises:
        Any exception raised by `func` for an item, re-raised when that
        item's result is retrieved.
    """
    # Imported locally so this function does not rely on a module-level
    # `import concurrent` being present.
    from concurrent.futures import as_completed

    executor = get_reusable_executor(max_workers=max_workers, initializer=patch_logger)
    futures = [executor.submit(func, item) for item in items]
    # `executor.map` would yield in input order, making fast results wait
    # behind slower earlier ones; `as_completed` yields them as they finish.
    for fut in as_completed(futures):
        yield fut.result()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prediction-market-agent-tooling"
version = "0.48.16"
version = "0.48.17"
description = "Tools to benchmark, deploy and monitor prediction market agents."
authors = ["Gnosis"]
readme = "README.md"
Expand Down Expand Up @@ -47,6 +47,7 @@ tavily-python = "^0.3.9"
sqlmodel = "^0.0.21"
psycopg2-binary = "^2.9.9"
base58 = ">=1.0.2,<2.0"
loky = "^3.4.1"

[tool.poetry.extras]
openai = ["openai"]
Expand Down
15 changes: 15 additions & 0 deletions tests/tools/test_parallelism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from prediction_market_agent_tooling.tools.parallelism import par_generator, par_map


def test_par_map() -> None:
    """Check that par_map returns the square of every input, in input order."""
    # NOTE(review): presumably a lambda is used on purpose, to exercise a
    # callable that is not importable by name in the workers - confirm.
    square = lambda x: x**2
    inputs = list(range(100))
    expected = [square(value) for value in inputs]
    outputs = par_map(inputs, square, max_workers=5)
    assert expected == outputs


def test_par_generator() -> None:
    """Check that par_generator yields the square of every input."""
    square = lambda x: x**2
    inputs = list(range(100))
    expected = [square(value) for value in inputs]
    produced = par_generator(inputs, square, max_workers=5)
    # Sort before comparing so the check does not depend on yield order.
    assert expected == sorted(produced)

0 comments on commit a6a0a22

Please sign in to comment.