From 3bbdd0bd6ebd34b099e725873410271fa151d5ae Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:44:57 -0500 Subject: [PATCH 1/8] Add streaming kwarg to run --- src/controlflow/__init__.py | 2 +- src/controlflow/events/base.py | 2 +- src/controlflow/events/events.py | 38 ++- src/controlflow/orchestration/orchestrator.py | 200 ++++++++--- src/controlflow/run.py | 69 +++- src/controlflow/stream.py | 316 ++++++++---------- tests/test_run.py | 35 ++ 7 files changed, 400 insertions(+), 262 deletions(-) diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 27a0218c..8fd513b7 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -16,7 +16,7 @@ from .instructions import instructions from .decorators import flow, task from .tools import tool -from .run import run, run_async, run_tasks, run_tasks_async +from .run import run, run_async, run_tasks, run_tasks_async, Stream from .plan import plan import controlflow.orchestration diff --git a/src/controlflow/events/base.py b/src/controlflow/events/base.py index aad788cb..ad794042 100644 --- a/src/controlflow/events/base.py +++ b/src/controlflow/events/base.py @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]: return [] def __repr__(self) -> str: - return f"<{self.event} {self.timestamp}>" + return f"" class UnpersistedEvent(Event): diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index b7c9f413..bf3adf7c 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -90,15 +90,17 @@ def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]: ) return calls - def to_content(self) -> "AgentContent": - return AgentContent( - agent=self.agent, - content=self.message["content"], - agent_message_id=self.message.get("id"), - ) + def to_content(self) -> Optional["AgentContent"]: + if self.message.get("content"): + return AgentContent( + agent=self.agent, + content=self.message["content"], + agent_message_id=self.message.get("id"), + ) def all_related_events(self, tools: list[Tool]) -> list[Event]: - return [self, self.to_content()] + self.to_tool_calls(tools) + content = self.to_content() + return [self] + ([content] if content else []) + self.to_tool_calls(tools) def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: @@ -178,16 +180,22 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: ) return deltas - def to_content_delta(self) -> "AgentContentDelta": - return AgentContentDelta( - agent=self.agent, - content_delta=self.message_delta["content"], - content_snapshot=self.message_snapshot["content"], - agent_message_id=self.message_snapshot.get("id"), - ) + def to_content_delta(self) -> Optional["AgentContentDelta"]: + if self.message_delta.get("content"): + return AgentContentDelta( + agent=self.agent, + content_delta=self.message_delta["content"], + content_snapshot=self.message_snapshot["content"], + agent_message_id=self.message_snapshot.get("id"), + ) def all_related_events(self, tools: list[Tool]) -> list[Event]: - return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools) + content_delta = self.to_content_delta() + return ( + [self] + + ([content_delta] if content_delta else []) + + self.to_tool_call_deltas(tools) + ) class AgentContent(UnpersistedEvent): diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 10d04269..a4f68242 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,7 +1,7 @@ import logging -from typing import AsyncIterator, Callable, Iterator, Optional, TypeVar, Union +from typing import AsyncIterator, Callable, Iterator, Optional, Set, TypeVar, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, PrivateAttr, field_validator import controlflow from controlflow.agents.agent import Agent @@ -61,6 +61,7 @@ class Orchestrator(ControlFlowModel): handlers: list[Union[Handler, AsyncHandler]] = Field( None, validate_default=True, exclude=True ) + _processed_event_ids: Set[str] = PrivateAttr(default_factory=set) @field_validator("turn_strategy", mode="before") def _validate_turn_strategy(cls, v): @@ -90,28 +91,42 @@ def _validate_handlers(cls, v): ] return v or [] - def handle_event(self, event: Event): + def handle_event(self, event: Event) -> Event: """ Handle an event by passing it to all handlers and persisting if necessary. + Includes idempotency check to prevent double-processing events. Args: event (Event): The event to handle. """ from controlflow.events.events import AgentContentDelta + # Skip if we've already processed this event + if event.id in self._processed_event_ids: + return event + for handler in self.handlers: if isinstance(handler, Handler): handler.handle(event) if event.persist: self.flow.add_events([event]) - async def handle_event_async(self, event: Event): + # Mark event as processed + self._processed_event_ids.add(event.id) + return event + + async def handle_event_async(self, event: Event) -> Event: """ Handle an event asynchronously by passing it to all handlers and persisting if necessary. + Includes idempotency check to prevent double-processing events. Args: event (Event): The event to handle. """ + # Skip if we've already processed this event + if event.id in self._processed_event_ids: + return event + if not isinstance(event, AgentMessageDelta): logger.debug(f"Handling event asynchronously: {repr(event)}") for handler in self.handlers: @@ -122,6 +137,10 @@ async def handle_event_async(self, event: Event): if event.persist: self.flow.add_events([event]) + # Mark event as processed + self._processed_event_ids.add(event.id) + return event + def get_available_agents(self) -> dict[Agent, list[Task]]: """ Get a dictionary of all available agents for active tasks, mapped to @@ -225,7 +244,60 @@ def _run_agent_turn( run_context.agent_turns += 1 - @prefect_task(task_run_name="Orchestrator.run()") + @prefect_task(task_run_name="Run agent orchestrator") + def _run( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run the orchestrator, yielding handled events as they occur.""" + # Initialize the agent if not already set + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + + # Signal the start of orchestration + yield self.handle_event( + OrchestratorStart(orchestrator=self, run_context=run_context) + ) + + try: + while True: + if run_context.should_end(): + break + + yield self.handle_event( + AgentTurnStart(orchestrator=self, agent=self.agent) + ) + + # Run turn and yield its events + for event in self._run_agent_turn( + run_context=run_context, + model_kwargs=model_kwargs, + ): + yield self.handle_event(event) + + yield self.handle_event( + AgentTurnEnd(orchestrator=self, agent=self.agent) + ) + + # Select the next agent for the following turn + if available_agents := self.get_available_agents(): + self.agent = self.turn_strategy.get_next_agent( + self.agent, available_agents + ) + + except Exception as exc: + # Yield error event if something goes wrong + yield self.handle_event(OrchestratorError(orchestrator=self, error=exc)) + raise + finally: + # Signal the end of orchestration + yield self.handle_event( + OrchestratorEnd(orchestrator=self, run_context=run_context) + ) + def run( self, max_llm_calls: Optional[int] = None, @@ -234,10 +306,21 @@ def run( run_until: Optional[ Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, - ) -> RunContext: + stream: bool = False, + ) -> Union[RunContext, Iterator[Event]]: """ - Run the orchestrator, handling events internally. - Returns the final run context. + Run the orchestrator. + + Args: + max_llm_calls: Maximum number of LLM calls allowed + max_agent_turns: Maximum number of agent turns allowed + model_kwargs: Additional kwargs for the model + run_until: Condition for ending the run + stream: If True, return iterator of events. If False, consume events and return context + + Returns: + If stream=True, returns Iterator[Event] + If stream=False, returns RunContext """ # Create run context at the outermost level if run_until is None: @@ -257,19 +340,26 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - for event in self._run( + iterator = self._run( run_context=run_context, model_kwargs=model_kwargs, - ): - self.handle_event(event) + ) + + if stream: + return iterator + + # Consume iterator if not streaming + for _ in iterator: + pass return run_context - def _run( + @prefect_task(task_run_name="Run agent orchestrator") + async def _run_async( self, run_context: RunContext, model_kwargs: Optional[dict] = None, - ) -> Iterator[Event]: - """Run the orchestrator, yielding events as they occur.""" + ) -> AsyncIterator[Event]: + """Run the orchestrator asynchronously, yielding handled events as they occur.""" # Initialize the agent if not already set if not self.agent: self.agent = self.turn_strategy.get_next_agent( @@ -277,23 +367,29 @@ def _run( ) # Signal the start of orchestration - yield OrchestratorStart(orchestrator=self, run_context=run_context) + yield await self.handle_event_async( + OrchestratorStart(orchestrator=self, run_context=run_context) + ) try: while True: if run_context.should_end(): break - yield AgentTurnStart(orchestrator=self, agent=self.agent) + yield await self.handle_event_async( + AgentTurnStart(orchestrator=self, agent=self.agent) + ) # Run turn and yield its events - for event in self._run_agent_turn( + async for event in self._run_agent_turn_async( run_context=run_context, model_kwargs=model_kwargs, ): - yield event + yield await self.handle_event_async(event) - yield AgentTurnEnd(orchestrator=self, agent=self.agent) + yield await self.handle_event_async( + AgentTurnEnd(orchestrator=self, agent=self.agent) + ) # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -303,13 +399,16 @@ def _run( except Exception as exc: # Yield error event if something goes wrong - yield OrchestratorError(orchestrator=self, error=exc) + yield await self.handle_event_async( + OrchestratorError(orchestrator=self, error=exc) + ) raise finally: # Signal the end of orchestration - yield OrchestratorEnd(orchestrator=self, run_context=run_context) + yield await self.handle_event_async( + OrchestratorEnd(orchestrator=self, run_context=run_context) + ) - @prefect_task async def run_async( self, max_llm_calls: Optional[int] = None, @@ -318,10 +417,21 @@ async def run_async( run_until: Optional[ Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, - ) -> RunContext: + stream: bool = False, + ) -> Union[RunContext, AsyncIterator[Event]]: """ - Run the orchestrator asynchronously, handling events internally. - Returns the final run context. + Run the orchestrator asynchronously. + + Args: + max_llm_calls: Maximum number of LLM calls allowed + max_agent_turns: Maximum number of agent turns allowed + model_kwargs: Additional kwargs for the model + run_until: Condition for ending the run + stream: If True, return async iterator of events. If False, consume events and return context + + Returns: + If stream=True, returns AsyncIterator[Event] + If stream=False, returns RunContext """ # Create run context at the outermost level if run_until is None: @@ -341,11 +451,17 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - async for event in self._run_async( + iterator = self._run_async( run_context=run_context, model_kwargs=model_kwargs, - ): - await self.handle_event_async(event) + ) + + if stream: + return iterator + + # Consume iterator if not streaming + async for _ in iterator: + pass return run_context @prefect_task(task_run_name="Agent turn: {self.agent.name}") @@ -353,7 +469,7 @@ def run_agent_turn( self, run_context: RunContext, model_kwargs: Optional[dict] = None, - ) -> int: + ) -> Iterator[Event]: """ Run a single agent turn, which may consist of multiple LLM calls. """ @@ -365,11 +481,9 @@ def run_agent_turn( for task in assigned_tasks: if not task.is_running(): task.mark_running() - self.handle_event( - OrchestratorMessage( - content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " - f"with objective: {task.objective}" - ) + yield OrchestratorMessage( + content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " + f"with objective: {task.objective}" ) while not self.turn_strategy.should_end_turn(): @@ -394,7 +508,7 @@ def run_agent_turn( tools=tools, model_kwargs=model_kwargs, ): - self.handle_event(event) + yield event run_context.llm_calls += 1 for task in assigned_tasks: @@ -402,20 +516,18 @@ def run_agent_turn( run_context.agent_turns += 1 - @prefect_task + @prefect_task(task_run_name="Agent turn: {self.agent.name}") async def run_agent_turn_async( self, run_context: RunContext, model_kwargs: Optional[dict] = None, - ) -> int: + ) -> AsyncIterator[Event]: """ Run a single agent turn asynchronously, which may consist of multiple LLM calls. Args: max_llm_calls (Optional[int]): The number of LLM calls allowed. - Returns: - int: The number of LLM calls made during this turn. """ assigned_tasks = self.get_tasks("assigned") @@ -425,11 +537,9 @@ async def run_agent_turn_async( for task in assigned_tasks: if not task.is_running(): task.mark_running() - await self.handle_event_async( - OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " - f"with objective: {task.objective}" - ) + yield OrchestratorMessage( + content=f"Starting task {task.name} (ID {task.id}) " + f"with objective: {task.objective}" ) while not self.turn_strategy.should_end_turn(): @@ -454,7 +564,7 @@ async def run_agent_turn_async( tools=tools, model_kwargs=model_kwargs, ): - await self.handle_event_async(event) + yield event run_context.llm_calls += 1 for task in assigned_tasks: diff --git a/src/controlflow/run.py b/src/controlflow/run.py index dc4e285d..bd734d43 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,25 +1,17 @@ -from typing import Any, Callable, Optional, Union - -from prefect.context import TaskRunContext +from typing import Any, Callable, Iterator, Optional, Union import controlflow from controlflow.agents.agent import Agent +from controlflow.events.events import Event from controlflow.flows import Flow, get_flow from controlflow.orchestration.conditions import RunContext, RunEndCondition from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy +from controlflow.stream import Stream, filter_events from controlflow.tasks.task import Task from controlflow.utilities.prefect import prefect_task -def get_task_run_name() -> str: - context = TaskRunContext.get() - tasks = context.parameters["tasks"] - task_names = " | ".join(t.friendly_name() for t in tasks) - return f"Run task{'s' if len(tasks) > 1 else ''}: {task_names}" - - -@prefect_task(task_run_name=get_task_run_name) def run_tasks( tasks: list[Task], instructions: str = None, @@ -32,11 +24,29 @@ def run_tasks( handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, -) -> list[Any]: + stream: Union[bool, Stream] = False, +) -> Union[list[Any], Iterator[tuple[Event, Any, Optional[Any]]]]: """ Run a list of tasks. - Returns a list of task results corresponding to the input tasks, or raises an error if any tasks failed. + Args: + tasks: List of tasks to run. + instructions: Instructions for the tasks. + flow: Flow to run the tasks in. + agent: Agent to run the tasks with. + turn_strategy: Turn strategy to use for the tasks. + raise_on_failure: Whether to raise an error if any tasks fail. + max_llm_calls: Maximum number of LLM calls to make. + max_agent_turns: Maximum number of agent turns to make. + handlers: List of handlers to use for the tasks. + model_kwargs: Keyword arguments to pass to the LLM. + run_until: Condition to stop running tasks. + stream: If True, stream all events. Can also provide StreamFilter flags to filter specific events. + e.g. StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS + + Returns: + If not streaming: List of task results + If streaming: Iterator of (event, snapshot, delta) tuples """ flow = flow or get_flow() or Flow() @@ -49,13 +59,19 @@ def run_tasks( ) with controlflow.instructions(instructions): - orchestrator.run( + result = orchestrator.run( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, model_kwargs=model_kwargs, run_until=run_until, + stream=bool(stream), ) + if stream: + # Convert True to ALL filter, otherwise use provided filter + stream_filter = Stream.ALL if stream is True else stream + return filter_events(result, Stream._filter_names(stream_filter)) + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: @@ -67,7 +83,6 @@ def run_tasks( return [t.result for t in tasks] -@prefect_task(task_run_name=get_task_run_name) async def run_tasks_async( tasks: list[Task], instructions: str = None, @@ -122,8 +137,24 @@ def run( handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, + stream: Union[bool, Stream] = False, **task_kwargs, -) -> Any: +) -> Union[Any, Iterator[tuple[Event, Any, Optional[Any]]]]: + """ + Run a single task. + + Args: + objective: Objective of the task. + turn_strategy: Turn strategy to use for the task. + max_llm_calls: Maximum number of LLM calls to make. + max_agent_turns: Maximum number of agent turns to make. + raise_on_failure: Whether to raise an error if the task fails. + handlers: List of handlers to use for the task. + model_kwargs: Keyword arguments to pass to the LLM. + run_until: Condition to stop running the task. + stream: If True, stream all events. Can also provide StreamFilter flags to filter specific events. + e.g. StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS + """ task = Task(objective=objective, **task_kwargs) results = run_tasks( tasks=[task], @@ -134,8 +165,12 @@ def run( handlers=handlers, model_kwargs=model_kwargs, run_until=run_until, + stream=stream, ) - return results[0] + if stream: + return results + else: + return results[0] async def run_async( diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index e8552ebf..d4d5650b 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -1,24 +1,6 @@ -# Example usage -# -# # Stream all events -# for event in cf.stream.events("Write a story"): -# print(event) -# -# # Stream just messages -# for event in cf.stream.events("Write a story", events='messages'): -# print(event.content) -# -# # Stream just the result -# for delta, snapshot in cf.stream.result("Write a story"): -# print(f"New: {delta}") -# -# # Stream results from multiple tasks -# for delta, snapshot in cf.stream.result_from_tasks([task1, task2]): -# print(f"New result: {delta}") -# -from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union - -from controlflow.events.base import Event +from enum import Flag, auto +from typing import Any, Iterator, Optional + from controlflow.events.events import ( AgentContent, AgentContentDelta, @@ -26,184 +8,152 @@ AgentMessageDelta, AgentToolCall, AgentToolCallDelta, + Event, ToolResult, ) -from controlflow.orchestration.handler import AsyncHandler, Handler -from controlflow.orchestration.orchestrator import Orchestrator -from controlflow.tasks.task import Task - -StreamEvents = Union[ - list[str], - Literal["all", "messages", "content", "tools", "completion_tools", "agent_tools"], -] - - -def event_filter(events: StreamEvents) -> Callable[[Event], bool]: - def _event_filter(event: Event) -> bool: - if events == "all": - return True - elif events == "messages": - return isinstance(event, (AgentMessage, AgentMessageDelta)) - elif events == "content": - return isinstance(event, (AgentContent, AgentContentDelta)) - elif events == "tools": - return isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)) - elif events == "completion_tools": - if isinstance(event, (AgentToolCall, AgentToolCallDelta)): - return event.tool and event.tool.metadata.get("is_completion_tool") - elif isinstance(event, ToolResult): - return event.tool_result and event.tool_result.tool.metadata.get( - "is_completion_tool" - ) - return False - elif events == "agent_tools": - if isinstance(event, (AgentToolCall, AgentToolCallDelta)): - return event.tool and event.tool in event.agent.get_tools() - elif isinstance(event, ToolResult): - return ( - event.tool_result - and event.tool_result.tool in event.agent.get_tools() - ) - return False - else: - raise ValueError(f"Invalid event type: {events}") - - return _event_filter - - -# -------------------- BELOW HERE IS THE OLD STUFF -------------------- -def events( - objective: str, - *, - events: StreamEvents = "all", - filter_fn: Optional[Callable[[Event], bool]] = None, - **kwargs, -) -> Iterator[Event]: +class Stream(Flag): """ - Stream events from a task execution. - - Args: - objective: The task objective - events: Which events to stream. Can be list of event types or: - 'all' - all events - 'messages' - agent messages - 'tools' - all tool calls/results - 'completion_tools' - only completion tools - filter_fn: Optional additional filter function - **kwargs: Additional arguments passed to Task + Filter flags for event streaming. - Returns: - Iterator of Event objects + Can be combined using bitwise operators: + stream_filter = StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS """ - def get_event_filter(): - if isinstance(events, list): - return lambda e: e.event in events - elif events == "messages": - return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) - elif events == "tools": - return lambda e: isinstance(e, (AgentToolCall, ToolResult)) - elif events == "completion_tools": - return lambda e: ( - isinstance(e, (AgentToolCall, ToolResult)) - and e.tool_call["name"].startswith("mark_task_") - ) - else: # 'all' - return lambda e: True - - event_filter = get_event_filter() - - def event_handler(event: Event): - if event_filter(event) and (not filter_fn or filter_fn(event)): - yield event - - task = Task(objective=objective) - task.run(handlers=[Handler(event_handler)], **kwargs) - - -def result( - objective: str, - **kwargs, -) -> Iterator[tuple[Any, Any]]: + NONE = 0 + CONTENT = auto() # Agent content and deltas + AGENT_TOOLS = auto() # Non-completion tool events + RESULTS = auto() # Completion tool events + TOOLS = AGENT_TOOLS | RESULTS # All tool events + ALL = CONTENT | TOOLS # Everything + + @classmethod + def _filter_names(cls, filter: "Stream") -> list[str]: + """Convert StreamFilter to list of filter names for filter_events()""" + names = [] + if filter & cls.CONTENT: + names.append("content") + if filter & cls.AGENT_TOOLS: + names.append("agent_tools") + if filter & cls.RESULTS: + names.append("results") + return names + + +def filter_events( + events: Iterator[Event], filters: list[str] +) -> Iterator[tuple[Event, Any, Optional[Any]]]: """ - Stream result from a task execution. - - Args: - objective: The task objective - **kwargs: Additional arguments passed to Task + Filter events based on a list of event types or shortcuts. + Returns tuples of (event, snapshot, delta) where snapshot and delta depend on the event type. Returns: - Iterator of (delta, accumulated) result tuples + Iterator of (event, snapshot, delta) tuples where: + - event: The original event + - snapshot: Full state (e.g., complete message, tool state) + - delta: Incremental change (None for non-delta events) + + Patterns for different event types: + - Content events: (event, full_text, new_text) + - Tool calls: (event, tool_state, tool_delta) + - Tool results: (event, result_state, None) """ - current_result = None - def result_handler(event: Event): - nonlocal current_result + def is_completion_tool_event(event: Event) -> bool: + """Check if an event is related to a completion tool call.""" if isinstance(event, ToolResult): - if event.tool_call["name"].startswith("mark_task_"): - result = event.tool_result.result # Get actual result value - if result != current_result: # Only yield if changed - current_result = result - yield (result, result) # For now delta == full result - - task = Task(objective=objective) - task.run(handlers=[Handler(result_handler)], **kwargs) - - -def events_from_tasks( - tasks: list[Task], - events: StreamEvents = "all", - filter_fn: Optional[Callable[[Event], bool]] = None, - **kwargs, -) -> Iterator[Event]: - """Stream events from multiple task executions.""" - - def get_event_filter(): - if isinstance(events, list): - return lambda e: e.event in events - elif events == "messages": - return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) - elif events == "tools": - return lambda e: isinstance(e, (AgentToolCall, ToolResult)) - elif events == "completion_tools": - return lambda e: ( - isinstance(e, (AgentToolCall, ToolResult)) - and e.tool_call["name"].startswith("mark_task_") - ) - else: # 'all' - return lambda e: True - - event_filter = get_event_filter() + tool = event.tool_result.tool + elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): + tool = event.tool + else: + return False - def event_handler(event: Event): - if event_filter(event) and (not filter_fn or filter_fn(event)): - yield event + return tool and tool.metadata.get("is_completion_tool") - orchestrator = Orchestrator( - tasks=tasks, handlers=[Handler(event_handler)], **kwargs - ) - orchestrator.run() + def is_agent_tool_event(event: Event) -> bool: + """Check if an event is related to a regular (non-completion) tool call.""" + if isinstance(event, ToolResult): + tool = event.tool_result.tool + elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): + tool = event.tool + else: + return False + return tool and not tool.metadata.get("is_completion_tool") -def result_from_tasks( - tasks: list[Task], - **kwargs, -) -> Iterator[tuple[Any, Any]]: - """Stream results from multiple task executions.""" - current_results = {task.id: None for task in tasks} + # Expand shortcuts to event types and build filtering predicates + event_filters = [] + for filter_name in filters: + if filter_name == "content": + event_filters.append( + lambda e: e.event in {"agent-content", "agent-content-delta"} + ) + elif filter_name == "tools": + event_filters.append( + lambda e: e.event + in { + "agent-tool-call", + "agent-tool-call-delta", + "tool-result", + } + ) + elif filter_name == "agent_tools": + event_filters.append( + lambda e: ( + e.event + in { + "agent-tool-call", + "agent-tool-call-delta", + "tool-result", + } + and is_agent_tool_event(e) + ) + ) + elif filter_name == "results": + event_filters.append( + lambda e: ( + e.event + in { + "agent-tool-call", + "agent-tool-call-delta", + "tool-result", + } + and is_completion_tool_event(e) + ) + ) + else: + # Raw event type + event_filters.append(lambda e, t=filter_name: e.event == t) + + def passes_filters(event: Event) -> bool: + return any(f(event) for f in event_filters) + + for event in events: + if not passes_filters(event): + continue + + # Message events + if isinstance(event, AgentMessage): + yield event, event.message, None + elif isinstance(event, AgentMessageDelta): + yield event, event.message_snapshot, event.message_delta + + # Content events + elif isinstance(event, AgentContent): + yield event, event.content, None + elif isinstance(event, AgentContentDelta): + yield event, event.content_snapshot, event.content_delta + + # Tool call events + elif isinstance(event, AgentToolCall): + yield event, event.tool_call, None + elif isinstance(event, AgentToolCallDelta): + yield event, event.tool_call_snapshot, event.tool_call_delta + + # Tool result events + elif isinstance(event, ToolResult): + yield event, event.tool_result, None - def result_handler(event: Event): - if isinstance(event, ToolResult): - if event.tool_call["name"].startswith("mark_task_"): - task_id = event.task.id - result = event.tool_result.result - if result != current_results[task_id]: - current_results[task_id] = result - yield (result, result) - - orchestrator = Orchestrator( - tasks=tasks, handlers=[Handler(result_handler)], **kwargs - ) - orchestrator.run() + else: + yield event, None, None diff --git a/tests/test_run.py b/tests/test_run.py index 7c14a25d..bbc65b4a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,10 +1,13 @@ +import controlflow from controlflow import instructions from controlflow.events.base import Event from controlflow.events.events import AgentMessage +from controlflow.llm.messages import AIMessage from controlflow.orchestration.conditions import AnyComplete, AnyFailed, MaxLLMCalls from controlflow.orchestration.handler import Handler from controlflow.run import run, run_async, run_tasks, run_tasks_async from controlflow.tasks.task import Task +from tests.fixtures.controlflow import default_fake_llm class TestHandlers: @@ -167,3 +170,35 @@ async def test_min_failed(self): assert task1.is_failed() assert task2.is_incomplete() assert task3.is_failed() + + +class TestRunStreaming: + def test_stream_all(self, default_fake_llm): + result = run("what's 2 + 2", stream=True, max_llm_calls=1) + r = list(result) + assert len(r) == 2 + assert r[0][0].event == "agent-content-delta" + assert r[1][0].event == "agent-content" + + def test_stream_content(self, default_fake_llm): + response = AIMessage( + id="run-2af8bb73-661f-4ec3-92ff-d7d8e3074926", + name="Marvin", + role="ai", + content="", + tool_calls=[ + { + "name": "mark_task_12345_successful", + "args": {"task_result": "Hello!"}, + "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", + "type": "tool_call", + } + ], + ) + + default_fake_llm.set_responses(["Hello!", response]) + result = run("say hello", stream=True, max_llm_calls=1) + r = list(result) + assert len(r) == 2 + assert r[0][0].event == "agent-content-delta" + assert r[1][0].event == "agent-content" From a4aa51ee3bfe96cf33c8856aad9deedbf29385ce Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 20:26:50 -0500 Subject: [PATCH 2/8] Add streaming config and tests --- src/controlflow/agents/agent.py | 24 +++++++- src/controlflow/events/events.py | 2 +- src/controlflow/run.py | 7 ++- src/controlflow/stream.py | 94 +++++++------------------------- src/controlflow/tasks/task.py | 39 ++++++++----- tests/test_run.py | 55 +++++++++++++++---- 6 files changed, 115 insertions(+), 106 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 6c4381a9..185a833f 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -8,6 +8,7 @@ Any, AsyncGenerator, Generator, + Iterator, Optional, Union, ) @@ -43,10 +44,11 @@ from controlflow.utilities.prefect import create_markdown_artifact, prefect_task if TYPE_CHECKING: + from controlflow.events.events import Event + from controlflow.flows import Flow from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.turn_strategies import TurnStrategy - from controlflow.tasks import Task - from controlflow.tools.tools import Tool + from controlflow.stream import Stream logger = logging.getLogger(__name__) @@ -223,13 +225,29 @@ def run( *, turn_strategy: Optional["TurnStrategy"] = None, handlers: Optional[list["Handler"]] = None, + stream: Union[bool, "Stream"] = False, **task_kwargs, - ): + ) -> Union[Any, Iterator[tuple["Event", Any, Optional[Any]]]]: + """ + Run a task with this agent. + + Args: + objective: The objective to accomplish + turn_strategy: Optional turn strategy to use + handlers: Optional list of handlers + stream: If True, stream all events. Can also provide StreamFilter flags. + **task_kwargs: Additional kwargs passed to Task creation + + Returns: + If not streaming: The task result + If streaming: Iterator of (event, snapshot, delta) tuples + """ return controlflow.run( objective=objective, agents=[self], turn_strategy=turn_strategy, handlers=handlers, + stream=stream, **task_kwargs, ) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index bf3adf7c..d0953ee1 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -214,7 +214,7 @@ class AgentContentDelta(UnpersistedEvent): class AgentToolCall(Event): - event: Literal["tool-call"] = "tool-call" + event: Literal["agent-tool-call"] = "agent-tool-call" agent: Agent agent_message_id: Optional[str] = None tool_call: Union[ToolCall, InvalidToolCall] diff --git a/src/controlflow/run.py b/src/controlflow/run.py index bd734d43..b0811682 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -41,8 +41,9 @@ def run_tasks( handlers: List of handlers to use for the tasks. model_kwargs: Keyword arguments to pass to the LLM. run_until: Condition to stop running tasks. - stream: If True, stream all events. Can also provide StreamFilter flags to filter specific events. - e.g. StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS + stream: If True, stream all events (equivalent to Stream.ALL). + Can also provide Stream flags to filter specific events. + e.g. Stream.CONTENT | Stream.AGENT_TOOLS Returns: If not streaming: List of task results @@ -70,7 +71,7 @@ def run_tasks( if stream: # Convert True to ALL filter, otherwise use provided filter stream_filter = Stream.ALL if stream is True else stream - return filter_events(result, Stream._filter_names(stream_filter)) + return filter_events(result, stream_filter) if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index d4d5650b..3fa69c5e 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -18,34 +18,22 @@ class Stream(Flag): Filter flags for event streaming. Can be combined using bitwise operators: - stream_filter = StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS + stream_filter = Stream.CONTENT | Stream.AGENT_TOOLS """ NONE = 0 + ALL = auto() # All events CONTENT = auto() # Agent content and deltas AGENT_TOOLS = auto() # Non-completion tool events RESULTS = auto() # Completion tool events TOOLS = AGENT_TOOLS | RESULTS # All tool events - ALL = CONTENT | TOOLS # Everything - - @classmethod - def _filter_names(cls, filter: "Stream") -> list[str]: - """Convert StreamFilter to list of filter names for filter_events()""" - names = [] - if filter & cls.CONTENT: - names.append("content") - if filter & cls.AGENT_TOOLS: - names.append("agent_tools") - if filter & cls.RESULTS: - names.append("results") - return names def filter_events( - events: Iterator[Event], filters: list[str] + events: Iterator[Event], stream_filter: Stream ) -> Iterator[tuple[Event, Any, Optional[Any]]]: """ - Filter events based on a list of event types or shortcuts. + Filter events based on Stream flags. Returns tuples of (event, snapshot, delta) where snapshot and delta depend on the event type. Returns: @@ -58,6 +46,7 @@ def filter_events( - Content events: (event, full_text, new_text) - Tool calls: (event, tool_state, tool_delta) - Tool results: (event, result_state, None) + - Other events: (event, None, None) """ def is_completion_tool_event(event: Event) -> bool: @@ -71,66 +60,25 @@ def is_completion_tool_event(event: Event) -> bool: return tool and tool.metadata.get("is_completion_tool") - def is_agent_tool_event(event: Event) -> bool: - """Check if an event is related to a regular (non-completion) tool call.""" - if isinstance(event, ToolResult): - tool = event.tool_result.tool - elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): - tool = event.tool - else: - return False + def should_include_event(event: Event) -> bool: + # Pass all events if ALL is specified + if stream_filter == Stream.ALL: + return True - return tool and not tool.metadata.get("is_completion_tool") - - # Expand shortcuts to event types and build filtering predicates - event_filters = [] - for filter_name in filters: - if filter_name == "content": - event_filters.append( - lambda e: e.event in {"agent-content", "agent-content-delta"} - ) - elif filter_name == "tools": - event_filters.append( - lambda e: e.event - in { - "agent-tool-call", - "agent-tool-call-delta", - "tool-result", - } - ) - elif filter_name == "agent_tools": - event_filters.append( - lambda e: ( - e.event - in { - "agent-tool-call", - "agent-tool-call-delta", - "tool-result", - } - and is_agent_tool_event(e) - ) - ) - elif filter_name == "results": - event_filters.append( - lambda e: ( - e.event - in { - "agent-tool-call", - "agent-tool-call-delta", - "tool-result", - } - and is_completion_tool_event(e) - ) - ) - else: - # Raw event type - event_filters.append(lambda e, t=filter_name: e.event == t) + # Content events + if isinstance(event, (AgentContent, AgentContentDelta)): + return bool(stream_filter & Stream.CONTENT) - def passes_filters(event: Event) -> bool: - return any(f(event) for f in event_filters) + # Tool events + if isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)): + if is_completion_tool_event(event): + return bool(stream_filter & Stream.RESULTS) + return bool(stream_filter & Stream.AGENT_TOOLS) + + return False for event in events: - if not passes_filters(event): + if not should_include_event(event): continue # Message events @@ -154,6 +102,6 @@ def passes_filters(event: Event) -> bool: # Tool result events elif isinstance(event, ToolResult): yield event, event.tool_result, None - else: + # Pass through any other events with no snapshot/delta yield event, None, None diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 2ea43730..c7e2dd1a 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -9,6 +9,7 @@ Callable, Generator, GenericAlias, + Iterator, Literal, Optional, TypeVar, @@ -49,26 +50,20 @@ unwrap, ) from controlflow.utilities.logging import get_logger -from controlflow.utilities.prefect import prefect_task as prefect_task if TYPE_CHECKING: + from controlflow.events.events import Event from controlflow.flows import Flow from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.turn_strategies import TurnStrategy + from controlflow.stream import Stream T = TypeVar("T") logger = get_logger(__name__) - COMPLETION_TOOLS = Literal["SUCCEED", "FAIL"] -def get_task_run_name(): - context = TaskRunContext.get() - task = context.parameters["self"] - return f"Task.run() ({task.friendly_name()})" - - class Labels(RootModel): root: tuple[Any, ...] @@ -392,7 +387,6 @@ def add_dependency(self, task: "Task"): self.depends_on.add(task) task._downstreams.add(self) - @prefect_task(task_run_name=get_task_run_name) def run( self, agent: Optional[Agent] = None, @@ -403,12 +397,27 @@ def run( handlers: list["Handler"] = None, raise_on_failure: bool = True, model_kwargs: Optional[dict] = None, - ) -> T: + stream: Union[bool, "Stream"] = False, + ) -> Union[T, Iterator[tuple["Event", Any, Optional[Any]]]]: """ Run the task - """ - controlflow.run_tasks( + Args: + agent: Optional agent to run the task + flow: Optional flow to run the task in + turn_strategy: Optional turn strategy to use + max_llm_calls: Maximum number of LLM calls to make + max_agent_turns: Maximum number of agent turns to make + handlers: Optional list of handlers + raise_on_failure: Whether to raise on task failure + model_kwargs: Optional kwargs to pass to the model + stream: If True, stream all events. Can also provide StreamFilter flags. + + Returns: + If not streaming: The task result + If streaming: Iterator of (event, snapshot, delta) tuples + """ + result = controlflow.run_tasks( tasks=[self], flow=flow, agent=agent, @@ -418,14 +427,16 @@ def run( raise_on_failure=False, handlers=handlers, model_kwargs=model_kwargs, + stream=stream, ) - if self.is_successful(): + if stream: + return result + elif self.is_successful(): return self.result elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") - @prefect_task(task_run_name=get_task_run_name) async def run_async( self, agent: Optional[Agent] = None, diff --git a/tests/test_run.py b/tests/test_run.py index bbc65b4a..4128e748 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,5 +1,7 @@ +import pytest + import controlflow -from controlflow import instructions +from controlflow import Stream, instructions from controlflow.events.base import Event from controlflow.events.events import AgentMessage from controlflow.llm.messages import AIMessage @@ -173,14 +175,10 @@ async def test_min_failed(self): class TestRunStreaming: - def test_stream_all(self, default_fake_llm): - result = run("what's 2 + 2", stream=True, max_llm_calls=1) - r = list(result) - assert len(r) == 2 - assert r[0][0].event == "agent-content-delta" - assert r[1][0].event == "agent-content" + @pytest.fixture + def task(self, default_fake_llm): + task = controlflow.Task("say hello", id="12345") - def test_stream_content(self, default_fake_llm): response = AIMessage( id="run-2af8bb73-661f-4ec3-92ff-d7d8e3074926", name="Marvin", @@ -197,8 +195,41 @@ def test_stream_content(self, default_fake_llm): ) default_fake_llm.set_responses(["Hello!", response]) - result = run("say hello", stream=True, max_llm_calls=1) + + return task + + def test_stream_all(self, default_fake_llm): + result = run("what's 2 + 2", stream=True, max_llm_calls=1) r = list(result) - assert len(r) == 2 - assert r[0][0].event == "agent-content-delta" - assert r[1][0].event == "agent-content" + assert len(r) > 5 + + def test_stream_task(self, task): + result = list(task.run(stream=True)) + assert result[0][0].event == "orchestrator-start" + assert result[1][0].event == "agent-turn-start" + assert result[-1][0].event == "orchestrator-end" + assert any(r[0].event == "agent-message" for r in result) + assert any(r[0].event == "agent-message-delta" for r in result) + assert any(r[0].event == "agent-content" for r in result) + assert any(r[0].event == "agent-content-delta" for r in result) + assert any(r[0].event == "agent-tool-call" for r in result) + + def test_stream_content(self, task): + result = list(task.run(stream=Stream.CONTENT)) + assert all( + r[0].event in ("agent-content", "agent-content-delta") for r in result + ) + + def test_stream_tools(self, task): + result = list(task.run(stream=Stream.TOOLS)) + assert all( + r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for r in result + ) + + def test_stream_results(self, task): + result = list(task.run(stream=Stream.RESULTS)) + assert all( + r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for r in result + ) From abdc8406f2b3d8e9708ec43a6b2d5f4b7b674876 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 22:16:39 -0500 Subject: [PATCH 3/8] Add streaming --- docs/mint.json | 1 + docs/patterns/running-tasks.mdx | 40 +++ docs/patterns/streaming-tasks.mdx | 242 ++++++++++++++++++ src/controlflow/events/orchestrator_events.py | 5 +- src/controlflow/events/task_events.py | 33 +++ src/controlflow/flows/flow.py | 6 +- src/controlflow/orchestration/orchestrator.py | 176 ++++--------- src/controlflow/stream.py | 19 +- src/controlflow/tasks/task.py | 21 ++ tests/test_run.py | 2 +- 10 files changed, 408 insertions(+), 137 deletions(-) create mode 100644 docs/patterns/streaming-tasks.mdx create mode 100644 src/controlflow/events/task_events.py diff --git a/docs/mint.json b/docs/mint.json index 4a44f92d..53713d78 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -55,6 +55,7 @@ "patterns/running-tasks", "patterns/task-results", "patterns/tools", + "patterns/streaming-tasks", "patterns/interactivity", "patterns/dependencies", "patterns/memory", diff --git a/docs/patterns/running-tasks.mdx b/docs/patterns/running-tasks.mdx index 7d05519b..3227a7db 100644 --- a/docs/patterns/running-tasks.mdx +++ b/docs/patterns/running-tasks.mdx @@ -154,7 +154,47 @@ Crafting worlds and shaping dreams. +## Streaming + + +In addition to running tasks to completion, ControlFlow supports streaming events during task execution. This allows you to process or display intermediate outputs like agent messages, tool calls, and results in real-time. + +To enable streaming, set `stream=True` when running tasks: + +```python +import controlflow as cf + +# Stream all events +for event, snapshot, delta in cf.run("Write a poem", stream=True, handlers=[]): + print(f"Event type: {event.event}") + + if event.event == "agent-content": + print(f"Agent said: {snapshot}") + elif event.event == "agent-tool-call": + print(f"Tool called: {snapshot}") +``` + +You can also filter which events you want to receive using the `Stream` enum: + +```python +import controlflow as cf + +# Only stream content events +for event, content, delta in cf.run( + "Write a poem", + stream=cf.Stream.CONTENT, + handlers=[], # remove the default print handler +): + if delta: + # Print incremental content updates + print(delta, end="", flush=True) + else: + # Print complete messages + print(content) +``` + +For more details on working with streaming events, including programmatic event handlers, see the [Streaming guide](/patterns/streaming). ## Multi-Agent Collaboration For tasks involving multiple agents, ControlFlow needs a way to manage their collaboration. What makes this more complicated than simply making an LLM call and moving on to the next agent is that it may take multiple LLM calls to complete a single agentic "turn" of work. diff --git a/docs/patterns/streaming-tasks.mdx b/docs/patterns/streaming-tasks.mdx new file mode 100644 index 00000000..66109769 --- /dev/null +++ b/docs/patterns/streaming-tasks.mdx @@ -0,0 +1,242 @@ +--- +title: Streaming +description: Process agent responses, tool calls and results in real-time through streaming or handlers. +icon: wave-square +--- + +import { VersionBadge } from '/snippets/version-badge.mdx' + + +ControlFlow provides two ways to process events during task execution: +- [**Streaming**](#streaming): Iterate over events in real-time using a Python iterator +- [**Handlers**](#handlers): Register callback functions that are called for each event + +Both approaches give you access to the same events - which one you choose depends on how you want to integrate with your application. + +## Streaming + + + +When you enable streaming, task execution returns an iterator that yields events as they occur. Each iteration provides a tuple of (event, snapshot, delta) representing what just happened in the workflow: + +```python +import controlflow as cf + +for event, snapshot, delta in cf.run( + "Write a poem about AI", + stream=True, +): + # For complete events, snapshot contains the full content + if event.event == "agent-content": + print(f"Agent wrote: {snapshot}") + + # For delta events, delta contains just what's new + elif event.event == "agent-content-delta": + print(delta, end="", flush=True) +``` + +You can focus on specific events using the `Stream` enum: + +```python +from controlflow import Stream + +# Only stream content updates +for event, snapshot, delta in cf.run( + "Write a poem", + stream=Stream.CONTENT +): + print(delta if delta else snapshot) +``` + +The available stream filters are: +- `Stream.ALL`: All events (equivalent to `stream=True`) +- `Stream.CONTENT`: Agent content and content deltas +- `Stream.TOOLS`: All tool events +- `Stream.COMPLETION_TOOLS`: Completion tool events (like marking a task successful or failed) +- `Stream.AGENT_TOOLS`: Tools used by agents for any purpose other than completing a task +- `Stream.TASK_EVENTS`: Task state change events +You can combine filters with the `|` operator: + +```python +# Stream content and tool events +stream = Stream.CONTENT | Stream.TOOLS +``` + +## Handlers + + +For more complex event processing, or when you want to decouple event handling from your main workflow, use handlers: + +```python +from controlflow.orchestration.handler import Handler +from controlflow.events.events import AgentMessage + +class LoggingHandler(Handler): + def on_agent_message(self, event: AgentMessage): + print(f"Agent {event.agent.name} said: {event.message['content']}") + + def on_tool_result(self, event: ToolResult): + print(f"Tool call result: {event.tool_result.str_result}") + +# Use the handler +cf.run("Write a poem", handlers=[LoggingHandler()]) +``` + +Handlers are especially useful for: +- Adding logging or monitoring +- Collecting metrics +- Updating UI elements +- Processing events asynchronously + +Handlers call their `on_` methods for each event type. For a complete list of available methods, see the [Event Details](#event-details) section below. + + +### Async Handlers + + + +For asynchronous event processing, use `AsyncHandler`: + +```python +import asyncio +from controlflow.orchestration.handler import AsyncHandler + +class AsyncLoggingHandler(AsyncHandler): + async def on_agent_message(self, event: AgentMessage): + await asyncio.sleep(0.1) # Simulate async operation + print(f"Agent {event.agent.name} said: {event.message['content']}") + +await cf.run_async("Write a poem", handlers=[AsyncLoggingHandler()]) +``` + +## Example: Real-time Content Display + +Here's a complete example showing both approaches to display content in real-time: + + +```python Streaming +import controlflow as cf +from controlflow import Stream + +for event, snapshot, delta in cf.run( + "Write a story about time travel", + stream=Stream.CONTENT +): + # Print character by character + if delta: + print(delta, end="", flush=True) +``` + +```python Handler +import controlflow as cf +from controlflow.orchestration.handler import Handler + +class ContentHandler(Handler): + def on_agent_content_delta(self, event): + # Print character by character + print(event.content_delta, end="", flush=True) + +cf.run( + "Write a story about time travel", + handlers=[ContentHandler()] +) +``` + + +## Event Details + +Now that we've seen how to process events, let's look at the types of events you can receive: + +### Content Events +Content events give you access to what an agent is saying or writing: + +```python +# Complete content +{ + "event": "agent-content", + "agent": agent, # Agent object + "content": "Hello, world!", # The complete content + "agent_message_id": "msg_123" # Optional ID linking to parent message +} + +# Content delta (incremental update) +{ + "event": "agent-content-delta", + "agent": agent, + "content_delta": "Hello", # New content since last update + "content_snapshot": "Hello, world!", # Complete content so far + "agent_message_id": "msg_123" +} +``` + +### Tool Events +Tool events let you observe when agents use tools and get their results: + +```python +# Tool being called +{ + "event": "agent-tool-call", + "agent": agent, + "tool_call": {...}, # The complete tool call info + "tool": tool, # The Tool object being called + "args": {...}, # Arguments passed to the tool + "agent_message_id": "msg_123" +} + +# Tool call delta (incremental update) +{ + "event": "agent-tool-call-delta", + "agent": agent, + "tool_call_delta": {...}, # Changes to the tool call + "tool_call_snapshot": {...}, # Complete tool call info so far + "tool": tool, + "args": {...}, + "agent_message_id": "msg_123" +} + +# Tool result +{ + "event": "tool-result", + "agent": agent, + "tool_result": { + "tool_call": {...}, # The original tool call + "tool": tool, # The Tool object that was called + "result": any, # The raw result value + "str_result": "...", # String representation of result + "is_error": False # Whether the tool call failed + } +} +``` + +### Workflow Events +These events mark key points in task execution: +- `OrchestratorStart`/`End`: Workflow orchestration starting/ending +- `AgentTurnStart`/`End`: An agent's turn starting/ending +- `OrchestratorError`: An error occurred during orchestration + +### Handler Methods + +Each handler can implement methods for different types of events. The method will be called whenever that type of event occurs. Here are all available handler methods: + +| Method | Event Type | Description | +|--------|------------|-------------| +| `on_event(event)` | Any | Called for every event, before any specific handler | +| `on_agent_message(event)` | AgentMessage | Raw LLM output containing both content and tool calls | +| `on_agent_message_delta(event)` | AgentMessageDelta | Incremental updates to raw LLM output | +| `on_agent_content(event)` | AgentContent | Unstructured text output from an agent | +| `on_agent_content_delta(event)` | AgentContentDelta | Incremental updates to agent content | +| `on_agent_tool_call(event)` | AgentToolCall | Tool being called by an agent | +| `on_agent_tool_call_delta(event)` | AgentToolCallDelta | Incremental updates to a tool call | +| `on_tool_result(event)` | ToolResult | Result returned from a tool | +| `on_orchestrator_start(event)` | OrchestratorStart | Workflow orchestration starting | +| `on_orchestrator_end(event)` | OrchestratorEnd | Workflow orchestration completed | +| `on_agent_turn_start(event)` | AgentTurnStart | An agent beginning their turn | +| `on_agent_turn_end(event)` | AgentTurnEnd | An agent completing their turn | +| `on_orchestrator_error(event)` | OrchestratorError | Error during orchestration | + +Note that AgentMessage is the "raw" output from the LLM and contains both unstructured content and structured tool calls. When you receive an AgentMessage, you will also receive separate AgentContent and/or AgentToolCall events for any content or tool calls contained in that message. This allows you to: +1. Handle all LLM output in one place with `on_agent_message` +2. Handle just content with `on_agent_content` +3. Handle just tool calls with `on_agent_tool_call` + +For streaming cases, the delta events (e.g. AgentMessageDelta, AgentContentDelta) provide incremental updates as the LLM generates its response. diff --git a/src/controlflow/events/orchestrator_events.py b/src/controlflow/events/orchestrator_events.py index 88370f74..611dd18b 100644 --- a/src/controlflow/events/orchestrator_events.py +++ b/src/controlflow/events/orchestrator_events.py @@ -1,16 +1,17 @@ from dataclasses import Field -from typing import TYPE_CHECKING, Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional from pydantic.functional_serializers import PlainSerializer from controlflow.agents.agent import Agent -from controlflow.events.base import UnpersistedEvent +from controlflow.events.base import Event, UnpersistedEvent if TYPE_CHECKING: from controlflow.orchestration.conditions import RunContext from controlflow.orchestration.orchestrator import Orchestrator +# Orchestrator events class OrchestratorStart(UnpersistedEvent): event: Literal["orchestrator-start"] = "orchestrator-start" persist: bool = False diff --git a/src/controlflow/events/task_events.py b/src/controlflow/events/task_events.py new file mode 100644 index 00000000..dbc2598f --- /dev/null +++ b/src/controlflow/events/task_events.py @@ -0,0 +1,33 @@ +from dataclasses import Field +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional + +from pydantic.functional_serializers import PlainSerializer + +from controlflow.events.base import UnpersistedEvent +from controlflow.tasks.task import Task + + +# Task events +class TaskStart(UnpersistedEvent): + event: Literal["task-start"] = "task-start" + task: Task + + +class TaskSuccess(UnpersistedEvent): + event: Literal["task-success"] = "task-success" + task: Task + result: Annotated[ + Any, + PlainSerializer(lambda x: str(x) if x else None, return_type=Optional[str]), + ] = None + + +class TaskFailure(UnpersistedEvent): + event: Literal["task-failure"] = "task-failure" + task: Task + reason: Optional[str] = None + + +class TaskSkipped(UnpersistedEvent): + event: Literal["task-skipped"] = "task-skipped" + task: Task diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index def848fc..55916b2e 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -112,9 +112,11 @@ def get_events( return events def add_events(self, events: list[Event]): - for event in events: + persist_events = [e for e in events if e.persist] + for event in persist_events: event.thread_id = self.thread_id - self.history.add_events(thread_id=self.thread_id, events=events) + if persist_events: + self.history.add_events(thread_id=self.thread_id, events=persist_events) @contextmanager def create_context(self, **prefect_kwargs) -> Generator[Self, None, None]: diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index a4f68242..d3e508d6 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,4 +1,5 @@ import logging +from contextlib import contextmanager from typing import AsyncIterator, Callable, Iterator, Optional, Set, TypeVar, Union from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -31,6 +32,7 @@ from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy from controlflow.tasks.task import Task from controlflow.tools.tools import Tool, as_tools +from controlflow.utilities.context import ctx from controlflow.utilities.general import ControlFlowModel from controlflow.utilities.prefect import prefect_task @@ -231,12 +233,13 @@ def _run_agent_turn( tools = self.get_tools() # Run model and yield events - for event in self.agent._run_model( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - yield event + with ctx(orchestrator=self): + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event run_context.llm_calls += 1 for task in assigned_tasks: @@ -298,6 +301,12 @@ def _run( OrchestratorEnd(orchestrator=self, run_context=run_context) ) + @contextmanager + def create_context(self): + """Create a context with this orchestrator.""" + with ctx(orchestrator=self): + yield self + def run( self, max_llm_calls: Optional[int] = None, @@ -340,17 +349,25 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - iterator = self._run( - run_context=run_context, - model_kwargs=model_kwargs, - ) - + # If streaming, return a generator that maintains the context if stream: - return iterator - # Consume iterator if not streaming - for _ in iterator: - pass + def event_generator(): + with ctx(orchestrator=self): + yield from self._run( + run_context=run_context, + model_kwargs=model_kwargs, + ) + + return event_generator() + + # If not streaming, consume events within the context + with ctx(orchestrator=self): + for _ in self._run( + run_context=run_context, + model_kwargs=model_kwargs, + ): + pass return run_context @prefect_task(task_run_name="Run agent orchestrator") @@ -451,126 +468,27 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - iterator = self._run_async( - run_context=run_context, - model_kwargs=model_kwargs, - ) - + # If streaming, return an async generator that maintains the context if stream: - return iterator - - # Consume iterator if not streaming - async for _ in iterator: - pass - return run_context - - @prefect_task(task_run_name="Agent turn: {self.agent.name}") - def run_agent_turn( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> Iterator[Event]: - """ - Run a single agent turn, which may consist of multiple LLM calls. - """ - assigned_tasks = self.get_tasks("assigned") - - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in assigned_tasks: - if not task.is_running(): - task.mark_running() - yield OrchestratorMessage( - content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " - f"with objective: {task.objective}" - ) - - while not self.turn_strategy.should_end_turn(): - # fail any tasks that have reached their max llm calls - for task in assigned_tasks: - if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: - task.mark_failed(reason="Max LLM calls reached for this task.") - - # Check if there are any ready tasks left - if not any(t.is_ready() for t in assigned_tasks): - logger.debug("No `ready` tasks to run") - break - - if run_context.should_end(): - break - - messages = self.compile_messages() - tools = self.get_tools() - - for event in self.agent._run_model( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - yield event - - run_context.llm_calls += 1 - for task in assigned_tasks: - task._llm_calls += 1 - - run_context.agent_turns += 1 - - @prefect_task(task_run_name="Agent turn: {self.agent.name}") - async def run_agent_turn_async( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> AsyncIterator[Event]: - """ - Run a single agent turn asynchronously, which may consist of multiple LLM calls. - - Args: - max_llm_calls (Optional[int]): The number of LLM calls allowed. - - """ - assigned_tasks = self.get_tasks("assigned") - - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in assigned_tasks: - if not task.is_running(): - task.mark_running() - yield OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " - f"with objective: {task.objective}" - ) - while not self.turn_strategy.should_end_turn(): - # fail any tasks that have reached their max llm calls - for task in assigned_tasks: - if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: - task.mark_failed(reason="Max LLM calls reached for this task.") + async def event_generator(): + with ctx(orchestrator=self): + async for event in self._run_async( + run_context=run_context, + model_kwargs=model_kwargs, + ): + yield event - # Check if there are any ready tasks left - if not any(t.is_ready() for t in assigned_tasks): - logger.debug("No `ready` tasks to run") - break - - if run_context.should_end(): - break + return event_generator() - messages = self.compile_messages() - tools = self.get_tools() - - async for event in self.agent._run_model_async( - messages=messages, - tools=tools, + # If not streaming, consume events within the context + with ctx(orchestrator=self): + async for _ in self._run_async( + run_context=run_context, model_kwargs=model_kwargs, ): - yield event - - run_context.llm_calls += 1 - for task in assigned_tasks: - task._llm_calls += 1 - - run_context.agent_turns += 1 + pass + return run_context def compile_prompt(self) -> str: """ diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index 3fa69c5e..65db7e22 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -11,6 +11,12 @@ Event, ToolResult, ) +from controlflow.events.task_events import ( + TaskFailure, + TaskSkipped, + TaskStart, + TaskSuccess, +) class Stream(Flag): @@ -25,8 +31,9 @@ class Stream(Flag): ALL = auto() # All events CONTENT = auto() # Agent content and deltas AGENT_TOOLS = auto() # Non-completion tool events - RESULTS = auto() # Completion tool events - TOOLS = AGENT_TOOLS | RESULTS # All tool events + COMPLETION_TOOLS = auto() # Completion tool events + TOOLS = AGENT_TOOLS | COMPLETION_TOOLS # All tool events + TASK_EVENTS = auto() # Task state change events def filter_events( @@ -46,6 +53,7 @@ def filter_events( - Content events: (event, full_text, new_text) - Tool calls: (event, tool_state, tool_delta) - Tool results: (event, result_state, None) + - Task events: (event, task_state, None) - Other events: (event, None, None) """ @@ -72,9 +80,13 @@ def should_include_event(event: Event) -> bool: # Tool events if isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)): if is_completion_tool_event(event): - return bool(stream_filter & Stream.RESULTS) + return bool(stream_filter & Stream.COMPLETION_TOOLS) return bool(stream_filter & Stream.AGENT_TOOLS) + # Task events + if isinstance(event, (TaskStart, TaskSuccess, TaskFailure, TaskSkipped)): + return bool(stream_filter & Stream.TASK_EVENTS) + return False for event in events: @@ -102,6 +114,7 @@ def should_include_event(event: Event) -> bool: # Tool result events elif isinstance(event, ToolResult): yield event, event.tool_result, None + else: # Pass through any other events with no snapshot/delta yield event, None, None diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index c7e2dd1a..3a07d471 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -575,18 +575,39 @@ def set_status(self, status: TaskStatus): tui.update_task(self) def mark_running(self): + """Mark the task as running and emit a TaskStart event.""" self.set_status(TaskStatus.RUNNING) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskStart + + orchestrator.handle_event(TaskStart(task=self)) def mark_successful(self, result: T = None): + """Mark the task as successful and emit a TaskSuccess event.""" self.result = self.validate_result(result) self.set_status(TaskStatus.SUCCESSFUL) + breakpoint() + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskSuccess + + orchestrator.handle_event(TaskSuccess(task=self, result=result)) def mark_failed(self, reason: Optional[str] = None): + """Mark the task as failed and emit a TaskFailure event.""" self.result = reason self.set_status(TaskStatus.FAILED) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskFailure + + orchestrator.handle_event(TaskFailure(task=self, reason=reason)) def mark_skipped(self): + """Mark the task as skipped and emit a TaskSkipped event.""" self.set_status(TaskStatus.SKIPPED) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskSkipped + + orchestrator.handle_event(TaskSkipped(task=self)) def get_success_tool(self) -> Tool: """ diff --git a/tests/test_run.py b/tests/test_run.py index 4128e748..5e3e29dd 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -228,7 +228,7 @@ def test_stream_tools(self, task): ) def test_stream_results(self, task): - result = list(task.run(stream=Stream.RESULTS)) + result = list(task.run(stream=Stream.COMPLETION_TOOLS)) assert all( r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") for r in result From 29cffb073cbe8642a0fe1e1b87fad178b56fc9c7 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:26:35 -0500 Subject: [PATCH 4/8] Add task events --- docs/patterns/streaming-tasks.mdx | 15 +- src/controlflow/orchestration/orchestrator.py | 174 +++++++++--------- src/controlflow/tasks/task.py | 9 +- 3 files changed, 106 insertions(+), 92 deletions(-) diff --git a/docs/patterns/streaming-tasks.mdx b/docs/patterns/streaming-tasks.mdx index 66109769..d30ec458 100644 --- a/docs/patterns/streaming-tasks.mdx +++ b/docs/patterns/streaming-tasks.mdx @@ -54,7 +54,8 @@ The available stream filters are: - `Stream.TOOLS`: All tool events - `Stream.COMPLETION_TOOLS`: Completion tool events (like marking a task successful or failed) - `Stream.AGENT_TOOLS`: Tools used by agents for any purpose other than completing a task -- `Stream.TASK_EVENTS`: Task state change events +- `Stream.TASK_EVENTS`: Task lifecycle events (starting, completion, failure, etc) + You can combine filters with the `|` operator: ```python @@ -209,7 +210,15 @@ Tool events let you observe when agents use tools and get their results: ``` ### Workflow Events -These events mark key points in task execution: +### Task Events +Events that mark key points in a task's lifecycle: +- `TaskStart`: A task has begun execution +- `TaskSuccess`: A task completed successfully (includes the final result) +- `TaskFailure`: A task failed (includes the error reason) +- `TaskSkipped`: A task was skipped + +### Orchestration Events +Events related to orchestrating the overall workflow: - `OrchestratorStart`/`End`: Workflow orchestration starting/ending - `AgentTurnStart`/`End`: An agent's turn starting/ending - `OrchestratorError`: An error occurred during orchestration @@ -239,4 +248,4 @@ Note that AgentMessage is the "raw" output from the LLM and contains both unstru 2. Handle just content with `on_agent_content` 3. Handle just tool calls with `on_agent_tool_call` -For streaming cases, the delta events (e.g. AgentMessageDelta, AgentContentDelta) provide incremental updates as the LLM generates its response. +For streaming cases, the delta events (e.g. AgentMessageDelta, AgentContentDelta) provide incremental updates as the LLM generates its response. Task events, in contrast, are complete events that mark important points in a task's lifecycle - you can use these to track progress and get results without managing the task object directly.. diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index d3e508d6..277c1c70 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,4 +1,5 @@ import logging +from collections import deque from contextlib import contextmanager from typing import AsyncIterator, Callable, Iterator, Optional, Set, TypeVar, Union @@ -64,6 +65,7 @@ class Orchestrator(ControlFlowModel): None, validate_default=True, exclude=True ) _processed_event_ids: Set[str] = PrivateAttr(default_factory=set) + _pending_events: deque[Event] = PrivateAttr(default_factory=deque) @field_validator("turn_strategy", mode="before") def _validate_turn_strategy(cls, v): @@ -93,16 +95,15 @@ def _validate_handlers(cls, v): ] return v or [] + def add_event(self, event: Event) -> None: + """Add an event to be handled and yielded during the next run loop iteration""" + self._pending_events.append(event) + def handle_event(self, event: Event) -> Event: """ Handle an event by passing it to all handlers and persisting if necessary. Includes idempotency check to prevent double-processing events. - - Args: - event (Event): The event to handle. """ - from controlflow.events.events import AgentContentDelta - # Skip if we've already processed this event if event.id in self._processed_event_ids: return event @@ -261,29 +262,36 @@ def _run( ) # Signal the start of orchestration - yield self.handle_event( - OrchestratorStart(orchestrator=self, run_context=run_context) - ) + start_event = OrchestratorStart(orchestrator=self, run_context=run_context) + self.handle_event(start_event) + yield start_event try: while True: if run_context.should_end(): break - yield self.handle_event( - AgentTurnStart(orchestrator=self, agent=self.agent) - ) + turn_start = AgentTurnStart(orchestrator=self, agent=self.agent) + self.handle_event(turn_start) + yield turn_start # Run turn and yield its events for event in self._run_agent_turn( run_context=run_context, model_kwargs=model_kwargs, ): - yield self.handle_event(event) + self.handle_event(event) + yield event - yield self.handle_event( - AgentTurnEnd(orchestrator=self, agent=self.agent) - ) + # Handle any events added during the turn + while self._pending_events: + event = self._pending_events.popleft() + self.handle_event(event) + yield event + + turn_end = AgentTurnEnd(orchestrator=self, agent=self.agent) + self.handle_event(turn_end) + yield turn_end # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -293,19 +301,21 @@ def _run( except Exception as exc: # Yield error event if something goes wrong - yield self.handle_event(OrchestratorError(orchestrator=self, error=exc)) + error_event = OrchestratorError(orchestrator=self, error=exc) + self.handle_event(error_event) + yield error_event raise finally: # Signal the end of orchestration - yield self.handle_event( - OrchestratorEnd(orchestrator=self, run_context=run_context) - ) - - @contextmanager - def create_context(self): - """Create a context with this orchestrator.""" - with ctx(orchestrator=self): - yield self + end_event = OrchestratorEnd(orchestrator=self, run_context=run_context) + self.handle_event(end_event) + yield end_event + + # Handle any final pending events + while self._pending_events: + event = self._pending_events.popleft() + self.handle_event(event) + yield event def run( self, @@ -349,26 +359,17 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - # If streaming, return a generator that maintains the context - if stream: - - def event_generator(): - with ctx(orchestrator=self): - yield from self._run( - run_context=run_context, - model_kwargs=model_kwargs, - ) - - return event_generator() + result = self._run( + run_context=run_context, + model_kwargs=model_kwargs, + ) - # If not streaming, consume events within the context - with ctx(orchestrator=self): - for _ in self._run( - run_context=run_context, - model_kwargs=model_kwargs, - ): + if stream: + return result + else: + for _ in result: pass - return run_context + return run_context @prefect_task(task_run_name="Run agent orchestrator") async def _run_async( @@ -384,29 +385,36 @@ async def _run_async( ) # Signal the start of orchestration - yield await self.handle_event_async( - OrchestratorStart(orchestrator=self, run_context=run_context) - ) + start_event = OrchestratorStart(orchestrator=self, run_context=run_context) + await self.handle_event_async(start_event) + yield start_event try: while True: if run_context.should_end(): break - yield await self.handle_event_async( - AgentTurnStart(orchestrator=self, agent=self.agent) - ) + turn_start = AgentTurnStart(orchestrator=self, agent=self.agent) + await self.handle_event_async(turn_start) + yield turn_start # Run turn and yield its events async for event in self._run_agent_turn_async( run_context=run_context, model_kwargs=model_kwargs, ): - yield await self.handle_event_async(event) + await self.handle_event_async(event) + yield event - yield await self.handle_event_async( - AgentTurnEnd(orchestrator=self, agent=self.agent) - ) + # Handle any events added during the turn + while self._pending_events: + event = self._pending_events.popleft() + await self.handle_event_async(event) + yield event + + turn_end = AgentTurnEnd(orchestrator=self, agent=self.agent) + await self.handle_event_async(turn_end) + yield turn_end # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -416,15 +424,21 @@ async def _run_async( except Exception as exc: # Yield error event if something goes wrong - yield await self.handle_event_async( - OrchestratorError(orchestrator=self, error=exc) - ) + error_event = OrchestratorError(orchestrator=self, error=exc) + await self.handle_event_async(error_event) + yield error_event raise finally: # Signal the end of orchestration - yield await self.handle_event_async( - OrchestratorEnd(orchestrator=self, run_context=run_context) - ) + end_event = OrchestratorEnd(orchestrator=self, run_context=run_context) + await self.handle_event_async(end_event) + yield end_event + + # Handle any final pending events + while self._pending_events: + event = self._pending_events.popleft() + await self.handle_event_async(event) + yield event async def run_async( self, @@ -468,27 +482,17 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - # If streaming, return an async generator that maintains the context - if stream: + result = self._run_async( + run_context=run_context, + model_kwargs=model_kwargs, + ) - async def event_generator(): - with ctx(orchestrator=self): - async for event in self._run_async( - run_context=run_context, - model_kwargs=model_kwargs, - ): - yield event - - return event_generator() - - # If not streaming, consume events within the context - with ctx(orchestrator=self): - async for _ in self._run_async( - run_context=run_context, - model_kwargs=model_kwargs, - ): + if stream: + return result + else: + for _ in result: pass - return run_context + return run_context def compile_prompt(self) -> str: """ @@ -657,12 +661,14 @@ async def _run_agent_turn_async( messages = self.compile_messages() tools = self.get_tools() - async for event in self.agent._run_model_async( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - yield event + # Run model and yield events + with ctx(orchestrator=self): + async for event in self.agent._run_model_async( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event run_context.llm_calls += 1 for task in assigned_tasks: diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 3a07d471..6603f840 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -580,17 +580,16 @@ def mark_running(self): if orchestrator := ctx.get("orchestrator"): from controlflow.events.task_events import TaskStart - orchestrator.handle_event(TaskStart(task=self)) + orchestrator.add_event(TaskStart(task=self)) def mark_successful(self, result: T = None): """Mark the task as successful and emit a TaskSuccess event.""" self.result = self.validate_result(result) self.set_status(TaskStatus.SUCCESSFUL) - breakpoint() if orchestrator := ctx.get("orchestrator"): from controlflow.events.task_events import TaskSuccess - orchestrator.handle_event(TaskSuccess(task=self, result=result)) + orchestrator.add_event(TaskSuccess(task=self, result=result)) def mark_failed(self, reason: Optional[str] = None): """Mark the task as failed and emit a TaskFailure event.""" @@ -599,7 +598,7 @@ def mark_failed(self, reason: Optional[str] = None): if orchestrator := ctx.get("orchestrator"): from controlflow.events.task_events import TaskFailure - orchestrator.handle_event(TaskFailure(task=self, reason=reason)) + orchestrator.add_event(TaskFailure(task=self, reason=reason)) def mark_skipped(self): """Mark the task as skipped and emit a TaskSkipped event.""" @@ -607,7 +606,7 @@ def mark_skipped(self): if orchestrator := ctx.get("orchestrator"): from controlflow.events.task_events import TaskSkipped - orchestrator.handle_event(TaskSkipped(task=self)) + orchestrator.add_event(TaskSkipped(task=self)) def get_success_tool(self) -> Tool: """ From 981d6c39d35606f0a161329aa522ba43f2799342 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:53:14 -0500 Subject: [PATCH 5/8] Add async tests --- src/controlflow/agents/agent.py | 2 + src/controlflow/orchestration/orchestrator.py | 45 ------ src/controlflow/run.py | 44 +++--- src/controlflow/stream.py | 146 +++++++++--------- src/controlflow/tasks/task.py | 8 +- tests/test_run.py | 55 ++++++- 6 files changed, 152 insertions(+), 148 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 185a833f..6e2ab985 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -257,6 +257,7 @@ async def run_async( *, turn_strategy: Optional["TurnStrategy"] = None, handlers: Optional[list[Union["Handler", "AsyncHandler"]]] = None, + stream: Union[bool, "Stream"] = False, **task_kwargs, ): return await controlflow.run_async( @@ -264,6 +265,7 @@ async def run_async( agents=[self], turn_strategy=turn_strategy, handlers=handlers, + stream=stream, **task_kwargs, ) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 277c1c70..29628d18 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -676,51 +676,6 @@ async def _run_agent_turn_async( run_context.agent_turns += 1 - async def _run_async( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> AsyncIterator[Event]: - """Run the orchestrator asynchronously, yielding events as they occur.""" - # Initialize the agent if not already set - if not self.agent: - self.agent = self.turn_strategy.get_next_agent( - None, self.get_available_agents() - ) - - # Signal the start of orchestration - yield OrchestratorStart(orchestrator=self, run_context=run_context) - - try: - while True: - if run_context.should_end(): - break - - yield AgentTurnStart(orchestrator=self, agent=self.agent) - - # Run turn and yield its events - async for event in self._run_agent_turn_async( - run_context=run_context, - model_kwargs=model_kwargs, - ): - yield event - - yield AgentTurnEnd(orchestrator=self, agent=self.agent) - - # Select the next agent for the following turn - if available_agents := self.get_available_agents(): - self.agent = self.turn_strategy.get_next_agent( - self.agent, available_agents - ) - - except Exception as exc: - # Yield error event if something goes wrong - yield OrchestratorError(orchestrator=self, error=exc) - raise - finally: - # Signal the end of orchestration - yield OrchestratorEnd(orchestrator=self, run_context=run_context) - # Rebuild all models with forward references after Orchestrator is defined OrchestratorStart.model_rebuild() diff --git a/src/controlflow/run.py b/src/controlflow/run.py index b0811682..2362abea 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterator, Optional, Union +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union import controlflow from controlflow.agents.agent import Agent @@ -7,7 +7,7 @@ from controlflow.orchestration.conditions import RunContext, RunEndCondition from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy -from controlflow.stream import Stream, filter_events +from controlflow.stream import Stream, filter_events_async, filter_events_sync from controlflow.tasks.task import Task from controlflow.utilities.prefect import prefect_task @@ -28,26 +28,6 @@ def run_tasks( ) -> Union[list[Any], Iterator[tuple[Event, Any, Optional[Any]]]]: """ Run a list of tasks. - - Args: - tasks: List of tasks to run. - instructions: Instructions for the tasks. - flow: Flow to run the tasks in. - agent: Agent to run the tasks with. - turn_strategy: Turn strategy to use for the tasks. - raise_on_failure: Whether to raise an error if any tasks fail. - max_llm_calls: Maximum number of LLM calls to make. - max_agent_turns: Maximum number of agent turns to make. - handlers: List of handlers to use for the tasks. - model_kwargs: Keyword arguments to pass to the LLM. - run_until: Condition to stop running tasks. - stream: If True, stream all events (equivalent to Stream.ALL). - Can also provide Stream flags to filter specific events. - e.g. Stream.CONTENT | Stream.AGENT_TOOLS - - Returns: - If not streaming: List of task results - If streaming: Iterator of (event, snapshot, delta) tuples """ flow = flow or get_flow() or Flow() @@ -71,7 +51,7 @@ def run_tasks( if stream: # Convert True to ALL filter, otherwise use provided filter stream_filter = Stream.ALL if stream is True else stream - return filter_events(result, stream_filter) + return filter_events_sync(result, stream_filter) if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] @@ -96,7 +76,8 @@ async def run_tasks_async( handlers: list[Union[Handler, AsyncHandler]] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, -): + stream: Union[bool, Stream] = False, +) -> Union[list[Any], AsyncIterator[tuple[Event, Any, Optional[Any]]]]: """ Run a list of tasks asynchronously. """ @@ -110,13 +91,19 @@ async def run_tasks_async( ) with controlflow.instructions(instructions): - await orchestrator.run_async( + result = await orchestrator.run_async( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, model_kwargs=model_kwargs, run_until=run_until, + stream=bool(stream), ) + if stream: + # Convert True to ALL filter, otherwise use provided filter + stream_filter = Stream.ALL if stream is True else stream + return filter_events_async(result, stream_filter) + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: @@ -186,6 +173,7 @@ async def run_async( handlers: list[Union[Handler, AsyncHandler]] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, + stream: Union[bool, Stream] = False, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -200,5 +188,9 @@ async def run_async( handlers=handlers, model_kwargs=model_kwargs, run_until=run_until, + stream=stream, ) - return results[0] + if stream: + return results + else: + return results[0] diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index 65db7e22..f7a2bd2e 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -1,5 +1,5 @@ from enum import Flag, auto -from typing import Any, Iterator, Optional +from typing import Any, AsyncIterator, Iterator, Optional, Union from controlflow.events.events import ( AgentContent, @@ -36,85 +36,83 @@ class Stream(Flag): TASK_EVENTS = auto() # Task state change events -def filter_events( - events: Iterator[Event], stream_filter: Stream -) -> Iterator[tuple[Event, Any, Optional[Any]]]: - """ - Filter events based on Stream flags. - Returns tuples of (event, snapshot, delta) where snapshot and delta depend on the event type. - - Returns: - Iterator of (event, snapshot, delta) tuples where: - - event: The original event - - snapshot: Full state (e.g., complete message, tool state) - - delta: Incremental change (None for non-delta events) - - Patterns for different event types: - - Content events: (event, full_text, new_text) - - Tool calls: (event, tool_state, tool_delta) - - Tool results: (event, result_state, None) - - Task events: (event, task_state, None) - - Other events: (event, None, None) - """ - - def is_completion_tool_event(event: Event) -> bool: - """Check if an event is related to a completion tool call.""" - if isinstance(event, ToolResult): - tool = event.tool_result.tool - elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): - tool = event.tool - else: - return False +def should_include_event(event: Event, stream_filter: Stream) -> bool: + """Determine if an event should be included based on the stream filter.""" + # Pass all events if ALL is specified + if stream_filter == Stream.ALL: + return True - return tool and tool.metadata.get("is_completion_tool") + # Content events + if isinstance(event, (AgentContent, AgentContentDelta)): + return bool(stream_filter & Stream.CONTENT) - def should_include_event(event: Event) -> bool: - # Pass all events if ALL is specified - if stream_filter == Stream.ALL: - return True + # Tool events + if isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)): + if is_completion_tool_event(event): + return bool(stream_filter & Stream.COMPLETION_TOOLS) + return bool(stream_filter & Stream.AGENT_TOOLS) - # Content events - if isinstance(event, (AgentContent, AgentContentDelta)): - return bool(stream_filter & Stream.CONTENT) + # Task events + if isinstance(event, (TaskStart, TaskSuccess, TaskFailure, TaskSkipped)): + return bool(stream_filter & Stream.TASK_EVENTS) - # Tool events - if isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)): - if is_completion_tool_event(event): - return bool(stream_filter & Stream.COMPLETION_TOOLS) - return bool(stream_filter & Stream.AGENT_TOOLS) + return False - # Task events - if isinstance(event, (TaskStart, TaskSuccess, TaskFailure, TaskSkipped)): - return bool(stream_filter & Stream.TASK_EVENTS) +def is_completion_tool_event(event: Event) -> bool: + """Check if an event is related to a completion tool call.""" + if isinstance(event, ToolResult): + tool = event.tool_result.tool + elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): + tool = event.tool + else: return False + return tool and tool.metadata.get("is_completion_tool") + + +def process_event(event: Event) -> tuple[Event, Any, Optional[Any]]: + """Process a single event and return the appropriate tuple.""" + # Message events + if isinstance(event, AgentMessage): + return event, event.message, None + elif isinstance(event, AgentMessageDelta): + return event, event.message_snapshot, event.message_delta + + # Content events + elif isinstance(event, AgentContent): + return event, event.content, None + elif isinstance(event, AgentContentDelta): + return event, event.content_snapshot, event.content_delta + + # Tool call events + elif isinstance(event, AgentToolCall): + return event, event.tool_call, None + elif isinstance(event, AgentToolCallDelta): + return event, event.tool_call_snapshot, event.tool_call_delta + + # Tool result events + elif isinstance(event, ToolResult): + return event, event.tool_result, None + + else: + # Pass through any other events with no snapshot/delta + return event, None, None + + +def filter_events_sync( + events: Iterator[Event], stream_filter: Stream +) -> Iterator[tuple[Event, Any, Optional[Any]]]: + """Synchronously filter events based on Stream flags.""" for event in events: - if not should_include_event(event): - continue - - # Message events - if isinstance(event, AgentMessage): - yield event, event.message, None - elif isinstance(event, AgentMessageDelta): - yield event, event.message_snapshot, event.message_delta - - # Content events - elif isinstance(event, AgentContent): - yield event, event.content, None - elif isinstance(event, AgentContentDelta): - yield event, event.content_snapshot, event.content_delta - - # Tool call events - elif isinstance(event, AgentToolCall): - yield event, event.tool_call, None - elif isinstance(event, AgentToolCallDelta): - yield event, event.tool_call_snapshot, event.tool_call_delta - - # Tool result events - elif isinstance(event, ToolResult): - yield event, event.tool_result, None - - else: - # Pass through any other events with no snapshot/delta - yield event, None, None + if should_include_event(event, stream_filter): + yield process_event(event) + + +async def filter_events_async( + events: AsyncIterator[Event], stream_filter: Stream +) -> AsyncIterator[tuple[Event, Any, Optional[Any]]]: + """Asynchronously filter events based on Stream flags.""" + async for event in events: + if should_include_event(event, stream_filter): + yield process_event(event) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 6603f840..bf502492 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -446,12 +446,13 @@ async def run_async( max_agent_turns: int = None, handlers: list[Union["Handler", "AsyncHandler"]] = None, raise_on_failure: bool = True, + stream: Union[bool, "Stream"] = False, ) -> T: """ Run the task """ - await controlflow.run_tasks_async( + result = await controlflow.run_tasks_async( tasks=[self], flow=flow, agent=agent, @@ -460,9 +461,12 @@ async def run_async( max_agent_turns=max_agent_turns, raise_on_failure=False, handlers=handlers, + stream=stream, ) - if self.is_successful(): + if stream: + return result + elif self.is_successful(): return self.result elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") diff --git a/tests/test_run.py b/tests/test_run.py index 5e3e29dd..c8f1cabd 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -175,6 +175,10 @@ async def test_min_failed(self): class TestRunStreaming: + # Helper function to collect async iterator results + async def collect_stream(self, ait): + return [x async for x in ait] + @pytest.fixture def task(self, default_fake_llm): task = controlflow.Task("say hello", id="12345") @@ -203,6 +207,11 @@ def test_stream_all(self, default_fake_llm): r = list(result) assert len(r) > 5 + async def test_stream_all_async(self, default_fake_llm): + result = await run_async("what's 2 + 2", stream=True, max_llm_calls=1) + r = await self.collect_stream(result) + assert len(r) > 5 + def test_stream_task(self, task): result = list(task.run(stream=True)) assert result[0][0].event == "orchestrator-start" @@ -214,12 +223,29 @@ def test_stream_task(self, task): assert any(r[0].event == "agent-content-delta" for r in result) assert any(r[0].event == "agent-tool-call" for r in result) + async def test_stream_task_async(self, task): + result = await task.run_async(stream=True) + r = await self.collect_stream(result) + assert r[0][0].event == "orchestrator-start" + assert r[1][0].event == "agent-turn-start" + assert r[-1][0].event == "orchestrator-end" + assert any(x[0].event == "agent-message" for x in r) + assert any(x[0].event == "agent-message-delta" for x in r) + assert any(x[0].event == "agent-content" for x in r) + assert any(x[0].event == "agent-content-delta" for x in r) + assert any(x[0].event == "agent-tool-call" for x in r) + def test_stream_content(self, task): result = list(task.run(stream=Stream.CONTENT)) assert all( r[0].event in ("agent-content", "agent-content-delta") for r in result ) + async def test_stream_content_async(self, task): + result = await task.run_async(stream=Stream.CONTENT) + r = await self.collect_stream(result) + assert all(x[0].event in ("agent-content", "agent-content-delta") for x in r) + def test_stream_tools(self, task): result = list(task.run(stream=Stream.TOOLS)) assert all( @@ -227,9 +253,36 @@ def test_stream_tools(self, task): for r in result ) - def test_stream_results(self, task): + async def test_stream_tools_async(self, task): + result = await task.run_async(stream=Stream.TOOLS) + r = await self.collect_stream(result) + assert all( + x[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for x in r + ) + + def test_stream_completion_tools(self, task): result = list(task.run(stream=Stream.COMPLETION_TOOLS)) assert all( r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") for r in result ) + + async def test_stream_completion_tools_async(self, task): + result = await task.run_async(stream=Stream.COMPLETION_TOOLS) + r = await self.collect_stream(result) + assert all( + x[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for x in r + ) + + def test_stream_task_events(self, task): + result = list(task.run(stream=Stream.TASK_EVENTS)) + assert result[-1][0].event == "task-success" + assert result[0][0].task is task + + async def test_stream_task_events_async(self, task): + result = await task.run_async(stream=Stream.TASK_EVENTS) + r = await self.collect_stream(result) + assert r[-1][0].event == "task-success" + assert r[0][0].task is task From feadc6c493be318f55a1050bd6774e2b397e00b8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:54:01 -0500 Subject: [PATCH 6/8] Update docs --- docs/patterns/streaming-tasks.mdx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/patterns/streaming-tasks.mdx b/docs/patterns/streaming-tasks.mdx index d30ec458..ce8d8405 100644 --- a/docs/patterns/streaming-tasks.mdx +++ b/docs/patterns/streaming-tasks.mdx @@ -1,7 +1,7 @@ --- title: Streaming description: Process agent responses, tool calls and results in real-time through streaming or handlers. -icon: wave-square +icon: bars-staggered --- import { VersionBadge } from '/snippets/version-badge.mdx' @@ -35,7 +35,7 @@ for event, snapshot, delta in cf.run( print(delta, end="", flush=True) ``` -You can focus on specific events using the `Stream` enum: +You can focus on specific events using the `Stream` enum. Here, we return only content updates: ```python from controlflow import Stream @@ -63,6 +63,8 @@ You can combine filters with the `|` operator: stream = Stream.CONTENT | Stream.TOOLS ``` +For more complex filtering, set stream=True and filter the events manually, or use a handler. + ## Handlers From 78063212b57df9f57a298aee729202dddd5f44d5 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:56:59 -0500 Subject: [PATCH 7/8] Update streaming-tasks.mdx --- docs/patterns/streaming-tasks.mdx | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/patterns/streaming-tasks.mdx b/docs/patterns/streaming-tasks.mdx index ce8d8405..090acf5c 100644 --- a/docs/patterns/streaming-tasks.mdx +++ b/docs/patterns/streaming-tasks.mdx @@ -38,12 +38,12 @@ for event, snapshot, delta in cf.run( You can focus on specific events using the `Stream` enum. Here, we return only content updates: ```python -from controlflow import Stream +import controlflow as cf # Only stream content updates for event, snapshot, delta in cf.run( "Write a poem", - stream=Stream.CONTENT + stream=cf.Stream.CONTENT ): print(delta if delta else snapshot) ``` @@ -119,11 +119,10 @@ Here's a complete example showing both approaches to display content in real-tim ```python Streaming import controlflow as cf -from controlflow import Stream for event, snapshot, delta in cf.run( "Write a story about time travel", - stream=Stream.CONTENT + stream=cf.Stream.CONTENT ): # Print character by character if delta: From 77638288df7d6c7293b559cc351914f2f7959ec2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:57:37 -0500 Subject: [PATCH 8/8] Update orchestrator.py --- src/controlflow/orchestration/orchestrator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 29628d18..2fbdccec 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -490,7 +490,7 @@ async def run_async( if stream: return result else: - for _ in result: + async for _ in result: pass return run_context