Skip to content

Commit

Permalink
Load task / flow defaults in utility module
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Jun 16, 2024
1 parent c19309d commit 613e37c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 14 deletions.
19 changes: 12 additions & 7 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from contextlib import asynccontextmanager
from typing import Callable, Union

import prefect
from pydantic import Field, PrivateAttr, model_validator

import controlflow
Expand All @@ -22,6 +21,7 @@
from controlflow.tui.app import TUIApp as TUI
from controlflow.utilities.context import ctx
from controlflow.utilities.prefect import create_markdown_artifact
from controlflow.utilities.prefect import task as prefect_task
from controlflow.utilities.tasks import all_complete, any_incomplete
from controlflow.utilities.types import ControlFlowModel

Expand All @@ -32,13 +32,18 @@ def create_messages_markdown_artifact(messages, thread_id):
markdown_messages = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages])
create_markdown_artifact(
key="messages",
markdown=inspect.cleandoc(f"""
markdown=inspect.cleandoc(
"""
# Messages
*Thread ID: {thread_id}*
{markdown_messages}
"""),
""".format(
thread_id=thread_id,
markdown_messages=markdown_messages,
)
),
)


Expand Down Expand Up @@ -214,7 +219,7 @@ def _setup_run(self):
handlers=handlers,
)

@prefect.task(task_run_name="Run LLM")
@prefect_task(task_run_name="Run LLM")
async def run_once_async(self) -> list[MessageType]:
async with self.tui():
payload = self._setup_run()
Expand Down Expand Up @@ -252,7 +257,7 @@ async def run_once_async(self) -> list[MessageType]:

return response_handler.response_messages

@prefect.task(task_run_name="Run LLM")
@prefect_task(task_run_name="Run LLM")
def run_once(self) -> list[MessageType]:
payload = self._setup_run()
if payload is None:
Expand Down Expand Up @@ -289,7 +294,7 @@ def run_once(self) -> list[MessageType]:

return response_handler.response_messages

@prefect.task(task_run_name="Run LLM Controller")
@prefect_task(task_run_name="Run LLM Controller")
async def run_async(self) -> list[MessageType]:
"""
Run the controller until all tasks are complete.
Expand All @@ -313,7 +318,7 @@ async def run_async(self) -> list[MessageType]:
self._should_stop = False
return messages

@prefect.task(task_run_name="Run LLM Controller")
@prefect_task(task_run_name="Run LLM Controller")
def run(self) -> list[MessageType]:
"""
Run the controller until all tasks are complete.
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_LiteralGenericAlias,
)

import prefect
from prefect.context import TaskRunContext
from pydantic import (
Field,
Expand All @@ -35,6 +34,7 @@
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import PrefectTrackingTask
from controlflow.utilities.prefect import task as prefect_task
from controlflow.utilities.tasks import (
collect_tasks,
visit_task_collection,
Expand Down Expand Up @@ -313,7 +313,7 @@ async def run_once_async(self, agent: "Agent" = None, flow: "Flow" = None):
controller = controlflow.Controller(tasks=[self], agents=agent, flow=flow)
await controller.run_once_async()

@prefect.task(task_run_name=get_task_run_name)
@prefect_task(task_run_name=get_task_run_name)
def _run(
self,
raise_on_error: bool = True,
Expand Down
24 changes: 21 additions & 3 deletions src/controlflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import inspect
from typing import Any, Callable, Optional, Union

import prefect

import controlflow
from controlflow.core.agent import Agent
from controlflow.core.flow import Flow
from controlflow.core.task import Task
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import flow as prefect_flow
from controlflow.utilities.prefect import task as prefect_task

# from controlflow.utilities.marvin import patch_marvin
from controlflow.utilities.tasks import resolve_tasks
Expand Down Expand Up @@ -63,12 +63,16 @@ def flow(
tools=tools,
agents=agents,
lazy=lazy,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
timeout_seconds=timeout_seconds,
prefect_kwargs=prefect_kwargs,
)

sig = inspect.signature(fn)

# the flow decorator creates a proper prefect flow
@prefect.flow(
@prefect_flow(
timeout_seconds=timeout_seconds,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
Expand Down Expand Up @@ -136,6 +140,10 @@ def task(
tools: Optional[list[Callable[..., Any]]] = None,
user_access: Optional[bool] = None,
lazy: Optional[bool] = None,
retries: Optional[int] = None,
retry_delay_seconds: Optional[Union[float, int]] = None,
timeout_seconds: Optional[Union[float, int]] = None,
prefect_kwargs: Optional[dict[str, Any]] = None,
):
"""
A decorator that turns a Python function into a Task. The Task objective is
Expand Down Expand Up @@ -174,6 +182,10 @@ def task(
tools=tools,
user_access=user_access,
lazy=lazy,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
timeout_seconds=timeout_seconds,
prefect_kwargs=prefect_kwargs,
)

sig = inspect.signature(fn)
Expand All @@ -186,6 +198,12 @@ def task(

result_type = fn.__annotations__.get("return")

@prefect_task(
timeout_seconds=timeout_seconds,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
**prefect_kwargs or {},
)
@functools.wraps(fn)
def wrapper(*args, lazy_: bool = None, **kwargs):
# first process callargs
Expand Down
4 changes: 4 additions & 0 deletions src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class Settings(ControlFlowSettings):

# ------------ display and logging settings ------------

log_prints: bool = Field(
True, description="Whether to log prints to eh Prefect logger by default."
)

# ------------ flow settings ------------

eager_mode: bool = Field(
Expand Down
26 changes: 24 additions & 2 deletions src/controlflow/utilities/prefect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from uuid import UUID

import prefect
import prefect.tasks
from prefect import get_client as get_prefect_client
from prefect.artifacts import ArtifactRequest
Expand All @@ -34,12 +35,33 @@
)
from pydantic import TypeAdapter

import controlflow
from controlflow.utilities.types import ControlFlowModel

if TYPE_CHECKING:
from controlflow.llm.tools import Tool


def task(*args, **kwargs):
"""
A decorator that creates a Prefect task with ControlFlow defaults
"""

kwargs.setdefault("log_prints", controlflow.settings.log_prints)

return prefect.task(*args, **kwargs)


def flow(*args, **kwargs):
"""
A decorator that creates a Prefect flow with ControlFlow defaults
"""

kwargs.setdefault("log_prints", controlflow.settings.log_prints)

return prefect.flow(*args, **kwargs)


def create_markdown_artifact(
key: str,
markdown: str,
Expand Down Expand Up @@ -353,7 +375,7 @@ def prefect_task_context(**kwargs):
)

@contextmanager
@prefect.task(**kwargs)
@task(**kwargs)
def task_context():
yield

Expand Down Expand Up @@ -387,7 +409,7 @@ def prefect_flow_context(**kwargs):
)

@contextmanager
@prefect.flow(**kwargs)
@flow(**kwargs)
def flow_context():
yield

Expand Down

0 comments on commit 613e37c

Please sign in to comment.