Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore ai name to messages #82

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,22 @@ async def run_once_async(self):
response_handler = ResponseHandler()
payload["handlers"].append(response_handler)

messages = []
for msg in payload["messages"]:
if isinstance(msg, AIMessage) and msg.name:
msg = msg.copy()
msg.content = (
f"Message from agent: {msg.name}\n\n{msg.content}"
)
messages.append(msg)

response_gen = await completion_async(
messages=payload["messages"],
messages=messages,
model=agent.model,
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
# assistant_name=agent.name,
# message_preprocessor=payload["message_preprocessor"],
ai_name=agent.name,
stream=True,
)
async for _ in response_gen:
Expand All @@ -245,14 +253,20 @@ def run_once(self):
response_handler = ResponseHandler()
payload["handlers"].append(response_handler)

messages = []
for msg in payload["messages"]:
if isinstance(msg, AIMessage) and msg.name:
msg = msg.copy()
msg.content = f"Message from agent: {msg.name}\n\n{msg.content}"
messages.append(msg)

response_gen = completion(
messages=payload["messages"],
messages=messages,
model=agent.model,
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
# assistant_name=agent.name,
# message_preprocessor=payload["message_preprocessor"],
ai_name=agent.name,
stream=True,
)
for _ in response_gen:
Expand Down
20 changes: 15 additions & 5 deletions src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import math
from typing import AsyncGenerator, Callable, Generator, Optional, Union, cast
from typing import AsyncGenerator, Callable, Generator, Optional, Union

import langchain_core.language_models as lc_models

Expand All @@ -24,6 +24,7 @@ def _completion_generator(
model: lc_models.BaseChatModel,
tools: Optional[list[Callable]],
max_iterations: int,
ai_name: Optional[str],
stream: bool,
**kwargs,
) -> Generator[CompletionEvent, None, None]:
Expand All @@ -47,7 +48,9 @@ def _completion_generator(
input=messages + response_messages,
**kwargs,
)
response_message = AIMessage.from_message(response_message)
response_message = AIMessage.from_message(
response_message, name=ai_name
)

else:
deltas: list[AIMessageChunk] = []
Expand All @@ -57,7 +60,7 @@ def _completion_generator(
input=messages + response_messages,
**kwargs,
):
delta = AIMessageChunk.from_chunk(delta)
delta = AIMessageChunk.from_chunk(delta, name=ai_name)
deltas.append(delta)

if snapshot is None:
Expand Down Expand Up @@ -127,6 +130,7 @@ async def _completion_async_generator(
model: lc_models.BaseChatModel,
tools: Optional[list[Callable]],
max_iterations: int,
ai_name: Optional[str],
stream: bool,
**kwargs,
) -> AsyncGenerator[CompletionEvent, None]:
Expand All @@ -151,7 +155,9 @@ async def _completion_async_generator(
tools=tools or None,
**kwargs,
)
response_message = AIMessage.from_message(response_message)
response_message = AIMessage.from_message(
response_message, name=ai_name
)

else:
deltas: list[AIMessageChunk] = []
Expand All @@ -162,7 +168,7 @@ async def _completion_async_generator(
tools=tools or None,
**kwargs,
):
delta = cast(AIMessageChunk, delta)
delta = AIMessageChunk.from_chunk(delta, name=ai_name)
deltas.append(delta)

if snapshot is None:
Expand Down Expand Up @@ -251,6 +257,7 @@ def completion(
tools: list[Callable] = None,
max_iterations: int = None,
handlers: list[CompletionHandler] = None,
ai_name: Optional[str] = None,
stream: bool = False,
**kwargs,
) -> Union[list[MessageType], Generator[MessageType, None, None]]:
Expand All @@ -266,6 +273,7 @@ def completion(
model=model,
tools=tools,
max_iterations=max_iterations,
ai_name=ai_name,
stream=stream,
**kwargs,
)
Expand All @@ -286,6 +294,7 @@ async def completion_async(
tools: list[Callable] = None,
max_iterations: int = None,
handlers: list[CompletionHandler] = None,
ai_name: Optional[str] = None,
stream: bool = False,
**kwargs,
) -> Union[list[MessageType], Generator[MessageType, None, None]]:
Expand All @@ -301,6 +310,7 @@ async def completion_async(
model=model,
tools=tools,
max_iterations=max_iterations,
ai_name=ai_name,
stream=stream,
**kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions src/controlflow/llm/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def to_message(self, **kwargs) -> AIMessage:
def __add__(self, other: Any) -> "AIMessageChunk": # type: ignore
result = super().__add__(other)
result.timestamp = self.timestamp
result.name = self.name
return result


Expand Down
8 changes: 5 additions & 3 deletions src/controlflow/llm/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ def wrapper(*args, **kwargs):

class Tool(langchain_core.tools.StructuredTool):
"""
A subclass of StructuredTool that is compatible with Pydantic v1 models
(which Langchain uses) and v2 models (which ControlFlow users).
A subclass of StructuredTool that is compatible with functions whose
signatures include either Pydantic v1 models (which Langchain uses) or v2
models (which ControlFlow users).

Note that THIS is a Pydantic v1 model because it subclasses the Langchain class.
Note that THIS class is a Pydantic v1 model because it subclasses the Langchain
class.
"""

tags: dict[str, Any] = pydantic.v1.Field(default_factory=dict)
Expand Down
Loading