Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new event types #381

Merged
merged 12 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def _run_model(
from controlflow.events.events import (
AgentMessage,
AgentMessageDelta,
ToolCallEvent,
ToolResultEvent,
ToolResult,
)

tools = as_tools(self.get_tools() + tools)
Expand All @@ -312,12 +311,17 @@ def _run_model(
else:
response += delta

yield AgentMessageDelta(agent=self, delta=delta, snapshot=response)
yield from AgentMessageDelta(
agent=self, message_delta=delta, message_snapshot=response
).all_related_events(tools=tools)

else:
response: AIMessage = model.invoke(messages)

yield AgentMessage(agent=self, message=response)
yield from AgentMessage(agent=self, message=response).all_related_events(
tools=tools
)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}
Expand All @@ -335,9 +339,8 @@ def _run_model(
logger.debug(f"Response: {response}")

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call)
result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)
yield ToolResult(agent=self, tool_result=result)

@prefect_task(task_run_name="Call LLM")
async def _run_model_async(
Expand All @@ -350,8 +353,7 @@ async def _run_model_async(
from controlflow.events.events import (
AgentMessage,
AgentMessageDelta,
ToolCallEvent,
ToolResultEvent,
ToolResult,
)

tools = as_tools(self.get_tools() + tools)
Expand All @@ -371,12 +373,18 @@ async def _run_model_async(
else:
response += delta

yield AgentMessageDelta(agent=self, delta=delta, snapshot=response)
for event in AgentMessageDelta(
agent=self, message_delta=delta, message_snapshot=response
).all_related_events(tools=tools):
yield event

else:
response: AIMessage = await model.ainvoke(messages)

yield AgentMessage(agent=self, message=response)
for event in AgentMessage(agent=self, message=response).all_related_events(
tools=tools
):
yield event

create_markdown_artifact(
markdown=f"""
Expand All @@ -395,6 +403,5 @@ async def _run_model_async(
logger.debug(f"Response: {response}")

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call)
result = await handle_tool_call_async(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)
yield ToolResult(agent=self, tool_result=result)
2 changes: 1 addition & 1 deletion src/controlflow/events/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]:
return []

def __repr__(self) -> str:
return f"{self.event} ({self.timestamp})"
return f"<{self.event} {self.timestamp}>"


class UnpersistedEvent(Event):
Expand Down
151 changes: 128 additions & 23 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Literal, Optional, Union

import pydantic_core
from pydantic import ConfigDict, field_validator, model_validator

from controlflow.agents.agent import Agent
Expand All @@ -11,7 +12,8 @@
HumanMessage,
ToolMessage,
)
from controlflow.tools.tools import InvalidToolCall, ToolCall, ToolResult
from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall
from controlflow.tools.tools import ToolResult as ToolResultPayload
from controlflow.utilities.logging import get_logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,7 +57,7 @@ class AgentMessage(Event):
message: dict

@field_validator("message", mode="before")
def _message(cls, v):
def _as_message_dict(cls, v):
if isinstance(v, BaseMessage):
v = v.model_dump()
v["type"] = "ai"
Expand All @@ -70,6 +72,34 @@ def _finalize(self):
def ai_message(self) -> AIMessage:
return AIMessage(**self.message)

def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]:
calls = []
for tool_call in (
self.message["tool_calls"] + self.message["invalid_tool_calls"]
):
tool = next((t for t in tools if t.name == tool_call.get("name")), None)
if tool:
calls.append(
AgentToolCall(
agent=self.agent,
tool_call=tool_call,
tool=tool,
args=tool_call["args"],
agent_message_id=self.message.get("id"),
)
)
return calls

def to_content(self) -> "AgentContent":
return AgentContent(
agent=self.agent,
content=self.message["content"],
agent_message_id=self.message.get("id"),
)

def all_related_events(self, tools: list[Tool]) -> list[Event]:
return [self, self.to_content()] + self.to_tool_calls(tools)

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
if self.agent.name == context.agent.name:
return [self.ai_message]
Expand All @@ -87,62 +117,137 @@ class AgentMessageDelta(UnpersistedEvent):
event: Literal["agent-message-delta"] = "agent-message-delta"

agent: Agent
delta: dict
snapshot: dict
message_delta: dict
message_snapshot: dict

@field_validator("delta", "snapshot", mode="before")
def _message(cls, v):
@field_validator("message_delta", "message_snapshot", mode="before")
def _as_message_dict(cls, v):
if isinstance(v, BaseMessage):
v = v.model_dump()
v["type"] = "AIMessageChunk"
return v

@model_validator(mode="after")
def _finalize(self):
self.delta["name"] = self.agent.name
self.snapshot["name"] = self.agent.name
self.message_delta["name"] = self.agent.name
self.message_snapshot["name"] = self.agent.name
return self

@property
def delta_message(self) -> AIMessageChunk:
return AIMessageChunk(**self.delta)
def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]:
deltas = []
for call_delta in self.message_delta.get("tool_call_chunks", []):
# First match chunks by index because streaming chunks come in sequence (0,1,2...)
# and this index lets us correlate deltas to their snapshots during streaming
chunk_snapshot = next(
(
c
for c in self.message_snapshot.get("tool_call_chunks", [])
if c.get("index", -1) == call_delta.get("index", -2)
),
None,
)

if chunk_snapshot and chunk_snapshot.get("id"):
# Once we have the matching chunk, use its ID to find the full tool call
# The full tool calls contain properly parsed arguments (as Python dicts)
# while chunks just contain raw JSON strings
call_snapshot = next(
(
c
for c in self.message_snapshot["tool_calls"]
if c.get("id") == chunk_snapshot["id"]
),
None,
)

@property
def snapshot_message(self) -> AIMessage:
return AIMessage(**self.snapshot | {"type": "ai"})
if call_snapshot:
tool = next(
(t for t in tools if t.name == call_snapshot.get("name")), None
)
# Use call_snapshot.args which is already parsed into a Python dict
# This avoids issues with pydantic's more limited JSON parser
deltas.append(
AgentToolCallDelta(
agent=self.agent,
tool_call_delta=call_delta,
tool_call_snapshot=call_snapshot,
tool=tool,
args=call_snapshot.get("args", {}),
agent_message_id=self.message_snapshot.get("id"),
)
)
return deltas

def to_content_delta(self) -> "AgentContentDelta":
return AgentContentDelta(
agent=self.agent,
content_delta=self.message_delta["content"],
content_snapshot=self.message_snapshot["content"],
agent_message_id=self.message_snapshot.get("id"),
)

def all_related_events(self, tools: list[Tool]) -> list[Event]:
return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools)

class EndTurn(Event):
event: Literal["end-turn"] = "end-turn"

class AgentContent(UnpersistedEvent):
event: Literal["agent-content"] = "agent-content"
agent: Agent
next_agent_name: Optional[str] = None
agent_message_id: Optional[str] = None
content: Union[str, list[Union[str, dict]]]


class ToolCallEvent(Event):
class AgentContentDelta(UnpersistedEvent):
event: Literal["agent-content-delta"] = "agent-content-delta"
agent: Agent
agent_message_id: Optional[str] = None
content_delta: Union[str, list[Union[str, dict]]]
content_snapshot: Union[str, list[Union[str, dict]]]


class AgentToolCall(Event):
event: Literal["tool-call"] = "tool-call"
agent: Agent
agent_message_id: Optional[str] = None
tool_call: Union[ToolCall, InvalidToolCall]
tool: Optional[Tool] = None
args: dict = {}


class ToolResultEvent(Event):
class AgentToolCallDelta(UnpersistedEvent):
event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta"
agent: Agent
agent_message_id: Optional[str] = None
tool_call_delta: dict
tool_call_snapshot: dict
tool: Optional[Tool] = None
args: dict = {}


class EndTurn(Event):
event: Literal["end-turn"] = "end-turn"
agent: Agent
next_agent_name: Optional[str] = None


class ToolResult(Event):
event: Literal["tool-result"] = "tool-result"
agent: Agent
tool_call: Union[ToolCall, InvalidToolCall]
tool_result: ToolResult
tool_result: ToolResultPayload

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
if self.agent.name == context.agent.name:
return [
ToolMessage(
content=self.tool_result.str_result,
tool_call_id=self.tool_call["id"],
tool_call_id=self.tool_result.tool_call["id"],
name=self.agent.name,
)
]
else:
return OrchestratorMessage(
prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool '
f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} '
f'call: {self.tool_result.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} '
f'produced this result:',
content=self.tool_result.str_result,
name=self.agent.name,
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/events/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_event_validator() -> TypeAdapter:
AgentMessage,
EndTurn,
OrchestratorMessage,
ToolResultEvent,
ToolResult,
UserMessage,
)

Expand All @@ -30,7 +30,7 @@ def get_event_validator() -> TypeAdapter:
UserMessage,
AgentMessage,
EndTurn,
ToolResultEvent,
ToolResult,
Event,
]
return TypeAdapter(list[types])
Expand Down
12 changes: 6 additions & 6 deletions src/controlflow/events/message_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from controlflow.events.base import Event, UnpersistedEvent
from controlflow.events.events import (
AgentMessage,
ToolCallEvent,
ToolResultEvent,
AgentToolCall,
ToolResult,
)
from controlflow.llm.messages import (
AIMessage,
Expand All @@ -28,8 +28,8 @@
class CombinedAgentMessage(UnpersistedEvent):
event: Literal["combined-agent-message"] = "combined-agent-message"
agent_message: AgentMessage
tool_call: list[ToolCallEvent] = []
tool_results: list[ToolResultEvent] = []
tool_call: list[AgentToolCall] = []
tool_results: list[ToolResult] = []

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
messages = []
Expand Down Expand Up @@ -213,9 +213,9 @@ def organize_events(self, context: CompileContext) -> list[Event]:
event.ai_message.tool_calls + event.ai_message.invalid_tool_calls
):
tool_calls[tc["id"]] = combined_event
elif isinstance(event, ToolResultEvent):
elif isinstance(event, ToolResult):
combined_event: CombinedAgentMessage = tool_calls.get(
event.tool_call["id"]
event.tool_result.tool_call["id"]
)
if combined_event:
combined_event.tool_results.append(event)
Expand Down
Loading