1
- import inspect
2
1
import math
3
2
from typing import AsyncGenerator , Callable , Generator , Tuple , Union
4
3
5
4
import litellm
6
5
from litellm .utils import trim_messages
7
6
8
7
import controlflow
9
- from controlflow .llm .handlers import AsyncStreamHandler , StreamHandler
8
+ from controlflow .llm .handlers import CompoundHandler , StreamHandler
10
9
from controlflow .llm .tools import (
11
10
as_tools ,
12
11
get_tool_calls ,
19
18
)
20
19
21
20
22
- async def maybe_coro (coro ):
23
- if inspect .isawaitable (coro ):
24
- await coro
25
-
26
-
27
21
def completion (
28
22
messages : list [Union [dict , ControlFlowMessage ]],
29
23
model = None ,
@@ -47,8 +41,7 @@ def completion(
47
41
response_messages = []
48
42
new_messages = []
49
43
50
- if handlers is None :
51
- handlers = []
44
+ handler = CompoundHandler (handlers = handlers or [])
52
45
53
46
if model is None :
54
47
model = controlflow .settings .model
@@ -70,20 +63,17 @@ def completion(
70
63
response_messages = as_cf_messages ([response ])
71
64
72
65
# on message done
73
- for h in handlers :
74
- for msg in response_messages :
75
- if msg .has_tool_calls ():
76
- h .on_tool_call_done (msg )
77
- else :
78
- h .on_message_done (msg )
66
+ for msg in response_messages :
67
+ new_messages . append ( msg )
68
+ if msg .has_tool_calls ():
69
+ handler .on_tool_call_done (msg )
70
+ else :
71
+ handler .on_message_done (msg )
79
72
80
- new_messages . extend ( response_messages )
73
+ # tool calls
81
74
for tool_call in get_tool_calls (response_messages ):
82
75
tool_message = handle_tool_call (tool_call , tools )
83
-
84
- # on tool result
85
- for h in handlers :
86
- h .on_tool_result (tool_message )
76
+ handler .on_tool_result (tool_message )
87
77
new_messages .append (tool_message )
88
78
89
79
counter += 1
@@ -120,9 +110,7 @@ def completion_stream(
120
110
snapshot_message = None
121
111
new_messages = []
122
112
123
- if handlers is None :
124
- handlers = []
125
-
113
+ handler = CompoundHandler (handlers = handlers or [])
126
114
if model is None :
127
115
model = controlflow .settings .model
128
116
@@ -149,35 +137,33 @@ def completion_stream(
149
137
150
138
# on message created
151
139
if len (deltas ) == 1 :
152
- for h in handlers :
153
- if snapshot_message .has_tool_calls ():
154
- h .on_tool_call_created (delta = delta_message )
155
- else :
156
- h .on_message_created (delta = delta_message )
157
-
158
- # on message delta
159
- for h in handlers :
160
140
if snapshot_message .has_tool_calls ():
161
- h . on_tool_call_delta (delta = delta_message , snapshot = snapshot_message )
141
+ handler . on_tool_call_created (delta = delta_message )
162
142
else :
163
- h .on_message_delta (delta = delta_message , snapshot = snapshot_message )
143
+ handler .on_message_created (delta = delta_message )
144
+
145
+ # on message delta
146
+ if snapshot_message .has_tool_calls ():
147
+ handler .on_tool_call_delta (
148
+ delta = delta_message , snapshot = snapshot_message
149
+ )
150
+ else :
151
+ handler .on_message_delta (delta = delta_message , snapshot = snapshot_message )
164
152
165
153
yield snapshot_message
166
154
167
155
new_messages .append (snapshot_message )
168
156
169
157
# on message done
170
- for h in handlers :
171
- if snapshot_message .has_tool_calls ():
172
- h .on_tool_call_done (snapshot_message )
173
- else :
174
- h .on_message_done (snapshot_message )
158
+ if snapshot_message .has_tool_calls ():
159
+ handler .on_tool_call_done (snapshot_message )
160
+ else :
161
+ handler .on_message_done (snapshot_message )
175
162
176
163
# tool calls
177
164
for tool_call in get_tool_calls ([snapshot_message ]):
178
165
tool_message = handle_tool_call (tool_call , tools )
179
- for h in handlers :
180
- h .on_tool_result (tool_message )
166
+ handler .on_tool_result (tool_message )
181
167
new_messages .append (tool_message )
182
168
yield tool_message
183
169
@@ -191,7 +177,7 @@ async def completion_async(
191
177
model = None ,
192
178
tools : list [Callable ] = None ,
193
179
max_iterations = None ,
194
- handlers : list [Union [ AsyncStreamHandler , StreamHandler ] ] = None ,
180
+ handlers : list [StreamHandler ] = None ,
195
181
** kwargs ,
196
182
) -> list [ControlFlowMessage ]:
197
183
"""
@@ -209,9 +195,7 @@ async def completion_async(
209
195
response_messages = []
210
196
new_messages = []
211
197
212
- if handlers is None :
213
- handlers = []
214
-
198
+ handler = CompoundHandler (handlers = handlers or [])
215
199
if model is None :
216
200
model = controlflow .settings .model
217
201
@@ -231,19 +215,18 @@ async def completion_async(
231
215
232
216
response_messages = as_cf_messages ([response ])
233
217
234
- # on message done
235
- for h in handlers :
236
- for msg in response_messages :
237
- if msg .has_tool_calls ():
238
- await maybe_coro ( h .on_tool_call_done (msg ) )
239
- else :
240
- await maybe_coro ( h .on_message_done (msg ) )
218
+ # on done
219
+ for msg in response_messages :
220
+ new_messages . append ( msg )
221
+ if msg .has_tool_calls ():
222
+ handler .on_tool_call_done (msg )
223
+ else :
224
+ handler .on_message_done (msg )
241
225
242
- new_messages . extend ( response_messages )
226
+ # tool calls
243
227
for tool_call in get_tool_calls (response_messages ):
244
228
tool_message = handle_tool_call (tool_call , tools )
245
- for h in handlers :
246
- await maybe_coro (h .on_tool_result (tool_message ))
229
+ handler .on_tool_result (tool_message )
247
230
new_messages .append (tool_message )
248
231
249
232
counter += 1
@@ -258,7 +241,7 @@ async def completion_stream_async(
258
241
model = None ,
259
242
tools : list [Callable ] = None ,
260
243
max_iterations : int = None ,
261
- handlers : list [Union [ AsyncStreamHandler , StreamHandler ] ] = None ,
244
+ handlers : list [StreamHandler ] = None ,
262
245
** kwargs ,
263
246
) -> AsyncGenerator [ControlFlowMessage , None ]:
264
247
"""
@@ -280,9 +263,7 @@ async def completion_stream_async(
280
263
snapshot_message = None
281
264
new_messages = []
282
265
283
- if handlers is None :
284
- handlers = []
285
-
266
+ handler = CompoundHandler (handlers = handlers or [])
286
267
if model is None :
287
268
model = controlflow .settings .model
288
269
@@ -309,43 +290,32 @@ async def completion_stream_async(
309
290
310
291
# on message created
311
292
if len (deltas ) == 1 :
312
- for h in handlers :
313
- if snapshot_message .has_tool_calls ():
314
- await maybe_coro (h .on_tool_call_created (delta = delta_message ))
315
- else :
316
- await maybe_coro (h .on_message_created (delta = delta_message ))
317
-
318
- # on message delta
319
- for h in handlers :
320
293
if snapshot_message .has_tool_calls ():
321
- await maybe_coro (
322
- h .on_tool_call_delta (
323
- delta = delta_message , snapshot = snapshot_message
324
- )
325
- )
294
+ handler .on_tool_call_created (delta = delta_message )
326
295
else :
327
- await maybe_coro (
328
- h .on_message_delta (
329
- delta = delta_message , snapshot = snapshot_message
330
- )
331
- )
332
-
333
- yield snapshot_message
296
+ handler .on_message_created (delta = delta_message )
334
297
335
- new_messages .append (snapshot_message )
336
-
337
- # on message done
338
- for h in handlers :
298
+ # on message delta
339
299
if snapshot_message .has_tool_calls ():
340
- await maybe_coro (h .on_tool_call_done (snapshot_message ))
300
+ handler .on_tool_call_delta (
301
+ delta = delta_message , snapshot = snapshot_message
302
+ )
341
303
else :
342
- await maybe_coro (h .on_message_done (snapshot_message ))
304
+ handler .on_message_delta (delta = delta_message , snapshot = snapshot_message )
305
+
306
+ # on message done
307
+ if snapshot_message .has_tool_calls ():
308
+ handler .on_tool_call_done (snapshot_message )
309
+ else :
310
+ handler .on_message_done (snapshot_message )
311
+
312
+ new_messages .append (snapshot_message )
313
+ yield snapshot_message
343
314
344
315
# tool calls
345
316
for tool_call in get_tool_calls ([snapshot_message ]):
346
317
tool_message = handle_tool_call (tool_call , tools )
347
- for h in handlers :
348
- await maybe_coro (h .on_tool_result (tool_message ))
318
+ handler .on_tool_result (tool_message )
349
319
new_messages .append (tool_message )
350
320
yield tool_message
351
321
0 commit comments