From 613e37c34d5c8fac4866498fec26cb82a6fdf3d1 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 16 Jun 2024 14:46:35 -0400 Subject: [PATCH] Load task / flow defaults in utility module --- src/controlflow/core/controller/controller.py | 19 +++++++++----- src/controlflow/core/task.py | 4 +-- src/controlflow/decorators.py | 24 ++++++++++++++--- src/controlflow/settings.py | 4 +++ src/controlflow/utilities/prefect.py | 26 +++++++++++++++++-- 5 files changed, 63 insertions(+), 14 deletions(-) diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index 71326c4d..e9694c1e 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -5,7 +5,6 @@ from contextlib import asynccontextmanager from typing import Callable, Union -import prefect from pydantic import Field, PrivateAttr, model_validator import controlflow @@ -22,6 +21,7 @@ from controlflow.tui.app import TUIApp as TUI from controlflow.utilities.context import ctx from controlflow.utilities.prefect import create_markdown_artifact +from controlflow.utilities.prefect import task as prefect_task from controlflow.utilities.tasks import all_complete, any_incomplete from controlflow.utilities.types import ControlFlowModel @@ -32,13 +32,18 @@ def create_messages_markdown_artifact(messages, thread_id): markdown_messages = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages]) create_markdown_artifact( key="messages", - markdown=inspect.cleandoc(f""" + markdown=inspect.cleandoc( + """ # Messages *Thread ID: {thread_id}* {markdown_messages} - """), + """.format( + thread_id=thread_id, + markdown_messages=markdown_messages, + ) + ), ) @@ -214,7 +219,7 @@ def _setup_run(self): handlers=handlers, ) - @prefect.task(task_run_name="Run LLM") + @prefect_task(task_run_name="Run LLM") async def run_once_async(self) -> list[MessageType]: async with self.tui(): payload = self._setup_run() @@ -252,7 +257,7 @@ async def run_once_async(self) -> list[MessageType]: return response_handler.response_messages - @prefect.task(task_run_name="Run LLM") + @prefect_task(task_run_name="Run LLM") def run_once(self) -> list[MessageType]: payload = self._setup_run() if payload is None: @@ -289,7 +294,7 @@ def run_once(self) -> list[MessageType]: return response_handler.response_messages - @prefect.task(task_run_name="Run LLM Controller") + @prefect_task(task_run_name="Run LLM Controller") async def run_async(self) -> list[MessageType]: """ Run the controller until all tasks are complete. @@ -313,7 +318,7 @@ async def run_async(self) -> list[MessageType]: self._should_stop = False return messages - @prefect.task(task_run_name="Run LLM Controller") + @prefect_task(task_run_name="Run LLM Controller") def run(self) -> list[MessageType]: """ Run the controller until all tasks are complete. diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 208a3c45..2182a346 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -15,7 +15,6 @@ _LiteralGenericAlias, ) -import prefect from prefect.context import TaskRunContext from pydantic import ( Field, @@ -35,6 +34,7 @@ from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import PrefectTrackingTask +from controlflow.utilities.prefect import task as prefect_task from controlflow.utilities.tasks import ( collect_tasks, visit_task_collection, @@ -313,7 +313,7 @@ async def run_once_async(self, agent: "Agent" = None, flow: "Flow" = None): controller = controlflow.Controller(tasks=[self], agents=agent, flow=flow) await controller.run_once_async() - @prefect.task(task_run_name=get_task_run_name) + @prefect_task(task_run_name=get_task_run_name) def _run( self, raise_on_error: bool = True, diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index 7b96ec66..f59605dd 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -2,13 +2,13 @@ import inspect from typing import Any, Callable, Optional, Union -import prefect - import controlflow from controlflow.core.agent import Agent from controlflow.core.flow import Flow from controlflow.core.task import Task from controlflow.utilities.logging import get_logger +from controlflow.utilities.prefect import flow as prefect_flow +from controlflow.utilities.prefect import task as prefect_task # from controlflow.utilities.marvin import patch_marvin from controlflow.utilities.tasks import resolve_tasks @@ -63,12 +63,16 @@ def flow( tools=tools, agents=agents, lazy=lazy, + retries=retries, + retry_delay_seconds=retry_delay_seconds, + timeout_seconds=timeout_seconds, + prefect_kwargs=prefect_kwargs, ) sig = inspect.signature(fn) # the flow decorator creates a proper prefect flow - @prefect.flow( + @prefect_flow( timeout_seconds=timeout_seconds, retries=retries, retry_delay_seconds=retry_delay_seconds, @@ -136,6 +140,10 @@ def task( tools: Optional[list[Callable[..., Any]]] = None, user_access: Optional[bool] = None, lazy: Optional[bool] = None, + retries: Optional[int] = None, + retry_delay_seconds: Optional[Union[float, int]] = None, + timeout_seconds: Optional[Union[float, int]] = None, + prefect_kwargs: Optional[dict[str, Any]] = None, ): """ A decorator that turns a Python function into a Task. The Task objective is @@ -174,6 +182,10 @@ def task( tools=tools, user_access=user_access, lazy=lazy, + retries=retries, + retry_delay_seconds=retry_delay_seconds, + timeout_seconds=timeout_seconds, + prefect_kwargs=prefect_kwargs, ) sig = inspect.signature(fn) @@ -186,6 +198,12 @@ def task( result_type = fn.__annotations__.get("return") + @prefect_task( + timeout_seconds=timeout_seconds, + retries=retries, + retry_delay_seconds=retry_delay_seconds, + **prefect_kwargs or {}, + ) @functools.wraps(fn) def wrapper(*args, lazy_: bool = None, **kwargs): # first process callargs diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index a8435c83..81dde035 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -44,6 +44,10 @@ class Settings(ControlFlowSettings): # ------------ display and logging settings ------------ + log_prints: bool = Field( + True, description="Whether to log prints to eh Prefect logger by default." + ) + # ------------ flow settings ------------ eager_mode: bool = Field( diff --git a/src/controlflow/utilities/prefect.py b/src/controlflow/utilities/prefect.py index 7bda8e8d..26a4fe03 100644 --- a/src/controlflow/utilities/prefect.py +++ b/src/controlflow/utilities/prefect.py @@ -9,6 +9,7 @@ ) from uuid import UUID +import prefect import prefect.tasks from prefect import get_client as get_prefect_client from prefect.artifacts import ArtifactRequest @@ -34,12 +35,33 @@ ) from pydantic import TypeAdapter +import controlflow from controlflow.utilities.types import ControlFlowModel if TYPE_CHECKING: from controlflow.llm.tools import Tool +def task(*args, **kwargs): + """ + A decorator that creates a Prefect task with ControlFlow defaults + """ + + kwargs.setdefault("log_prints", controlflow.settings.log_prints) + + return prefect.task(*args, **kwargs) + + +def flow(*args, **kwargs): + """ + A decorator that creates a Prefect flow with ControlFlow defaults + """ + + kwargs.setdefault("log_prints", controlflow.settings.log_prints) + + return prefect.flow(*args, **kwargs) + + def create_markdown_artifact( key: str, markdown: str, @@ -353,7 +375,7 @@ def prefect_task_context(**kwargs): ) @contextmanager - @prefect.task(**kwargs) + @task(**kwargs) def task_context(): yield @@ -387,7 +409,7 @@ def prefect_flow_context(**kwargs): ) @contextmanager - @prefect.flow(**kwargs) + @flow(**kwargs) def flow_context(): yield