From eb7534cd6e9ab5f9ce6d808d473d20895cc6c07f Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 7 Nov 2024 18:32:43 -0500 Subject: [PATCH 01/12] Add new event types --- src/controlflow/agents/agent.py | 36 ++-- src/controlflow/events/events.py | 89 +++++++++- src/controlflow/events/history.py | 4 +- src/controlflow/events/message_compiler.py | 10 +- src/controlflow/handlers/__init__.py | 0 src/controlflow/handlers/callback_handler.py | 24 +++ .../print_handler.py | 14 +- src/controlflow/handlers/queue_handler.py | 56 ++++++ src/controlflow/orchestration/handler.py | 20 +-- src/controlflow/orchestration/orchestrator.py | 8 +- src/controlflow/settings.py | 6 +- src/controlflow/stream.py | 167 ++++++++++++++++++ 12 files changed, 388 insertions(+), 46 deletions(-) create mode 100644 src/controlflow/handlers/__init__.py create mode 100644 src/controlflow/handlers/callback_handler.py rename src/controlflow/{orchestration => handlers}/print_handler.py (95%) create mode 100644 src/controlflow/handlers/queue_handler.py create mode 100644 src/controlflow/stream.py diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index e14ee063..5bbd0da5 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -289,10 +289,13 @@ def _run_model( model_kwargs: Optional[dict] = None, ) -> Generator[Event, None, None]: from controlflow.events.events import ( + AgentContent, + AgentContentDelta, AgentMessage, AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, + AgentToolCall, + AgentToolCallDelta, + ToolResult, ) tools = as_tools(self.get_tools() + tools) @@ -312,12 +315,17 @@ def _run_model( else: response += delta - yield AgentMessageDelta(agent=self, delta=delta, snapshot=response) + yield from AgentMessageDelta( + agent=self, delta=delta, snapshot=response + ).all_related_events(tools=tools) else: response: AIMessage = model.invoke(messages) - yield AgentMessage(agent=self, message=response) + yield from AgentMessage(agent=self, message=response).all_related_events( + tools=tools + ) + create_markdown_artifact( markdown=f""" {response.content or '(No content)'} @@ -335,9 +343,8 @@ def _run_model( logger.debug(f"Response: {response}") for tool_call in response.tool_calls + response.invalid_tool_calls: - yield ToolCallEvent(agent=self, tool_call=tool_call) result = handle_tool_call(tool_call, tools=tools) - yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_call=tool_call, tool_result=result) @prefect_task(task_run_name="Call LLM") async def _run_model_async( @@ -350,8 +357,8 @@ async def _run_model_async( from controlflow.events.events import ( AgentMessage, AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, + AgentToolCall, + ToolResult, ) tools = as_tools(self.get_tools() + tools) @@ -371,12 +378,18 @@ async def _run_model_async( else: response += delta - yield AgentMessageDelta(agent=self, delta=delta, snapshot=response) + for event in AgentMessageDelta( + agent=self, delta=delta, snapshot=response + ).all_related_events(tools=tools): + yield event else: response: AIMessage = await model.ainvoke(messages) - yield AgentMessage(agent=self, message=response) + for event in AgentMessage(agent=self, message=response).all_related_events( + tools=tools + ): + yield event create_markdown_artifact( markdown=f""" @@ -395,6 +408,5 @@ async def _run_model_async( logger.debug(f"Response: {response}") for tool_call in response.tool_calls + response.invalid_tool_calls: - yield ToolCallEvent(agent=self, tool_call=tool_call) result = await handle_tool_call_async(tool_call, tools=tools) - yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_call=tool_call, tool_result=result) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 6e5c6d17..5a3845d8 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -11,7 +11,7 @@ HumanMessage, ToolMessage, ) -from controlflow.tools.tools import InvalidToolCall, ToolCall, ToolResult +from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall, ToolResult from controlflow.utilities.logging import get_logger if TYPE_CHECKING: @@ -70,6 +70,29 @@ def _finalize(self): def ai_message(self) -> AIMessage: return AIMessage(**self.message) + def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]: + calls = [] + for tool_call in ( + self.message["tool_calls"] + self.message["invalid_tool_calls"] + ): + tool = next((t for t in tools if t.name == tool_call.get("name")), None) + if tool: + calls.append( + AgentToolCall( + agent=self.agent, + tool_call=tool_call, + tool=tool, + args=tool_call["args"], + ) + ) + return calls + + def to_content(self) -> "AgentContent": + return AgentContent(agent=self.agent, content=self.message["content"]) + + def all_related_events(self, tools: list[Tool]) -> list[Event]: + return [self, self.to_content()] + self.to_tool_calls(tools) + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: return [self.ai_message] @@ -111,6 +134,64 @@ def delta_message(self) -> AIMessageChunk: def snapshot_message(self) -> AIMessage: return AIMessage(**self.snapshot | {"type": "ai"}) + def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: + deltas = [] + for call_delta in self.delta["tool_call_chunks"]: + # try to retrieve the matching snapshot based on index + call_snapshot = next( + ( + c + for i, c in enumerate(self.snapshot["tool_calls"]) + if i == call_delta.get("index") + ), + None, + ) + + tool = next((t for t in tools if t.name == call_snapshot.get("name")), None) + if call_snapshot: + deltas.append( + AgentToolCallDelta( + agent=self.agent, + delta=call_delta, + snapshot=call_snapshot, + tool=tool, + args=call_snapshot["args"], + ) + ) + return deltas + + def to_content_delta(self) -> "AgentContentDelta": + return AgentContentDelta( + agent=self.agent, + delta=self.delta["content"], + snapshot=self.snapshot["content"], + ) + + def all_related_events(self, tools: list[Tool]) -> list[Event]: + return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools) + + +class AgentContent(UnpersistedEvent): + event: Literal["agent-content"] = "agent-content" + agent: Agent + content: Union[str, list[Union[str, dict]]] + + +class AgentContentDelta(UnpersistedEvent): + event: Literal["agent-content-delta"] = "agent-content-delta" + agent: Agent + delta: str + snapshot: str + + +class AgentToolCallDelta(UnpersistedEvent): + event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta" + agent: Agent + delta: dict + snapshot: dict + tool: Tool + args: dict + class EndTurn(Event): event: Literal["end-turn"] = "end-turn" @@ -118,13 +199,15 @@ class EndTurn(Event): next_agent_name: Optional[str] = None -class ToolCallEvent(Event): +class AgentToolCall(Event): event: Literal["tool-call"] = "tool-call" agent: Agent tool_call: Union[ToolCall, InvalidToolCall] + tool: Tool + args: dict -class ToolResultEvent(Event): +class ToolResult(Event): event: Literal["tool-result"] = "tool-result" agent: Agent tool_call: Union[ToolCall, InvalidToolCall] diff --git a/src/controlflow/events/history.py b/src/controlflow/events/history.py index e62cc660..154d5185 100644 --- a/src/controlflow/events/history.py +++ b/src/controlflow/events/history.py @@ -21,7 +21,7 @@ def get_event_validator() -> TypeAdapter: AgentMessage, EndTurn, OrchestratorMessage, - ToolResultEvent, + ToolResult, UserMessage, ) @@ -30,7 +30,7 @@ def get_event_validator() -> TypeAdapter: UserMessage, AgentMessage, EndTurn, - ToolResultEvent, + ToolResult, Event, ] return TypeAdapter(list[types]) diff --git a/src/controlflow/events/message_compiler.py b/src/controlflow/events/message_compiler.py index aff21195..63a8f23e 100644 --- a/src/controlflow/events/message_compiler.py +++ b/src/controlflow/events/message_compiler.py @@ -8,8 +8,8 @@ from controlflow.events.base import Event, UnpersistedEvent from controlflow.events.events import ( AgentMessage, - ToolCallEvent, - ToolResultEvent, + AgentToolCall, + ToolResult, ) from controlflow.llm.messages import ( AIMessage, @@ -28,8 +28,8 @@ class CombinedAgentMessage(UnpersistedEvent): event: Literal["combined-agent-message"] = "combined-agent-message" agent_message: AgentMessage - tool_call: list[ToolCallEvent] = [] - tool_results: list[ToolResultEvent] = [] + tool_call: list[AgentToolCall] = [] + tool_results: list[ToolResult] = [] def to_messages(self, context: "CompileContext") -> list[BaseMessage]: messages = [] @@ -213,7 +213,7 @@ def organize_events(self, context: CompileContext) -> list[Event]: event.ai_message.tool_calls + event.ai_message.invalid_tool_calls ): tool_calls[tc["id"]] = combined_event - elif isinstance(event, ToolResultEvent): + elif isinstance(event, ToolResult): combined_event: CombinedAgentMessage = tool_calls.get( event.tool_call["id"] ) diff --git a/src/controlflow/handlers/__init__.py b/src/controlflow/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/controlflow/handlers/callback_handler.py b/src/controlflow/handlers/callback_handler.py new file mode 100644 index 00000000..ec73a19f --- /dev/null +++ b/src/controlflow/handlers/callback_handler.py @@ -0,0 +1,24 @@ +""" +A handler that calls a callback function for each event. +""" + +from typing import TYPE_CHECKING, Any, Callable, Coroutine + +from controlflow.events.base import Event +from controlflow.orchestration.handler import AsyncHandler, Handler + + +class CallbackHandler(Handler): + def __init__(self, callback: Callable[[Event], None]): + self.callback = callback + + def on_event(self, event: Event): + self.callback(event) + + +class AsyncCallbackHandler(AsyncHandler): + def __init__(self, callback: Callable[[Event], Coroutine[Any, Any, None]]): + self.callback = callback + + async def on_event(self, event: Event): + await self.callback(event) diff --git a/src/controlflow/orchestration/print_handler.py b/src/controlflow/handlers/print_handler.py similarity index 95% rename from src/controlflow/orchestration/print_handler.py rename to src/controlflow/handlers/print_handler.py index 2d05918c..0cb141c5 100644 --- a/src/controlflow/orchestration/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -15,8 +15,8 @@ from controlflow.events.events import ( AgentMessage, AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, + AgentToolCall, + ToolResult, ) from controlflow.events.orchestrator_events import ( OrchestratorEnd, @@ -44,7 +44,7 @@ def update_live(self, latest: BaseMessage = None): # gather all tool events first for _, event in events: - if isinstance(event, ToolResultEvent): + if isinstance(event, ToolResult): tool_results[event.tool_call["id"]] = event for _, event in events: @@ -83,7 +83,7 @@ def on_agent_message(self, event: AgentMessage): self.events[event.ai_message.id] = event self.update_live() - def on_tool_call(self, event: ToolCallEvent): + def on_tool_call(self, event: AgentToolCall): # if collecting input on the terminal, pause the live display # to avoid overwriting the input prompt if event.tool_call["name"] == "cli_input": @@ -91,7 +91,7 @@ def on_tool_call(self, event: ToolCallEvent): self.live.stop() self.events.clear() - def on_tool_result(self, event: ToolResultEvent): + def on_tool_result(self, event: ToolResult): # skip completion tools if configured to do so if not self.include_completion_tools and event.tool_result.tool_metadata.get( "is_completion_tool" @@ -135,7 +135,7 @@ def status(icon, text) -> Table: def format_event( event: Union[AgentMessageDelta, AgentMessage], - tool_results: dict[str, ToolResultEvent] = None, + tool_results: dict[str, ToolResult] = None, ) -> Panel: title = f"Agent: {event.agent.name}" @@ -200,7 +200,7 @@ def format_tool_call(tool_call: ToolCall) -> Panel: return status(Spinner("dots"), f'Tool call: "{tool_call["name"]}"') -def format_tool_result(event: ToolResultEvent) -> Panel: +def format_tool_result(event: ToolResult) -> Panel: if event.tool_result.is_error: icon = ":x:" else: diff --git a/src/controlflow/handlers/queue_handler.py b/src/controlflow/handlers/queue_handler.py new file mode 100644 index 00000000..0441823d --- /dev/null +++ b/src/controlflow/handlers/queue_handler.py @@ -0,0 +1,56 @@ +""" +A handler that queues events in a queue. +""" + +import asyncio +import queue +from typing import TYPE_CHECKING, Any, Callable, Coroutine + +from controlflow.events.base import Event +from controlflow.events.events import ( + AgentMessage, + AgentMessageDelta, + AgentToolCall, + ToolResult, +) +from controlflow.orchestration.handler import AsyncHandler, Handler + + +class QueueHandler(Handler): + def __init__( + self, queue: queue.Queue = None, event_filter: Callable[[Event], bool] = None + ): + self.queue = queue or queue.Queue() + self.event_filter = event_filter + + def on_event(self, event: Event): + if self.event_filter and not self.event_filter(event): + return + self.queue.put(event) + + +class AsyncQueueHandler(AsyncHandler): + def __init__( + self, queue: asyncio.Queue = None, event_filter: Callable[[Event], bool] = None + ): + self.queue = queue or asyncio.Queue() + self.event_filter = event_filter + + async def on_event(self, event: Event): + if self.event_filter and not self.event_filter(event): + return + await self.queue.put(event) + + +def message_filter(event: Event) -> bool: + return isinstance(event, (AgentMessage, AgentMessageDelta)) + + +def tool_filter(event: Event) -> bool: + return isinstance(event, (AgentToolCall, ToolResult)) + + +def result_filter(event: Event) -> bool: + return isinstance(event, (AgentToolCall, ToolResult)) and event.tool_call[ + "name" + ].startswith("mark_task_") diff --git a/src/controlflow/orchestration/handler.py b/src/controlflow/orchestration/handler.py index 9843a744..bd77a772 100644 --- a/src/controlflow/orchestration/handler.py +++ b/src/controlflow/orchestration/handler.py @@ -7,10 +7,10 @@ from controlflow.events.events import ( AgentMessage, AgentMessageDelta, + AgentToolCall, EndTurn, OrchestratorMessage, - ToolCallEvent, - ToolResultEvent, + ToolResult, UserMessage, ) from controlflow.events.orchestrator_events import ( @@ -54,10 +54,10 @@ def on_agent_message(self, event: "AgentMessage"): def on_agent_message_delta(self, event: "AgentMessageDelta"): pass - def on_tool_call(self, event: "ToolCallEvent"): + def on_tool_call(self, event: "AgentToolCall"): pass - def on_tool_result(self, event: "ToolResultEvent"): + def on_tool_result(self, event: "ToolResult"): pass def on_orchestrator_message(self, event: "OrchestratorMessage"): @@ -70,14 +70,6 @@ def on_end_turn(self, event: "EndTurn"): pass -class CallbackHandler(Handler): - def __init__(self, callback: Callable[[Event], None]): - self.callback = callback - - def on_event(self, event: Event): - self.callback(event) - - class AsyncHandler: async def handle(self, event: Event): """ @@ -112,10 +104,10 @@ async def on_agent_message(self, event: "AgentMessage"): async def on_agent_message_delta(self, event: "AgentMessageDelta"): pass - async def on_tool_call(self, event: "ToolCallEvent"): + async def on_tool_call(self, event: "AgentToolCall"): pass - async def on_tool_result(self, event: "ToolResultEvent"): + async def on_tool_result(self, event: "ToolResult"): pass async def on_orchestrator_message(self, event: "OrchestratorMessage"): diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 94292639..97993b33 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -72,10 +72,14 @@ def _validate_handlers(cls, v): Returns: list[Handler]: The validated list of handlers. """ - from controlflow.orchestration.print_handler import PrintHandler + from controlflow.handlers.print_handler import PrintHandler if v is None and controlflow.settings.enable_default_print_handler: - v = [PrintHandler()] + v = [ + PrintHandler( + include_completion_tools=controlflow.settings.default_print_handler_include_completion_tools + ) + ] return v or [] def handle_event(self, event: Event): diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index c7f839c3..cb2fbe72 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -73,7 +73,11 @@ def _validate_pretty_print_agent_events(cls, data: dict) -> dict: enable_default_print_handler: bool = Field( default=True, description="If True, a PrintHandler will be enabled and automatically " - "pretty-print agent events. Note that this may interfere with logging.", + "pretty-print agent events and completion tools.", + ) + default_print_handler_include_completion_tools: bool = Field( + default=True, + description="If True, the default PrintHandler will include completion tools.", ) # ------------ orchestration settings ------------ diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py new file mode 100644 index 00000000..d54fae87 --- /dev/null +++ b/src/controlflow/stream.py @@ -0,0 +1,167 @@ +# Example usage +# +# # Stream all events +# for event in cf.stream.events("Write a story"): +# print(event) +# +# # Stream just messages +# for event in cf.stream.events("Write a story", events='messages'): +# print(event.content) +# +# # Stream just the result +# for delta, snapshot in cf.stream.result("Write a story"): +# print(f"New: {delta}") +# +# # Stream results from multiple tasks +# for delta, snapshot in cf.stream.result_from_tasks([task1, task2]): +# print(f"New result: {delta}") +# +from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union + +from controlflow.events.base import Event +from controlflow.events.events import ( + AgentMessage, + AgentMessageDelta, + AgentToolCall, + ToolResult, +) +from controlflow.orchestration.handler import AsyncHandler, Handler +from controlflow.orchestration.orchestrator import Orchestrator +from controlflow.tasks.task import Task + +StreamEvents = Union[list[str], Literal["all", "messages", "tools", "completion_tools"]] + + +def events( + objective: str, + *, + events: StreamEvents = "all", + filter_fn: Optional[Callable[[Event], bool]] = None, + **kwargs, +) -> Iterator[Event]: + """ + Stream events from a task execution. + + Args: + objective: The task objective + events: Which events to stream. Can be list of event types or: + 'all' - all events + 'messages' - agent messages + 'tools' - all tool calls/results + 'completion_tools' - only completion tools + filter_fn: Optional additional filter function + **kwargs: Additional arguments passed to Task + + Returns: + Iterator of Event objects + """ + + def get_event_filter(): + if isinstance(events, list): + return lambda e: e.event in events + elif events == "messages": + return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) + elif events == "tools": + return lambda e: isinstance(e, (AgentToolCall, ToolResult)) + elif events == "completion_tools": + return lambda e: ( + isinstance(e, (AgentToolCall, ToolResult)) + and e.tool_call["name"].startswith("mark_task_") + ) + else: # 'all' + return lambda e: True + + event_filter = get_event_filter() + + def event_handler(event: Event): + if event_filter(event) and (not filter_fn or filter_fn(event)): + yield event + + task = Task(objective=objective) + task.run(handlers=[Handler(event_handler)], **kwargs) + + +def result( + objective: str, + **kwargs, +) -> Iterator[tuple[Any, Any]]: + """ + Stream result from a task execution. + + Args: + objective: The task objective + **kwargs: Additional arguments passed to Task + + Returns: + Iterator of (delta, accumulated) result tuples + """ + current_result = None + + def result_handler(event: Event): + nonlocal current_result + if isinstance(event, ToolResult): + if event.tool_call["name"].startswith("mark_task_"): + result = event.tool_result.result # Get actual result value + if result != current_result: # Only yield if changed + current_result = result + yield (result, result) # For now delta == full result + + task = Task(objective=objective) + task.run(handlers=[Handler(result_handler)], **kwargs) + + +def events_from_tasks( + tasks: list[Task], + events: StreamEvents = "all", + filter_fn: Optional[Callable[[Event], bool]] = None, + **kwargs, +) -> Iterator[Event]: + """Stream events from multiple task executions.""" + + def get_event_filter(): + if isinstance(events, list): + return lambda e: e.event in events + elif events == "messages": + return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) + elif events == "tools": + return lambda e: isinstance(e, (AgentToolCall, ToolResult)) + elif events == "completion_tools": + return lambda e: ( + isinstance(e, (AgentToolCall, ToolResult)) + and e.tool_call["name"].startswith("mark_task_") + ) + else: # 'all' + return lambda e: True + + event_filter = get_event_filter() + + def event_handler(event: Event): + if event_filter(event) and (not filter_fn or filter_fn(event)): + yield event + + orchestrator = Orchestrator( + tasks=tasks, handlers=[Handler(event_handler)], **kwargs + ) + orchestrator.run() + + +def result_from_tasks( + tasks: list[Task], + **kwargs, +) -> Iterator[tuple[Any, Any]]: + """Stream results from multiple task executions.""" + current_results = {task.id: None for task in tasks} + + def result_handler(event: Event): + if isinstance(event, ToolResult): + if event.tool_call["name"].startswith("mark_task_"): + task_id = event.task.id + result = event.tool_result.result + if result != current_results[task_id]: + current_results[task_id] = result + yield (result, result) + + orchestrator = Orchestrator( + tasks=tasks, handlers=[Handler(result_handler)], **kwargs + ) + orchestrator.run() From b65d3ca32fc8700870db10c2dc9b5057dbf710c6 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 9 Nov 2024 16:38:19 -0500 Subject: [PATCH 02/12] Update event names --- src/controlflow/agents/agent.py | 8 +- src/controlflow/events/events.py | 69 ++-- src/controlflow/events/orchestrator_events.py | 19 +- src/controlflow/handlers/print_handler.py | 44 +-- src/controlflow/orchestration/orchestrator.py | 303 +++++++++++++----- src/controlflow/stream.py | 44 ++- src/controlflow/tools/tools.py | 16 +- 7 files changed, 338 insertions(+), 165 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 5bbd0da5..97f39f49 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -316,7 +316,7 @@ def _run_model( response += delta yield from AgentMessageDelta( - agent=self, delta=delta, snapshot=response + agent=self, message_delta=delta, message_snapshot=response ).all_related_events(tools=tools) else: @@ -344,7 +344,7 @@ def _run_model( for tool_call in response.tool_calls + response.invalid_tool_calls: result = handle_tool_call(tool_call, tools=tools) - yield ToolResult(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_result=result) @prefect_task(task_run_name="Call LLM") async def _run_model_async( @@ -379,7 +379,7 @@ async def _run_model_async( response += delta for event in AgentMessageDelta( - agent=self, delta=delta, snapshot=response + agent=self, message_delta=delta, message_snapshot=response ).all_related_events(tools=tools): yield event @@ -409,4 +409,4 @@ async def _run_model_async( for tool_call in response.tool_calls + response.invalid_tool_calls: result = await handle_tool_call_async(tool_call, tools=tools) - yield ToolResult(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_result=result) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 5a3845d8..55d00c35 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -11,7 +11,9 @@ HumanMessage, ToolMessage, ) -from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall, ToolResult +from controlflow.tools.tools import InvalidToolCall, Tool +from controlflow.tools.tools import ToolCall as ToolCallPayload +from controlflow.tools.tools import ToolResult as ToolResultPayload from controlflow.utilities.logging import get_logger if TYPE_CHECKING: @@ -55,7 +57,7 @@ class AgentMessage(Event): message: dict @field_validator("message", mode="before") - def _message(cls, v): + def _as_message_dict(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "ai" @@ -110,11 +112,11 @@ class AgentMessageDelta(UnpersistedEvent): event: Literal["agent-message-delta"] = "agent-message-delta" agent: Agent - delta: dict - snapshot: dict + message_delta: dict + message_snapshot: dict - @field_validator("delta", "snapshot", mode="before") - def _message(cls, v): + @field_validator("message_delta", "message_snapshot", mode="before") + def _as_message_dict(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "AIMessageChunk" @@ -122,26 +124,18 @@ def _message(cls, v): @model_validator(mode="after") def _finalize(self): - self.delta["name"] = self.agent.name - self.snapshot["name"] = self.agent.name + self.message_delta["name"] = self.agent.name + self.message_snapshot["name"] = self.agent.name return self - @property - def delta_message(self) -> AIMessageChunk: - return AIMessageChunk(**self.delta) - - @property - def snapshot_message(self) -> AIMessage: - return AIMessage(**self.snapshot | {"type": "ai"}) - def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: deltas = [] - for call_delta in self.delta["tool_call_chunks"]: + for call_delta in self.message_delta["tool_call_chunks"]: # try to retrieve the matching snapshot based on index call_snapshot = next( ( c - for i, c in enumerate(self.snapshot["tool_calls"]) + for i, c in enumerate(self.message_snapshot["tool_calls"]) if i == call_delta.get("index") ), None, @@ -152,8 +146,8 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: deltas.append( AgentToolCallDelta( agent=self.agent, - delta=call_delta, - snapshot=call_snapshot, + tool_call_delta=call_delta, + tool_call_snapshot=call_snapshot, tool=tool, args=call_snapshot["args"], ) @@ -163,8 +157,8 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: def to_content_delta(self) -> "AgentContentDelta": return AgentContentDelta( agent=self.agent, - delta=self.delta["content"], - snapshot=self.snapshot["content"], + content_delta=self.message_delta["content"], + content_snapshot=self.message_snapshot["content"], ) def all_related_events(self, tools: list[Tool]) -> list[Event]: @@ -180,17 +174,25 @@ class AgentContent(UnpersistedEvent): class AgentContentDelta(UnpersistedEvent): event: Literal["agent-content-delta"] = "agent-content-delta" agent: Agent - delta: str - snapshot: str + content_delta: str + content_snapshot: str + + +class AgentToolCall(Event): + event: Literal["tool-call"] = "tool-call" + agent: Agent + tool_call: Union[ToolCallPayload, InvalidToolCall] + tool: Optional[Tool] = None + args: dict = {} class AgentToolCallDelta(UnpersistedEvent): event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta" agent: Agent - delta: dict - snapshot: dict - tool: Tool - args: dict + tool_call_delta: dict + tool_call_snapshot: dict + tool: Optional[Tool] = None + args: dict = {} class EndTurn(Event): @@ -199,19 +201,10 @@ class EndTurn(Event): next_agent_name: Optional[str] = None -class AgentToolCall(Event): - event: Literal["tool-call"] = "tool-call" - agent: Agent - tool_call: Union[ToolCall, InvalidToolCall] - tool: Tool - args: dict - - class ToolResult(Event): event: Literal["tool-result"] = "tool-result" agent: Agent - tool_call: Union[ToolCall, InvalidToolCall] - tool_result: ToolResult + tool_result: ToolResultPayload def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: diff --git a/src/controlflow/events/orchestrator_events.py b/src/controlflow/events/orchestrator_events.py index 932fe8de..88370f74 100644 --- a/src/controlflow/events/orchestrator_events.py +++ b/src/controlflow/events/orchestrator_events.py @@ -1,41 +1,46 @@ from dataclasses import Field -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from pydantic.functional_serializers import PlainSerializer from controlflow.agents.agent import Agent from controlflow.events.base import UnpersistedEvent -from controlflow.orchestration.orchestrator import Orchestrator + +if TYPE_CHECKING: + from controlflow.orchestration.conditions import RunContext + from controlflow.orchestration.orchestrator import Orchestrator class OrchestratorStart(UnpersistedEvent): event: Literal["orchestrator-start"] = "orchestrator-start" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" + run_context: "RunContext" class OrchestratorEnd(UnpersistedEvent): event: Literal["orchestrator-end"] = "orchestrator-end" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" + run_context: "RunContext" class OrchestratorError(UnpersistedEvent): event: Literal["orchestrator-error"] = "orchestrator-error" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" error: Annotated[Exception, PlainSerializer(lambda x: str(x), return_type=str)] class AgentTurnStart(UnpersistedEvent): event: Literal["agent-turn-start"] = "agent-turn-start" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" agent: Agent class AgentTurnEnd(UnpersistedEvent): event: Literal["agent-turn-end"] = "agent-turn-end" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" agent: Agent diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index 0cb141c5..92a418e3 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -45,7 +45,7 @@ def update_live(self, latest: BaseMessage = None): # gather all tool events first for _, event in events: if isinstance(event, ToolResult): - tool_results[event.tool_call["id"]] = event + tool_results[event.tool_result.tool_call["id"]] = event for _, event in events: if isinstance(event, (AgentMessageDelta, AgentMessage)): @@ -76,7 +76,7 @@ def on_orchestrator_error(self, event: OrchestratorError): self.live.stop() def on_agent_message_delta(self, event: AgentMessageDelta): - self.events[event.snapshot_message.id] = event + self.events[event.message_snapshot["id"]] = event self.update_live() def on_agent_message(self, event: AgentMessage): @@ -93,15 +93,17 @@ def on_tool_call(self, event: AgentToolCall): def on_tool_result(self, event: ToolResult): # skip completion tools if configured to do so - if not self.include_completion_tools and event.tool_result.tool_metadata.get( - "is_completion_tool" + if ( + not self.include_completion_tools + and event.tool_result.tool + and event.tool_result.tool.metadata.get("is_completion_tool") ): return - self.events[f"tool-result:{event.tool_call['id']}"] = event + self.events[f"tool-result:{event.tool_result.tool_call['id']}"] = event # # if we were paused, resume the live display - if self.paused_id and self.paused_id == event.tool_call["id"]: + if self.paused_id and self.paused_id == event.tool_result.tool_call["id"]: self.paused_id = None # print newline to avoid odd formatting issues print() @@ -141,22 +143,22 @@ def format_event( content = [] if isinstance(event, AgentMessageDelta): - message = event.snapshot_message + message = event.message_snapshot elif isinstance(event, AgentMessage): - message = event.ai_message + message = event.message else: return - if message.content: - if isinstance(message.content, str): - content.append(Markdown(str(message.content))) - elif isinstance(message.content, dict): - if "content" in message.content: - content.append(Markdown(str(message.content["content"]))) - elif "text" in message.content: - content.append(Markdown(str(message.content["text"]))) - elif isinstance(message.content, list): - for item in message.content: + if message["content"]: + if isinstance(message["content"], str): + content.append(Markdown(str(message["content"]))) + elif isinstance(message["content"], dict): + if "content" in message["content"]: + content.append(Markdown(str(message["content"]["content"]))) + elif "text" in message["content"]: + content.append(Markdown(str(message["content"]["text"]))) + elif isinstance(message["content"], list): + for item in message["content"]: if isinstance(item, str): content.append(Markdown(str(item))) elif "content" in item: @@ -165,7 +167,7 @@ def format_event( content.append(Markdown(str(item["text"]))) tool_content = [] - for tool_call in message.tool_calls + message.invalid_tool_calls: + for tool_call in message["tool_calls"] + message["invalid_tool_calls"]: tool_result = (tool_results or {}).get(tool_call["id"]) if tool_result: c = format_tool_result(tool_result) @@ -207,7 +209,7 @@ def format_tool_result(event: ToolResult) -> Panel: icon = ":white_check_mark:" if controlflow.settings.tools_verbose: - msg = f'Tool call: "{event.tool_call["name"]}"\n\nTool args: {event.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' + msg = f'Tool call: "{event.tool_result.tool_call["name"]}"\n\nTool args: {event.tool_result.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' else: - msg = f'Tool call: "{event.tool_call["name"]}"' + msg = f'Tool call: "{event.tool_result.tool_call["name"]}"' return status(icon, msg) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 97993b33..e6f68bd1 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, TypeVar, Union +from typing import AsyncIterator, Callable, Iterator, Optional, TypeVar, Union from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,13 @@ from controlflow.events.base import Event from controlflow.events.events import AgentMessageDelta, OrchestratorMessage from controlflow.events.message_compiler import MessageCompiler +from controlflow.events.orchestrator_events import ( + AgentTurnEnd, + AgentTurnStart, + OrchestratorEnd, + OrchestratorError, + OrchestratorStart, +) from controlflow.flows import Flow from controlflow.instructions import get_instructions from controlflow.llm.messages import BaseMessage @@ -167,6 +174,56 @@ def get_memories(self) -> list[Memory]: return memories + def _run_agent_turn( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run a single agent turn, yielding events as they occur.""" + assigned_tasks = self.get_tasks("assigned") + + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + yield OrchestratorMessage( + content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " + f"with objective: {task.objective}" + ) + + while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls + for task in assigned_tasks: + if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: + task.mark_failed(reason="Max LLM calls reached for this task.") + + # Check if there are any ready tasks left + if not any(t.is_ready() for t in assigned_tasks): + logger.debug("No `ready` tasks to run") + break + + if run_context.should_end(): + break + + messages = self.compile_messages() + tools = self.get_tools() + + # Run model and yield events + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event + + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 + + run_context.agent_turns += 1 + @prefect_task(task_run_name="Orchestrator.run()") def run( self, @@ -177,9 +234,11 @@ def run( Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, ) -> RunContext: - import controlflow.events.orchestrator_events - - # Create the base termination condition + """ + Run the orchestrator, handling events internally. + Returns the final run context. + """ + # Create run context at the outermost level if run_until is None: run_until = AllComplete() elif not isinstance(run_until, RunEndCondition): @@ -197,6 +256,19 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) + for event in self._run( + run_context=run_context, + model_kwargs=model_kwargs, + ): + self.handle_event(event) + return run_context + + def _run( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run the orchestrator, yielding events as they occur.""" # Initialize the agent if not already set if not self.agent: self.agent = self.turn_strategy.get_next_agent( @@ -204,29 +276,23 @@ def run( ) # Signal the start of orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) - ) + yield OrchestratorStart(orchestrator=self, run_context=run_context) try: while True: if run_context.should_end(): break - self.handle_event( - controlflow.events.orchestrator_events.AgentTurnStart( - orchestrator=self, agent=self.agent - ) - ) - self.run_agent_turn( + yield AgentTurnStart(orchestrator=self, agent=self.agent) + + # Run turn and yield its events + for event in self._run_agent_turn( run_context=run_context, model_kwargs=model_kwargs, - ) - self.handle_event( - controlflow.events.orchestrator_events.AgentTurnEnd( - orchestrator=self, agent=self.agent - ) - ) + ): + yield event + + yield AgentTurnEnd(orchestrator=self, agent=self.agent) # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -235,21 +301,12 @@ def run( ) except Exception as exc: - # Handle any exceptions that occur during orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorError( - orchestrator=self, error=exc - ) - ) + # Yield error event if something goes wrong + yield OrchestratorError(orchestrator=self, error=exc) raise finally: # Signal the end of orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorEnd( - orchestrator=self - ) - ) - return run_context + yield OrchestratorEnd(orchestrator=self, run_context=run_context) @prefect_task async def run_async( @@ -261,9 +318,11 @@ async def run_async( Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, ) -> RunContext: - import controlflow.events.orchestrator_events - - # Create the base termination condition + """ + Run the orchestrator asynchronously, handling events internally. + Returns the final run context. + """ + # Create run context at the outermost level if run_until is None: run_until = AllComplete() elif not isinstance(run_until, RunEndCondition): @@ -281,58 +340,11 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - # Initialize the agent if not already set - if not self.agent: - self.agent = self.turn_strategy.get_next_agent( - None, self.get_available_agents() - ) - - # Signal the start of orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) - ) - - try: - while True: - if run_context.should_end(): - break - - await self.handle_event_async( - controlflow.events.orchestrator_events.AgentTurnStart( - orchestrator=self, agent=self.agent - ) - ) - await self.run_agent_turn_async( - run_context=run_context, - model_kwargs=model_kwargs, - ) - await self.handle_event_async( - controlflow.events.orchestrator_events.AgentTurnEnd( - orchestrator=self, agent=self.agent - ) - ) - - # Select the next agent for the following turn - if available_agents := self.get_available_agents(): - self.agent = self.turn_strategy.get_next_agent( - self.agent, available_agents - ) - - except Exception as exc: - # Handle any exceptions that occur during orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorError( - orchestrator=self, error=exc - ) - ) - raise - finally: - # Signal the end of orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorEnd( - orchestrator=self - ) - ) + async for event in self._run_async( + run_context=run_context, + model_kwargs=model_kwargs, + ): + await self.handle_event_async(event) return run_context @prefect_task(task_run_name="Agent turn: {self.agent.name}") @@ -580,5 +592,124 @@ def get_task_hierarchy(self) -> dict: return hierarchy + async def _run_agent_turn_async( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> AsyncIterator[Event]: + """Async version of _run_agent_turn.""" + assigned_tasks = self.get_tasks("assigned") + + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + yield OrchestratorMessage( + content=f"Starting task {task.name} (ID {task.id}) " + f"with objective: {task.objective}" + ) + + while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls + for task in assigned_tasks: + if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: + task.mark_failed(reason="Max LLM calls reached for this task.") + + # Check if there are any ready tasks left + if not any(t.is_ready() for t in assigned_tasks): + logger.debug("No `ready` tasks to run") + break + + if run_context.should_end(): + break + + messages = self.compile_messages() + tools = self.get_tools() + + async for event in self.agent._run_model_async( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event + + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 + + run_context.agent_turns += 1 + + async def _run_async( + self, + max_llm_calls: Optional[int] = None, + max_agent_turns: Optional[int] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[ + Union[RunEndCondition, Callable[[RunContext], bool]] + ] = None, + ) -> AsyncIterator[Event]: + """Async version of _run.""" + # Create the base termination condition + if run_until is None: + run_until = AllComplete() + elif not isinstance(run_until, RunEndCondition): + run_until = FnCondition(run_until) + + # Add max_llm_calls condition + if max_llm_calls is None: + max_llm_calls = controlflow.settings.orchestrator_max_llm_calls + run_until = run_until | MaxLLMCalls(max_llm_calls) + + # Add max_agent_turns condition + if max_agent_turns is None: + max_agent_turns = controlflow.settings.orchestrator_max_agent_turns + run_until = run_until | MaxAgentTurns(max_agent_turns) + + run_context = RunContext(orchestrator=self, run_end_condition=run_until) + + # Initialize the agent if not already set + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + + # Signal the start of orchestration + yield OrchestratorStart(orchestrator=self, run_context=run_context) + + try: + while True: + if run_context.should_end(): + break + + yield AgentTurnStart(orchestrator=self, agent=self.agent) + + async for event in self._run_agent_turn_async( + run_context=run_context, + model_kwargs=model_kwargs, + ): + yield event + + yield AgentTurnEnd(orchestrator=self, agent=self.agent) + + # Select the next agent for the following turn + if available_agents := self.get_available_agents(): + self.agent = self.turn_strategy.get_next_agent( + self.agent, available_agents + ) + + except Exception as exc: + yield OrchestratorError(orchestrator=self, error=exc) + raise + finally: + yield OrchestratorEnd(orchestrator=self) + +# Rebuild all models with forward references after Orchestrator is defined +OrchestratorStart.model_rebuild() +OrchestratorEnd.model_rebuild() +OrchestratorError.model_rebuild() +AgentTurnStart.model_rebuild() +AgentTurnEnd.model_rebuild() RunContext.model_rebuild() diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index d54fae87..e8552ebf 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -20,16 +20,58 @@ from controlflow.events.base import Event from controlflow.events.events import ( + AgentContent, + AgentContentDelta, AgentMessage, AgentMessageDelta, AgentToolCall, + AgentToolCallDelta, ToolResult, ) from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.orchestrator import Orchestrator from controlflow.tasks.task import Task -StreamEvents = Union[list[str], Literal["all", "messages", "tools", "completion_tools"]] +StreamEvents = Union[ + list[str], + Literal["all", "messages", "content", "tools", "completion_tools", "agent_tools"], +] + + +def event_filter(events: StreamEvents) -> Callable[[Event], bool]: + def _event_filter(event: Event) -> bool: + if events == "all": + return True + elif events == "messages": + return isinstance(event, (AgentMessage, AgentMessageDelta)) + elif events == "content": + return isinstance(event, (AgentContent, AgentContentDelta)) + elif events == "tools": + return isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)) + elif events == "completion_tools": + if isinstance(event, (AgentToolCall, AgentToolCallDelta)): + return event.tool and event.tool.metadata.get("is_completion_tool") + elif isinstance(event, ToolResult): + return event.tool_result and event.tool_result.tool.metadata.get( + "is_completion_tool" + ) + return False + elif events == "agent_tools": + if isinstance(event, (AgentToolCall, AgentToolCallDelta)): + return event.tool and event.tool in event.agent.get_tools() + elif isinstance(event, ToolResult): + return ( + event.tool_result + and event.tool_result.tool in event.agent.get_tools() + ) + return False + else: + raise ValueError(f"Invalid event type: {events}") + + return _event_filter + + +# -------------------- BELOW HERE IS THE OLD STUFF -------------------- def events( diff --git a/src/controlflow/tools/tools.py b/src/controlflow/tools/tools.py index 224cfa31..cce08fc4 100644 --- a/src/controlflow/tools/tools.py +++ b/src/controlflow/tools/tools.py @@ -298,16 +298,16 @@ def output_to_string(output: Any) -> str: class ToolResult(ControlFlowModel): - tool_call_id: str + tool_call: Union[ToolCall, InvalidToolCall] + tool: Optional[Tool] = None result: Any = Field(exclude=True, repr=False) str_result: str = Field(repr=False) is_error: bool = False - tool_metadata: dict = {} def handle_tool_call( tool_call: Union[ToolCall, InvalidToolCall], tools: list[Tool] -) -> Any: +) -> ToolResult: """ Given a ToolCall and set of available tools, runs the tool call and returns a ToolResult object @@ -340,15 +340,15 @@ def handle_tool_call( raise exc return ToolResult( - tool_call_id=tool_call["id"], + tool_call=tool_call, + tool=tool, result=fn_output, str_result=output_to_string(fn_output), is_error=is_error, - tool_metadata=tool.metadata if tool else {}, ) -async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any: +async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> ToolResult: """ Given a ToolCall and set of available tools, runs the tool call and returns a ToolResult object @@ -381,9 +381,9 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any: raise exc return ToolResult( - tool_call_id=tool_call["id"], + tool_call=tool_call, + tool=tool, result=fn_output, str_result=output_to_string(fn_output), is_error=is_error, - tool_metadata=tool.metadata if tool else {}, ) From 516dd4b6ea29d6b2fba63449adaae1e415ec3b3a Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 9 Nov 2024 22:49:25 -0500 Subject: [PATCH 03/12] Update print handler --- src/controlflow/events/base.py | 2 +- src/controlflow/events/events.py | 25 +- src/controlflow/events/message_compiler.py | 2 +- src/controlflow/handlers/print_handler.py | 367 ++++++++++-------- src/controlflow/orchestration/orchestrator.py | 4 +- 5 files changed, 220 insertions(+), 180 deletions(-) diff --git a/src/controlflow/events/base.py b/src/controlflow/events/base.py index 1ae915d6..aad788cb 100644 --- a/src/controlflow/events/base.py +++ b/src/controlflow/events/base.py @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]: return [] def __repr__(self) -> str: - return f"{self.event} ({self.timestamp})" + return f"<{self.event} {self.timestamp}>" class UnpersistedEvent(Event): diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 55d00c35..ea71463a 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -85,12 +85,17 @@ def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]: tool_call=tool_call, tool=tool, args=tool_call["args"], + agent_message_id=self.message.get("id"), ) ) return calls def to_content(self) -> "AgentContent": - return AgentContent(agent=self.agent, content=self.message["content"]) + return AgentContent( + agent=self.agent, + content=self.message["content"], + agent_message_id=self.message.get("id"), + ) def all_related_events(self, tools: list[Tool]) -> list[Event]: return [self, self.to_content()] + self.to_tool_calls(tools) @@ -141,8 +146,10 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: None, ) - tool = next((t for t in tools if t.name == call_snapshot.get("name")), None) if call_snapshot: + tool = next( + (t for t in tools if t.name == call_snapshot.get("name")), None + ) deltas.append( AgentToolCallDelta( agent=self.agent, @@ -150,6 +157,7 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: tool_call_snapshot=call_snapshot, tool=tool, args=call_snapshot["args"], + agent_message_id=self.message_snapshot.get("id"), ) ) return deltas @@ -159,6 +167,7 @@ def to_content_delta(self) -> "AgentContentDelta": agent=self.agent, content_delta=self.message_delta["content"], content_snapshot=self.message_snapshot["content"], + agent_message_id=self.message_snapshot.get("id"), ) def all_related_events(self, tools: list[Tool]) -> list[Event]: @@ -168,19 +177,22 @@ def all_related_events(self, tools: list[Tool]) -> list[Event]: class AgentContent(UnpersistedEvent): event: Literal["agent-content"] = "agent-content" agent: Agent + agent_message_id: Optional[str] = None content: Union[str, list[Union[str, dict]]] class AgentContentDelta(UnpersistedEvent): event: Literal["agent-content-delta"] = "agent-content-delta" agent: Agent - content_delta: str - content_snapshot: str + agent_message_id: Optional[str] = None + content_delta: Union[str, list[Union[str, dict]]] + content_snapshot: Union[str, list[Union[str, dict]]] class AgentToolCall(Event): event: Literal["tool-call"] = "tool-call" agent: Agent + agent_message_id: Optional[str] = None tool_call: Union[ToolCallPayload, InvalidToolCall] tool: Optional[Tool] = None args: dict = {} @@ -189,6 +201,7 @@ class AgentToolCall(Event): class AgentToolCallDelta(UnpersistedEvent): event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta" agent: Agent + agent_message_id: Optional[str] = None tool_call_delta: dict tool_call_snapshot: dict tool: Optional[Tool] = None @@ -211,14 +224,14 @@ def to_messages(self, context: "CompileContext") -> list[BaseMessage]: return [ ToolMessage( content=self.tool_result.str_result, - tool_call_id=self.tool_call["id"], + tool_call_id=self.tool_result.tool_call["id"], name=self.agent.name, ) ] else: return OrchestratorMessage( prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool ' - f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' + f'call: {self.tool_result.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' f'produced this result:', content=self.tool_result.str_result, name=self.agent.name, diff --git a/src/controlflow/events/message_compiler.py b/src/controlflow/events/message_compiler.py index 63a8f23e..430f026e 100644 --- a/src/controlflow/events/message_compiler.py +++ b/src/controlflow/events/message_compiler.py @@ -215,7 +215,7 @@ def organize_events(self, context: CompileContext) -> list[Event]: tool_calls[tc["id"]] = combined_event elif isinstance(event, ToolResult): combined_event: CombinedAgentMessage = tool_calls.get( - event.tool_call["id"] + event.tool_result.tool_call["id"] ) if combined_event: combined_event.tool_results.append(event) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index 92a418e3..c8ca252e 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -1,7 +1,8 @@ import datetime -from typing import Union +from typing import Optional import rich +from pydantic import BaseModel from rich import box from rich.console import Group from rich.live import Live @@ -10,89 +11,198 @@ from rich.spinner import Spinner from rich.table import Table -import controlflow -from controlflow.events.base import Event -from controlflow.events.events import ( - AgentMessage, - AgentMessageDelta, - AgentToolCall, - ToolResult, -) +from controlflow.events.events import AgentContentDelta, AgentToolCallDelta, ToolResult from controlflow.events.orchestrator_events import ( OrchestratorEnd, OrchestratorError, OrchestratorStart, ) -from controlflow.llm.messages import BaseMessage from controlflow.orchestration.handler import Handler -from controlflow.tools.tools import ToolCall +from controlflow.tools.tools import Tool from controlflow.utilities.rich import console as cf_console -class PrintHandler(Handler): - def __init__(self, include_completion_tools: bool = True): - self.events: dict[str, Event] = {} - self.paused_id: str = None - self.include_completion_tools = include_completion_tools - super().__init__() +class DisplayState(BaseModel): + """Base class for content to be displayed.""" - def update_live(self, latest: BaseMessage = None): - events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0])) - content = [] + agent_name: str + first_timestamp: datetime.datetime - tool_results = {} # To track tool results by their call ID + def format_timestamp(self) -> str: + """Format the timestamp for display.""" + local_timestamp = self.first_timestamp.astimezone() + return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) - # gather all tool events first - for _, event in events: - if isinstance(event, ToolResult): - tool_results[event.tool_result.tool_call["id"]] = event - for _, event in events: - if isinstance(event, (AgentMessageDelta, AgentMessage)): - if formatted := format_event(event, tool_results=tool_results): - content.append(formatted) +class ContentState(DisplayState): + """State for content being streamed.""" - if not content: - return - elif self.live.is_started: - self.live.update(Group(*content), refresh=True) - elif latest: - cf_console.print(format_event(latest)) + content: str = "" - def on_orchestrator_start(self, event: OrchestratorStart): - self.live: Live = Live( - auto_refresh=False, console=cf_console, vertical_overflow="visible" + @staticmethod + def _convert_content_to_str(content) -> str: + """Convert various content formats to a string.""" + if isinstance(content, str): + return content + + if isinstance(content, dict): + return content.get("content", content.get("text", "")) + + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + part = item.get("content", item.get("text", "")) + if part: + parts.append(part) + return "\n".join(parts) + + return str(content) + + def update_content(self, new_content) -> None: + """Update content, converting complex content types to string.""" + self.content = self._convert_content_to_str(new_content) + + def render_panel(self) -> Panel: + """Render content as a markdown panel.""" + return Panel( + Markdown(self.content), + title=f"[bold]Agent: {self.agent_name}[/]", + subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", + subtitle_align="right", + border_style="blue", + box=box.ROUNDED, + width=100, + padding=(1, 2), ) - self.events.clear() - try: - self.live.start() - except rich.errors.LiveError: - pass - def on_orchestrator_end(self, event: OrchestratorEnd): - self.live.stop() - def on_orchestrator_error(self, event: OrchestratorError): - self.live.stop() +class ToolState(DisplayState): + """State for a tool call and its result.""" - def on_agent_message_delta(self, event: AgentMessageDelta): - self.events[event.message_snapshot["id"]] = event - self.update_live() + name: str + args: dict + result: Optional[str] = None + is_error: bool = False + is_complete: bool = False + tool: Optional[Tool] = None - def on_agent_message(self, event: AgentMessage): - self.events[event.ai_message.id] = event - self.update_live() + def render_panel(self, show_details: bool = True) -> Panel: + """Render tool state as a panel with status indicator.""" + t = Table.grid(padding=1) - def on_tool_call(self, event: AgentToolCall): - # if collecting input on the terminal, pause the live display - # to avoid overwriting the input prompt - if event.tool_call["name"] == "cli_input": - self.paused_id = event.tool_call["id"] - self.live.stop() - self.events.clear() + if self.is_complete: + icon = ":x:" if self.is_error else ":white_check_mark:" + if show_details and self.result: + tool_text = f'Tool "{self.name}": {self.result}' + else: + tool_text = f'Tool "{self.name}" completed' + else: + icon = Spinner("dots") + tool_text = f'Tool "{self.name}" running...' + if show_details and self.args: + tool_text += f"\nArguments: {self.args}" + + t.add_row(icon, tool_text) + + return Panel( + t, + subtitle=f"[italic]{self.format_timestamp()}[/]", + subtitle_align="right", + border_style="red" if self.is_error else "blue", + box=box.ROUNDED, + width=100, + padding=(1, 2), + ) + + +class PrintHandler(Handler): + def __init__(self, include_completion_tools: bool = True): + super().__init__() + self.include_completion_tools = include_completion_tools + self.live: Optional[Live] = None + self.paused_id: Optional[str] = None + self.states: dict[str, DisplayState] = {} + + def update_display(self): + """Render all current state as panels and update display.""" + if not self.live or not self.live.is_started or self.paused_id: + return + + # Sort states by timestamp and render panels + sorted_states = sorted(self.states.values(), key=lambda s: s.first_timestamp) + panels = [ + state.render_panel(show_details=self.include_completion_tools) + if isinstance(state, ToolState) + else state.render_panel() + for state in sorted_states + ] + + if panels: + self.live.update(Group(*panels), refresh=True) + + def on_agent_content_delta(self, event: AgentContentDelta): + """Handle content delta events by updating content state.""" + if not event.content_delta: + return + if event.agent_message_id not in self.states: + state = ContentState( + agent_name=event.agent.name, + first_timestamp=event.timestamp, + ) + state.update_content(event.content_snapshot) + self.states[event.agent_message_id] = state + else: + state = self.states[event.agent_message_id] + if isinstance(state, ContentState): + state.update_content(event.content_snapshot) + + self.update_display() + + def on_agent_tool_call_delta(self, event: AgentToolCallDelta): + """Handle tool call delta events by updating tool state.""" + # Handle CLI input special case + if event.tool_call_snapshot["name"] == "cli_input": + self.paused_id = event.tool_call_snapshot["id"] + if self.live and self.live.is_started: + self.live.stop() + return + + # Skip completion tools if configured + if ( + not self.include_completion_tools + and event.tool + and event.tool.metadata.get("is_completion_tool") + ): + return + + tool_id = event.tool_call_snapshot["id"] + if tool_id not in self.states: + self.states[tool_id] = ToolState( + agent_name=event.agent.name, + first_timestamp=event.timestamp, + name=event.tool_call_snapshot["name"], + args=event.args, + tool=event.tool, + ) + + self.update_display() def on_tool_result(self, event: ToolResult): - # skip completion tools if configured to do so + """Handle tool result events by updating tool state.""" + # Handle CLI input resume + if event.tool_result.tool_call["name"] == "cli_input": + if self.paused_id == event.tool_result.tool_call["id"]: + self.paused_id = None + print() + self.live = Live(console=cf_console, auto_refresh=False) + self.live.start() + return + + # Skip completion tools if configured if ( not self.include_completion_tools and event.tool_result.tool @@ -100,116 +210,33 @@ def on_tool_result(self, event: ToolResult): ): return - self.events[f"tool-result:{event.tool_result.tool_call['id']}"] = event + tool_id = event.tool_result.tool_call["id"] + if tool_id in self.states: + state = self.states[tool_id] + if isinstance(state, ToolState): + state.is_complete = True + state.is_error = event.tool_result.is_error + state.result = event.tool_result.str_result - # # if we were paused, resume the live display - if self.paused_id and self.paused_id == event.tool_result.tool_call["id"]: - self.paused_id = None - # print newline to avoid odd formatting issues - print() - self.live = Live(auto_refresh=False) - self.live.start() - self.update_live(latest=event) - - -ROLE_COLORS = { - "system": "gray", - "ai": "blue", - "user": "green", -} -ROLE_NAMES = { - "system": "System", - "ai": "Agent", - "user": "User", -} - - -def format_timestamp(timestamp: datetime.datetime) -> str: - local_timestamp = timestamp.astimezone() - return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) - - -def status(icon, text) -> Table: - t = Table.grid(padding=1) - t.add_row(icon, text) - return t - - -def format_event( - event: Union[AgentMessageDelta, AgentMessage], - tool_results: dict[str, ToolResult] = None, -) -> Panel: - title = f"Agent: {event.agent.name}" - - content = [] - if isinstance(event, AgentMessageDelta): - message = event.message_snapshot - elif isinstance(event, AgentMessage): - message = event.message - else: - return - - if message["content"]: - if isinstance(message["content"], str): - content.append(Markdown(str(message["content"]))) - elif isinstance(message["content"], dict): - if "content" in message["content"]: - content.append(Markdown(str(message["content"]["content"]))) - elif "text" in message["content"]: - content.append(Markdown(str(message["content"]["text"]))) - elif isinstance(message["content"], list): - for item in message["content"]: - if isinstance(item, str): - content.append(Markdown(str(item))) - elif "content" in item: - content.append(Markdown(str(item["content"]))) - elif "text" in item: - content.append(Markdown(str(item["text"]))) - - tool_content = [] - for tool_call in message["tool_calls"] + message["invalid_tool_calls"]: - tool_result = (tool_results or {}).get(tool_call["id"]) - if tool_result: - c = format_tool_result(tool_result) - else: - c = format_tool_call(tool_call) - if c: - tool_content.append(c) - - if content and tool_content: - content.append("\n") - - return Panel( - Group(*content, *tool_content), - title=f"[bold]{title}[/]", - subtitle=f"[italic]{format_timestamp(event.timestamp)}[/]", - title_align="left", - subtitle_align="right", - border_style=ROLE_COLORS.get("ai", "red"), - box=box.ROUNDED, - width=100, - expand=True, - padding=(1, 2), - ) - - -def format_tool_call(tool_call: ToolCall) -> Panel: - if controlflow.settings.tools_verbose: - return status( - Spinner("dots"), - f'Tool call: "{tool_call["name"]}"\n\nTool args: {tool_call["args"]}', - ) - return status(Spinner("dots"), f'Tool call: "{tool_call["name"]}"') + self.update_display() + def on_orchestrator_start(self, event: OrchestratorStart): + """Initialize live display.""" + self.live = Live( + auto_refresh=False, console=cf_console, vertical_overflow="visible" + ) + self.states.clear() + try: + self.live.start() + except rich.errors.LiveError: + pass -def format_tool_result(event: ToolResult) -> Panel: - if event.tool_result.is_error: - icon = ":x:" - else: - icon = ":white_check_mark:" + def on_orchestrator_end(self, event: OrchestratorEnd): + """Clean up live display.""" + if self.live and self.live.is_started: + self.live.stop() - if controlflow.settings.tools_verbose: - msg = f'Tool call: "{event.tool_result.tool_call["name"]}"\n\nTool args: {event.tool_result.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' - else: - msg = f'Tool call: "{event.tool_result.tool_call["name"]}"' - return status(icon, msg) + def on_orchestrator_error(self, event: OrchestratorError): + """Clean up live display on error.""" + if self.live and self.live.is_started: + self.live.stop() diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index e6f68bd1..7edafb4f 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -96,8 +96,8 @@ def handle_event(self, event: Event): Args: event (Event): The event to handle. """ - if not isinstance(event, AgentMessageDelta): - logger.debug(f"Handling event: {repr(event)}") + from controlflow.events.events import AgentContentDelta + for handler in self.handlers: if isinstance(handler, Handler): handler.handle(event) From 11eb2c7acab5363702eaca9cdc738fca61c3c5c3 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 9 Nov 2024 22:53:27 -0500 Subject: [PATCH 04/12] Add multi-color borders --- src/controlflow/handlers/print_handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index c8ca252e..f8f0d783 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -100,11 +100,13 @@ def render_panel(self, show_details: bool = True) -> Panel: tool_text = f'Tool "{self.name}": {self.result}' else: tool_text = f'Tool "{self.name}" completed' + border_style = "red" if self.is_error else "green" else: icon = Spinner("dots") tool_text = f'Tool "{self.name}" running...' if show_details and self.args: tool_text += f"\nArguments: {self.args}" + border_style = "dim" t.add_row(icon, tool_text) @@ -112,7 +114,7 @@ def render_panel(self, show_details: bool = True) -> Panel: t, subtitle=f"[italic]{self.format_timestamp()}[/]", subtitle_align="right", - border_style="red" if self.is_error else "blue", + border_style=border_style, box=box.ROUNDED, width=100, padding=(1, 2), From dde5bd275be3a6da67b784c49c87ff6c600234a6 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 08:36:01 -0500 Subject: [PATCH 05/12] Update print handler --- src/controlflow/agents/agent.py | 5 -- src/controlflow/events/events.py | 9 ++- src/controlflow/handlers/print_handler.py | 77 +++++++++++++++++------ 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 97f39f49..6c4381a9 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -289,12 +289,8 @@ def _run_model( model_kwargs: Optional[dict] = None, ) -> Generator[Event, None, None]: from controlflow.events.events import ( - AgentContent, - AgentContentDelta, AgentMessage, AgentMessageDelta, - AgentToolCall, - AgentToolCallDelta, ToolResult, ) @@ -357,7 +353,6 @@ async def _run_model_async( from controlflow.events.events import ( AgentMessage, AgentMessageDelta, - AgentToolCall, ToolResult, ) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index ea71463a..9920800e 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Literal, Optional, Union +import pydantic_core from pydantic import ConfigDict, field_validator, model_validator from controlflow.agents.agent import Agent @@ -140,8 +141,8 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: call_snapshot = next( ( c - for i, c in enumerate(self.message_snapshot["tool_calls"]) - if i == call_delta.get("index") + for c in self.message_snapshot["tool_call_chunks"] + if c.get("index", -1) == call_delta.get("index", -2) ), None, ) @@ -156,7 +157,9 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: tool_call_delta=call_delta, tool_call_snapshot=call_snapshot, tool=tool, - args=call_snapshot["args"], + args=pydantic_core.from_json( + call_snapshot["args"] or "{}", allow_partial=True + ), agent_message_id=self.message_snapshot.get("id"), ) ) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index f8f0d783..ec6d8a08 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -90,34 +90,67 @@ class ToolState(DisplayState): is_complete: bool = False tool: Optional[Tool] = None - def render_panel(self, show_details: bool = True) -> Panel: - """Render tool state as a panel with status indicator.""" - t = Table.grid(padding=1) - + def get_status_style(self) -> tuple[str | Spinner, str, str]: + """Returns (icon, text style, border style) for current status.""" if self.is_complete: - icon = ":x:" if self.is_error else ":white_check_mark:" - if show_details and self.result: - tool_text = f'Tool "{self.name}": {self.result}' + if self.is_error: + return "❌", "red", "red" else: - tool_text = f'Tool "{self.name}" completed' - border_style = "red" if self.is_error else "green" - else: - icon = Spinner("dots") - tool_text = f'Tool "{self.name}" running...' - if show_details and self.args: - tool_text += f"\nArguments: {self.args}" - border_style = "dim" + return "✅", "green", "green3" # Slightly softer green + return ( + Spinner("dots"), + "yellow", + "gray50", + ) # Animated spinner, softer running state - t.add_row(icon, tool_text) + def render_panel(self, show_details: bool = True) -> Panel: + """Render tool state as a panel with status indicator.""" + icon, text_style, border_style = self.get_status_style() + + # Main content with clean layout + table = Table.grid(padding=0, expand=True) + + # Tool name and icon as a clean header + header = Table.grid(padding=1) + header.add_column(width=2) # Icon + header.add_column() # Name + tool_name = self.name.replace("_", " ").title() # Prettier display name + header.add_row(icon, f"[{text_style} bold]{tool_name}[/]") + table.add_row(header) + + if show_details: + details = Table.grid(padding=(0, 2)) + details.add_column(style="dim", width=9) + details.add_column() + + # Arguments with pretty formatting + if self.args: + details.add_row( + " Input:", # Indent to align with tool name + rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), + ) + + # Result when complete + if self.is_complete and self.result: + label = "Error" if self.is_error else "Output" + style = "red" if self.is_error else "green3" + details.add_row( + f" {label}:", # Indent to align with tool name + f"[{style}]{self.result}[/]", + ) + + table.add_row(details) return Panel( - t, + table, + title=f"[bold]Agent: {self.agent_name}[/]", subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", subtitle_align="right", border_style=border_style, box=box.ROUNDED, width=100, - padding=(1, 2), + padding=(0, 1), ) @@ -190,6 +223,10 @@ def on_agent_tool_call_delta(self, event: AgentToolCallDelta): args=event.args, tool=event.tool, ) + else: + state = self.states[tool_id] + if isinstance(state, ToolState): + state.args = event.args self.update_display() @@ -200,7 +237,7 @@ def on_tool_result(self, event: ToolResult): if self.paused_id == event.tool_result.tool_call["id"]: self.paused_id = None print() - self.live = Live(console=cf_console, auto_refresh=False) + self.live = Live(console=cf_console, auto_refresh=True) self.live.start() return @@ -225,7 +262,7 @@ def on_tool_result(self, event: ToolResult): def on_orchestrator_start(self, event: OrchestratorStart): """Initialize live display.""" self.live = Live( - auto_refresh=False, console=cf_console, vertical_overflow="visible" + auto_refresh=True, console=cf_console, vertical_overflow="visible" ) self.states.clear() try: From 2c5c95a970b9eba2a789c8c65d5c8a224e7d1fa0 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 10:42:44 -0500 Subject: [PATCH 06/12] Update print handler --- src/controlflow/handlers/print_handler.py | 86 ++++++++++++++++++----- src/controlflow/tasks/task.py | 11 ++- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index ec6d8a08..55fc1579 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -103,18 +103,78 @@ def get_status_style(self) -> tuple[str | Spinner, str, str]: "gray50", ) # Animated spinner, softer running state + def render_completion_tool(self) -> Panel: + """Special rendering for completion tools.""" + table = Table.grid(padding=0, expand=True) + header = Table.grid(padding=1) + header.add_column(width=2) + header.add_column() + + is_success_tool = self.tool.metadata.get("is_success_tool", False) + is_fail_tool = self.tool.metadata.get("is_fail_tool", False) + task = self.tool.metadata.get("completion_task") + task_name = task.friendly_name() if task else "Unknown Task" + task_result = task.result if task else None + + if not self.is_complete: + # Running state - muted style + icon = Spinner("dots") + message = f"Working on task: {task_name}" + text_style = "dim" + border_style = "gray50" + else: + if self.is_error: + # Tool execution failed (error in the tool itself) + icon = "❌" + message = f"Error marking task status: {task_name}" + text_style = "red" + border_style = "red" + if self.result: + message += f"\nError: {self.result}" + elif is_fail_tool: + # Failure tool succeeded (task is being marked as failed) + icon = "❌" + message = f"Task failed: {task_name}" + text_style = "red" + border_style = "red" + if task_result: + message += f"\nReason: {task_result or 'No reason provided'}" + else: + # Success tool succeeded (task completed normally) + icon = "✓" + message = f"Task complete: {task_name}" + text_style = "dim" + border_style = "gray50" + + header.add_row(icon, f"[{text_style}]{message}[/]") + table.add_row(header) + + return Panel( + table, + title=f"[bold]Agent: {self.agent_name}[/]", + subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", + subtitle_align="right", + border_style=border_style, + box=box.ROUNDED, + width=100, + padding=(0, 1), + ) + def render_panel(self, show_details: bool = True) -> Panel: """Render tool state as a panel with status indicator.""" - icon, text_style, border_style = self.get_status_style() + # Handle completion tools separately + if self.tool and self.tool.metadata.get("is_completion_tool"): + return self.render_completion_tool() - # Main content with clean layout + # Regular tool display logic continues as before... + icon, text_style, border_style = self.get_status_style() table = Table.grid(padding=0, expand=True) - # Tool name and icon as a clean header header = Table.grid(padding=1) - header.add_column(width=2) # Icon - header.add_column() # Name - tool_name = self.name.replace("_", " ").title() # Prettier display name + header.add_column(width=2) + header.add_column() + tool_name = self.name.replace("_", " ").title() header.add_row(icon, f"[{text_style} bold]{tool_name}[/]") table.add_row(header) @@ -123,19 +183,17 @@ def render_panel(self, show_details: bool = True) -> Panel: details.add_column(style="dim", width=9) details.add_column() - # Arguments with pretty formatting if self.args: details.add_row( - " Input:", # Indent to align with tool name + " Input:", rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), ) - # Result when complete if self.is_complete and self.result: label = "Error" if self.is_error else "Output" style = "red" if self.is_error else "green3" details.add_row( - f" {label}:", # Indent to align with tool name + f" {label}:", f"[{style}]{self.result}[/]", ) @@ -206,14 +264,6 @@ def on_agent_tool_call_delta(self, event: AgentToolCallDelta): self.live.stop() return - # Skip completion tools if configured - if ( - not self.include_completion_tools - and event.tool - and event.tool.metadata.get("is_completion_tool") - ): - return - tool_id = event.tool_call_snapshot["id"] if tool_id not in self.states: self.states[tool_id] = ToolState( diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index ad36127c..2ea43730 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -583,7 +583,11 @@ def get_success_tool(self) -> Tool: """ options = {} instructions = [] - metadata = {"is_completion_tool": True} + metadata = { + "is_completion_tool": True, + "is_success_tool": True, + "completion_task": self, + } result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide @@ -714,6 +718,11 @@ def get_fail_tool(self) -> Tool: failure.""" ), include_return_description=False, + metadata={ + "is_completion_tool": True, + "is_fail_tool": True, + "completion_task": self, + }, ) def fail(reason: str) -> str: self.mark_failed(reason=reason) From 00bda7b849f83604c48680bd9583f5cbbc5b7fc6 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:31:38 -0500 Subject: [PATCH 07/12] Fix streaming for completion results --- src/controlflow/events/events.py | 46 ++++++--- src/controlflow/handlers/print_handler.py | 116 +++++++++++++++++----- 2 files changed, 119 insertions(+), 43 deletions(-) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 9920800e..d4b1c87f 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -137,8 +137,9 @@ def _finalize(self): def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: deltas = [] for call_delta in self.message_delta["tool_call_chunks"]: - # try to retrieve the matching snapshot based on index - call_snapshot = next( + # First match chunks by index because streaming chunks come in sequence (0,1,2...) + # and this index lets us correlate deltas to their snapshots during streaming + chunk_snapshot = next( ( c for c in self.message_snapshot["tool_call_chunks"] @@ -147,22 +148,35 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: None, ) - if call_snapshot: - tool = next( - (t for t in tools if t.name == call_snapshot.get("name")), None + if chunk_snapshot and chunk_snapshot.get("id"): + # Once we have the matching chunk, use its ID to find the full tool call + # The full tool calls contain properly parsed arguments (as Python dicts) + # while chunks just contain raw JSON strings + call_snapshot = next( + ( + c + for c in self.message_snapshot["tool_calls"] + if c.get("id") == chunk_snapshot["id"] + ), + None, ) - deltas.append( - AgentToolCallDelta( - agent=self.agent, - tool_call_delta=call_delta, - tool_call_snapshot=call_snapshot, - tool=tool, - args=pydantic_core.from_json( - call_snapshot["args"] or "{}", allow_partial=True - ), - agent_message_id=self.message_snapshot.get("id"), + + if call_snapshot: + tool = next( + (t for t in tools if t.name == call_snapshot.get("name")), None + ) + # Use call_snapshot.args which is already parsed into a Python dict + # This avoids issues with pydantic's more limited JSON parser + deltas.append( + AgentToolCallDelta( + agent=self.agent, + tool_call_delta=call_delta, + tool_call_snapshot=call_snapshot, + tool=tool, + args=call_snapshot.get("args", {}), + agent_message_id=self.message_snapshot.get("id"), + ) ) - ) return deltas def to_content_delta(self) -> "AgentContentDelta": diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index 55fc1579..6dfd7447 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -103,7 +103,9 @@ def get_status_style(self) -> tuple[str | Spinner, str, str]: "gray50", ) # Animated spinner, softer running state - def render_completion_tool(self) -> Panel: + def render_completion_tool( + self, show_inputs: bool = False, show_outputs: bool = False + ) -> Panel: """Special rendering for completion tools.""" table = Table.grid(padding=0, expand=True) header = Table.grid(padding=1) @@ -114,33 +116,30 @@ def render_completion_tool(self) -> Panel: is_fail_tool = self.tool.metadata.get("is_fail_tool", False) task = self.tool.metadata.get("completion_task") task_name = task.friendly_name() if task else "Unknown Task" + # completion tools store their results on the task, rather than returning them directly task_result = task.result if task else None if not self.is_complete: - # Running state - muted style icon = Spinner("dots") message = f"Working on task: {task_name}" text_style = "dim" border_style = "gray50" else: if self.is_error: - # Tool execution failed (error in the tool itself) icon = "❌" message = f"Error marking task status: {task_name}" text_style = "red" border_style = "red" - if self.result: + if show_outputs and self.result: message += f"\nError: {self.result}" elif is_fail_tool: - # Failure tool succeeded (task is being marked as failed) icon = "❌" message = f"Task failed: {task_name}" text_style = "red" border_style = "red" - if task_result: - message += f"\nReason: {task_result or 'No reason provided'}" + if show_outputs and task_result: + message += f"\nReason: {task_result}" else: - # Success tool succeeded (task completed normally) icon = "✓" message = f"Task complete: {task_name}" text_style = "dim" @@ -149,6 +148,33 @@ def render_completion_tool(self) -> Panel: header.add_row(icon, f"[{text_style}]{message}[/]") table.add_row(header) + # Show details (streaming args or final result) + if show_outputs and self.args: + details = Table.grid(padding=(0, 2)) + details.add_column(style="dim", width=9) + details.add_column() + + # If complete and successful, show task_result + if ( + self.is_complete + and not self.is_error + and not is_fail_tool + and task_result + ): + label = "Result" if is_success_tool else "Reason" + details.add_row( + f" {label}:", + f"{task_result}", + ) + # Otherwise show streaming args + else: + label = "Result" if is_success_tool else "Reason" + details.add_row( + f" {label}:", + rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), + ) + table.add_row(details) + return Panel( table, title=f"[bold]Agent: {self.agent_name}[/]", @@ -161,13 +187,17 @@ def render_completion_tool(self) -> Panel: padding=(0, 1), ) - def render_panel(self, show_details: bool = True) -> Panel: + def render_panel( + self, + show_inputs: bool = True, + show_outputs: bool = True, + ) -> Panel: """Render tool state as a panel with status indicator.""" - # Handle completion tools separately if self.tool and self.tool.metadata.get("is_completion_tool"): - return self.render_completion_tool() + return self.render_completion_tool( + show_inputs=show_inputs, show_outputs=show_outputs + ) - # Regular tool display logic continues as before... icon, text_style, border_style = self.get_status_style() table = Table.grid(padding=0, expand=True) @@ -178,18 +208,18 @@ def render_panel(self, show_details: bool = True) -> Panel: header.add_row(icon, f"[{text_style} bold]{tool_name}[/]") table.add_row(header) - if show_details: + if show_inputs or show_outputs: details = Table.grid(padding=(0, 2)) details.add_column(style="dim", width=9) details.add_column() - if self.args: + if show_inputs and self.args: details.add_row( " Input:", rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), ) - if self.is_complete and self.result: + if show_outputs and self.is_complete and self.result: label = "Error" if self.is_error else "Output" style = "red" if self.is_error else "green3" details.add_row( @@ -213,9 +243,21 @@ def render_panel(self, show_details: bool = True) -> Panel: class PrintHandler(Handler): - def __init__(self, include_completion_tools: bool = True): + def __init__( + self, + show_completion_tools: bool = True, + show_tool_inputs: bool = True, + show_tool_outputs: bool = True, + show_completion_tool_results: bool = False, + ): super().__init__() - self.include_completion_tools = include_completion_tools + # Tool display settings + self.show_completion_tools = show_completion_tools + self.show_tool_inputs = show_tool_inputs + self.show_tool_outputs = show_tool_outputs + # Completion tool specific settings + self.show_completion_tool_results = show_completion_tool_results + self.live: Optional[Live] = None self.paused_id: Optional[str] = None self.states: dict[str, DisplayState] = {} @@ -225,14 +267,34 @@ def update_display(self): if not self.live or not self.live.is_started or self.paused_id: return - # Sort states by timestamp and render panels sorted_states = sorted(self.states.values(), key=lambda s: s.first_timestamp) - panels = [ - state.render_panel(show_details=self.include_completion_tools) - if isinstance(state, ToolState) - else state.render_panel() - for state in sorted_states - ] + panels = [] + + for state in sorted_states: + if isinstance(state, ToolState): + is_completion_tool = state.tool and state.tool.metadata.get( + "is_completion_tool" + ) + + # Skip completion tools if disabled + if not self.show_completion_tools and is_completion_tool: + continue + + if is_completion_tool: + panels.append( + state.render_completion_tool( + show_outputs=self.show_completion_tool_results + ) + ) + else: + panels.append( + state.render_panel( + show_inputs=self.show_tool_inputs, + show_outputs=self.show_tool_outputs, + ) + ) + else: + panels.append(state.render_panel()) if panels: self.live.update(Group(*panels), refresh=True) @@ -287,13 +349,13 @@ def on_tool_result(self, event: ToolResult): if self.paused_id == event.tool_result.tool_call["id"]: self.paused_id = None print() - self.live = Live(console=cf_console, auto_refresh=True) + self.live = Live(console=cf_console, auto_refresh=False) self.live.start() return - # Skip completion tools if configured + # Skip completion tools if disabled if ( - not self.include_completion_tools + not self.show_completion_tools and event.tool_result.tool and event.tool_result.tool.metadata.get("is_completion_tool") ): From 994f4bc93cfdeb45108a77a4f9c1af24327111ca Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:39:59 -0500 Subject: [PATCH 08/12] Ensure live updates spinner --- src/controlflow/handlers/print_handler.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index 6dfd7447..3de7d72c 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -21,6 +21,9 @@ from controlflow.tools.tools import Tool from controlflow.utilities.rich import console as cf_console +# Global spinner for consistent animation +RUNNING_SPINNER = Spinner("dots") + class DisplayState(BaseModel): """Base class for content to be displayed.""" @@ -98,10 +101,10 @@ def get_status_style(self) -> tuple[str | Spinner, str, str]: else: return "✅", "green", "green3" # Slightly softer green return ( - Spinner("dots"), + RUNNING_SPINNER, "yellow", "gray50", - ) # Animated spinner, softer running state + ) # Use shared spinner instance def render_completion_tool( self, show_inputs: bool = False, show_outputs: bool = False @@ -120,7 +123,7 @@ def render_completion_tool( task_result = task.result if task else None if not self.is_complete: - icon = Spinner("dots") + icon = RUNNING_SPINNER # Use shared spinner instance message = f"Working on task: {task_name}" text_style = "dim" border_style = "gray50" @@ -349,7 +352,11 @@ def on_tool_result(self, event: ToolResult): if self.paused_id == event.tool_result.tool_call["id"]: self.paused_id = None print() - self.live = Live(console=cf_console, auto_refresh=False) + self.live = Live( + console=cf_console, + vertical_overflow="visible", + auto_refresh=True, + ) self.live.start() return @@ -374,7 +381,9 @@ def on_tool_result(self, event: ToolResult): def on_orchestrator_start(self, event: OrchestratorStart): """Initialize live display.""" self.live = Live( - auto_refresh=True, console=cf_console, vertical_overflow="visible" + console=cf_console, + vertical_overflow="visible", + auto_refresh=True, ) self.states.clear() try: From c6fac8be98b88eafa9dd0b933d3b691304fbfd51 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:42:29 -0500 Subject: [PATCH 09/12] Update defaults --- src/controlflow/orchestration/orchestrator.py | 3 ++- src/controlflow/settings.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 7edafb4f..73fdeb65 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -84,7 +84,8 @@ def _validate_handlers(cls, v): if v is None and controlflow.settings.enable_default_print_handler: v = [ PrintHandler( - include_completion_tools=controlflow.settings.default_print_handler_include_completion_tools + show_completion_tools=controlflow.settings.default_print_handler_show_completion_tools, + show_completion_tool_results=controlflow.settings.default_print_handler_show_completion_tool_results, ) ] return v or [] diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index cb2fbe72..4f4c8b0f 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -75,10 +75,14 @@ def _validate_pretty_print_agent_events(cls, data: dict) -> dict: description="If True, a PrintHandler will be enabled and automatically " "pretty-print agent events and completion tools.", ) - default_print_handler_include_completion_tools: bool = Field( + default_print_handler_show_completion_tools: bool = Field( default=True, description="If True, the default PrintHandler will include completion tools.", ) + default_print_handler_show_completion_tool_results: bool = Field( + default=False, + description="If True, the default PrintHandler will show the full results of completion tools.", + ) # ------------ orchestration settings ------------ orchestrator_max_agent_turns: Optional[int] = Field( From 072e460df106d586317e481f8fbb7654f5de009f Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:36:54 -0500 Subject: [PATCH 10/12] Fix import --- tests/tools/test_lc_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tools/test_lc_tools.py b/tests/tools/test_lc_tools.py index 3e380843..89267926 100644 --- a/tests/tools/test_lc_tools.py +++ b/tests/tools/test_lc_tools.py @@ -5,7 +5,7 @@ from pydantic import BaseModel import controlflow -from controlflow.events.events import AIMessage, ToolCall +from controlflow.events.events import AgentToolCall, AIMessage class LCBaseToolInput(BaseModel): @@ -26,7 +26,7 @@ def test_lc_base_tool(default_fake_llm, monkeypatch): AIMessage( content="", tool_calls=[ - ToolCall( + AgentToolCall( id="abc", name="TestTool", args={"x": 3}, @@ -52,7 +52,7 @@ def test_ddg_tool(default_fake_llm, monkeypatch): AIMessage( content="", tool_calls=[ - ToolCall( + AgentToolCall( id="abc", name="duckduckgo_search", args={"query": "top business headlines"}, From 0b0c9fb312dddebaca6fd7ed0f3edf71172e5bdf Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:08:36 -0500 Subject: [PATCH 11/12] Clean up tests --- src/controlflow/events/events.py | 9 +++--- src/controlflow/orchestration/orchestrator.py | 31 ++++--------------- tests/tools/test_lc_tools.py | 6 ++-- tests/utilities/test_testing.py | 14 ++++----- 4 files changed, 20 insertions(+), 40 deletions(-) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index d4b1c87f..b7c9f413 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -12,8 +12,7 @@ HumanMessage, ToolMessage, ) -from controlflow.tools.tools import InvalidToolCall, Tool -from controlflow.tools.tools import ToolCall as ToolCallPayload +from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall from controlflow.tools.tools import ToolResult as ToolResultPayload from controlflow.utilities.logging import get_logger @@ -136,13 +135,13 @@ def _finalize(self): def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: deltas = [] - for call_delta in self.message_delta["tool_call_chunks"]: + for call_delta in self.message_delta.get("tool_call_chunks", []): # First match chunks by index because streaming chunks come in sequence (0,1,2...) # and this index lets us correlate deltas to their snapshots during streaming chunk_snapshot = next( ( c - for c in self.message_snapshot["tool_call_chunks"] + for c in self.message_snapshot.get("tool_call_chunks", []) if c.get("index", -1) == call_delta.get("index", -2) ), None, @@ -210,7 +209,7 @@ class AgentToolCall(Event): event: Literal["tool-call"] = "tool-call" agent: Agent agent_message_id: Optional[str] = None - tool_call: Union[ToolCallPayload, InvalidToolCall] + tool_call: Union[ToolCall, InvalidToolCall] tool: Optional[Tool] = None args: dict = {} diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 73fdeb65..10d04269 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -644,32 +644,10 @@ async def _run_agent_turn_async( async def _run_async( self, - max_llm_calls: Optional[int] = None, - max_agent_turns: Optional[int] = None, + run_context: RunContext, model_kwargs: Optional[dict] = None, - run_until: Optional[ - Union[RunEndCondition, Callable[[RunContext], bool]] - ] = None, ) -> AsyncIterator[Event]: - """Async version of _run.""" - # Create the base termination condition - if run_until is None: - run_until = AllComplete() - elif not isinstance(run_until, RunEndCondition): - run_until = FnCondition(run_until) - - # Add max_llm_calls condition - if max_llm_calls is None: - max_llm_calls = controlflow.settings.orchestrator_max_llm_calls - run_until = run_until | MaxLLMCalls(max_llm_calls) - - # Add max_agent_turns condition - if max_agent_turns is None: - max_agent_turns = controlflow.settings.orchestrator_max_agent_turns - run_until = run_until | MaxAgentTurns(max_agent_turns) - - run_context = RunContext(orchestrator=self, run_end_condition=run_until) - + """Run the orchestrator asynchronously, yielding events as they occur.""" # Initialize the agent if not already set if not self.agent: self.agent = self.turn_strategy.get_next_agent( @@ -686,6 +664,7 @@ async def _run_async( yield AgentTurnStart(orchestrator=self, agent=self.agent) + # Run turn and yield its events async for event in self._run_agent_turn_async( run_context=run_context, model_kwargs=model_kwargs, @@ -701,10 +680,12 @@ async def _run_async( ) except Exception as exc: + # Yield error event if something goes wrong yield OrchestratorError(orchestrator=self, error=exc) raise finally: - yield OrchestratorEnd(orchestrator=self) + # Signal the end of orchestration + yield OrchestratorEnd(orchestrator=self, run_context=run_context) # Rebuild all models with forward references after Orchestrator is defined diff --git a/tests/tools/test_lc_tools.py b/tests/tools/test_lc_tools.py index 89267926..3e380843 100644 --- a/tests/tools/test_lc_tools.py +++ b/tests/tools/test_lc_tools.py @@ -5,7 +5,7 @@ from pydantic import BaseModel import controlflow -from controlflow.events.events import AgentToolCall, AIMessage +from controlflow.events.events import AIMessage, ToolCall class LCBaseToolInput(BaseModel): @@ -26,7 +26,7 @@ def test_lc_base_tool(default_fake_llm, monkeypatch): AIMessage( content="", tool_calls=[ - AgentToolCall( + ToolCall( id="abc", name="TestTool", args={"x": 3}, @@ -52,7 +52,7 @@ def test_ddg_tool(default_fake_llm, monkeypatch): AIMessage( content="", tool_calls=[ - AgentToolCall( + ToolCall( id="abc", name="duckduckgo_search", args={"query": "top business headlines"}, diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index 2a4e78fb..380bdd5f 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -37,15 +37,15 @@ def test_record_task_events(default_fake_llm): assert response == events[1].ai_message assert events[3].event == "tool-result" - assert events[3].tool_call == { + assert events[3].tool_result.tool_call == { "name": "mark_task_12345_successful", "args": {"task_result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } - assert events[3].tool_result.model_dump() == dict( - tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe", - str_result='Task #12345 ("say hello") marked successful.', - is_error=False, - tool_metadata={"is_completion_tool": True}, - ) + tool_result = events[3].tool_result.model_dump() + assert tool_result["tool_call"]["id"] == "call_ZEPdV8mCgeBe5UHjKzm6e3pe" + assert tool_result["str_result"] == 'Task #12345 ("say hello") marked successful.' + assert not tool_result["is_error"] + assert tool_result["tool"]["metadata"]["is_completion_tool"] + assert tool_result["tool"]["metadata"]["is_success_tool"] From afe44ad1ee05fe02d6ceed0a8597688960aa9eab Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:15:20 -0500 Subject: [PATCH 12/12] Fix invalid 3.9 syntax --- src/controlflow/handlers/print_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py index 3de7d72c..cc7ddf5d 100644 --- a/src/controlflow/handlers/print_handler.py +++ b/src/controlflow/handlers/print_handler.py @@ -1,5 +1,5 @@ import datetime -from typing import Optional +from typing import Optional, Union import rich from pydantic import BaseModel @@ -93,7 +93,7 @@ class ToolState(DisplayState): is_complete: bool = False tool: Optional[Tool] = None - def get_status_style(self) -> tuple[str | Spinner, str, str]: + def get_status_style(self) -> tuple[Union[str, Spinner], str, str]: """Returns (icon, text style, border style) for current status.""" if self.is_complete: if self.is_error: