Skip to content

Commit

Permalink
Merge pull request #84 from PrefectHQ/multi-agent
Browse files Browse the repository at this point in the history
Improve multi-agent moderation
  • Loading branch information
jlowin authored Jun 10, 2024
2 parents fdd8d02 + 75bcd9d commit 33d668c
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 261 deletions.
5 changes: 3 additions & 2 deletions src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import logging
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional

from pydantic import Field

import controlflow
from controlflow.core.task import Task
from controlflow.llm.models import BaseChatModel, get_default_model
from controlflow.tools.talk_to_human import talk_to_human
from controlflow.utilities.types import ControlFlowModel

if TYPE_CHECKING:
from controlflow.core.task import Task
logger = logging.getLogger(__name__)


Expand Down
257 changes: 114 additions & 143 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import logging
import math
from collections import defaultdict
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, contextmanager
from functools import cached_property
from typing import Callable, Union

from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator

import controlflow
from controlflow.core.agent import Agent
from controlflow.core.controller.moderators import classify_moderator
from controlflow.core.flow import Flow, get_flow
from controlflow.core.graph import Graph
from controlflow.core.task import Task
Expand All @@ -25,13 +24,20 @@
logger = logging.getLogger(__name__)


def add_agent_name_to_message(msg: MessageType):
def add_agent_name_to_messages(messages: list[MessageType]) -> list[MessageType]:
"""
If the message is from a named assistant, prefix the message with the assistant's name.
"""
if isinstance(msg, AIMessage) and msg.name:
msg = msg.model_copy(update={"content": f"{msg.name}: {msg.content}"})
return msg
new_messages = []
for msg in messages:
if isinstance(msg, AIMessage) and msg.name:
msg = msg.copy(
update={
"content": f'(Message from agent "{msg.name}")\n\n{msg.content}'
}
)
new_messages.append(msg)
return new_messages


class Controller(BaseModel):
Expand Down Expand Up @@ -88,11 +94,13 @@ def _finalize(self):
def _create_end_turn_tool(self) -> Callable:
def end_turn():
"""
Call this tool to skip your turn and let another agent go next. This
is useful if you are stuck and can not complete any tasks. If this
tool is used 3 times by any agent the workflow will be aborted
automatically, so only use it if you are truly stuck and unable to
proceed.
This tool is for emergencies only; you should not use it normally.
If you find yourself in a situation where you are repeatedly invoked
and your normal tools do not work, or you can not escape the loop,
use this tool to signal to the controller that you are stuck. A new
agent will be selected to go next. If this tool is used 3 times by
an agent the workflow will be aborted automatically.
"""

# the agent's name is used as the key to track the number of times
Expand All @@ -110,13 +118,6 @@ def end_turn():

return end_turn

def choose_agent(self, agents: list[Agent], tasks: list[Task]) -> Agent:
return classify_moderator(
agents=agents,
tasks=tasks,
iteration=self._iteration,
)

@asynccontextmanager
async def tui(self):
if tui := ctx.get("tui"):
Expand All @@ -129,155 +130,125 @@ async def tui(self):
else:
yield

def _run_once_payload(self):
@contextmanager
def _setup_run(self):
"""
Generate the payload for a single run of the controller.
"""
if all(t.is_complete() for t in self.tasks):
return

# TODO: show the agent the entire graph, not just immediate upstreams
# get the tasks to run
tasks = self.graph.ready_tasks()
# get the agents
agent_candidates = [a for t in tasks for a in t.get_agents() if t.is_ready]
if len({a.name for a in agent_candidates}) != len(agent_candidates):
raise ValueError(
"Multiple agents with the same name were found. Agents must have unique names."
)
if self.agents:
agents = [a for a in agent_candidates if a in self.agents]
else:
agents = agent_candidates

# select the next agent
if len(agents) == 0:
raise ValueError(
"No agents were provided that are assigned to tasks that are ready to be run."
)
elif len(agents) == 1:
agent = agents[0]
else:
agent = self.choose_agent(agents=agents, tasks=tasks)

from controlflow.core.controller.instruction_template import (
MainTemplate,
)

tools = self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()]

# add tools for any inactive tasks that the agent is assigned to
assigned_tools = []
for task in tasks:
if agent in task.get_agents():
assigned_tools.extend(task.get_tools())
if not assigned_tools:
raise ValueError(
f"Agent {agent.name} is not assigned to any of the tasks that are ready to be run."
with self.flow:
# TODO: show the agent the entire graph, not just immediate upstreams
tasks = self.graph.topological_sort()
ready_tasks = [t for t in tasks if t.is_ready]

# if there are no ready tasks, return. This will usually happen because
# all the tasks are complete.
if not ready_tasks:
yield None
return

# get an agent from the next ready task
agents = ready_tasks[0].get_agents()
if len(agents) != 1:
moderator = ready_tasks[0].get_moderator()
agent = moderator(agents=agents, task=ready_tasks[0], flow=self.flow)
ready_tasks[0]._iteration += 1
else:
agent = agents[0]

from controlflow.core.controller.instruction_template import MainTemplate

tools = self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()]

# add tools for any ready tasks that the agent is assigned to
for task in ready_tasks:
if agent in task.get_agents():
tools.extend(task.get_tools())

instructions_template = MainTemplate(
agent=agent,
controller=self,
tasks=tasks,
context=self.context,
instructions=get_instructions(),
)
tools.extend(assigned_tools)

# tools = [prefect.task(tool) for tool in tools]

instructions_template = MainTemplate(
agent=agent,
controller=self,
tasks=tasks,
context=self.context,
instructions=get_instructions(),
)
instructions = instructions_template.render()

# prepare messages
system_message = SystemMessage(content=instructions)
messages = self.history.load_messages(thread_id=self.flow.thread_id)

# setup handlers
handlers = []
if controlflow.settings.enable_tui:
handlers.append(TUIHandler())
if controlflow.settings.enable_print_handler:
handlers.append(PrintHandler())

# yield the agent payload
return dict(
agent=agent,
messages=[system_message] + messages,
tools=tools,
handlers=handlers,
# message_preprocessor=add_agent_name_to_message,
)
instructions = instructions_template.render()

# prepare messages
system_message = SystemMessage(content=instructions)
messages = self.history.load_messages(thread_id=self.flow.thread_id)

# setup handlers
handlers = []
if controlflow.settings.enable_tui:
handlers.append(TUIHandler())
if controlflow.settings.enable_print_handler:
handlers.append(PrintHandler())

with ctx(controller_agent=agent):
# yield the agent payload
yield dict(
agent=agent,
messages=[system_message] + messages,
tools=tools,
handlers=handlers,
)

async def run_once_async(self):
async with self.tui():
with self.flow:
payload = self._run_once_payload()
if payload is not None:
agent: Agent = payload.pop("agent")
response_handler = ResponseHandler()
payload["handlers"].append(response_handler)

messages = []
for msg in payload["messages"]:
if isinstance(msg, AIMessage) and msg.name:
msg = msg.copy()
msg.content = (
f"Message from agent: {msg.name}\n\n{msg.content}"
)
messages.append(msg)

response_gen = await completion_async(
messages=messages,
model=agent.model,
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
ai_name=agent.name,
stream=True,
)
async for _ in response_gen:
pass

# save history
self.history.save_messages(
thread_id=self.flow.thread_id,
messages=response_handler.response_messages,
)
self._iteration += 1

def run_once(self):
with self.flow:
payload = self._run_once_payload()
if payload is not None:
with self._setup_run() as payload:
if payload is None:
return
agent: Agent = payload.pop("agent")
response_handler = ResponseHandler()
payload["handlers"].append(response_handler)

messages = []
for msg in payload["messages"]:
if isinstance(msg, AIMessage) and msg.name:
msg = msg.copy()
msg.content = f"Message from agent: {msg.name}\n\n{msg.content}"
messages.append(msg)

response_gen = completion(
messages=messages,
response_gen = await completion_async(
messages=payload["messages"],
model=agent.model,
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
ai_name=agent.name,
stream=True,
ai_name=agent.name,
message_preprocessor=add_agent_name_to_messages,
)
for _ in response_gen:
async for _ in response_gen:
pass

# save history
self.history.save_messages(
thread_id=self.flow.thread_id,
messages=response_handler.response_messages,
)
self._iteration += 1

def run_once(self):
with self._setup_run() as payload:
if payload is None:
return
agent: Agent = payload.pop("agent")
response_handler = ResponseHandler()
payload["handlers"].append(response_handler)

response_gen = completion(
messages=payload["messages"],
model=agent.model,
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
stream=True,
ai_name=agent.name,
message_preprocessor=add_agent_name_to_messages,
)
for _ in response_gen:
pass

# save history
self.history.save_messages(
thread_id=self.flow.thread_id,
messages=response_handler.response_messages,
)
self._iteration += 1
self._iteration += 1

async def run_async(self):
"""
Expand Down
19 changes: 12 additions & 7 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ class TasksTemplate(Template):
template: str = """
## Tasks
Your job is to complete any tasks assigned to you. Tasks may have multiple agents assigned.
Your job is to complete any tasks assigned to you. Tasks may have
multiple agents assigned.
### Current tasks
### Ready tasks
These tasks are assigned to you and ready to be worked on because their dependencies have been completed:
These tasks are ready to be worked on because their dependencies have
been completed. You can only work on tasks assigned to you.
{% for task in tasks %}
{% if task.is_ready %}
Expand All @@ -74,7 +76,7 @@ class TasksTemplate(Template):
### Other tasks
These tasks are either not ready yet or are dependencies of other tasks. They are provided for context.
These tasks are provided for context only. They may be upstream or downstream of the active tasks.
{% for task in tasks %}
{% if not task.is_ready %}
Expand Down Expand Up @@ -131,7 +133,8 @@ class CommunicationTemplate(Template):
- You need to post a message or otherwise communicate to complete a
task. For example, the task instructs you to write, discuss, or
otherwise produce content (and does not accept a result, or the result
that meets the objective is different than the instructed actions).
that meets the objective is different than the instructed actions, or
multiple agents are assigned to the discussion).
- You need to communicate with other agents to complete a task.
- You want to write your thought process for future reference.
Expand Down Expand Up @@ -183,8 +186,10 @@ class ContextTemplate(Template):
Information about the flow and controller.
### Flow
{% if flow.name %} Flow name: {{ flow.name }} {% endif %}
{% if flow.description %} Flow description: {{ flow.description }} {% endif %}
{% if flow.name %}Flow name: {{ flow.name }} {% endif %}
{% if flow.description %}Flow description: {{ flow.description }} {% endif %}
Flow context:
{% for key, value in flow.context.items() %}
- *{{ key }}*: {{ value }}
Expand Down
Loading

0 comments on commit 33d668c

Please sign in to comment.