Skip to content

Commit

Permalink
Merge pull request #364 from PrefectHQ/async-handlers
Browse files Browse the repository at this point in the history
Add async handlers
  • Loading branch information
jlowin authored Oct 25, 2024
2 parents 2dac05a + fe147ee commit 6954c99
Show file tree
Hide file tree
Showing 11 changed files with 5,308 additions and 35 deletions.
32 changes: 28 additions & 4 deletions docs/patterns/running-tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -405,17 +405,17 @@ You can then use the orchestrator's `run()` method to step through the loop manu

Handlers in ControlFlow provide a way to observe and react to events that occur during task execution. They allow you to customize logging, monitoring, or take specific actions based on the orchestration process.

Handlers implement the `Handler` interface, which defines methods for various events that can occur during task execution, including agent messages (and message deltas), user messages, tool calls, tool results, orchestrator sessions starting or stopping, and more.
ControlFlow supports both synchronous and asynchronous handlers. Synchronous handlers implement the `Handler` interface, while asynchronous handlers implement the `AsyncHandler` interface. Both interfaces define methods for various events that can occur during task execution, including agent messages (and message deltas), user messages, tool calls, tool results, orchestrator sessions starting or stopping, and more.

ControlFlow includes a built-in `PrintHandler` that pretty-prints agent responses and tool calls to the terminal. It's used by default if `controlflow.settings.pretty_print_agent_events=True` and no other handlers are provided.

### How handlers work

Whenever an event is generated by ControlFlow, the orchestrator will pass it to all of its registered handlers. Each handler will dispatch to one of its methods based on the type of event. For example, an `AgentMessage` event will be handled by the handler's `on_agent_message` method. The `on_event` method is always called for every event. This table describes all event types and the methods they are dispatched to:
Whenever an event is generated by ControlFlow, the orchestrator will pass it to all of its registered handlers. Each handler will dispatch to one of its methods based on the type of event. For example, an `AgentMessage` event will be handled by the handler's `on_agent_message` method (or `on_agent_message_async` for async handlers). The `on_event` method is always called for every event. This table describes all event types and the methods they are dispatched to:

| Event Type | Method |
|------------|--------|
| `Event` (all events) | `on_event` |
| `Event` (all events) | `on_event` |
| `UserMessage` | `on_user_message` |
| `OrchestratorMessage` | `on_orchestrator_message` |
| `AgentMessage` | `on_agent_message` |
Expand All @@ -430,7 +430,9 @@ Whenever an event is generated by ControlFlow, the orchestrator will pass it to

### Writing a custom handler

To create a custom handler, subclass the `Handler` class and implement the methods for the events you're interested in. Here's a simple example that logs agent messages:
To create a custom handler, subclass either the `Handler` class for synchronous handlers or the `AsyncHandler` class for asynchronous handlers. Implement the methods for the events you're interested in. Here are examples of both types:

#### Synchronous Handler

```python
import controlflow as cf
Expand All @@ -443,3 +445,25 @@ class LoggingHandler(Handler):

cf.run("Write a short poem about AI", handlers=[LoggingHandler()])
```

#### Asynchronous Handler

<VersionBadge version="0.11.1" />

```python
import asyncio
import controlflow as cf
from controlflow.orchestration.handler import AsyncHandler
from controlflow.events.events import AgentMessage

class AsyncLoggingHandler(AsyncHandler):
async def on_agent_message(self, event: AgentMessage):
await asyncio.sleep(0.1) # Simulate some async operation
print(f"Agent {event.agent.name} said: {event.ai_message.content}")

await cf.run_async("Write a short poem about AI", handlers=[AsyncLoggingHandler()])
```

When using asynchronous handlers, make sure to use the `run_async` function or other asynchronous methods in ControlFlow to properly handle the asynchronous events.

You can use both synchronous and asynchronous handlers together in the same async run. The orchestrator will automatically handle both types appropriately.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ dependencies = [
"langchain_openai>=0.2",
"langchain-anthropic>=0.2",
"markdownify>=0.12.1",
"openai<1.47", # 1.47.0 introduced a bug with attempting to reuse an async client that doesnt have an obvious solution
"openai<1.47",
# 1.47.0 introduced a bug with attempting to reuse an async client that doesnt have an obvious solution
"pydantic-settings>=2.2.1",
"textual>=0.61.1",
"tiktoken>=0.7.0",
"typer>=0.10",
"ipython>=8.18.1",
]
readme = "README.md"
requires-python = ">= 3.9"
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from controlflow.utilities.prefect import create_markdown_artifact, prefect_task

if TYPE_CHECKING:
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.handler import AsyncHandler, Handler
from controlflow.orchestration.turn_strategies import TurnStrategy
from controlflow.tasks import Task
from controlflow.tools.tools import Tool
Expand Down Expand Up @@ -229,7 +229,7 @@ async def run_async(
objective: str,
*,
turn_strategy: "TurnStrategy" = None,
handlers: list["Handler"] = None,
handlers: list[Union["Handler", "AsyncHandler"]] = None,
**task_kwargs,
):
return await controlflow.run_async(
Expand Down
8 changes: 4 additions & 4 deletions src/controlflow/orchestration/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from pydantic import BaseModel, field_validator

from controlflow.tasks.task import Task
from controlflow.utilities.general import ControlFlowModel
from controlflow.utilities.logging import get_logger

if TYPE_CHECKING:
from controlflow.orchestration.orchestrator import Orchestrator
from controlflow.tasks.task import Task

logger = get_logger(__name__)

Expand Down Expand Up @@ -101,7 +101,7 @@ def should_end(self, context: RunContext) -> bool:


class AllComplete(RunEndCondition):
def __init__(self, tasks: Optional[list[Task]] = None):
def __init__(self, tasks: Optional[list["Task"]] = None):
self.tasks = tasks

def should_end(self, context: RunContext) -> bool:
Expand All @@ -113,7 +113,7 @@ def should_end(self, context: RunContext) -> bool:


class AnyComplete(RunEndCondition):
def __init__(self, tasks: Optional[list[Task]] = None, min_complete: int = 1):
def __init__(self, tasks: Optional[list["Task"]] = None, min_complete: int = 1):
self.tasks = tasks
if min_complete < 1:
raise ValueError("min_complete must be at least 1")
Expand All @@ -128,7 +128,7 @@ def should_end(self, context: RunContext) -> bool:


class AnyFailed(RunEndCondition):
def __init__(self, tasks: Optional[list[Task]] = None, min_failed: int = 1):
def __init__(self, tasks: Optional[list["Task"]] = None, min_failed: int = 1):
self.tasks = tasks
if min_failed < 1:
raise ValueError("min_failed must be at least 1")
Expand Down
53 changes: 52 additions & 1 deletion src/controlflow/orchestration/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Callable
import asyncio
from typing import TYPE_CHECKING, Callable, Coroutine, Union

from controlflow.events.base import Event

Expand Down Expand Up @@ -75,3 +76,53 @@ def __init__(self, callback: Callable[[Event], None]):

def on_event(self, event: Event):
self.callback(event)


class AsyncHandler:
async def handle(self, event: Event):
"""
Handle is called whenever an event is emitted.
By default, it dispatches to a method named after the event type e.g.
`self.on_{event_type}(event=event)`.
The `on_event` method is always called for every event.
"""
await self.on_event(event=event)
event_type = event.event.replace("-", "_")
method = getattr(self, f"on_{event_type}", None)
if method:
await method(event=event)

async def on_event(self, event: Event):
pass

async def on_orchestrator_start(self, event: "OrchestratorStart"):
pass

async def on_orchestrator_end(self, event: "OrchestratorEnd"):
pass

async def on_orchestrator_error(self, event: "OrchestratorError"):
pass

async def on_agent_message(self, event: "AgentMessage"):
pass

async def on_agent_message_delta(self, event: "AgentMessageDelta"):
pass

async def on_tool_call(self, event: "ToolCallEvent"):
pass

async def on_tool_result(self, event: "ToolResultEvent"):
pass

async def on_orchestrator_message(self, event: "OrchestratorMessage"):
pass

async def on_user_message(self, event: "UserMessage"):
pass

async def on_end_turn(self, event: "EndTurn"):
pass
39 changes: 28 additions & 11 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
RunContext,
RunEndCondition,
)
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.handler import AsyncHandler, Handler
from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy
from controlflow.tasks.task import Task
from controlflow.tools.tools import Tool, as_tools
Expand Down Expand Up @@ -51,7 +51,7 @@ class Orchestrator(ControlFlowModel):
description="The strategy to use for managing agent turns",
validate_default=True,
)
handlers: list[Handler] = Field(None, validate_default=True)
handlers: list[Union[Handler, AsyncHandler]] = Field(None, validate_default=True)

@field_validator("turn_strategy", mode="before")
def _validate_turn_strategy(cls, v):
Expand Down Expand Up @@ -86,7 +86,25 @@ def handle_event(self, event: Event):
if not isinstance(event, AgentMessageDelta):
logger.debug(f"Handling event: {repr(event)}")
for handler in self.handlers:
handler.handle(event)
if isinstance(handler, Handler):
handler.handle(event)
if event.persist:
self.flow.add_events([event])

async def handle_event_async(self, event: Event):
"""
Handle an event asynchronously by passing it to all handlers and persisting if necessary.
Args:
event (Event): The event to handle.
"""
if not isinstance(event, AgentMessageDelta):
logger.debug(f"Handling event asynchronously: {repr(event)}")
for handler in self.handlers:
if isinstance(handler, AsyncHandler):
await handler.handle(event)
elif isinstance(handler, Handler):
handler.handle(event)
if event.persist:
self.flow.add_events([event])

Expand Down Expand Up @@ -264,17 +282,16 @@ async def run_async(
)

# Signal the start of orchestration
self.handle_event(
await self.handle_event_async(
controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self)
)

try:
while True:
# Check termination condition
if run_context.should_end():
break

self.handle_event(
await self.handle_event_async(
controlflow.events.orchestrator_events.AgentTurnStart(
orchestrator=self, agent=self.agent
)
Expand All @@ -283,7 +300,7 @@ async def run_async(
run_context=run_context,
model_kwargs=model_kwargs,
)
self.handle_event(
await self.handle_event_async(
controlflow.events.orchestrator_events.AgentTurnEnd(
orchestrator=self, agent=self.agent
)
Expand All @@ -297,15 +314,15 @@ async def run_async(

except Exception as exc:
# Handle any exceptions that occur during orchestration
self.handle_event(
await self.handle_event_async(
controlflow.events.orchestrator_events.OrchestratorError(
orchestrator=self, error=exc
)
)
raise
finally:
# Signal the end of orchestration
self.handle_event(
await self.handle_event_async(
controlflow.events.orchestrator_events.OrchestratorEnd(
orchestrator=self
)
Expand Down Expand Up @@ -389,7 +406,7 @@ async def run_agent_turn_async(
for task in assigned_tasks:
if not task.is_running():
task.mark_running()
self.handle_event(
await self.handle_event_async(
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
Expand Down Expand Up @@ -418,7 +435,7 @@ async def run_agent_turn_async(
tools=tools,
model_kwargs=model_kwargs,
):
self.handle_event(event)
await self.handle_event_async(event)

run_context.llm_calls += 1
for task in assigned_tasks:
Expand Down
6 changes: 3 additions & 3 deletions src/controlflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from controlflow.agents.agent import Agent
from controlflow.flows import Flow, get_flow
from controlflow.orchestration.conditions import RunContext, RunEndCondition
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.handler import AsyncHandler, Handler
from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy
from controlflow.tasks.task import Task
from controlflow.utilities.prefect import prefect_task
Expand Down Expand Up @@ -77,7 +77,7 @@ async def run_tasks_async(
raise_on_failure: bool = True,
max_llm_calls: int = None,
max_agent_turns: int = None,
handlers: list[Handler] = None,
handlers: list[Union[Handler, AsyncHandler]] = None,
model_kwargs: Optional[dict] = None,
run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None,
):
Expand Down Expand Up @@ -147,7 +147,7 @@ async def run_async(
max_llm_calls: int = None,
max_agent_turns: int = None,
raise_on_failure: bool = True,
handlers: list[Handler] = None,
handlers: list[Union[Handler, AsyncHandler]] = None,
model_kwargs: Optional[dict] = None,
run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None,
**task_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

if TYPE_CHECKING:
from controlflow.flows import Flow
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.handler import AsyncHandler, Handler
from controlflow.orchestration.turn_strategies import TurnStrategy

T = TypeVar("T")
Expand Down Expand Up @@ -430,7 +430,7 @@ async def run_async(
turn_strategy: "TurnStrategy" = None,
max_llm_calls: int = None,
max_agent_turns: int = None,
handlers: list["Handler"] = None,
handlers: list[Union["Handler", "AsyncHandler"]] = None,
raise_on_failure: bool = True,
) -> T:
"""
Expand Down
21 changes: 20 additions & 1 deletion tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from controlflow.events.events import AgentMessage
from controlflow.flows import Flow
from controlflow.instructions import instructions
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.handler import AsyncHandler, Handler
from controlflow.tasks.task import (
COMPLETE_STATUSES,
INCOMPLETE_STATUSES,
Expand Down Expand Up @@ -535,6 +535,17 @@ def on_event(self, event: Event):
def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

class AsyncExampleHandler(AsyncHandler):
def __init__(self):
self.events = []
self.agent_messages = []

async def on_event(self, event: Event):
self.events.append(event)

async def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_task_run_with_handlers(self, default_fake_llm):
handler = self.ExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
Expand All @@ -551,6 +562,14 @@ async def test_task_run_async_with_handlers(self, default_fake_llm):
assert len(handler.events) > 0
assert len(handler.agent_messages) == 1

async def test_task_run_async_with_async_handlers(self, default_fake_llm):
handler = self.AsyncExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
await task.run_async(handlers=[handler], max_llm_calls=1)

assert len(handler.events) > 0
assert len(handler.agent_messages) == 1


class TestCompletionTools:
def test_default_completion_tools(self):
Expand Down
Loading

0 comments on commit 6954c99

Please sign in to comment.