diff --git a/poetry.lock b/poetry.lock index abb30f73..c43d0fc6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -765,6 +765,17 @@ deprecation = ">=2.0,<3.0" [package.extras] pydantic = ["pydantic (>=1.0.0,<3.0)"] +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + [[package]] name = "colorama" version = "0.4.6" @@ -2584,6 +2595,20 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] +[[package]] +name = "loky" +version = "3.4.1" +description = "A robust implementation of concurrent.futures.ProcessPoolExecutor" +optional = false +python-versions = ">=3.7" +files = [ + {file = "loky-3.4.1-py3-none-any.whl", hash = "sha256:7132da80d1a057b5917ff32c7867b65ed164aae84c259a1dbc44375791280c87"}, + {file = "loky-3.4.1.tar.gz", hash = "sha256:66db350de68c301299c882ace3b8f06ba5c4cb2c45f8fcffd498160ce8280753"}, +] + 
+[package.dependencies] +cloudpickle = "*" + [[package]] name = "lru-dict" version = "1.2.0" @@ -3826,6 +3851,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs optional = false python-versions = ">=3.8" files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] @@ -3836,6 +3862,7 @@ description = "A collection of ASN.1-based protocols modules" optional = false python-versions = ">=3.8" files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] @@ -5765,4 +5792,4 @@ openai = ["openai"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "71f54ad45b96fb9bf3163d9e4da88a75332de81bae9bf319e73d48b056b1126c" +content-hash = "658abc9433182dac32862d8925b9fd262609b55adf55ecc7bbbed1af39709883" diff --git a/prediction_market_agent_tooling/loggers.py b/prediction_market_agent_tooling/loggers.py index a44da36f..c8cc2df0 100644 --- a/prediction_market_agent_tooling/loggers.py +++ b/prediction_market_agent_tooling/loggers.py @@ -49,6 +49,11 @@ def patch_logger() -> None: Function to patch loggers according to the deployed environment. Patches Loguru's logger, Python's default logger, warnings library and also monkey-patch print function as many libraries just use it. """ + if not getattr(logger, "_patched", False): + logger._patched = True # type: ignore[attr-defined] # Hacky way to store a flag on the logger object, to not patch it multiple times. 
+ else: + return + config = LogConfig() if config.LOG_FORMAT == LogFormat.GCP: @@ -116,6 +121,4 @@ def simple_warning_format(message, category, filename, lineno, line=None): # ty ) # Escape new lines, because otherwise logs will be broken. -if not getattr(logger, "_patched", False): - patch_logger() - logger._patched = True # type: ignore[attr-defined] # Hacky way to store a flag on the logger object, to not patch it multiple times. +patch_logger() diff --git a/prediction_market_agent_tooling/tools/parallelism.py b/prediction_market_agent_tooling/tools/parallelism.py index bbeca891..26e80f4e 100644 --- a/prediction_market_agent_tooling/tools/parallelism.py +++ b/prediction_market_agent_tooling/tools/parallelism.py @@ -1,12 +1,8 @@ -import concurrent -from concurrent.futures import Executor -from concurrent.futures.process import ProcessPoolExecutor -from concurrent.futures.thread import ThreadPoolExecutor from typing import Callable, Generator, TypeVar -# Max workers to 5 to avoid rate limiting on some APIs, create a custom executor if you need more workers. -DEFAULT_THREADPOOL_EXECUTOR = ThreadPoolExecutor(max_workers=5) -DEFAULT_PROCESSPOOL_EXECUTOR = ProcessPoolExecutor(max_workers=5) +from loky import get_reusable_executor + +from prediction_market_agent_tooling.loggers import patch_logger A = TypeVar("A") B = TypeVar("B") @@ -15,14 +11,11 @@ def par_map( items: list[A], func: Callable[[A], B], - executor: Executor = DEFAULT_THREADPOOL_EXECUTOR, + max_workers: int = 5, ) -> "list[B]": - """Applies the function to each element using the specified executor. Awaits for all results. - If executor is ProcessPoolExecutor, make sure the function passed is pickable, e.g. no lambda functions - """ - futures: list[concurrent.futures._base.Future[B]] = [ - executor.submit(func, item) for item in items - ] + """Applies the function to each element using the specified executor. 
Awaits for all results.""" + executor = get_reusable_executor(max_workers=max_workers, initializer=patch_logger) + futures = [executor.submit(func, item) for item in items] results = [] for fut in futures: results.append(fut.result()) @@ -32,13 +25,9 @@ def par_generator( items: list[A], func: Callable[[A], B], - executor: Executor = DEFAULT_THREADPOOL_EXECUTOR, + max_workers: int = 5, ) -> Generator[B, None, None]: - """Applies the function to each element using the specified executor. Yields results as they come. - If executor is ProcessPoolExecutor, make sure the function passed is pickable, e.g. no lambda functions. - """ - futures: list[concurrent.futures._base.Future[B]] = [ - executor.submit(func, item) for item in items - ] - for fut in concurrent.futures.as_completed(futures): - yield fut.result() + """Applies the function to each element using a reusable loky process-pool executor. Yields results as they come.""" + executor = get_reusable_executor(max_workers=max_workers, initializer=patch_logger) + for res in executor.map(func, items): + yield res diff --git a/pyproject.toml b/pyproject.toml index 139b1e4f..8fc3f2e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "prediction-market-agent-tooling" -version = "0.48.16" +version = "0.48.17" description = "Tools to benchmark, deploy and monitor prediction market agents." 
authors = ["Gnosis"] readme = "README.md" @@ -47,6 +47,7 @@ tavily-python = "^0.3.9" sqlmodel = "^0.0.21" psycopg2-binary = "^2.9.9" base58 = ">=1.0.2,<2.0" +loky = "^3.4.1" [tool.poetry.extras] openai = ["openai"] diff --git a/tests/tools/test_parallelism.py b/tests/tools/test_parallelism.py new file mode 100644 index 00000000..6f830edf --- /dev/null +++ b/tests/tools/test_parallelism.py @@ -0,0 +1,15 @@ +from prediction_market_agent_tooling.tools.parallelism import par_generator, par_map + + +def test_par_map() -> None: + l = list(range(100)) + f = lambda x: x**2 + results = par_map(l, f, max_workers=5) + assert [f(x) for x in l] == results + + +def test_par_generator() -> None: + l = list(range(100)) + f = lambda x: x**2 + results = par_generator(l, f, max_workers=5) + assert [f(x) for x in l] == sorted(results)