Skip to content

Commit

Permalink
Runtime, output format, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper committed Nov 25, 2024
1 parent 6e03579 commit 74c7e5b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
47 changes: 41 additions & 6 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,30 @@ def set_table(self, key, table):
self._tables[key] = table


class _ParallelExecutionRunnableResult:
def __init__(self, data, runtime):
self.data = data
self.runtime = runtime


class ParallelExecutionRunnable:
"""
Runnable to be run by a ParallelExecution step. Subclasses must assign execution_mechanism with one of:
* "multiprocessing" – To run in a separate process. This is appropriate for CPU or GPU intensive tasks as they
would otherwise block the main process by holding Python's Global Interpreter Lock (GIL).
* "threading" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise
block the main event loop thread.
* "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event
loop to continue running while waiting for a response.
* "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It
means that the runnable will not actually be run in parallel to anything else.
Subclasses must also override the run() method with user code that handles the event and returns a result.
Subclasses may optionally override the init() method if the user's implementation of run() requires prior
initialization.
"""

execution_mechanism = None

def __init__(self, name):
Expand All @@ -1452,6 +1475,18 @@ def init(self):
def run(self, event):
return event

def _run(self, event):
start = time.monotonic()
data = self.run(event)
end = time.monotonic()
return _ParallelExecutionRunnableResult(data, end - start)

async def _async_run(self, event):
start = time.monotonic()
data = await self.run(event)
end = time.monotonic()
return _ParallelExecutionRunnableResult(data, end - start)


class ParallelExecution(Flow):
"""
Expand Down Expand Up @@ -1516,16 +1551,16 @@ async def _do(self, event):
raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
runnables_encountered.add(id(runnable))
if runnable.execution_mechanism == "asyncio":
future = asyncio.get_running_loop().create_task(runnable.run(event))
future = asyncio.get_running_loop().create_task(runnable._async_run(event))
elif runnable.execution_mechanism == "naive":
future = asyncio.get_running_loop().create_future()
future.set_result(runnable.run(event))
future.set_result(runnable._run(event))
else:
executor = self._executors[runnable.execution_mechanism]
future = asyncio.get_running_loop().run_in_executor(executor, runnable.run, event)
future = asyncio.get_running_loop().run_in_executor(executor, runnable._run, event)
futures.append(future)
results = await asyncio.gather(*futures)
event.body = {"inputs": event.body, "outputs": {}}
results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures)
event.body = {"input": event.body, "results": {}}
for index, result in enumerate(results):
event.body["outputs"][self.runnables[index].name] = result
event.body["results"][self.runnables[index].name] = {"runtime": result.runtime, "output": result.data}
return await self._do_downstream(event)
16 changes: 9 additions & 7 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4809,13 +4809,15 @@ def select_runnables(self, event):
controller = source.run()
controller.emit(0)
controller.terminate()
result = controller.await_termination()
termination_result = controller.await_termination()
end = time.monotonic()

assert end - start < 6
assert result == [
{
"inputs": 0,
"outputs": {"busy1": 1, "busy2": 1, "sleep1": 1, "sleep2": 1, "asleep1": 1, "asleep2": 1, "naive": 1},
}
]
termination_result = termination_result[0]
assert termination_result.keys() == {"input", "results"}
assert termination_result["input"] == 0
results = termination_result["results"]
assert results.keys() == {"busy1", "busy2", "sleep1", "sleep2", "asleep1", "asleep2", "naive"}
for result in results.values():
assert result["output"] == 1
assert 1 < result["runtime"] < 2

0 comments on commit 74c7e5b

Please sign in to comment.