Skip to content

Commit 572a188

Browse files
authored
Merge pull request #68 from jlowin/handlers
Simplify handlers
2 parents 67dd674 + 99601f9 commit 572a188

File tree

5 files changed

+105
-141
lines changed

5 files changed

+105
-141
lines changed

src/controlflow/llm/completions.py

Lines changed: 56 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import inspect
21
import math
32
from typing import AsyncGenerator, Callable, Generator, Tuple, Union
43

54
import litellm
65
from litellm.utils import trim_messages
76

87
import controlflow
9-
from controlflow.llm.handlers import AsyncStreamHandler, StreamHandler
8+
from controlflow.llm.handlers import CompoundHandler, StreamHandler
109
from controlflow.llm.tools import (
1110
as_tools,
1211
get_tool_calls,
@@ -19,11 +18,6 @@
1918
)
2019

2120

22-
async def maybe_coro(coro):
23-
if inspect.isawaitable(coro):
24-
await coro
25-
26-
2721
def completion(
2822
messages: list[Union[dict, ControlFlowMessage]],
2923
model=None,
@@ -47,8 +41,7 @@ def completion(
4741
response_messages = []
4842
new_messages = []
4943

50-
if handlers is None:
51-
handlers = []
44+
handler = CompoundHandler(handlers=handlers or [])
5245

5346
if model is None:
5447
model = controlflow.settings.model
@@ -70,20 +63,17 @@ def completion(
7063
response_messages = as_cf_messages([response])
7164

7265
# on message done
73-
for h in handlers:
74-
for msg in response_messages:
75-
if msg.has_tool_calls():
76-
h.on_tool_call_done(msg)
77-
else:
78-
h.on_message_done(msg)
66+
for msg in response_messages:
67+
new_messages.append(msg)
68+
if msg.has_tool_calls():
69+
handler.on_tool_call_done(msg)
70+
else:
71+
handler.on_message_done(msg)
7972

80-
new_messages.extend(response_messages)
73+
# tool calls
8174
for tool_call in get_tool_calls(response_messages):
8275
tool_message = handle_tool_call(tool_call, tools)
83-
84-
# on tool result
85-
for h in handlers:
86-
h.on_tool_result(tool_message)
76+
handler.on_tool_result(tool_message)
8777
new_messages.append(tool_message)
8878

8979
counter += 1
@@ -120,9 +110,7 @@ def completion_stream(
120110
snapshot_message = None
121111
new_messages = []
122112

123-
if handlers is None:
124-
handlers = []
125-
113+
handler = CompoundHandler(handlers=handlers or [])
126114
if model is None:
127115
model = controlflow.settings.model
128116

@@ -149,35 +137,33 @@ def completion_stream(
149137

150138
# on message created
151139
if len(deltas) == 1:
152-
for h in handlers:
153-
if snapshot_message.has_tool_calls():
154-
h.on_tool_call_created(delta=delta_message)
155-
else:
156-
h.on_message_created(delta=delta_message)
157-
158-
# on message delta
159-
for h in handlers:
160140
if snapshot_message.has_tool_calls():
161-
h.on_tool_call_delta(delta=delta_message, snapshot=snapshot_message)
141+
handler.on_tool_call_created(delta=delta_message)
162142
else:
163-
h.on_message_delta(delta=delta_message, snapshot=snapshot_message)
143+
handler.on_message_created(delta=delta_message)
144+
145+
# on message delta
146+
if snapshot_message.has_tool_calls():
147+
handler.on_tool_call_delta(
148+
delta=delta_message, snapshot=snapshot_message
149+
)
150+
else:
151+
handler.on_message_delta(delta=delta_message, snapshot=snapshot_message)
164152

165153
yield snapshot_message
166154

167155
new_messages.append(snapshot_message)
168156

169157
# on message done
170-
for h in handlers:
171-
if snapshot_message.has_tool_calls():
172-
h.on_tool_call_done(snapshot_message)
173-
else:
174-
h.on_message_done(snapshot_message)
158+
if snapshot_message.has_tool_calls():
159+
handler.on_tool_call_done(snapshot_message)
160+
else:
161+
handler.on_message_done(snapshot_message)
175162

176163
# tool calls
177164
for tool_call in get_tool_calls([snapshot_message]):
178165
tool_message = handle_tool_call(tool_call, tools)
179-
for h in handlers:
180-
h.on_tool_result(tool_message)
166+
handler.on_tool_result(tool_message)
181167
new_messages.append(tool_message)
182168
yield tool_message
183169

@@ -191,7 +177,7 @@ async def completion_async(
191177
model=None,
192178
tools: list[Callable] = None,
193179
max_iterations=None,
194-
handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
180+
handlers: list[StreamHandler] = None,
195181
**kwargs,
196182
) -> list[ControlFlowMessage]:
197183
"""
@@ -209,9 +195,7 @@ async def completion_async(
209195
response_messages = []
210196
new_messages = []
211197

212-
if handlers is None:
213-
handlers = []
214-
198+
handler = CompoundHandler(handlers=handlers or [])
215199
if model is None:
216200
model = controlflow.settings.model
217201

@@ -231,19 +215,18 @@ async def completion_async(
231215

232216
response_messages = as_cf_messages([response])
233217

234-
# on message done
235-
for h in handlers:
236-
for msg in response_messages:
237-
if msg.has_tool_calls():
238-
await maybe_coro(h.on_tool_call_done(msg))
239-
else:
240-
await maybe_coro(h.on_message_done(msg))
218+
# on done
219+
for msg in response_messages:
220+
new_messages.append(msg)
221+
if msg.has_tool_calls():
222+
handler.on_tool_call_done(msg)
223+
else:
224+
handler.on_message_done(msg)
241225

242-
new_messages.extend(response_messages)
226+
# tool calls
243227
for tool_call in get_tool_calls(response_messages):
244228
tool_message = handle_tool_call(tool_call, tools)
245-
for h in handlers:
246-
await maybe_coro(h.on_tool_result(tool_message))
229+
handler.on_tool_result(tool_message)
247230
new_messages.append(tool_message)
248231

249232
counter += 1
@@ -258,7 +241,7 @@ async def completion_stream_async(
258241
model=None,
259242
tools: list[Callable] = None,
260243
max_iterations: int = None,
261-
handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
244+
handlers: list[StreamHandler] = None,
262245
**kwargs,
263246
) -> AsyncGenerator[ControlFlowMessage, None]:
264247
"""
@@ -280,9 +263,7 @@ async def completion_stream_async(
280263
snapshot_message = None
281264
new_messages = []
282265

283-
if handlers is None:
284-
handlers = []
285-
266+
handler = CompoundHandler(handlers=handlers or [])
286267
if model is None:
287268
model = controlflow.settings.model
288269

@@ -309,43 +290,32 @@ async def completion_stream_async(
309290

310291
# on message created
311292
if len(deltas) == 1:
312-
for h in handlers:
313-
if snapshot_message.has_tool_calls():
314-
await maybe_coro(h.on_tool_call_created(delta=delta_message))
315-
else:
316-
await maybe_coro(h.on_message_created(delta=delta_message))
317-
318-
# on message delta
319-
for h in handlers:
320293
if snapshot_message.has_tool_calls():
321-
await maybe_coro(
322-
h.on_tool_call_delta(
323-
delta=delta_message, snapshot=snapshot_message
324-
)
325-
)
294+
handler.on_tool_call_created(delta=delta_message)
326295
else:
327-
await maybe_coro(
328-
h.on_message_delta(
329-
delta=delta_message, snapshot=snapshot_message
330-
)
331-
)
332-
333-
yield snapshot_message
296+
handler.on_message_created(delta=delta_message)
334297

335-
new_messages.append(snapshot_message)
336-
337-
# on message done
338-
for h in handlers:
298+
# on message delta
339299
if snapshot_message.has_tool_calls():
340-
await maybe_coro(h.on_tool_call_done(snapshot_message))
300+
handler.on_tool_call_delta(
301+
delta=delta_message, snapshot=snapshot_message
302+
)
341303
else:
342-
await maybe_coro(h.on_message_done(snapshot_message))
304+
handler.on_message_delta(delta=delta_message, snapshot=snapshot_message)
305+
306+
# on message done
307+
if snapshot_message.has_tool_calls():
308+
handler.on_tool_call_done(snapshot_message)
309+
else:
310+
handler.on_message_done(snapshot_message)
311+
312+
new_messages.append(snapshot_message)
313+
yield snapshot_message
343314

344315
# tool calls
345316
for tool_call in get_tool_calls([snapshot_message]):
346317
tool_message = handle_tool_call(tool_call, tools)
347-
for h in handlers:
348-
await maybe_coro(h.on_tool_result(tool_message))
318+
handler.on_tool_result(tool_message)
349319
new_messages.append(tool_message)
350320
yield tool_message
351321

src/controlflow/llm/handlers.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from controlflow.llm.tools import get_tool_calls
21
from controlflow.utilities.context import ctx
32
from controlflow.utilities.types import AssistantMessage, ToolMessage
43

@@ -26,53 +25,56 @@ def on_tool_result(self, message: ToolMessage):
2625
pass
2726

2827

29-
class AsyncStreamHandler(StreamHandler):
30-
async def on_message_created(self, delta: AssistantMessage):
31-
pass
28+
class CompoundHandler(StreamHandler):
29+
def __init__(self, handlers: list[StreamHandler]):
30+
self.handlers = handlers
3231

33-
async def on_message_delta(
34-
self, delta: AssistantMessage, snapshot: AssistantMessage
35-
):
36-
pass
32+
def on_message_created(self, delta: AssistantMessage):
33+
for handler in self.handlers:
34+
handler.on_message_created(delta)
3735

38-
async def on_message_done(self, message: AssistantMessage):
39-
pass
36+
def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
37+
for handler in self.handlers:
38+
handler.on_message_delta(delta, snapshot)
4039

41-
async def on_tool_call_created(self, delta: AssistantMessage):
42-
pass
40+
def on_message_done(self, message: AssistantMessage):
41+
for handler in self.handlers:
42+
handler.on_message_done(message)
4343

44-
async def on_tool_call_delta(
45-
self, delta: AssistantMessage, snapshot: AssistantMessage
46-
):
47-
pass
44+
def on_tool_call_created(self, delta: AssistantMessage):
45+
for handler in self.handlers:
46+
handler.on_tool_call_created(delta)
4847

49-
async def on_tool_call_done(self, message: AssistantMessage):
50-
pass
48+
def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
49+
for handler in self.handlers:
50+
handler.on_tool_call_delta(delta, snapshot)
5151

52-
async def on_tool_result(self, message: ToolMessage):
53-
pass
52+
def on_tool_call_done(self, message: AssistantMessage):
53+
for handler in self.handlers:
54+
handler.on_tool_call_done(message)
55+
56+
def on_tool_result(self, message: ToolMessage):
57+
for handler in self.handlers:
58+
handler.on_tool_result(message)
5459

5560

56-
class TUIHandler(AsyncStreamHandler):
57-
async def on_message_delta(
61+
class TUIHandler(StreamHandler):
62+
def on_message_delta(
5863
self, delta: AssistantMessage, snapshot: AssistantMessage
5964
) -> None:
6065
if tui := ctx.get("tui"):
6166
tui.update_message(message=snapshot)
6267

63-
async def on_tool_call_delta(
64-
self, delta: AssistantMessage, snapshot: AssistantMessage
65-
):
68+
def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
6669
if tui := ctx.get("tui"):
67-
for tool_call in get_tool_calls(snapshot):
68-
tui.update_message(message=snapshot)
70+
tui.update_message(message=snapshot)
6971

70-
async def on_tool_result(self, message: ToolMessage):
72+
def on_tool_result(self, message: ToolMessage):
7173
if tui := ctx.get("tui"):
7274
tui.update_tool_result(message=message)
7375

7476

75-
class PrintHandler(AsyncStreamHandler):
77+
class PrintHandler(StreamHandler):
7678
def on_message_created(self, delta: AssistantMessage):
7779
print(f"Created: {delta}\n")
7880

src/controlflow/tui/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def __init__(self, flow: "controlflow.Flow", **kwargs):
3939
async def run_context(
4040
self,
4141
run: bool = True,
42-
inline: bool = False,
42+
inline: bool = True,
4343
inline_stay_visible: bool = True,
4444
headless: bool = None,
45-
hold: bool = True,
45+
hold: bool = False,
4646
):
4747
if headless is None:
4848
headless = controlflow.settings.run_tui_headless

0 commit comments

Comments
 (0)