From b4d859c31a4a73e3bd5beaff892876283653e9b3 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:45:35 -0400 Subject: [PATCH] Support async tasks --- src/controlflow/decorators.py | 24 ++++++++++++++--------- tests/test_decorator.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index eeed17ae..5f175948 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -201,18 +201,24 @@ def _get_task(*args, **kwargs) -> Task: **task_kwargs, ) - @functools.wraps(fn) - @prefect_task( + if asyncio.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return await task.run_async() + else: + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return task.run() + + wrapper = prefect_task( timeout_seconds=timeout_seconds, retries=retries, retry_delay_seconds=retry_delay_seconds, - ) - def wrapper( - *args, - **kwargs, - ): - task = _get_task(*args, **kwargs) - return task.run() + )(wrapper) # store the `as_task` method for loading the task object wrapper.as_task = _get_task diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 0b07731b..468b652f 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -102,3 +102,39 @@ def partial_flow(): result = partial_flow() assert result == 10 + + +class TestTaskDecorator: + def test_task_decorator_sync_as_task(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_async_as_task(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_sync(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert write_poem("AI") + + async def test_task_decorator_async(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert await write_poem("AI")