Skip to content

Commit

Permalink
Merge pull request #19 from jlowin/run
Browse files Browse the repository at this point in the history
Overhaul task / agent running
  • Loading branch information
jlowin authored May 11, 2024
2 parents 2fc74c4 + fe34efe commit 62d5d4e
Show file tree
Hide file tree
Showing 13 changed files with 430 additions and 249 deletions.
2 changes: 1 addition & 1 deletion examples/choose_a_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@ai_flow
def demo():
task = Task("choose a number between 1 and 100", agents=[a1, a2], result_type=int)
return task.run_until_complete()
return task.run()


demo()
2 changes: 1 addition & 1 deletion examples/multi_agent_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def demo():
agents=[jerry, george, elaine, kramer, newman],
context=dict(topic=topic),
)
task.run_until_complete(moderator=Moderator())
task.run(moderator=Moderator())


demo()
10 changes: 3 additions & 7 deletions examples/pineapple_pizza.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,14 @@
def demo():
topic = "pineapple on pizza"

task = Task(
"Discuss the topic",
agents=[a1, a2],
context={"topic": topic},
)
task = Task("Discuss the topic", agents=[a1, a2], context={"topic": topic})
with instructions("2 sentences max"):
task.run_until_complete()
task.run()

task2 = Task(
"which argument do you find more compelling?", [a1.name, a2.name], agents=[a3]
)
task2.run_until_complete()
task2.run()


demo()
2 changes: 1 addition & 1 deletion examples/readme_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def demo():
interests = Task(
"ask user for three interests", result_type=list[str], user_access=True
)
interests.run_until_complete()
interests.run()

# set instructions for just the next task
with instructions("no more than 8 lines"):
Expand Down
30 changes: 30 additions & 0 deletions examples/write_and_critique_paper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from control_flow import Agent, Task

writer = Agent(name="writer")
editor = Agent(name="editor", instructions="you always find at least one problem")
critic = Agent(name="critic")


# ai tasks:
# - automatically supply context from kwargs
# - automatically wrap sub tasks in parent
# - automatically iterate over sub tasks if they are all completed but the parent isn't?


def write_paper(topic: str) -> str:
"""
Write a paragraph on the topic
"""
draft = Task(
"produce a 3-sentence draft on the topic",
str,
agents=[writer],
context=dict(topic=topic),
)
edits = Task("edit the draft", str, agents=[editor], depends_on=[draft])
critique = Task("is it good enough?", bool, agents=[critic], depends_on=[edits])
return critique


task = write_paper("AI and the future of work")
task.run()
3 changes: 3 additions & 0 deletions src/control_flow/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ async def run_async(self, tasks: list[Task] | Task | None = None):

def __hash__(self):
return id(self)


DEFAULT_AGENT = Agent(name="Marvin")
198 changes: 107 additions & 91 deletions src/control_flow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import json
import logging
from typing import Callable
from typing import Any

import marvin.utilities
import marvin.utilities.tools
import prefect
from marvin.beta.assistants import PrintHandler, Run
from marvin.beta.assistants import EndRun, PrintHandler, Run
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from openai.types.beta.threads.runs import ToolCall
from prefect import get_client as get_prefect_client
from prefect import task as prefect_task
from prefect.context import FlowRunContext
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator

from control_flow.core.agent import Agent
from control_flow.core.flow import Flow
from control_flow.core.controller.moderators import marvin_moderator
from control_flow.core.flow import Flow, get_flow, get_flow_messages
from control_flow.core.graph import Graph
from control_flow.core.task import Task
from control_flow.instructions import get_instructions as get_context_instructions
from control_flow.instructions import get_instructions
from control_flow.utilities.prefect import (
create_json_artifact,
create_python_artifact,
Expand All @@ -39,116 +43,94 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
"""

flow: Flow
agents: list[Agent]
flow: Flow = Field(
default_factory=get_flow,
description="The flow that the controller is a part of.",
)
tasks: list[Task] = Field(
None,
description="Tasks that the controller will complete.",
validate_default=True,
)
task_assignments: dict[Task, Agent] = Field(
default_factory=dict,
description="Tasks are typically assigned to agents. To "
"temporarily assign agent to a task without changing "
r"the task definition, use this field as {task: [agent]}",
)
agents: list[Agent] | None = None
run_dependencies: bool = True
context: dict = {}
graph: Graph = None
model_config: dict = dict(extra="forbid")

@field_validator("agents", mode="before")
def _validate_agents(cls, v):
if not v:
raise ValueError("At least one agent is required.")
return v
@model_validator(mode="before")
@classmethod
def _create_graph(cls, data: Any) -> Any:
if not data.get("graph"):
data["graph"] = Graph.from_tasks(data.get("tasks", []))
return data

@field_validator("tasks", mode="before")
def _validate_tasks(cls, v):
if not v:
raise ValueError("At least one task is required.")
return v

@field_validator("tasks", mode="before")
def _load_tasks_from_ctx(cls, v):
if v is None:
v = cls.context.get("tasks", None)
if not v:
raise ValueError("At least one task is required.")
return v

def all_tasks(self) -> list[Task]:
tasks = []
for task in self.tasks:
tasks.extend(task.trace_dependencies())

# add temporary assignments
assigned_tasks = []
for task in set(tasks):
if task in assigned_tasks:
task = task.model_copy(
update={"agents": task.agents + self.task_assignments.get(task, [])}
)
assigned_tasks.append(task)
return assigned_tasks

@expose_sync_method("run_agent")
async def run_agent_async(self, agent: Agent):
"""
Run the control flow.
"""
if agent not in self.agents:
raise ValueError("Agent not found in controller agents.")
def _create_end_run_tool(self) -> FunctionTool:
def end_run():
raise EndRun()

prefect_task = await self._get_prefect_run_agent_task(agent)
await prefect_task(agent=agent)
return marvin.utilities.tools.tool_from_function(
end_run,
description="End your turn if you have no tasks to work on. Only call this tool in an emergency; otherwise you can end your turn normally.",
)

async def _run_agent(self, agent: Agent, thread: Thread = None) -> Run:
async def _run_agent(
self, agent: Agent, tasks: list[Task] = None, thread: Thread = None
) -> Run:
"""
Run a single agent.
"""
from control_flow.core.controller.instruction_template import MainTemplate

instructions_template = MainTemplate(
agent=agent,
controller=self,
context=self.context,
instructions=get_context_instructions(),
)
@prefect_task(task_run_name=f'Run Agent: "{agent.name}"')
async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):
from control_flow.core.controller.instruction_template import MainTemplate

instructions = instructions_template.render()

tools = self.flow.tools + agent.get_tools()

# add tools for any inactive tasks that the agent is assigned to
for task in self.all_tasks():
if task.is_incomplete() and agent in task.agents:
tools = tools + task.get_tools()

# filter tools because duplicate names are not allowed
final_tools = []
final_tool_names = set()
for tool in tools:
if isinstance(tool, FunctionTool):
if tool.function.name in final_tool_names:
continue
final_tool_names.add(tool.function.name)
final_tools.append(wrap_prefect_tool(tool))

run = Run(
assistant=agent,
thread=thread or self.flow.thread,
instructions=instructions,
tools=final_tools,
event_handler_class=AgentHandler,
)
tasks = tasks or self.tasks

await run.run_async()
tools = self.flow.tools + agent.get_tools() + [self._create_end_run_tool()]

return run
# add tools for any inactive tasks that the agent is assigned to
for task in tasks:
if agent in task.agents:
tools = tools + task.get_tools()

async def _get_prefect_run_agent_task(
self, agent: Agent, thread: Thread = None
) -> Callable:
@prefect_task(task_run_name=f'Run Agent: "{agent.name}"')
async def _run_agent(agent: Agent, thread: Thread = None):
run = await self._run_agent(agent=agent, thread=thread)
instructions_template = MainTemplate(
agent=agent,
controller=self,
tasks=tasks,
context=self.context,
instructions=get_instructions(),
)

instructions = instructions_template.render()

# filter tools because duplicate names are not allowed
final_tools = []
final_tool_names = set()
for tool in tools:
if isinstance(tool, FunctionTool):
if tool.function.name in final_tool_names:
continue
final_tool_names.add(tool.function.name)
final_tools.append(wrap_prefect_tool(tool))

run = Run(
assistant=agent,
thread=thread or self.flow.thread,
instructions=instructions,
tools=final_tools,
event_handler_class=AgentHandler,
)

await run.run_async()

create_json_artifact(
key="messages",
Expand All @@ -162,7 +144,41 @@ async def _run_agent(agent: Agent, thread: Thread = None):
)
return run

return _run_agent
return await _run_agent(agent=agent, tasks=tasks, thread=thread)

@expose_sync_method("run_once")
async def run_once_async(self):
# get the tasks to run
if self.run_dependencies:
tasks = self.graph.upstream_dependencies(self.tasks)
else:
tasks = self.tasks

# get the agents
if self.agents:
agents = self.agents
else:
# if we are running dependencies, only load agents for tasks that are ready
if self.run_dependencies:
agents = list({a for t in tasks for a in t.agents if t.is_ready()})
else:
agents = list({a for t in tasks for a in t.agents})

# select the next agent
if len(agents) == 0:
agent = Agent()
elif len(agents) == 1:
agent = agents[0]
else:
agent = marvin_moderator(
agents=agents,
tasks=tasks,
context=dict(
history=get_flow_messages(), instructions=get_instructions()
),
)

return await self._run_agent(agent, tasks=tasks)


class AgentHandler(PrintHandler):
Expand Down
Loading

0 comments on commit 62d5d4e

Please sign in to comment.