diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index eeed17ae..5f175948 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -201,18 +201,24 @@ def _get_task(*args, **kwargs) -> Task: **task_kwargs, ) - @functools.wraps(fn) - @prefect_task( + if asyncio.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return await task.run_async() + else: + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return task.run() + + wrapper = prefect_task( timeout_seconds=timeout_seconds, retries=retries, retry_delay_seconds=retry_delay_seconds, - ) - def wrapper( - *args, - **kwargs, - ): - task = _get_task(*args, **kwargs) - return task.run() + )(wrapper) # store the `as_task` method for loading the task object wrapper.as_task = _get_task diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 0b07731b..468b652f 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -102,3 +102,39 @@ def partial_flow(): result = partial_flow() assert result == 10 + + +class TestTaskDecorator: + def test_task_decorator_sync_as_task(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_async_as_task(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_sync(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert write_poem("AI") + + async def test_task_decorator_async(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert await write_poem("AI")