diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index cc671969..508ffa1c 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -140,8 +140,9 @@ def to_messages(self, context: "CompileContext") -> list[BaseMessage]: ] else: return OrchestratorMessage( - prefix=f'The following {"failed" if self.tool_result.is_error else "successful"} ' - f'tool result was received by "{self.agent.name}" with ID {self.agent.id}', + prefix=f'The following {"failed " if self.tool_result.is_error else ""}' + f'tool result was received by "{self.agent.name}" with ID {self.agent.id}. ' + f'The tool call was: {self.tool_call}', content=self.tool_result.str_result, name=self.agent.name, ).to_messages(context) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index caf1688f..3c95d8b9 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -98,12 +98,24 @@ def get_tools(self) -> list[Tool]: list[Tool]: A list of available tools. """ tools = [] + + # add flow tools tools.extend(self.flow.tools) + + # add task tools for task in self.get_tasks("assigned"): tools.extend(task.get_tools()) + + # add completion tools + if task.completion_agents is None or self.agent in task.completion_agents: + tools.append(task.create_success_tool()) + tools.append(task.create_fail_tool()) + + # add turn strategy tools tools.extend( self.turn_strategy.get_tools(self.agent, self.get_available_agents()) ) + tools = as_tools(tools) return tools @@ -138,10 +150,6 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None): for event in self.agent._run_model(messages=messages, tools=tools): self.handle_event(event) - # Check if the current agent is still available - if self.agent not in self.get_available_agents(): - break - # at the end of each turn, select the next agent if available_agents := self.get_available_agents(): self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents) @@ -178,10 +186,6 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None): ): self.handle_event(event) - # Check if the current agent is still available - if self.agent not in self.get_available_agents(): - break - # at the end of each turn, select the next agent if available_agents := self.get_available_agents(): self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents) diff --git a/src/controlflow/orchestration/prompt_templates/tasks.jinja b/src/controlflow/orchestration/prompt_templates/tasks.jinja index 3c5297dc..39746dea 100644 --- a/src/controlflow/orchestration/prompt_templates/tasks.jinja +++ b/src/controlflow/orchestration/prompt_templates/tasks.jinja @@ -14,9 +14,12 @@ The following tasks are active: {% for task in tasks %} - Assigned agents: {{ task._serialize_agents(task.get_agents()) }} +Assigned agents: {{ task._serialize_agents(task.get_agents()) }} +{% if task.completion_agents -%} +Completion agents: {{ task._serialize_agents(task.completion_agents) }} +{% endif %} - {{ task.get_prompt() }} +{{ task.get_prompt() }} {% endfor %} diff --git a/src/controlflow/orchestration/turn_strategies.py b/src/controlflow/orchestration/turn_strategies.py index 8ca2240a..2a1efb19 100644 --- a/src/controlflow/orchestration/turn_strategies.py +++ b/src/controlflow/orchestration/turn_strategies.py @@ -62,8 +62,10 @@ def create_delegate_tool( strategy: TurnStrategy, available_agents: Dict[Agent, List[Task]] ) -> Tool: @tool - def delegate_to_agent(agent_id: str) -> str: - """Delegate to another agent.""" + def delegate_to_agent(agent_id: str, message: str = None) -> str: + """Delegate to another agent and optionally send a message.""" + if len(available_agents) <= 1: + return "Cannot delegate as there are no other available agents." next_agent = next( (a for a in available_agents.keys() if a.id == agent_id), None ) @@ -95,18 +97,17 @@ class Popcorn(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] ) -> List[Tool]: - return [create_delegate_tool(self, available_agents)] + if len(available_agents) > 1: + return [create_delegate_tool(self, available_agents)] + else: + return [create_end_turn_tool(self)] def get_next_agent( self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] ) -> Agent: if self.next_agent and self.next_agent in available_agents: return self.next_agent - return ( - current_agent - if current_agent in available_agents - else next(iter(available_agents)) - ) + return next(iter(available_agents)) # Always return an available agent class Random(TurnStrategy): diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index e656713d..5df4991b 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -117,6 +117,10 @@ class Task(ControlFlowModel): default_factory=list, description="Tools available to every agent working on this task.", ) + completion_agents: Optional[list[Agent]] = Field( + default=None, + description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.", + ) interactive: bool = False created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) _subtasks: set["Task"] = set() @@ -220,13 +224,13 @@ def __repr__(self) -> str: serialized = self.model_dump(include={"id", "objective"}) return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})" - @field_validator("agents", mode="after") + @field_validator("agents") def _validate_agents(cls, v): if isinstance(v, list) and not v: raise ValueError("Agents must be `None` or a non-empty list of agents.") return v - @field_validator("parent", mode="before") + @field_validator("parent") def _default_parent(cls, v): if v is None: parent_tasks = ctx.get("tasks", []) @@ -235,7 +239,7 @@ def _default_parent(cls, v): v = None return v - @field_validator("result_type", mode="before") + @field_validator("result_type") def _ensure_result_type_is_list_if_literal(cls, v): if isinstance(v, _LiteralGenericAlias): v = v.__args__ @@ -273,6 +277,13 @@ def _serialize_result_type(self, result_type: list["Task"]): def _serialize_agents(self, agents: list[Agent]): return [agent.serialize_for_prompt() for agent in self.get_agents()] + @field_serializer("completion_agents") + def _serialize_completion_agents(self, completion_agents: Optional[list[Agent]]): + if completion_agents is not None: + return [agent.serialize_for_prompt() for agent in completion_agents] + else: + return None + @field_serializer("tools") def _serialize_tools(self, tools: list[Callable]): return [t.serialize_for_prompt() for t in controlflow.tools.as_tools(tools)] @@ -444,12 +455,13 @@ def get_tools(self) -> list[Union[Tool, Callable]]: tools = self.tools.copy() if self.interactive: tools.append(cli_input) - tools.extend( - [ - self.create_success_tool(), - self.create_fail_tool(), - ] - ) + return tools + + def get_completion_tools(self) -> list[Tool]: + tools = [ + self.create_success_tool(), + self.create_fail_tool(), + ] return tools def get_prompt(self) -> str: diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index b5cc41e5..40faea60 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -166,6 +166,18 @@ def test_task_loads_agent_from_parent_before_flow(): assert child.get_agents() == [agent2] +def test_completion_agents_default(): + task = Task(objective="Test task") + assert task.completion_agents is None + + +def test_completion_agents_set(): + agent1 = Agent(name="Agent 1") + agent2 = Agent(name="Agent 2") + task = Task(objective="Test task", completion_agents=[agent1, agent2]) + assert task.completion_agents == [agent1, agent2] + + class TestTaskStatus: def test_task_status_transitions(self): task = SimpleTask()