Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow configurable completion agents #275

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ def to_messages(self, context: "CompileContext") -> list[BaseMessage]:
]
else:
return OrchestratorMessage(
prefix=f'The following {"failed" if self.tool_result.is_error else "successful"} '
f'tool result was received by "{self.agent.name}" with ID {self.agent.id}',
prefix=f'The following {"failed " if self.tool_result.is_error else ""}'
f'tool result was received by "{self.agent.name}" with ID {self.agent.id}. '
f'The tool call was: {self.tool_call}',
content=self.tool_result.str_result,
name=self.agent.name,
).to_messages(context)
20 changes: 12 additions & 8 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,24 @@ def get_tools(self) -> list[Tool]:
list[Tool]: A list of available tools.
"""
tools = []

# add flow tools
tools.extend(self.flow.tools)

# add task tools
for task in self.get_tasks("assigned"):
tools.extend(task.get_tools())

# add completion tools
if task.completion_agents is None or self.agent in task.completion_agents:
tools.append(task.create_success_tool())
tools.append(task.create_fail_tool())

# add turn strategy tools
tools.extend(
self.turn_strategy.get_tools(self.agent, self.get_available_agents())
)

tools = as_tools(tools)
return tools

Expand Down Expand Up @@ -138,10 +150,6 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None):
for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if the current agent is still available
if self.agent not in self.get_available_agents():
break

# at the end of each turn, select the next agent
if available_agents := self.get_available_agents():
self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents)
Expand Down Expand Up @@ -178,10 +186,6 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None):
):
self.handle_event(event)

# Check if the current agent is still available
if self.agent not in self.get_available_agents():
break

# at the end of each turn, select the next agent
if available_agents := self.get_available_agents():
self.agent = self.turn_strategy.get_next_agent(self.agent, available_agents)
Expand Down
7 changes: 5 additions & 2 deletions src/controlflow/orchestration/prompt_templates/tasks.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ The following tasks are active:

{% for task in tasks %}
<Task ID {{ task.id }}>
Assigned agents: {{ task._serialize_agents(task.get_agents()) }}
Assigned agents: {{ task._serialize_agents(task.get_agents()) }}
{% if task.completion_agents -%}
Completion agents: {{ task._serialize_agents(task.completion_agents) }}
{% endif %}

{{ task.get_prompt() }}
{{ task.get_prompt() }}
</Task>
{% endfor %}

Expand Down
17 changes: 9 additions & 8 deletions src/controlflow/orchestration/turn_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def create_delegate_tool(
strategy: TurnStrategy, available_agents: Dict[Agent, List[Task]]
) -> Tool:
@tool
def delegate_to_agent(agent_id: str) -> str:
"""Delegate to another agent."""
def delegate_to_agent(agent_id: str, message: str = None) -> str:
"""Delegate to another agent and optionally send a message."""
if len(available_agents) <= 1:
return "Cannot delegate as there are no other available agents."
next_agent = next(
(a for a in available_agents.keys() if a.id == agent_id), None
)
Expand Down Expand Up @@ -95,18 +97,17 @@ class Popcorn(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> List[Tool]:
return [create_delegate_tool(self, available_agents)]
if len(available_agents) > 1:
return [create_delegate_tool(self, available_agents)]
else:
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> Agent:
if self.next_agent and self.next_agent in available_agents:
return self.next_agent
return (
current_agent
if current_agent in available_agents
else next(iter(available_agents))
)
return next(iter(available_agents)) # Always return an available agent


class Random(TurnStrategy):
Expand Down
30 changes: 21 additions & 9 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ class Task(ControlFlowModel):
default_factory=list,
description="Tools available to every agent working on this task.",
)
completion_agents: Optional[list[Agent]] = Field(
default=None,
description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
)
interactive: bool = False
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
_subtasks: set["Task"] = set()
Expand Down Expand Up @@ -220,13 +224,13 @@ def __repr__(self) -> str:
serialized = self.model_dump(include={"id", "objective"})
return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})"

@field_validator("agents", mode="after")
@field_validator("agents")
def _validate_agents(cls, v):
if isinstance(v, list) and not v:
raise ValueError("Agents must be `None` or a non-empty list of agents.")
return v

@field_validator("parent", mode="before")
@field_validator("parent")
def _default_parent(cls, v):
if v is None:
parent_tasks = ctx.get("tasks", [])
Expand All @@ -235,7 +239,7 @@ def _default_parent(cls, v):
v = None
return v

@field_validator("result_type", mode="before")
@field_validator("result_type")
def _ensure_result_type_is_list_if_literal(cls, v):
if isinstance(v, _LiteralGenericAlias):
v = v.__args__
Expand Down Expand Up @@ -273,6 +277,13 @@ def _serialize_result_type(self, result_type: list["Task"]):
def _serialize_agents(self, agents: list[Agent]):
return [agent.serialize_for_prompt() for agent in self.get_agents()]

@field_serializer("completion_agents")
def _serialize_completion_agents(self, completion_agents: Optional[list[Agent]]):
if completion_agents is not None:
return [agent.serialize_for_prompt() for agent in completion_agents]
else:
return None

@field_serializer("tools")
def _serialize_tools(self, tools: list[Callable]):
return [t.serialize_for_prompt() for t in controlflow.tools.as_tools(tools)]
Expand Down Expand Up @@ -444,12 +455,13 @@ def get_tools(self) -> list[Union[Tool, Callable]]:
tools = self.tools.copy()
if self.interactive:
tools.append(cli_input)
tools.extend(
[
self.create_success_tool(),
self.create_fail_tool(),
]
)
return tools

def get_completion_tools(self) -> list[Tool]:
tools = [
self.create_success_tool(),
self.create_fail_tool(),
]
return tools

def get_prompt(self) -> str:
Expand Down
12 changes: 12 additions & 0 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ def test_task_loads_agent_from_parent_before_flow():
assert child.get_agents() == [agent2]


def test_completion_agents_default():
task = Task(objective="Test task")
assert task.completion_agents is None


def test_completion_agents_set():
agent1 = Agent(name="Agent 1")
agent2 = Agent(name="Agent 2")
task = Task(objective="Test task", completion_agents=[agent1, agent2])
assert task.completion_agents == [agent1, agent2]


class TestTaskStatus:
def test_task_status_transitions(self):
task = SimpleTask()
Expand Down