Skip to content

Commit aecf16e

Browse files
authored
Merge pull request #82 from PrefectHQ/ai-name
Restore ai name to messages
2 parents 404c22c + b768dff commit aecf16e

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

src/controlflow/core/controller/controller.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,22 @@ async def run_once_async(self):
217217
response_handler = ResponseHandler()
218218
payload["handlers"].append(response_handler)
219219

220+
messages = []
221+
for msg in payload["messages"]:
222+
if isinstance(msg, AIMessage) and msg.name:
223+
msg = msg.copy()
224+
msg.content = (
225+
f"Message from agent: {msg.name}\n\n{msg.content}"
226+
)
227+
messages.append(msg)
228+
220229
response_gen = await completion_async(
221-
messages=payload["messages"],
230+
messages=messages,
222231
model=agent.model,
223232
tools=payload["tools"],
224233
handlers=payload["handlers"],
225234
max_iterations=1,
226-
# assistant_name=agent.name,
227-
# message_preprocessor=payload["message_preprocessor"],
235+
ai_name=agent.name,
228236
stream=True,
229237
)
230238
async for _ in response_gen:
@@ -245,14 +253,20 @@ def run_once(self):
245253
response_handler = ResponseHandler()
246254
payload["handlers"].append(response_handler)
247255

256+
messages = []
257+
for msg in payload["messages"]:
258+
if isinstance(msg, AIMessage) and msg.name:
259+
msg = msg.copy()
260+
msg.content = f"Message from agent: {msg.name}\n\n{msg.content}"
261+
messages.append(msg)
262+
248263
response_gen = completion(
249-
messages=payload["messages"],
264+
messages=messages,
250265
model=agent.model,
251266
tools=payload["tools"],
252267
handlers=payload["handlers"],
253268
max_iterations=1,
254-
# assistant_name=agent.name,
255-
# message_preprocessor=payload["message_preprocessor"],
269+
ai_name=agent.name,
256270
stream=True,
257271
)
258272
for _ in response_gen:

src/controlflow/llm/completions.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
import math
3-
from typing import AsyncGenerator, Callable, Generator, Optional, Union, cast
3+
from typing import AsyncGenerator, Callable, Generator, Optional, Union
44

55
import langchain_core.language_models as lc_models
66

@@ -24,6 +24,7 @@ def _completion_generator(
2424
model: lc_models.BaseChatModel,
2525
tools: Optional[list[Callable]],
2626
max_iterations: int,
27+
ai_name: Optional[str],
2728
stream: bool,
2829
**kwargs,
2930
) -> Generator[CompletionEvent, None, None]:
@@ -47,7 +48,9 @@ def _completion_generator(
4748
input=messages + response_messages,
4849
**kwargs,
4950
)
50-
response_message = AIMessage.from_message(response_message)
51+
response_message = AIMessage.from_message(
52+
response_message, name=ai_name
53+
)
5154

5255
else:
5356
deltas: list[AIMessageChunk] = []
@@ -57,7 +60,7 @@ def _completion_generator(
5760
input=messages + response_messages,
5861
**kwargs,
5962
):
60-
delta = AIMessageChunk.from_chunk(delta)
63+
delta = AIMessageChunk.from_chunk(delta, name=ai_name)
6164
deltas.append(delta)
6265

6366
if snapshot is None:
@@ -127,6 +130,7 @@ async def _completion_async_generator(
127130
model: lc_models.BaseChatModel,
128131
tools: Optional[list[Callable]],
129132
max_iterations: int,
133+
ai_name: Optional[str],
130134
stream: bool,
131135
**kwargs,
132136
) -> AsyncGenerator[CompletionEvent, None]:
@@ -151,7 +155,9 @@ async def _completion_async_generator(
151155
tools=tools or None,
152156
**kwargs,
153157
)
154-
response_message = AIMessage.from_message(response_message)
158+
response_message = AIMessage.from_message(
159+
response_message, name=ai_name
160+
)
155161

156162
else:
157163
deltas: list[AIMessageChunk] = []
@@ -162,7 +168,7 @@ async def _completion_async_generator(
162168
tools=tools or None,
163169
**kwargs,
164170
):
165-
delta = cast(AIMessageChunk, delta)
171+
delta = AIMessageChunk.from_chunk(delta, name=ai_name)
166172
deltas.append(delta)
167173

168174
if snapshot is None:
@@ -251,6 +257,7 @@ def completion(
251257
tools: list[Callable] = None,
252258
max_iterations: int = None,
253259
handlers: list[CompletionHandler] = None,
260+
ai_name: Optional[str] = None,
254261
stream: bool = False,
255262
**kwargs,
256263
) -> Union[list[MessageType], Generator[MessageType, None, None]]:
@@ -266,6 +273,7 @@ def completion(
266273
model=model,
267274
tools=tools,
268275
max_iterations=max_iterations,
276+
ai_name=ai_name,
269277
stream=stream,
270278
**kwargs,
271279
)
@@ -286,6 +294,7 @@ async def completion_async(
286294
tools: list[Callable] = None,
287295
max_iterations: int = None,
288296
handlers: list[CompletionHandler] = None,
297+
ai_name: Optional[str] = None,
289298
stream: bool = False,
290299
**kwargs,
291300
) -> Union[list[MessageType], Generator[MessageType, None, None]]:
@@ -301,6 +310,7 @@ async def completion_async(
301310
model=model,
302311
tools=tools,
303312
max_iterations=max_iterations,
313+
ai_name=ai_name,
304314
stream=stream,
305315
**kwargs,
306316
)

src/controlflow/llm/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def to_message(self, **kwargs) -> AIMessage:
6363
def __add__(self, other: Any) -> "AIMessageChunk": # type: ignore
6464
result = super().__add__(other)
6565
result.timestamp = self.timestamp
66+
result.name = self.name
6667
return result
6768

6869

src/controlflow/llm/tools.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ def wrapper(*args, **kwargs):
4040

4141
class Tool(langchain_core.tools.StructuredTool):
4242
"""
43-
A subclass of StructuredTool that is compatible with Pydantic v1 models
44-
(which Langchain uses) and v2 models (which ControlFlow users).
43+
A subclass of StructuredTool that is compatible with functions whose
44+
signatures include either Pydantic v1 models (which Langchain uses) or v2
45+
models (which ControlFlow users).
4546
46-
Note that THIS is a Pydantic v1 model because it subclasses the Langchain class.
47+
Note that THIS class is a Pydantic v1 model because it subclasses the Langchain
48+
class.
4749
"""
4850

4951
tags: dict[str, Any] = pydantic.v1.Field(default_factory=dict)

0 commit comments

Comments
 (0)