Skip to content

Commit

Permalink
Merge pull request #381 from PrefectHQ/events
Browse files Browse the repository at this point in the history
Add new event types
  • Loading branch information
jlowin authored Nov 12, 2024
2 parents 504d3ad + afe44ad commit b3ad535
Show file tree
Hide file tree
Showing 18 changed files with 1,106 additions and 385 deletions.
31 changes: 19 additions & 12 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def _run_model(
from controlflow.events.events import (
AgentMessage,
AgentMessageDelta,
ToolCallEvent,
ToolResultEvent,
ToolResult,
)

tools = as_tools(self.get_tools() + tools)
Expand All @@ -312,12 +311,17 @@ def _run_model(
else:
response += delta

yield AgentMessageDelta(agent=self, delta=delta, snapshot=response)
yield from AgentMessageDelta(
agent=self, message_delta=delta, message_snapshot=response
).all_related_events(tools=tools)

else:
response: AIMessage = model.invoke(messages)

yield AgentMessage(agent=self, message=response)
yield from AgentMessage(agent=self, message=response).all_related_events(
tools=tools
)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}
Expand All @@ -335,9 +339,8 @@ def _run_model(
logger.debug(f"Response: {response}")

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call)
result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)
yield ToolResult(agent=self, tool_result=result)

@prefect_task(task_run_name="Call LLM")
async def _run_model_async(
Expand All @@ -350,8 +353,7 @@ async def _run_model_async(
from controlflow.events.events import (
AgentMessage,
AgentMessageDelta,
ToolCallEvent,
ToolResultEvent,
ToolResult,
)

tools = as_tools(self.get_tools() + tools)
Expand All @@ -371,12 +373,18 @@ async def _run_model_async(
else:
response += delta

yield AgentMessageDelta(agent=self, delta=delta, snapshot=response)
for event in AgentMessageDelta(
agent=self, message_delta=delta, message_snapshot=response
).all_related_events(tools=tools):
yield event

else:
response: AIMessage = await model.ainvoke(messages)

yield AgentMessage(agent=self, message=response)
for event in AgentMessage(agent=self, message=response).all_related_events(
tools=tools
):
yield event

create_markdown_artifact(
markdown=f"""
Expand All @@ -395,6 +403,5 @@ async def _run_model_async(
logger.debug(f"Response: {response}")

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call)
result = await handle_tool_call_async(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)
yield ToolResult(agent=self, tool_result=result)
2 changes: 1 addition & 1 deletion src/controlflow/events/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]:
return []

def __repr__(self) -> str:
return f"{self.event} ({self.timestamp})"
return f"<{self.event} {self.timestamp}>"


class UnpersistedEvent(Event):
Expand Down
151 changes: 128 additions & 23 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Literal, Optional, Union

import pydantic_core
from pydantic import ConfigDict, field_validator, model_validator

from controlflow.agents.agent import Agent
Expand All @@ -11,7 +12,8 @@
HumanMessage,
ToolMessage,
)
from controlflow.tools.tools import InvalidToolCall, ToolCall, ToolResult
from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall
from controlflow.tools.tools import ToolResult as ToolResultPayload
from controlflow.utilities.logging import get_logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,7 +57,7 @@ class AgentMessage(Event):
message: dict

@field_validator("message", mode="before")
def _message(cls, v):
def _as_message_dict(cls, v):
if isinstance(v, BaseMessage):
v = v.model_dump()
v["type"] = "ai"
Expand All @@ -70,6 +72,34 @@ def _finalize(self):
def ai_message(self) -> AIMessage:
return AIMessage(**self.message)

def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]:
calls = []
for tool_call in (
self.message["tool_calls"] + self.message["invalid_tool_calls"]
):
tool = next((t for t in tools if t.name == tool_call.get("name")), None)
if tool:
calls.append(
AgentToolCall(
agent=self.agent,
tool_call=tool_call,
tool=tool,
args=tool_call["args"],
agent_message_id=self.message.get("id"),
)
)
return calls

def to_content(self) -> "AgentContent":
return AgentContent(
agent=self.agent,
content=self.message["content"],
agent_message_id=self.message.get("id"),
)

def all_related_events(self, tools: list[Tool]) -> list[Event]:
return [self, self.to_content()] + self.to_tool_calls(tools)

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
if self.agent.name == context.agent.name:
return [self.ai_message]
Expand All @@ -87,62 +117,137 @@ class AgentMessageDelta(UnpersistedEvent):
event: Literal["agent-message-delta"] = "agent-message-delta"

agent: Agent
delta: dict
snapshot: dict
message_delta: dict
message_snapshot: dict

@field_validator("delta", "snapshot", mode="before")
def _message(cls, v):
@field_validator("message_delta", "message_snapshot", mode="before")
def _as_message_dict(cls, v):
if isinstance(v, BaseMessage):
v = v.model_dump()
v["type"] = "AIMessageChunk"
return v

@model_validator(mode="after")
def _finalize(self):
self.delta["name"] = self.agent.name
self.snapshot["name"] = self.agent.name
self.message_delta["name"] = self.agent.name
self.message_snapshot["name"] = self.agent.name
return self

@property
def delta_message(self) -> AIMessageChunk:
return AIMessageChunk(**self.delta)
def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]:
deltas = []
for call_delta in self.message_delta.get("tool_call_chunks", []):
# First match chunks by index because streaming chunks come in sequence (0,1,2...)
# and this index lets us correlate deltas to their snapshots during streaming
chunk_snapshot = next(
(
c
for c in self.message_snapshot.get("tool_call_chunks", [])
if c.get("index", -1) == call_delta.get("index", -2)
),
None,
)

if chunk_snapshot and chunk_snapshot.get("id"):
# Once we have the matching chunk, use its ID to find the full tool call
# The full tool calls contain properly parsed arguments (as Python dicts)
# while chunks just contain raw JSON strings
call_snapshot = next(
(
c
for c in self.message_snapshot["tool_calls"]
if c.get("id") == chunk_snapshot["id"]
),
None,
)

@property
def snapshot_message(self) -> AIMessage:
return AIMessage(**self.snapshot | {"type": "ai"})
if call_snapshot:
tool = next(
(t for t in tools if t.name == call_snapshot.get("name")), None
)
# Use call_snapshot.args which is already parsed into a Python dict
# This avoids issues with pydantic's more limited JSON parser
deltas.append(
AgentToolCallDelta(
agent=self.agent,
tool_call_delta=call_delta,
tool_call_snapshot=call_snapshot,
tool=tool,
args=call_snapshot.get("args", {}),
agent_message_id=self.message_snapshot.get("id"),
)
)
return deltas

def to_content_delta(self) -> "AgentContentDelta":
return AgentContentDelta(
agent=self.agent,
content_delta=self.message_delta["content"],
content_snapshot=self.message_snapshot["content"],
agent_message_id=self.message_snapshot.get("id"),
)

def all_related_events(self, tools: list[Tool]) -> list[Event]:
return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools)

class EndTurn(Event):
event: Literal["end-turn"] = "end-turn"

class AgentContent(UnpersistedEvent):
event: Literal["agent-content"] = "agent-content"
agent: Agent
next_agent_name: Optional[str] = None
agent_message_id: Optional[str] = None
content: Union[str, list[Union[str, dict]]]


class ToolCallEvent(Event):
class AgentContentDelta(UnpersistedEvent):
event: Literal["agent-content-delta"] = "agent-content-delta"
agent: Agent
agent_message_id: Optional[str] = None
content_delta: Union[str, list[Union[str, dict]]]
content_snapshot: Union[str, list[Union[str, dict]]]


class AgentToolCall(Event):
event: Literal["tool-call"] = "tool-call"
agent: Agent
agent_message_id: Optional[str] = None
tool_call: Union[ToolCall, InvalidToolCall]
tool: Optional[Tool] = None
args: dict = {}


class ToolResultEvent(Event):
class AgentToolCallDelta(UnpersistedEvent):
event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta"
agent: Agent
agent_message_id: Optional[str] = None
tool_call_delta: dict
tool_call_snapshot: dict
tool: Optional[Tool] = None
args: dict = {}


class EndTurn(Event):
event: Literal["end-turn"] = "end-turn"
agent: Agent
next_agent_name: Optional[str] = None


class ToolResult(Event):
event: Literal["tool-result"] = "tool-result"
agent: Agent
tool_call: Union[ToolCall, InvalidToolCall]
tool_result: ToolResult
tool_result: ToolResultPayload

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
if self.agent.name == context.agent.name:
return [
ToolMessage(
content=self.tool_result.str_result,
tool_call_id=self.tool_call["id"],
tool_call_id=self.tool_result.tool_call["id"],
name=self.agent.name,
)
]
else:
return OrchestratorMessage(
prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool '
f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} '
f'call: {self.tool_result.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} '
f'produced this result:',
content=self.tool_result.str_result,
name=self.agent.name,
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/events/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_event_validator() -> TypeAdapter:
AgentMessage,
EndTurn,
OrchestratorMessage,
ToolResultEvent,
ToolResult,
UserMessage,
)

Expand All @@ -30,7 +30,7 @@ def get_event_validator() -> TypeAdapter:
UserMessage,
AgentMessage,
EndTurn,
ToolResultEvent,
ToolResult,
Event,
]
return TypeAdapter(list[types])
Expand Down
12 changes: 6 additions & 6 deletions src/controlflow/events/message_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from controlflow.events.base import Event, UnpersistedEvent
from controlflow.events.events import (
AgentMessage,
ToolCallEvent,
ToolResultEvent,
AgentToolCall,
ToolResult,
)
from controlflow.llm.messages import (
AIMessage,
Expand All @@ -28,8 +28,8 @@
class CombinedAgentMessage(UnpersistedEvent):
event: Literal["combined-agent-message"] = "combined-agent-message"
agent_message: AgentMessage
tool_call: list[ToolCallEvent] = []
tool_results: list[ToolResultEvent] = []
tool_call: list[AgentToolCall] = []
tool_results: list[ToolResult] = []

def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
messages = []
Expand Down Expand Up @@ -213,9 +213,9 @@ def organize_events(self, context: CompileContext) -> list[Event]:
event.ai_message.tool_calls + event.ai_message.invalid_tool_calls
):
tool_calls[tc["id"]] = combined_event
elif isinstance(event, ToolResultEvent):
elif isinstance(event, ToolResult):
combined_event: CombinedAgentMessage = tool_calls.get(
event.tool_call["id"]
event.tool_result.tool_call["id"]
)
if combined_event:
combined_event.tool_results.append(event)
Expand Down
Loading

0 comments on commit b3ad535

Please sign in to comment.