diff --git a/docs/mint.json b/docs/mint.json index 4a44f92d..8fdb4180 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -55,6 +55,7 @@ "patterns/running-tasks", "patterns/task-results", "patterns/tools", + "patterns/streaming", "patterns/interactivity", "patterns/dependencies", "patterns/memory", diff --git a/docs/patterns/running-tasks.mdx b/docs/patterns/running-tasks.mdx index 7d05519b..3227a7db 100644 --- a/docs/patterns/running-tasks.mdx +++ b/docs/patterns/running-tasks.mdx @@ -154,7 +154,47 @@ Crafting worlds and shaping dreams. +## Streaming + + +In addition to running tasks to completion, ControlFlow supports streaming events during task execution. This allows you to process or display intermediate outputs like agent messages, tool calls, and results in real-time. + +To enable streaming, set `stream=True` when running tasks: + +```python +import controlflow as cf + +# Stream all events +for event, snapshot, delta in cf.run("Write a poem", stream=True, handlers=[]): + print(f"Event type: {event.event}") + + if event.event == "agent-content": + print(f"Agent said: {snapshot}") + elif event.event == "agent-tool-call": + print(f"Tool called: {snapshot}") +``` + +You can also filter which events you want to receive using the `Stream` enum: + +```python +import controlflow as cf + +# Only stream content events +for event, content, delta in cf.run( + "Write a poem", + stream=cf.Stream.CONTENT, + handlers=[], # remove the default print handler +): + if delta: + # Print incremental content updates + print(delta, end="", flush=True) + else: + # Print complete messages + print(content) +``` + +For more details on working with streaming events, including programmatic event handlers, see the [Streaming guide](/patterns/streaming). ## Multi-Agent Collaboration For tasks involving multiple agents, ControlFlow needs a way to manage their collaboration. 
What makes this more complicated than simply making an LLM call and moving on to the next agent is that it may take multiple LLM calls to complete a single agentic "turn" of work. diff --git a/docs/patterns/streaming.mdx b/docs/patterns/streaming.mdx new file mode 100644 index 00000000..090acf5c --- /dev/null +++ b/docs/patterns/streaming.mdx @@ -0,0 +1,252 @@ +--- +title: Streaming +description: Process agent responses, tool calls and results in real-time through streaming or handlers. +icon: bars-staggered +--- + +import { VersionBadge } from '/snippets/version-badge.mdx' + + +ControlFlow provides two ways to process events during task execution: +- [**Streaming**](#streaming): Iterate over events in real-time using a Python iterator +- [**Handlers**](#handlers): Register callback functions that are called for each event + +Both approaches give you access to the same events - which one you choose depends on how you want to integrate with your application. + +## Streaming + + + +When you enable streaming, task execution returns an iterator that yields events as they occur. Each iteration provides a tuple of (event, snapshot, delta) representing what just happened in the workflow: + +```python +import controlflow as cf + +for event, snapshot, delta in cf.run( + "Write a poem about AI", + stream=True, +): + # For complete events, snapshot contains the full content + if event.event == "agent-content": + print(f"Agent wrote: {snapshot}") + + # For delta events, delta contains just what's new + elif event.event == "agent-content-delta": + print(delta, end="", flush=True) +``` + +You can focus on specific events using the `Stream` enum. 
Here, we return only content updates: + +```python +import controlflow as cf + +# Only stream content updates +for event, snapshot, delta in cf.run( + "Write a poem", + stream=cf.Stream.CONTENT +): + print(delta if delta else snapshot) +``` + +The available stream filters are: +- `Stream.ALL`: All events (equivalent to `stream=True`) +- `Stream.CONTENT`: Agent content and content deltas +- `Stream.TOOLS`: All tool events +- `Stream.COMPLETION_TOOLS`: Completion tool events (like marking a task successful or failed) +- `Stream.AGENT_TOOLS`: Tools used by agents for any purpose other than completing a task +- `Stream.TASK_EVENTS`: Task lifecycle events (starting, completion, failure, etc.) + +You can combine filters with the `|` operator: + +```python +# Stream content and tool events +stream = cf.Stream.CONTENT | cf.Stream.TOOLS +``` + +For more complex filtering, set `stream=True` and filter the events manually, or use a handler. + +## Handlers + + +For more complex event processing, or when you want to decouple event handling from your main workflow, use handlers: + +```python +from controlflow.orchestration.handler import Handler +from controlflow.events.events import AgentMessage, ToolResult + +class LoggingHandler(Handler): + def on_agent_message(self, event: AgentMessage): + print(f"Agent {event.agent.name} said: {event.message['content']}") + + def on_tool_result(self, event: ToolResult): + print(f"Tool call result: {event.tool_result.str_result}") + +# Use the handler +cf.run("Write a poem", handlers=[LoggingHandler()]) +``` + +Handlers are especially useful for: +- Adding logging or monitoring +- Collecting metrics +- Updating UI elements +- Processing events asynchronously + +Handlers call their `on_` methods for each event type. For a complete list of available methods, see the [Event Details](#event-details) section below. 
+ + +### Async Handlers + + + +For asynchronous event processing, use `AsyncHandler`: + +```python +import asyncio +from controlflow.orchestration.handler import AsyncHandler + +class AsyncLoggingHandler(AsyncHandler): + async def on_agent_message(self, event: AgentMessage): + await asyncio.sleep(0.1) # Simulate async operation + print(f"Agent {event.agent.name} said: {event.message['content']}") + +await cf.run_async("Write a poem", handlers=[AsyncLoggingHandler()]) +``` + +## Example: Real-time Content Display + +Here's a complete example showing both approaches to display content in real-time: + + +```python Streaming +import controlflow as cf + +for event, snapshot, delta in cf.run( + "Write a story about time travel", + stream=cf.Stream.CONTENT +): + # Print character by character + if delta: + print(delta, end="", flush=True) +``` + +```python Handler +import controlflow as cf +from controlflow.orchestration.handler import Handler + +class ContentHandler(Handler): + def on_agent_content_delta(self, event): + # Print character by character + print(event.content_delta, end="", flush=True) + +cf.run( + "Write a story about time travel", + handlers=[ContentHandler()] +) +``` + + +## Event Details + +Now that we've seen how to process events, let's look at the types of events you can receive: + +### Content Events +Content events give you access to what an agent is saying or writing: + +```python +# Complete content +{ + "event": "agent-content", + "agent": agent, # Agent object + "content": "Hello, world!", # The complete content + "agent_message_id": "msg_123" # Optional ID linking to parent message +} + +# Content delta (incremental update) +{ + "event": "agent-content-delta", + "agent": agent, + "content_delta": "Hello", # New content since last update + "content_snapshot": "Hello, world!", # Complete content so far + "agent_message_id": "msg_123" +} +``` + +### Tool Events +Tool events let you observe when agents use tools and get their results: + +```python 
+# Tool being called +{ + "event": "agent-tool-call", + "agent": agent, + "tool_call": {...}, # The complete tool call info + "tool": tool, # The Tool object being called + "args": {...}, # Arguments passed to the tool + "agent_message_id": "msg_123" +} + +# Tool call delta (incremental update) +{ + "event": "agent-tool-call-delta", + "agent": agent, + "tool_call_delta": {...}, # Changes to the tool call + "tool_call_snapshot": {...}, # Complete tool call info so far + "tool": tool, + "args": {...}, + "agent_message_id": "msg_123" +} + +# Tool result +{ + "event": "tool-result", + "agent": agent, + "tool_result": { + "tool_call": {...}, # The original tool call + "tool": tool, # The Tool object that was called + "result": any, # The raw result value + "str_result": "...", # String representation of result + "is_error": False # Whether the tool call failed + } +} +``` + +### Workflow Events +### Task Events +Events that mark key points in a task's lifecycle: +- `TaskStart`: A task has begun execution +- `TaskSuccess`: A task completed successfully (includes the final result) +- `TaskFailure`: A task failed (includes the error reason) +- `TaskSkipped`: A task was skipped + +### Orchestration Events +Events related to orchestrating the overall workflow: +- `OrchestratorStart`/`End`: Workflow orchestration starting/ending +- `AgentTurnStart`/`End`: An agent's turn starting/ending +- `OrchestratorError`: An error occurred during orchestration + +### Handler Methods + +Each handler can implement methods for different types of events. The method will be called whenever that type of event occurs. 
Here are all available handler methods: + +| Method | Event Type | Description | +|--------|------------|-------------| +| `on_event(event)` | Any | Called for every event, before any specific handler | +| `on_agent_message(event)` | AgentMessage | Raw LLM output containing both content and tool calls | +| `on_agent_message_delta(event)` | AgentMessageDelta | Incremental updates to raw LLM output | +| `on_agent_content(event)` | AgentContent | Unstructured text output from an agent | +| `on_agent_content_delta(event)` | AgentContentDelta | Incremental updates to agent content | +| `on_agent_tool_call(event)` | AgentToolCall | Tool being called by an agent | +| `on_agent_tool_call_delta(event)` | AgentToolCallDelta | Incremental updates to a tool call | +| `on_tool_result(event)` | ToolResult | Result returned from a tool | +| `on_orchestrator_start(event)` | OrchestratorStart | Workflow orchestration starting | +| `on_orchestrator_end(event)` | OrchestratorEnd | Workflow orchestration completed | +| `on_agent_turn_start(event)` | AgentTurnStart | An agent beginning their turn | +| `on_agent_turn_end(event)` | AgentTurnEnd | An agent completing their turn | +| `on_orchestrator_error(event)` | OrchestratorError | Error during orchestration | + +Note that AgentMessage is the "raw" output from the LLM and contains both unstructured content and structured tool calls. When you receive an AgentMessage, you will also receive separate AgentContent and/or AgentToolCall events for any content or tool calls contained in that message. This allows you to: +1. Handle all LLM output in one place with `on_agent_message` +2. Handle just content with `on_agent_content` +3. Handle just tool calls with `on_agent_tool_call` + +For streaming cases, the delta events (e.g. AgentMessageDelta, AgentContentDelta) provide incremental updates as the LLM generates its response. 
Task events, in contrast, are complete events that mark important points in a task's lifecycle - you can use these to track progress and get results without managing the task object directly. diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 27a0218c..8fd513b7 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -16,7 +16,7 @@ from .instructions import instructions from .decorators import flow, task from .tools import tool -from .run import run, run_async, run_tasks, run_tasks_async +from .run import run, run_async, run_tasks, run_tasks_async, Stream from .plan import plan import controlflow.orchestration diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 6c4381a9..6e2ab985 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -8,6 +8,7 @@ Any, AsyncGenerator, Generator, + Iterator, Optional, Union, ) @@ -43,10 +44,11 @@ from controlflow.utilities.prefect import create_markdown_artifact, prefect_task if TYPE_CHECKING: + from controlflow.events.events import Event + from controlflow.flows import Flow from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.turn_strategies import TurnStrategy - from controlflow.tasks import Task - from controlflow.tools.tools import Tool + from controlflow.stream import Stream logger = logging.getLogger(__name__) @@ -223,13 +225,29 @@ def run( *, turn_strategy: Optional["TurnStrategy"] = None, handlers: Optional[list["Handler"]] = None, + stream: Union[bool, "Stream"] = False, **task_kwargs, - ): + ) -> Union[Any, Iterator[tuple["Event", Any, Optional[Any]]]]: + """ + Run a task with this agent. + + Args: + objective: The objective to accomplish + turn_strategy: Optional turn strategy to use + handlers: Optional list of handlers + stream: If True, stream all events. Can also provide StreamFilter flags. 
+ **task_kwargs: Additional kwargs passed to Task creation + + Returns: + If not streaming: The task result + If streaming: Iterator of (event, snapshot, delta) tuples + """ return controlflow.run( objective=objective, agents=[self], turn_strategy=turn_strategy, handlers=handlers, + stream=stream, **task_kwargs, ) @@ -239,6 +257,7 @@ async def run_async( *, turn_strategy: Optional["TurnStrategy"] = None, handlers: Optional[list[Union["Handler", "AsyncHandler"]]] = None, + stream: Union[bool, "Stream"] = False, **task_kwargs, ): return await controlflow.run_async( @@ -246,6 +265,7 @@ async def run_async( agents=[self], turn_strategy=turn_strategy, handlers=handlers, + stream=stream, **task_kwargs, ) diff --git a/src/controlflow/events/base.py b/src/controlflow/events/base.py index aad788cb..ad794042 100644 --- a/src/controlflow/events/base.py +++ b/src/controlflow/events/base.py @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]: return [] def __repr__(self) -> str: - return f"<{self.event} {self.timestamp}>" + return f"" class UnpersistedEvent(Event): diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index b7c9f413..d0953ee1 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -90,15 +90,17 @@ def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]: ) return calls - def to_content(self) -> "AgentContent": - return AgentContent( - agent=self.agent, - content=self.message["content"], - agent_message_id=self.message.get("id"), - ) + def to_content(self) -> Optional["AgentContent"]: + if self.message.get("content"): + return AgentContent( + agent=self.agent, + content=self.message["content"], + agent_message_id=self.message.get("id"), + ) def all_related_events(self, tools: list[Tool]) -> list[Event]: - return [self, self.to_content()] + self.to_tool_calls(tools) + content = self.to_content() + return [self] + ([content] if content else []) + 
self.to_tool_calls(tools) def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: @@ -178,16 +180,22 @@ def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: ) return deltas - def to_content_delta(self) -> "AgentContentDelta": - return AgentContentDelta( - agent=self.agent, - content_delta=self.message_delta["content"], - content_snapshot=self.message_snapshot["content"], - agent_message_id=self.message_snapshot.get("id"), - ) + def to_content_delta(self) -> Optional["AgentContentDelta"]: + if self.message_delta.get("content"): + return AgentContentDelta( + agent=self.agent, + content_delta=self.message_delta["content"], + content_snapshot=self.message_snapshot["content"], + agent_message_id=self.message_snapshot.get("id"), + ) def all_related_events(self, tools: list[Tool]) -> list[Event]: - return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools) + content_delta = self.to_content_delta() + return ( + [self] + + ([content_delta] if content_delta else []) + + self.to_tool_call_deltas(tools) + ) class AgentContent(UnpersistedEvent): @@ -206,7 +214,7 @@ class AgentContentDelta(UnpersistedEvent): class AgentToolCall(Event): - event: Literal["tool-call"] = "tool-call" + event: Literal["agent-tool-call"] = "agent-tool-call" agent: Agent agent_message_id: Optional[str] = None tool_call: Union[ToolCall, InvalidToolCall] diff --git a/src/controlflow/events/orchestrator_events.py b/src/controlflow/events/orchestrator_events.py index 88370f74..611dd18b 100644 --- a/src/controlflow/events/orchestrator_events.py +++ b/src/controlflow/events/orchestrator_events.py @@ -1,16 +1,17 @@ from dataclasses import Field -from typing import TYPE_CHECKING, Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional from pydantic.functional_serializers import PlainSerializer from controlflow.agents.agent import Agent -from controlflow.events.base import 
UnpersistedEvent +from controlflow.events.base import Event, UnpersistedEvent if TYPE_CHECKING: from controlflow.orchestration.conditions import RunContext from controlflow.orchestration.orchestrator import Orchestrator +# Orchestrator events class OrchestratorStart(UnpersistedEvent): event: Literal["orchestrator-start"] = "orchestrator-start" persist: bool = False diff --git a/src/controlflow/events/task_events.py b/src/controlflow/events/task_events.py new file mode 100644 index 00000000..dbc2598f --- /dev/null +++ b/src/controlflow/events/task_events.py @@ -0,0 +1,33 @@ +from dataclasses import Field +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional + +from pydantic.functional_serializers import PlainSerializer + +from controlflow.events.base import UnpersistedEvent +from controlflow.tasks.task import Task + + +# Task events +class TaskStart(UnpersistedEvent): + event: Literal["task-start"] = "task-start" + task: Task + + +class TaskSuccess(UnpersistedEvent): + event: Literal["task-success"] = "task-success" + task: Task + result: Annotated[ + Any, + PlainSerializer(lambda x: str(x) if x else None, return_type=Optional[str]), + ] = None + + +class TaskFailure(UnpersistedEvent): + event: Literal["task-failure"] = "task-failure" + task: Task + reason: Optional[str] = None + + +class TaskSkipped(UnpersistedEvent): + event: Literal["task-skipped"] = "task-skipped" + task: Task diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index def848fc..55916b2e 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -112,9 +112,11 @@ def get_events( return events def add_events(self, events: list[Event]): - for event in events: + persist_events = [e for e in events if e.persist] + for event in persist_events: event.thread_id = self.thread_id - self.history.add_events(thread_id=self.thread_id, events=events) + if persist_events: + self.history.add_events(thread_id=self.thread_id, events=persist_events) 
@contextmanager def create_context(self, **prefect_kwargs) -> Generator[Self, None, None]: diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 10d04269..2fbdccec 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,7 +1,9 @@ import logging -from typing import AsyncIterator, Callable, Iterator, Optional, TypeVar, Union +from collections import deque +from contextlib import contextmanager +from typing import AsyncIterator, Callable, Iterator, Optional, Set, TypeVar, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, PrivateAttr, field_validator import controlflow from controlflow.agents.agent import Agent @@ -31,6 +33,7 @@ from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy from controlflow.tasks.task import Task from controlflow.tools.tools import Tool, as_tools +from controlflow.utilities.context import ctx from controlflow.utilities.general import ControlFlowModel from controlflow.utilities.prefect import prefect_task @@ -61,6 +64,8 @@ class Orchestrator(ControlFlowModel): handlers: list[Union[Handler, AsyncHandler]] = Field( None, validate_default=True, exclude=True ) + _processed_event_ids: Set[str] = PrivateAttr(default_factory=set) + _pending_events: deque[Event] = PrivateAttr(default_factory=deque) @field_validator("turn_strategy", mode="before") def _validate_turn_strategy(cls, v): @@ -90,14 +95,18 @@ def _validate_handlers(cls, v): ] return v or [] - def handle_event(self, event: Event): + def add_event(self, event: Event) -> None: + """Add an event to be handled and yielded during the next run loop iteration""" + self._pending_events.append(event) + + def handle_event(self, event: Event) -> Event: """ Handle an event by passing it to all handlers and persisting if necessary. - - Args: - event (Event): The event to handle. 
+ Includes idempotency check to prevent double-processing events. """ - from controlflow.events.events import AgentContentDelta + # Skip if we've already processed this event + if event.id in self._processed_event_ids: + return event for handler in self.handlers: if isinstance(handler, Handler): @@ -105,13 +114,22 @@ def handle_event(self, event: Event): if event.persist: self.flow.add_events([event]) - async def handle_event_async(self, event: Event): + # Mark event as processed + self._processed_event_ids.add(event.id) + return event + + async def handle_event_async(self, event: Event) -> Event: """ Handle an event asynchronously by passing it to all handlers and persisting if necessary. + Includes idempotency check to prevent double-processing events. Args: event (Event): The event to handle. """ + # Skip if we've already processed this event + if event.id in self._processed_event_ids: + return event + if not isinstance(event, AgentMessageDelta): logger.debug(f"Handling event asynchronously: {repr(event)}") for handler in self.handlers: @@ -122,6 +140,10 @@ async def handle_event_async(self, event: Event): if event.persist: self.flow.add_events([event]) + # Mark event as processed + self._processed_event_ids.add(event.id) + return event + def get_available_agents(self) -> dict[Agent, list[Task]]: """ Get a dictionary of all available agents for active tasks, mapped to @@ -212,12 +234,13 @@ def _run_agent_turn( tools = self.get_tools() # Run model and yield events - for event in self.agent._run_model( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - yield event + with ctx(orchestrator=self): + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event run_context.llm_calls += 1 for task in assigned_tasks: @@ -225,7 +248,75 @@ def _run_agent_turn( run_context.agent_turns += 1 - @prefect_task(task_run_name="Orchestrator.run()") + @prefect_task(task_run_name="Run agent orchestrator") 
+ def _run( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run the orchestrator, yielding handled events as they occur.""" + # Initialize the agent if not already set + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + + # Signal the start of orchestration + start_event = OrchestratorStart(orchestrator=self, run_context=run_context) + self.handle_event(start_event) + yield start_event + + try: + while True: + if run_context.should_end(): + break + + turn_start = AgentTurnStart(orchestrator=self, agent=self.agent) + self.handle_event(turn_start) + yield turn_start + + # Run turn and yield its events + for event in self._run_agent_turn( + run_context=run_context, + model_kwargs=model_kwargs, + ): + self.handle_event(event) + yield event + + # Handle any events added during the turn + while self._pending_events: + event = self._pending_events.popleft() + self.handle_event(event) + yield event + + turn_end = AgentTurnEnd(orchestrator=self, agent=self.agent) + self.handle_event(turn_end) + yield turn_end + + # Select the next agent for the following turn + if available_agents := self.get_available_agents(): + self.agent = self.turn_strategy.get_next_agent( + self.agent, available_agents + ) + + except Exception as exc: + # Yield error event if something goes wrong + error_event = OrchestratorError(orchestrator=self, error=exc) + self.handle_event(error_event) + yield error_event + raise + finally: + # Signal the end of orchestration + end_event = OrchestratorEnd(orchestrator=self, run_context=run_context) + self.handle_event(end_event) + yield end_event + + # Handle any final pending events + while self._pending_events: + event = self._pending_events.popleft() + self.handle_event(event) + yield event + def run( self, max_llm_calls: Optional[int] = None, @@ -234,10 +325,21 @@ def run( run_until: Optional[ Union[RunEndCondition, Callable[[RunContext], 
bool]] ] = None, - ) -> RunContext: + stream: bool = False, + ) -> Union[RunContext, Iterator[Event]]: """ - Run the orchestrator, handling events internally. - Returns the final run context. + Run the orchestrator. + + Args: + max_llm_calls: Maximum number of LLM calls allowed + max_agent_turns: Maximum number of agent turns allowed + model_kwargs: Additional kwargs for the model + run_until: Condition for ending the run + stream: If True, return iterator of events. If False, consume events and return context + + Returns: + If stream=True, returns Iterator[Event] + If stream=False, returns RunContext """ # Create run context at the outermost level if run_until is None: @@ -257,19 +359,25 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - for event in self._run( + result = self._run( run_context=run_context, model_kwargs=model_kwargs, - ): - self.handle_event(event) - return run_context + ) - def _run( + if stream: + return result + else: + for _ in result: + pass + return run_context + + @prefect_task(task_run_name="Run agent orchestrator") + async def _run_async( self, run_context: RunContext, model_kwargs: Optional[dict] = None, - ) -> Iterator[Event]: - """Run the orchestrator, yielding events as they occur.""" + ) -> AsyncIterator[Event]: + """Run the orchestrator asynchronously, yielding handled events as they occur.""" # Initialize the agent if not already set if not self.agent: self.agent = self.turn_strategy.get_next_agent( @@ -277,23 +385,36 @@ def _run( ) # Signal the start of orchestration - yield OrchestratorStart(orchestrator=self, run_context=run_context) + start_event = OrchestratorStart(orchestrator=self, run_context=run_context) + await self.handle_event_async(start_event) + yield start_event try: while True: if run_context.should_end(): break - yield AgentTurnStart(orchestrator=self, agent=self.agent) + turn_start = AgentTurnStart(orchestrator=self, agent=self.agent) + await self.handle_event_async(turn_start) 
+ yield turn_start # Run turn and yield its events - for event in self._run_agent_turn( + async for event in self._run_agent_turn_async( run_context=run_context, model_kwargs=model_kwargs, ): + await self.handle_event_async(event) + yield event + + # Handle any events added during the turn + while self._pending_events: + event = self._pending_events.popleft() + await self.handle_event_async(event) yield event - yield AgentTurnEnd(orchestrator=self, agent=self.agent) + turn_end = AgentTurnEnd(orchestrator=self, agent=self.agent) + await self.handle_event_async(turn_end) + yield turn_end # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -303,13 +424,22 @@ def _run( except Exception as exc: # Yield error event if something goes wrong - yield OrchestratorError(orchestrator=self, error=exc) + error_event = OrchestratorError(orchestrator=self, error=exc) + await self.handle_event_async(error_event) + yield error_event raise finally: # Signal the end of orchestration - yield OrchestratorEnd(orchestrator=self, run_context=run_context) + end_event = OrchestratorEnd(orchestrator=self, run_context=run_context) + await self.handle_event_async(end_event) + yield end_event + + # Handle any final pending events + while self._pending_events: + event = self._pending_events.popleft() + await self.handle_event_async(event) + yield event - @prefect_task async def run_async( self, max_llm_calls: Optional[int] = None, @@ -318,10 +448,21 @@ async def run_async( run_until: Optional[ Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, - ) -> RunContext: + stream: bool = False, + ) -> Union[RunContext, AsyncIterator[Event]]: """ - Run the orchestrator asynchronously, handling events internally. - Returns the final run context. + Run the orchestrator asynchronously. 
+ + Args: + max_llm_calls: Maximum number of LLM calls allowed + max_agent_turns: Maximum number of agent turns allowed + model_kwargs: Additional kwargs for the model + run_until: Condition for ending the run + stream: If True, return async iterator of events. If False, consume events and return context + + Returns: + If stream=True, returns AsyncIterator[Event] + If stream=False, returns RunContext """ # Create run context at the outermost level if run_until is None: @@ -341,126 +482,17 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - async for event in self._run_async( + result = self._run_async( run_context=run_context, model_kwargs=model_kwargs, - ): - await self.handle_event_async(event) - return run_context - - @prefect_task(task_run_name="Agent turn: {self.agent.name}") - def run_agent_turn( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> int: - """ - Run a single agent turn, which may consist of multiple LLM calls. 
- """ - assigned_tasks = self.get_tasks("assigned") - - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in assigned_tasks: - if not task.is_running(): - task.mark_running() - self.handle_event( - OrchestratorMessage( - content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " - f"with objective: {task.objective}" - ) - ) - - while not self.turn_strategy.should_end_turn(): - # fail any tasks that have reached their max llm calls - for task in assigned_tasks: - if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: - task.mark_failed(reason="Max LLM calls reached for this task.") - - # Check if there are any ready tasks left - if not any(t.is_ready() for t in assigned_tasks): - logger.debug("No `ready` tasks to run") - break - - if run_context.should_end(): - break - - messages = self.compile_messages() - tools = self.get_tools() - - for event in self.agent._run_model( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - self.handle_event(event) - - run_context.llm_calls += 1 - for task in assigned_tasks: - task._llm_calls += 1 - - run_context.agent_turns += 1 - - @prefect_task - async def run_agent_turn_async( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> int: - """ - Run a single agent turn asynchronously, which may consist of multiple LLM calls. - - Args: - max_llm_calls (Optional[int]): The number of LLM calls allowed. - - Returns: - int: The number of LLM calls made during this turn. 
- """ - assigned_tasks = self.get_tasks("assigned") - - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in assigned_tasks: - if not task.is_running(): - task.mark_running() - await self.handle_event_async( - OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " - f"with objective: {task.objective}" - ) - ) - - while not self.turn_strategy.should_end_turn(): - # fail any tasks that have reached their max llm calls - for task in assigned_tasks: - if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: - task.mark_failed(reason="Max LLM calls reached for this task.") - - # Check if there are any ready tasks left - if not any(t.is_ready() for t in assigned_tasks): - logger.debug("No `ready` tasks to run") - break - - if run_context.should_end(): - break - - messages = self.compile_messages() - tools = self.get_tools() - - async for event in self.agent._run_model_async( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - await self.handle_event_async(event) - - run_context.llm_calls += 1 - for task in assigned_tasks: - task._llm_calls += 1 + ) - run_context.agent_turns += 1 + if stream: + return result + else: + async for _ in result: + pass + return run_context def compile_prompt(self) -> str: """ @@ -629,12 +661,14 @@ async def _run_agent_turn_async( messages = self.compile_messages() tools = self.get_tools() - async for event in self.agent._run_model_async( - messages=messages, - tools=tools, - model_kwargs=model_kwargs, - ): - yield event + # Run model and yield events + with ctx(orchestrator=self): + async for event in self.agent._run_model_async( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event run_context.llm_calls += 1 for task in assigned_tasks: @@ -642,51 +676,6 @@ async def _run_agent_turn_async( run_context.agent_turns += 1 - async def _run_async( - self, - run_context: RunContext, - model_kwargs: Optional[dict] = None, - ) -> 
AsyncIterator[Event]: - """Run the orchestrator asynchronously, yielding events as they occur.""" - # Initialize the agent if not already set - if not self.agent: - self.agent = self.turn_strategy.get_next_agent( - None, self.get_available_agents() - ) - - # Signal the start of orchestration - yield OrchestratorStart(orchestrator=self, run_context=run_context) - - try: - while True: - if run_context.should_end(): - break - - yield AgentTurnStart(orchestrator=self, agent=self.agent) - - # Run turn and yield its events - async for event in self._run_agent_turn_async( - run_context=run_context, - model_kwargs=model_kwargs, - ): - yield event - - yield AgentTurnEnd(orchestrator=self, agent=self.agent) - - # Select the next agent for the following turn - if available_agents := self.get_available_agents(): - self.agent = self.turn_strategy.get_next_agent( - self.agent, available_agents - ) - - except Exception as exc: - # Yield error event if something goes wrong - yield OrchestratorError(orchestrator=self, error=exc) - raise - finally: - # Signal the end of orchestration - yield OrchestratorEnd(orchestrator=self, run_context=run_context) - # Rebuild all models with forward references after Orchestrator is defined OrchestratorStart.model_rebuild() diff --git a/src/controlflow/run.py b/src/controlflow/run.py index dc4e285d..2362abea 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,25 +1,17 @@ -from typing import Any, Callable, Optional, Union - -from prefect.context import TaskRunContext +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union import controlflow from controlflow.agents.agent import Agent +from controlflow.events.events import Event from controlflow.flows import Flow, get_flow from controlflow.orchestration.conditions import RunContext, RunEndCondition from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy +from 
controlflow.stream import Stream, filter_events_async, filter_events_sync from controlflow.tasks.task import Task from controlflow.utilities.prefect import prefect_task -def get_task_run_name() -> str: - context = TaskRunContext.get() - tasks = context.parameters["tasks"] - task_names = " | ".join(t.friendly_name() for t in tasks) - return f"Run task{'s' if len(tasks) > 1 else ''}: {task_names}" - - -@prefect_task(task_run_name=get_task_run_name) def run_tasks( tasks: list[Task], instructions: str = None, @@ -32,11 +24,10 @@ def run_tasks( handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, -) -> list[Any]: + stream: Union[bool, Stream] = False, +) -> Union[list[Any], Iterator[tuple[Event, Any, Optional[Any]]]]: """ Run a list of tasks. - - Returns a list of task results corresponding to the input tasks, or raises an error if any tasks failed. """ flow = flow or get_flow() or Flow() @@ -49,13 +40,19 @@ def run_tasks( ) with controlflow.instructions(instructions): - orchestrator.run( + result = orchestrator.run( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, model_kwargs=model_kwargs, run_until=run_until, + stream=bool(stream), ) + if stream: + # Convert True to ALL filter, otherwise use provided filter + stream_filter = Stream.ALL if stream is True else stream + return filter_events_sync(result, stream_filter) + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: @@ -67,7 +64,6 @@ def run_tasks( return [t.result for t in tasks] -@prefect_task(task_run_name=get_task_run_name) async def run_tasks_async( tasks: list[Task], instructions: str = None, @@ -80,7 +76,8 @@ async def run_tasks_async( handlers: list[Union[Handler, AsyncHandler]] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = 
None, -): + stream: Union[bool, Stream] = False, +) -> Union[list[Any], AsyncIterator[tuple[Event, Any, Optional[Any]]]]: """ Run a list of tasks asynchronously. """ @@ -94,13 +91,19 @@ async def run_tasks_async( ) with controlflow.instructions(instructions): - await orchestrator.run_async( + result = await orchestrator.run_async( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, model_kwargs=model_kwargs, run_until=run_until, + stream=bool(stream), ) + if stream: + # Convert True to ALL filter, otherwise use provided filter + stream_filter = Stream.ALL if stream is True else stream + return filter_events_async(result, stream_filter) + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: @@ -122,8 +125,24 @@ def run( handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, + stream: Union[bool, Stream] = False, **task_kwargs, -) -> Any: +) -> Union[Any, Iterator[tuple[Event, Any, Optional[Any]]]]: + """ + Run a single task. + + Args: + objective: Objective of the task. + turn_strategy: Turn strategy to use for the task. + max_llm_calls: Maximum number of LLM calls to make. + max_agent_turns: Maximum number of agent turns to make. + raise_on_failure: Whether to raise an error if the task fails. + handlers: List of handlers to use for the task. + model_kwargs: Keyword arguments to pass to the LLM. + run_until: Condition to stop running the task. + stream: If True, stream all events. Can also provide StreamFilter flags to filter specific events. + e.g. 
StreamFilter.CONTENT | StreamFilter.AGENT_TOOLS + """ task = Task(objective=objective, **task_kwargs) results = run_tasks( tasks=[task], @@ -134,8 +153,12 @@ def run( handlers=handlers, model_kwargs=model_kwargs, run_until=run_until, + stream=stream, ) - return results[0] + if stream: + return results + else: + return results[0] async def run_async( @@ -150,6 +173,7 @@ async def run_async( handlers: list[Union[Handler, AsyncHandler]] = None, model_kwargs: Optional[dict] = None, run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, + stream: Union[bool, Stream] = False, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -164,5 +188,9 @@ async def run_async( handlers=handlers, model_kwargs=model_kwargs, run_until=run_until, + stream=stream, ) - return results[0] + if stream: + return results + else: + return results[0] diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py index e8552ebf..f7a2bd2e 100644 --- a/src/controlflow/stream.py +++ b/src/controlflow/stream.py @@ -1,24 +1,6 @@ -# Example usage -# -# # Stream all events -# for event in cf.stream.events("Write a story"): -# print(event) -# -# # Stream just messages -# for event in cf.stream.events("Write a story", events='messages'): -# print(event.content) -# -# # Stream just the result -# for delta, snapshot in cf.stream.result("Write a story"): -# print(f"New: {delta}") -# -# # Stream results from multiple tasks -# for delta, snapshot in cf.stream.result_from_tasks([task1, task2]): -# print(f"New result: {delta}") -# -from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union - -from controlflow.events.base import Event +from enum import Flag, auto +from typing import Any, AsyncIterator, Iterator, Optional, Union + from controlflow.events.events import ( AgentContent, AgentContentDelta, @@ -26,184 +8,111 @@ AgentMessageDelta, AgentToolCall, AgentToolCallDelta, + Event, ToolResult, ) -from 
controlflow.orchestration.handler import AsyncHandler, Handler -from controlflow.orchestration.orchestrator import Orchestrator -from controlflow.tasks.task import Task - -StreamEvents = Union[ - list[str], - Literal["all", "messages", "content", "tools", "completion_tools", "agent_tools"], -] - - -def event_filter(events: StreamEvents) -> Callable[[Event], bool]: - def _event_filter(event: Event) -> bool: - if events == "all": - return True - elif events == "messages": - return isinstance(event, (AgentMessage, AgentMessageDelta)) - elif events == "content": - return isinstance(event, (AgentContent, AgentContentDelta)) - elif events == "tools": - return isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)) - elif events == "completion_tools": - if isinstance(event, (AgentToolCall, AgentToolCallDelta)): - return event.tool and event.tool.metadata.get("is_completion_tool") - elif isinstance(event, ToolResult): - return event.tool_result and event.tool_result.tool.metadata.get( - "is_completion_tool" - ) - return False - elif events == "agent_tools": - if isinstance(event, (AgentToolCall, AgentToolCallDelta)): - return event.tool and event.tool in event.agent.get_tools() - elif isinstance(event, ToolResult): - return ( - event.tool_result - and event.tool_result.tool in event.agent.get_tools() - ) - return False - else: - raise ValueError(f"Invalid event type: {events}") - - return _event_filter - - -# -------------------- BELOW HERE IS THE OLD STUFF -------------------- - - -def events( - objective: str, - *, - events: StreamEvents = "all", - filter_fn: Optional[Callable[[Event], bool]] = None, - **kwargs, -) -> Iterator[Event]: - """ - Stream events from a task execution. - - Args: - objective: The task objective - events: Which events to stream. 
Can be list of event types or: - 'all' - all events - 'messages' - agent messages - 'tools' - all tool calls/results - 'completion_tools' - only completion tools - filter_fn: Optional additional filter function - **kwargs: Additional arguments passed to Task - - Returns: - Iterator of Event objects - """ +from controlflow.events.task_events import ( + TaskFailure, + TaskSkipped, + TaskStart, + TaskSuccess, +) - def get_event_filter(): - if isinstance(events, list): - return lambda e: e.event in events - elif events == "messages": - return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) - elif events == "tools": - return lambda e: isinstance(e, (AgentToolCall, ToolResult)) - elif events == "completion_tools": - return lambda e: ( - isinstance(e, (AgentToolCall, ToolResult)) - and e.tool_call["name"].startswith("mark_task_") - ) - else: # 'all' - return lambda e: True - - event_filter = get_event_filter() - - def event_handler(event: Event): - if event_filter(event) and (not filter_fn or filter_fn(event)): - yield event - - task = Task(objective=objective) - task.run(handlers=[Handler(event_handler)], **kwargs) - - -def result( - objective: str, - **kwargs, -) -> Iterator[tuple[Any, Any]]: - """ - Stream result from a task execution. - Args: - objective: The task objective - **kwargs: Additional arguments passed to Task +class Stream(Flag): + """ + Filter flags for event streaming. 
- Returns: - Iterator of (delta, accumulated) result tuples + Can be combined using bitwise operators: + stream_filter = Stream.CONTENT | Stream.AGENT_TOOLS """ - current_result = None - - def result_handler(event: Event): - nonlocal current_result - if isinstance(event, ToolResult): - if event.tool_call["name"].startswith("mark_task_"): - result = event.tool_result.result # Get actual result value - if result != current_result: # Only yield if changed - current_result = result - yield (result, result) # For now delta == full result - - task = Task(objective=objective) - task.run(handlers=[Handler(result_handler)], **kwargs) - - -def events_from_tasks( - tasks: list[Task], - events: StreamEvents = "all", - filter_fn: Optional[Callable[[Event], bool]] = None, - **kwargs, -) -> Iterator[Event]: - """Stream events from multiple task executions.""" - - def get_event_filter(): - if isinstance(events, list): - return lambda e: e.event in events - elif events == "messages": - return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) - elif events == "tools": - return lambda e: isinstance(e, (AgentToolCall, ToolResult)) - elif events == "completion_tools": - return lambda e: ( - isinstance(e, (AgentToolCall, ToolResult)) - and e.tool_call["name"].startswith("mark_task_") - ) - else: # 'all' - return lambda e: True - - event_filter = get_event_filter() - - def event_handler(event: Event): - if event_filter(event) and (not filter_fn or filter_fn(event)): - yield event - - orchestrator = Orchestrator( - tasks=tasks, handlers=[Handler(event_handler)], **kwargs - ) - orchestrator.run() - - -def result_from_tasks( - tasks: list[Task], - **kwargs, -) -> Iterator[tuple[Any, Any]]: - """Stream results from multiple task executions.""" - current_results = {task.id: None for task in tasks} - - def result_handler(event: Event): - if isinstance(event, ToolResult): - if event.tool_call["name"].startswith("mark_task_"): - task_id = event.task.id - result = 
event.tool_result.result - if result != current_results[task_id]: - current_results[task_id] = result - yield (result, result) - - orchestrator = Orchestrator( - tasks=tasks, handlers=[Handler(result_handler)], **kwargs - ) - orchestrator.run() + + NONE = 0 + ALL = auto() # All events + CONTENT = auto() # Agent content and deltas + AGENT_TOOLS = auto() # Non-completion tool events + COMPLETION_TOOLS = auto() # Completion tool events + TOOLS = AGENT_TOOLS | COMPLETION_TOOLS # All tool events + TASK_EVENTS = auto() # Task state change events + + +def should_include_event(event: Event, stream_filter: Stream) -> bool: + """Determine if an event should be included based on the stream filter.""" + # Pass all events if ALL is specified + if stream_filter == Stream.ALL: + return True + + # Content events + if isinstance(event, (AgentContent, AgentContentDelta)): + return bool(stream_filter & Stream.CONTENT) + + # Tool events + if isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)): + if is_completion_tool_event(event): + return bool(stream_filter & Stream.COMPLETION_TOOLS) + return bool(stream_filter & Stream.AGENT_TOOLS) + + # Task events + if isinstance(event, (TaskStart, TaskSuccess, TaskFailure, TaskSkipped)): + return bool(stream_filter & Stream.TASK_EVENTS) + + return False + + +def is_completion_tool_event(event: Event) -> bool: + """Check if an event is related to a completion tool call.""" + if isinstance(event, ToolResult): + tool = event.tool_result.tool + elif isinstance(event, (AgentToolCall, AgentToolCallDelta)): + tool = event.tool + else: + return False + + return tool and tool.metadata.get("is_completion_tool") + + +def process_event(event: Event) -> tuple[Event, Any, Optional[Any]]: + """Process a single event and return the appropriate tuple.""" + # Message events + if isinstance(event, AgentMessage): + return event, event.message, None + elif isinstance(event, AgentMessageDelta): + return event, event.message_snapshot, event.message_delta 
+ + # Content events + elif isinstance(event, AgentContent): + return event, event.content, None + elif isinstance(event, AgentContentDelta): + return event, event.content_snapshot, event.content_delta + + # Tool call events + elif isinstance(event, AgentToolCall): + return event, event.tool_call, None + elif isinstance(event, AgentToolCallDelta): + return event, event.tool_call_snapshot, event.tool_call_delta + + # Tool result events + elif isinstance(event, ToolResult): + return event, event.tool_result, None + + else: + # Pass through any other events with no snapshot/delta + return event, None, None + + +def filter_events_sync( + events: Iterator[Event], stream_filter: Stream +) -> Iterator[tuple[Event, Any, Optional[Any]]]: + """Synchronously filter events based on Stream flags.""" + for event in events: + if should_include_event(event, stream_filter): + yield process_event(event) + + +async def filter_events_async( + events: AsyncIterator[Event], stream_filter: Stream +) -> AsyncIterator[tuple[Event, Any, Optional[Any]]]: + """Asynchronously filter events based on Stream flags.""" + async for event in events: + if should_include_event(event, stream_filter): + yield process_event(event) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 2ea43730..bf502492 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -9,6 +9,7 @@ Callable, Generator, GenericAlias, + Iterator, Literal, Optional, TypeVar, @@ -49,26 +50,20 @@ unwrap, ) from controlflow.utilities.logging import get_logger -from controlflow.utilities.prefect import prefect_task as prefect_task if TYPE_CHECKING: + from controlflow.events.events import Event from controlflow.flows import Flow from controlflow.orchestration.handler import AsyncHandler, Handler from controlflow.orchestration.turn_strategies import TurnStrategy + from controlflow.stream import Stream T = TypeVar("T") logger = get_logger(__name__) - COMPLETION_TOOLS = Literal["SUCCEED", 
"FAIL"] -def get_task_run_name(): - context = TaskRunContext.get() - task = context.parameters["self"] - return f"Task.run() ({task.friendly_name()})" - - class Labels(RootModel): root: tuple[Any, ...] @@ -392,7 +387,6 @@ def add_dependency(self, task: "Task"): self.depends_on.add(task) task._downstreams.add(self) - @prefect_task(task_run_name=get_task_run_name) def run( self, agent: Optional[Agent] = None, @@ -403,12 +397,27 @@ def run( handlers: list["Handler"] = None, raise_on_failure: bool = True, model_kwargs: Optional[dict] = None, - ) -> T: + stream: Union[bool, "Stream"] = False, + ) -> Union[T, Iterator[tuple["Event", Any, Optional[Any]]]]: """ Run the task - """ - controlflow.run_tasks( + Args: + agent: Optional agent to run the task + flow: Optional flow to run the task in + turn_strategy: Optional turn strategy to use + max_llm_calls: Maximum number of LLM calls to make + max_agent_turns: Maximum number of agent turns to make + handlers: Optional list of handlers + raise_on_failure: Whether to raise on task failure + model_kwargs: Optional kwargs to pass to the model + stream: If True, stream all events. Can also provide StreamFilter flags. 
+ + Returns: + If not streaming: The task result + If streaming: Iterator of (event, snapshot, delta) tuples + """ + result = controlflow.run_tasks( tasks=[self], flow=flow, agent=agent, @@ -418,14 +427,16 @@ def run( raise_on_failure=False, handlers=handlers, model_kwargs=model_kwargs, + stream=stream, ) - if self.is_successful(): + if stream: + return result + elif self.is_successful(): return self.result elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") - @prefect_task(task_run_name=get_task_run_name) async def run_async( self, agent: Optional[Agent] = None, @@ -435,12 +446,13 @@ async def run_async( max_agent_turns: int = None, handlers: list[Union["Handler", "AsyncHandler"]] = None, raise_on_failure: bool = True, + stream: Union[bool, "Stream"] = False, ) -> T: """ Run the task """ - await controlflow.run_tasks_async( + result = await controlflow.run_tasks_async( tasks=[self], flow=flow, agent=agent, @@ -449,9 +461,12 @@ async def run_async( max_agent_turns=max_agent_turns, raise_on_failure=False, handlers=handlers, + stream=stream, ) - if self.is_successful(): + if stream: + return result + elif self.is_successful(): return self.result elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") @@ -564,18 +579,38 @@ def set_status(self, status: TaskStatus): tui.update_task(self) def mark_running(self): + """Mark the task as running and emit a TaskStart event.""" self.set_status(TaskStatus.RUNNING) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskStart + + orchestrator.add_event(TaskStart(task=self)) def mark_successful(self, result: T = None): + """Mark the task as successful and emit a TaskSuccess event.""" self.result = self.validate_result(result) self.set_status(TaskStatus.SUCCESSFUL) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskSuccess + + 
orchestrator.add_event(TaskSuccess(task=self, result=result)) def mark_failed(self, reason: Optional[str] = None): + """Mark the task as failed and emit a TaskFailure event.""" self.result = reason self.set_status(TaskStatus.FAILED) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskFailure + + orchestrator.add_event(TaskFailure(task=self, reason=reason)) def mark_skipped(self): + """Mark the task as skipped and emit a TaskSkipped event.""" self.set_status(TaskStatus.SKIPPED) + if orchestrator := ctx.get("orchestrator"): + from controlflow.events.task_events import TaskSkipped + + orchestrator.add_event(TaskSkipped(task=self)) def get_success_tool(self) -> Tool: """ diff --git a/tests/test_run.py b/tests/test_run.py index 7c14a25d..c8f1cabd 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,10 +1,15 @@ -from controlflow import instructions +import pytest + +import controlflow +from controlflow import Stream, instructions from controlflow.events.base import Event from controlflow.events.events import AgentMessage +from controlflow.llm.messages import AIMessage from controlflow.orchestration.conditions import AnyComplete, AnyFailed, MaxLLMCalls from controlflow.orchestration.handler import Handler from controlflow.run import run, run_async, run_tasks, run_tasks_async from controlflow.tasks.task import Task +from tests.fixtures.controlflow import default_fake_llm class TestHandlers: @@ -167,3 +172,117 @@ async def test_min_failed(self): assert task1.is_failed() assert task2.is_incomplete() assert task3.is_failed() + + +class TestRunStreaming: + # Helper function to collect async iterator results + async def collect_stream(self, ait): + return [x async for x in ait] + + @pytest.fixture + def task(self, default_fake_llm): + task = controlflow.Task("say hello", id="12345") + + response = AIMessage( + id="run-2af8bb73-661f-4ec3-92ff-d7d8e3074926", + name="Marvin", + role="ai", + content="", + tool_calls=[ + { + "name": 
"mark_task_12345_successful", + "args": {"task_result": "Hello!"}, + "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", + "type": "tool_call", + } + ], + ) + + default_fake_llm.set_responses(["Hello!", response]) + + return task + + def test_stream_all(self, default_fake_llm): + result = run("what's 2 + 2", stream=True, max_llm_calls=1) + r = list(result) + assert len(r) > 5 + + async def test_stream_all_async(self, default_fake_llm): + result = await run_async("what's 2 + 2", stream=True, max_llm_calls=1) + r = await self.collect_stream(result) + assert len(r) > 5 + + def test_stream_task(self, task): + result = list(task.run(stream=True)) + assert result[0][0].event == "orchestrator-start" + assert result[1][0].event == "agent-turn-start" + assert result[-1][0].event == "orchestrator-end" + assert any(r[0].event == "agent-message" for r in result) + assert any(r[0].event == "agent-message-delta" for r in result) + assert any(r[0].event == "agent-content" for r in result) + assert any(r[0].event == "agent-content-delta" for r in result) + assert any(r[0].event == "agent-tool-call" for r in result) + + async def test_stream_task_async(self, task): + result = await task.run_async(stream=True) + r = await self.collect_stream(result) + assert r[0][0].event == "orchestrator-start" + assert r[1][0].event == "agent-turn-start" + assert r[-1][0].event == "orchestrator-end" + assert any(x[0].event == "agent-message" for x in r) + assert any(x[0].event == "agent-message-delta" for x in r) + assert any(x[0].event == "agent-content" for x in r) + assert any(x[0].event == "agent-content-delta" for x in r) + assert any(x[0].event == "agent-tool-call" for x in r) + + def test_stream_content(self, task): + result = list(task.run(stream=Stream.CONTENT)) + assert all( + r[0].event in ("agent-content", "agent-content-delta") for r in result + ) + + async def test_stream_content_async(self, task): + result = await task.run_async(stream=Stream.CONTENT) + r = await self.collect_stream(result) + 
assert all(x[0].event in ("agent-content", "agent-content-delta") for x in r) + + def test_stream_tools(self, task): + result = list(task.run(stream=Stream.TOOLS)) + assert all( + r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for r in result + ) + + async def test_stream_tools_async(self, task): + result = await task.run_async(stream=Stream.TOOLS) + r = await self.collect_stream(result) + assert all( + x[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for x in r + ) + + def test_stream_completion_tools(self, task): + result = list(task.run(stream=Stream.COMPLETION_TOOLS)) + assert all( + r[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for r in result + ) + + async def test_stream_completion_tools_async(self, task): + result = await task.run_async(stream=Stream.COMPLETION_TOOLS) + r = await self.collect_stream(result) + assert all( + x[0].event in ("agent-tool-call", "agent-tool-call-delta", "tool-result") + for x in r + ) + + def test_stream_task_events(self, task): + result = list(task.run(stream=Stream.TASK_EVENTS)) + assert result[-1][0].event == "task-success" + assert result[0][0].task is task + + async def test_stream_task_events_async(self, task): + result = await task.run_async(stream=Stream.TASK_EVENTS) + r = await self.collect_stream(result) + assert r[-1][0].event == "task-success" + assert r[0][0].task is task