Skip to content

Commit

Permalink
Merge pull request #275 from PrefectHQ/completion-agents
Browse files Browse the repository at this point in the history
Allow configurable completion agents
  • Loading branch information
jlowin authored Sep 4, 2024
2 parents 19ee8a2 + 877c4c1 commit d886202
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 29 deletions.
5 changes: 3 additions & 2 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
]
else:
return OrchestratorMessage(
prefix=f'The following {"failed" if self.tool_result.is_error else "successful"} '
f'tool result was received by "{self.agent.name}" with ID {self.agent.id}',
prefix=f'The following {"failed " if self.tool_result.is_error else ""}'
f'tool result was received by "{self.agent.name}" with ID {self.agent.id}. '
f'The tool call was: {self.tool_call}',
content=self.tool_result.str_result,
name=self.agent.name,
).to_messages(context)
20 changes: 12 additions & 8 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,24 @@ def get_tools(self) -> list[Tool]:
list[Tool]: A list of available tools.
"""
tools = []

# add flow tools
tools.extend(self.flow.tools)

# add task tools
for task in self.get_tasks("assigned"):
tools.extend(task.get_tools())

# add completion tools
if task.completion_agents is None or self.agent in task.completion_agents:
tools.append(task.create_success_tool())
tools.append(task.create_fail_tool())

# add turn strategy tools
tools.extend(
self.turn_strategy.get_tools(self.agent, self.get_available_agents())
)

tools = as_tools(tools)
return tools

Expand Down Expand Up @@ -138,10 +150,6 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None):
for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if the current agent is still available
if self.agent not in self.get_available_agents():
break

# at the end of each turn, select the next agent
if available_agents := self.get_available_agents():
self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents)
Expand Down Expand Up @@ -178,10 +186,6 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None):
):
self.handle_event(event)

# Check if the current agent is still available
if self.agent not in self.get_available_agents():
break

# at the end of each turn, select the next agent
if available_agents := self.get_available_agents():
self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents)
Expand Down
7 changes: 5 additions & 2 deletions src/controlflow/orchestration/prompt_templates/tasks.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ The following tasks are active:

{% for task in tasks %}
<Task ID {{ task.id }}>
Assigned agents: {{ task._serialize_agents(task.get_agents()) }}
Assigned agents: {{ task._serialize_agents(task.get_agents()) }}
{% if task.completion_agents -%}
Completion agents: {{ task._serialize_agents(task.completion_agents) }}
{% endif %}

{{ task.get_prompt() }}
{{ task.get_prompt() }}
</Task>
{% endfor %}

Expand Down
17 changes: 9 additions & 8 deletions src/controlflow/orchestration/turn_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def create_delegate_tool(
strategy: TurnStrategy, available_agents: Dict[Agent, List[Task]]
) -> Tool:
@tool
def delegate_to_agent(agent_id: str) -> str:
"""Delegate to another agent."""
def delegate_to_agent(agent_id: str, message: str = None) -> str:
"""Delegate to another agent and optionally send a message."""
if len(available_agents) <= 1:
return "Cannot delegate as there are no other available agents."
next_agent = next(
(a for a in available_agents.keys() if a.id == agent_id), None
)
Expand Down Expand Up @@ -95,18 +97,17 @@ class Popcorn(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> List[Tool]:
return [create_delegate_tool(self, available_agents)]
if len(available_agents) > 1:
return [create_delegate_tool(self, available_agents)]
else:
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> Agent:
if self.next_agent and self.next_agent in available_agents:
return self.next_agent
return (
current_agent
if current_agent in available_agents
else next(iter(available_agents))
)
return next(iter(available_agents)) # Always return an available agent


class Random(TurnStrategy):
Expand Down
30 changes: 21 additions & 9 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ class Task(ControlFlowModel):
default_factory=list,
description="Tools available to every agent working on this task.",
)
completion_agents: Optional[list[Agent]] = Field(
default=None,
description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
)
interactive: bool = False
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
_subtasks: set["Task"] = set()
Expand Down Expand Up @@ -220,13 +224,13 @@ def __repr__(self) -> str:
serialized = self.model_dump(include={"id", "objective"})
return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})"

@field_validator("agents", mode="after")
@field_validator("agents")
def _validate_agents(cls, v):
if isinstance(v, list) and not v:
raise ValueError("Agents must be `None` or a non-empty list of agents.")
return v

@field_validator("parent", mode="before")
@field_validator("parent")
def _default_parent(cls, v):
if v is None:
parent_tasks = ctx.get("tasks", [])
Expand All @@ -235,7 +239,7 @@ def _default_parent(cls, v):
v = None
return v

@field_validator("result_type", mode="before")
@field_validator("result_type")
def _ensure_result_type_is_list_if_literal(cls, v):
if isinstance(v, _LiteralGenericAlias):
v = v.__args__
Expand Down Expand Up @@ -273,6 +277,13 @@ def _serialize_result_type(self, result_type: list["Task"]):
def _serialize_agents(self, agents: list[Agent]):
return [agent.serialize_for_prompt() for agent in self.get_agents()]

@field_serializer("completion_agents")
def _serialize_completion_agents(self, completion_agents: Optional[list[Agent]]):
if completion_agents is not None:
return [agent.serialize_for_prompt() for agent in completion_agents]
else:
return None

@field_serializer("tools")
def _serialize_tools(self, tools: list[Callable]):
return [t.serialize_for_prompt() for t in controlflow.tools.as_tools(tools)]
Expand Down Expand Up @@ -444,12 +455,13 @@ def get_tools(self) -> list[Union[Tool, Callable]]:
tools = self.tools.copy()
if self.interactive:
tools.append(cli_input)
tools.extend(
[
self.create_success_tool(),
self.create_fail_tool(),
]
)
return tools

def get_completion_tools(self) -> list[Tool]:
tools = [
self.create_success_tool(),
self.create_fail_tool(),
]
return tools

def get_prompt(self) -> str:
Expand Down
12 changes: 12 additions & 0 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ def test_task_loads_agent_from_parent_before_flow():
assert child.get_agents() == [agent2]


def test_completion_agents_default():
task = Task(objective="Test task")
assert task.completion_agents is None


def test_completion_agents_set():
agent1 = Agent(name="Agent 1")
agent2 = Agent(name="Agent 2")
task = Task(objective="Test task", completion_agents=[agent1, agent2])
assert task.completion_agents == [agent1, agent2]


class TestTaskStatus:
def test_task_status_transitions(self):
task = SimpleTask()
Expand Down

0 comments on commit d886202

Please sign in to comment.