From 74c7e5b3e029b10f905b19687d93bb6b07c611fe Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Mon, 25 Nov 2024 13:39:36 +0800 Subject: [PATCH] Runtime, output format, docs --- storey/flow.py | 47 ++++++++++++++++++++++++++++++++++++++++------ tests/test_flow.py | 16 +++++++++------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index c0457b82..19b5a4b4 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -1438,7 +1438,30 @@ def set_table(self, key, table): self._tables[key] = table +class _ParallelExecutionRunnableResult: + def __init__(self, data, runtime): + self.data = data + self.runtime = runtime + + class ParallelExecutionRunnable: + """ + Runnable to be run by a ParallelExecution step. Subclasses must assign execution_mechanism with one of: + * "multiprocessing" – To run in a separate process. This is appropriate for CPU or GPU intensive tasks as they + would otherwise block the main process by holding Python's Global Interpreter Lock (GIL). + * "threading" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise + block the main event loop thread. + * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event + loop to continue running while waiting for a response. + * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It + means that the runnable will not actually be run in parallel to anything else. + + Subclasses must also override the run() method with user code that handles the event and returns a result. + + Subclasses may optionally override the init() method if the user's implementation of run() requires prior + initialization. + """ + execution_mechanism = None def __init__(self, name): @@ -1452,6 +1475,18 @@ def init(self): def run(self, event): return event + def _run(self, event): + start = time.monotonic() + data = self.run(event) + end = time.monotonic() + return _ParallelExecutionRunnableResult(data, end - start) + + async def _async_run(self, event): + start = time.monotonic() + data = await self.run(event) + end = time.monotonic() + return _ParallelExecutionRunnableResult(data, end - start) + class ParallelExecution(Flow): """ @@ -1516,16 +1551,16 @@ async def _do(self, event): raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'") runnables_encountered.add(id(runnable)) if runnable.execution_mechanism == "asyncio": - future = asyncio.get_running_loop().create_task(runnable.run(event)) + future = asyncio.get_running_loop().create_task(runnable._async_run(event)) elif runnable.execution_mechanism == "naive": future = asyncio.get_running_loop().create_future() - future.set_result(runnable.run(event)) + future.set_result(runnable._run(event)) else: executor = self._executors[runnable.execution_mechanism] - future = asyncio.get_running_loop().run_in_executor(executor, runnable.run, event) + future = asyncio.get_running_loop().run_in_executor(executor, runnable._run, event) futures.append(future) - results = await asyncio.gather(*futures) - event.body = {"inputs": event.body, "outputs": {}} + results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures) + event.body = {"input": event.body, "results": {}} for index, result in enumerate(results): - event.body["outputs"][self.runnables[index].name] = result + event.body["results"][self.runnables[index].name] = {"runtime": result.runtime, "output": result.data} return await self._do_downstream(event) diff --git a/tests/test_flow.py b/tests/test_flow.py index b7a7780a..7fd5f4a5 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -4809,13 +4809,15 @@ def select_runnables(self, event): controller = source.run() controller.emit(0) controller.terminate() - result = controller.await_termination() + termination_result = controller.await_termination() end = time.monotonic() assert end - start < 6 - assert result == [ - { - "inputs": 0, - "outputs": {"busy1": 1, "busy2": 1, "sleep1": 1, "sleep2": 1, "asleep1": 1, "asleep2": 1, "naive": 1}, - } - ] + termination_result = termination_result[0] + assert termination_result.keys() == {"input", "results"} + assert termination_result["input"] == 0 + results = termination_result["results"] + assert results.keys() == {"busy1", "busy2", "sleep1", "sleep2", "asleep1", "asleep2", "naive"} + for result in results.values(): + assert result["output"] == 1 + assert 1 < result["runtime"] < 2