Skip to content

Commit 170d268

Browse files
committed
Expose and document handlers
1 parent dfc74c4 commit 170d268

File tree

12 files changed

+300
-47
lines changed

12 files changed

+300
-47
lines changed

docs/patterns/running-tasks.mdx

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,46 @@ The orchestrator is instantiated with the following arguments:
369369

370370
You can then use the orchestrator's `run()` method to step through the loop manually. If you call `run()` with no arguments, it will continue until all of the provided tasks are complete. You can provide `max_llm_calls` and `max_agent_turns` to further limit the behavior.
371371

372+
373+
## Using handlers
374+
375+
Handlers in ControlFlow provide a way to observe and react to events that occur during task execution. They allow you to customize logging, monitoring, or take specific actions based on the orchestration process.
376+
377+
Handlers implement the `Handler` interface, which defines methods for various events that can occur during task execution, including agent messages (and message deltas), user messages, tool calls, tool results, orchestrator sessions starting or stopping, and more.
378+
379+
ControlFlow includes a built-in `PrintHandler` that pretty-prints agent responses and tool calls to the terminal. It's used by default if `controlflow.settings.pretty_print_agent_events=True` and no other handlers are provided.
380+
381+
### How handlers work
382+
383+
Whenever an event is generated by ControlFlow, the orchestrator will pass it to all of its registered handlers. Each handler will dispatch to one of its methods based on the type of event. For example, an `AgentMessage` event will be handled by the handler's `on_agent_message` method. The `on_event` method is always called for every event. This table describes all event types and the methods they are dispatched to:
384+
385+
| Event Type | Method |
386+
|------------|--------|
387+
| `Event` (all events) | `on_event` |
388+
| `UserMessage` | `on_user_message` |
389+
| `OrchestratorMessage` | `on_orchestrator_message` |
390+
| `AgentMessage` | `on_agent_message` |
391+
| `AgentMessageDelta` | `on_agent_message_delta` |
392+
| `ToolCall` | `on_tool_call` |
393+
| `ToolResult` | `on_tool_result` |
394+
| `OrchestratorStart` | `on_orchestrator_start` |
395+
| `OrchestratorEnd` | `on_orchestrator_end` |
396+
| `OrchestratorError` | `on_orchestrator_error` |
397+
| `EndTurn` | `on_end_turn` |
398+
399+
400+
### Writing a custom handler
401+
402+
To create a custom handler, subclass the `Handler` class and implement the methods for the events you're interested in. Here's a simple example that logs agent messages:
403+
404+
```python
405+
import controlflow as cf
406+
from controlflow.orchestration.handler import Handler
407+
from controlflow.events.events import AgentMessage
408+
409+
class LoggingHandler(Handler):
410+
def on_agent_message(self, event: AgentMessage):
411+
print(f"Agent {event.agent.name} said: {event.ai_message.content}")
412+
413+
cf.run("Write a short poem about AI", handlers=[LoggingHandler()])
414+
```

src/controlflow/agents/agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .memory import Memory
3535

3636
if TYPE_CHECKING:
37+
from controlflow.orchestration.handler import Handler
3738
from controlflow.orchestration.turn_strategies import TurnStrategy
3839
from controlflow.tasks import Task
3940
from controlflow.tools.tools import Tool
@@ -196,12 +197,14 @@ def run(
196197
objective: str,
197198
*,
198199
turn_strategy: "TurnStrategy" = None,
200+
handlers: list["Handler"] = None,
199201
**task_kwargs,
200202
):
201203
return controlflow.run(
202204
objective=objective,
203205
agents=[self],
204206
turn_strategy=turn_strategy,
207+
handlers=handlers,
205208
**task_kwargs,
206209
)
207210

@@ -210,12 +213,14 @@ async def run_async(
210213
objective: str,
211214
*,
212215
turn_strategy: "TurnStrategy" = None,
216+
handlers: list["Handler"] = None,
213217
**task_kwargs,
214218
):
215219
return await controlflow.run_async(
216220
objective=objective,
217221
agents=[self],
218222
turn_strategy=turn_strategy,
223+
handlers=handlers,
219224
**task_kwargs,
220225
)
221226

src/controlflow/events/orchestrator_events.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Literal
22

3+
from controlflow.agents.agent import Agent
34
from controlflow.events.base import UnpersistedEvent
45
from controlflow.orchestration.orchestrator import Orchestrator
56

@@ -21,3 +22,17 @@ class OrchestratorError(UnpersistedEvent):
2122
persist: bool = False
2223
orchestrator: Orchestrator
2324
error: Exception
25+
26+
27+
class AgentTurnStart(UnpersistedEvent):
28+
event: Literal["agent-turn-start"] = "agent-turn-start"
29+
persist: bool = False
30+
orchestrator: Orchestrator
31+
agent: Agent
32+
33+
34+
class AgentTurnEnd(UnpersistedEvent):
35+
event: Literal["agent-turn-end"] = "agent-turn-end"
36+
persist: bool = False
37+
orchestrator: Orchestrator
38+
agent: Agent
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .orchestrator import Orchestrator
2+
from .handler import Handler
Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,77 @@
1-
from typing import Callable
1+
from typing import TYPE_CHECKING, Callable
22

33
from controlflow.events.base import Event
44

5+
if TYPE_CHECKING:
6+
from controlflow.events.events import (
7+
AgentMessage,
8+
AgentMessageDelta,
9+
EndTurn,
10+
OrchestratorMessage,
11+
ToolCallEvent,
12+
ToolResultEvent,
13+
UserMessage,
14+
)
15+
from controlflow.events.orchestrator_events import (
16+
OrchestratorEnd,
17+
OrchestratorError,
18+
OrchestratorStart,
19+
)
20+
521

622
class Handler:
723
def handle(self, event: Event):
24+
"""
25+
Handle is called whenever an event is emitted.
26+
27+
By default, it dispatches to a method named after the event type e.g.
28+
`self.on_{event_type}(event=event)`.
29+
30+
The `on_event` method is always called for every event.
31+
"""
32+
self.on_event(event=event)
833
event_type = event.event.replace("-", "_")
934
method = getattr(self, f"on_{event_type}", None)
1035
if method:
1136
method(event=event)
1237

38+
def on_event(self, event: Event):
39+
pass
40+
41+
def on_orchestrator_start(self, event: "OrchestratorStart"):
42+
pass
43+
44+
def on_orchestrator_end(self, event: "OrchestratorEnd"):
45+
pass
46+
47+
def on_orchestrator_error(self, event: "OrchestratorError"):
48+
pass
49+
50+
def on_agent_message(self, event: "AgentMessage"):
51+
pass
52+
53+
def on_agent_message_delta(self, event: "AgentMessageDelta"):
54+
pass
55+
56+
def on_tool_call(self, event: "ToolCallEvent"):
57+
pass
58+
59+
def on_tool_result(self, event: "ToolResultEvent"):
60+
pass
61+
62+
def on_orchestrator_message(self, event: "OrchestratorMessage"):
63+
pass
64+
65+
def on_user_message(self, event: "UserMessage"):
66+
pass
67+
68+
def on_end_turn(self, event: "EndTurn"):
69+
pass
70+
1371

1472
class CallbackHandler(Handler):
1573
def __init__(self, callback: Callable[[Event], None]):
1674
self.callback = callback
1775

18-
def handle(self, event: Event):
76+
def on_event(self, event: Event):
1977
self.callback(event)

src/controlflow/orchestration/orchestrator.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -160,24 +160,18 @@ def run(
160160
if max_llm_calls is not None and call_count >= max_llm_calls:
161161
break
162162

163+
self.handle_event(
164+
controlflow.events.orchestrator_events.AgentTurnStart(
165+
orchestrator=self, agent=self.agent
166+
)
167+
)
163168
turn_count += 1
164-
self.turn_strategy.begin_turn()
165-
166-
# Mark assigned tasks as running
167-
for task in self.get_tasks("assigned"):
168-
if not task.is_running():
169-
task.mark_running()
170-
self.flow.add_events(
171-
[
172-
OrchestratorMessage(
173-
content=f"Starting task {task.name} (ID {task.id}) "
174-
f"with objective: {task.objective}"
175-
)
176-
]
177-
)
178-
179-
# Run the agent's turn
180169
call_count += self.run_agent_turn(max_llm_calls - call_count)
170+
self.handle_event(
171+
controlflow.events.orchestrator_events.AgentTurnEnd(
172+
orchestrator=self, agent=self.agent
173+
)
174+
)
181175

182176
# Select the next agent for the following turn
183177
if available_agents := self.get_available_agents():
@@ -244,25 +238,20 @@ async def run_async(
244238
if max_llm_calls is not None and call_count >= max_llm_calls:
245239
break
246240

241+
self.handle_event(
242+
controlflow.events.orchestrator_events.AgentTurnStart(
243+
orchestrator=self, agent=self.agent
244+
)
245+
)
247246
turn_count += 1
248-
self.turn_strategy.begin_turn()
249-
250-
# Mark assigned tasks as running
251-
for task in self.get_tasks("assigned"):
252-
if not task.is_running():
253-
task.mark_running()
254-
self.flow.add_events(
255-
[
256-
OrchestratorMessage(
257-
content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}"
258-
)
259-
]
260-
)
261-
262-
# Run the agent's turn
263247
call_count += await self.run_agent_turn_async(
264248
max_llm_calls - call_count
265249
)
250+
self.handle_event(
251+
controlflow.events.orchestrator_events.AgentTurnEnd(
252+
orchestrator=self, agent=self.agent
253+
)
254+
)
266255

267256
# Select the next agent for the following turn
268257
if available_agents := self.get_available_agents():
@@ -300,6 +289,19 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
300289
call_count = 0
301290
assigned_tasks = self.get_tasks("assigned")
302291

292+
self.turn_strategy.begin_turn()
293+
294+
# Mark assigned tasks as running
295+
for task in assigned_tasks:
296+
if not task.is_running():
297+
task.mark_running()
298+
self.handle_event(
299+
OrchestratorMessage(
300+
content=f"Starting task {task.name} (ID {task.id}) "
301+
f"with objective: {task.objective}"
302+
)
303+
)
304+
303305
while not self.turn_strategy.should_end_turn():
304306
for task in assigned_tasks:
305307
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
@@ -340,6 +342,19 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
340342
call_count = 0
341343
assigned_tasks = self.get_tasks("assigned")
342344

345+
self.turn_strategy.begin_turn()
346+
347+
# Mark assigned tasks as running
348+
for task in assigned_tasks:
349+
if not task.is_running():
350+
task.mark_running()
351+
self.handle_event(
352+
OrchestratorMessage(
353+
content=f"Starting task {task.name} (ID {task.id}) "
354+
f"with objective: {task.objective}"
355+
)
356+
)
357+
343358
while not self.turn_strategy.should_end_turn():
344359
for task in assigned_tasks:
345360
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:

src/controlflow/orchestration/print_handler.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,6 @@ def __init__(self):
3535
self.paused_id: str = None
3636
super().__init__()
3737

38-
def on_orchestrator_start(self, event: OrchestratorStart):
39-
self.live: Live = Live(auto_refresh=False, console=cf_console)
40-
self.events.clear()
41-
try:
42-
self.live.start()
43-
except rich.errors.LiveError:
44-
pass
45-
46-
def on_orchestrator_end(self, event: OrchestratorEnd):
47-
self.live.stop()
48-
49-
def on_orchestrator_error(self, event: OrchestratorError):
50-
self.live.stop()
51-
5238
def update_live(self, latest: BaseMessage = None):
5339
events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0]))
5440
content = []
@@ -72,6 +58,20 @@ def update_live(self, latest: BaseMessage = None):
7258
elif latest:
7359
cf_console.print(format_event(latest))
7460

61+
def on_orchestrator_start(self, event: OrchestratorStart):
62+
self.live: Live = Live(auto_refresh=False, console=cf_console)
63+
self.events.clear()
64+
try:
65+
self.live.start()
66+
except rich.errors.LiveError:
67+
pass
68+
69+
def on_orchestrator_end(self, event: OrchestratorEnd):
70+
self.live.stop()
71+
72+
def on_orchestrator_error(self, event: OrchestratorError):
73+
self.live.stop()
74+
7575
def on_agent_message_delta(self, event: AgentMessageDelta):
7676
self.events[event.snapshot_message.id] = event
7777
self.update_live()

0 commit comments

Comments
 (0)