diff --git a/examples/choose_a_number.py b/examples/choose_a_number.py index de476b84..7e7f8758 100644 --- a/examples/choose_a_number.py +++ b/examples/choose_a_number.py @@ -10,7 +10,7 @@ @ai_flow def demo(): task = Task("choose a number between 1 and 100", agents=[a1, a2], result_type=int) - return task.run_until_complete() + return task.run() demo() diff --git a/examples/multi_agent_conversation.py b/examples/multi_agent_conversation.py index 211b6e72..55853181 100644 --- a/examples/multi_agent_conversation.py +++ b/examples/multi_agent_conversation.py @@ -70,7 +70,7 @@ def demo(): agents=[jerry, george, elaine, kramer, newman], context=dict(topic=topic), ) - task.run_until_complete(moderator=Moderator()) + task.run(moderator=Moderator()) demo() diff --git a/examples/pineapple_pizza.py b/examples/pineapple_pizza.py index 8452dda2..3f3721df 100644 --- a/examples/pineapple_pizza.py +++ b/examples/pineapple_pizza.py @@ -27,18 +27,14 @@ def demo(): topic = "pineapple on pizza" - task = Task( - "Discuss the topic", - agents=[a1, a2], - context={"topic": topic}, - ) + task = Task("Discuss the topic", agents=[a1, a2], context={"topic": topic}) with instructions("2 sentences max"): - task.run_until_complete() + task.run() task2 = Task( "which argument do you find more compelling?", [a1.name, a2.name], agents=[a3] ) - task2.run_until_complete() + task2.run() demo() diff --git a/examples/readme_example.py b/examples/readme_example.py index f1cd8d0f..2ab08991 100644 --- a/examples/readme_example.py +++ b/examples/readme_example.py @@ -29,7 +29,7 @@ def demo(): interests = Task( "ask user for three interests", result_type=list[str], user_access=True ) - interests.run_until_complete() + interests.run() # set instructions for just the next task with instructions("no more than 8 lines"): diff --git a/examples/write_and_critique_paper.py b/examples/write_and_critique_paper.py new file mode 100644 index 00000000..a2a0f7d7 --- /dev/null +++ b/examples/write_and_critique_paper.py @@ -0,0 +1,30 @@ +from control_flow import Agent, Task + +writer = Agent(name="writer") +editor = Agent(name="editor", instructions="you always find at least one problem") +critic = Agent(name="critic") + + +# ai tasks: +# - automatically supply context from kwargs +# - automatically wrap sub tasks in parent +# - automatically iterate over sub tasks if they are all completed but the parent isn't? + + +def write_paper(topic: str) -> str: + """ + Write a paragraph on the topic + """ + draft = Task( + "produce a 3-sentence draft on the topic", + str, + agents=[writer], + context=dict(topic=topic), + ) + edits = Task("edit the draft", str, agents=[editor], depends_on=[draft]) + critique = Task("is it good enough?", bool, agents=[critic], depends_on=[edits]) + return critique + + +task = write_paper("AI and the future of work") +task.run() diff --git a/src/control_flow/core/agent.py b/src/control_flow/core/agent.py index eb7d15f0..c43a6464 100644 --- a/src/control_flow/core/agent.py +++ b/src/control_flow/core/agent.py @@ -42,3 +42,6 @@ async def run_async(self, tasks: list[Task] | Task | None = None): def __hash__(self): return id(self) + + +DEFAULT_AGENT = Agent(name="Marvin") diff --git a/src/control_flow/core/controller/controller.py b/src/control_flow/core/controller/controller.py index 2eaecd94..5225d013 100644 --- a/src/control_flow/core/controller/controller.py +++ b/src/control_flow/core/controller/controller.py @@ -1,20 +1,24 @@ import json import logging -from typing import Callable +from typing import Any +import marvin.utilities +import marvin.utilities.tools import prefect -from marvin.beta.assistants import PrintHandler, Run +from marvin.beta.assistants import EndRun, PrintHandler, Run from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from openai.types.beta.threads.runs import ToolCall from prefect import get_client as get_prefect_client from prefect import task as prefect_task from prefect.context import FlowRunContext -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from control_flow.core.agent import Agent -from control_flow.core.flow import Flow +from control_flow.core.controller.moderators import marvin_moderator +from control_flow.core.flow import Flow, get_flow, get_flow_messages +from control_flow.core.graph import Graph from control_flow.core.task import Task -from control_flow.instructions import get_instructions as get_context_instructions +from control_flow.instructions import get_instructions from control_flow.utilities.prefect import ( create_json_artifact, create_python_artifact, @@ -39,116 +43,94 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): """ - flow: Flow - agents: list[Agent] + flow: Flow = Field( + default_factory=get_flow, + description="The flow that the controller is a part of.", + ) tasks: list[Task] = Field( None, description="Tasks that the controller will complete.", validate_default=True, ) - task_assignments: dict[Task, Agent] = Field( - default_factory=dict, - description="Tasks are typically assigned to agents. To " - "temporarily assign agent to a task without changing " - r"the task definition, use this field as {task: [agent]}", - ) + agents: list[Agent] | None = None + run_dependencies: bool = True context: dict = {} + graph: Graph = None model_config: dict = dict(extra="forbid") - @field_validator("agents", mode="before") - def _validate_agents(cls, v): - if not v: - raise ValueError("At least one agent is required.") - return v + @model_validator(mode="before") + @classmethod + def _create_graph(cls, data: Any) -> Any: + if not data.get("graph"): + data["graph"] = Graph.from_tasks(data.get("tasks", [])) + return data @field_validator("tasks", mode="before") def _validate_tasks(cls, v): - if not v: - raise ValueError("At least one task is required.") - return v - - @field_validator("tasks", mode="before") - def _load_tasks_from_ctx(cls, v): if v is None: v = cls.context.get("tasks", None) + if not v: + raise ValueError("At least one task is required.") return v - def all_tasks(self) -> list[Task]: - tasks = [] - for task in self.tasks: - tasks.extend(task.trace_dependencies()) - - # add temporary assignments - assigned_tasks = [] - for task in set(tasks): - if task in assigned_tasks: - task = task.model_copy( - update={"agents": task.agents + self.task_assignments.get(task, [])} - ) - assigned_tasks.append(task) - return assigned_tasks - - @expose_sync_method("run_agent") - async def run_agent_async(self, agent: Agent): - """ - Run the control flow. - """ - if agent not in self.agents: - raise ValueError("Agent not found in controller agents.") + def _create_end_run_tool(self) -> FunctionTool: + def end_run(): + raise EndRun() - prefect_task = await self._get_prefect_run_agent_task(agent) - await prefect_task(agent=agent) + return marvin.utilities.tools.tool_from_function( + end_run, + description="End your turn if you have no tasks to work on. Only call this tool in an emergency; otherwise you can end your turn normally.", + ) - async def _run_agent(self, agent: Agent, thread: Thread = None) -> Run: + async def _run_agent( + self, agent: Agent, tasks: list[Task] = None, thread: Thread = None + ) -> Run: """ Run a single agent. """ - from control_flow.core.controller.instruction_template import MainTemplate - instructions_template = MainTemplate( - agent=agent, - controller=self, - context=self.context, - instructions=get_context_instructions(), - ) + @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') + async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None): + from control_flow.core.controller.instruction_template import MainTemplate - instructions = instructions_template.render() - - tools = self.flow.tools + agent.get_tools() - - # add tools for any inactive tasks that the agent is assigned to - for task in self.all_tasks(): - if task.is_incomplete() and agent in task.agents: - tools = tools + task.get_tools() - - # filter tools because duplicate names are not allowed - final_tools = [] - final_tool_names = set() - for tool in tools: - if isinstance(tool, FunctionTool): - if tool.function.name in final_tool_names: - continue - final_tool_names.add(tool.function.name) - final_tools.append(wrap_prefect_tool(tool)) - - run = Run( - assistant=agent, - thread=thread or self.flow.thread, - instructions=instructions, - tools=final_tools, - event_handler_class=AgentHandler, - ) + tasks = tasks or self.tasks - await run.run_async() + tools = self.flow.tools + agent.get_tools() + [self._create_end_run_tool()] - return run + # add tools for any inactive tasks that the agent is assigned to + for task in tasks: + if agent in task.agents: + tools = tools + task.get_tools() - async def _get_prefect_run_agent_task( - self, agent: Agent, thread: Thread = None - ) -> Callable: - @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') - async def _run_agent(agent: Agent, thread: Thread = None): - run = await self._run_agent(agent=agent, thread=thread) + instructions_template = MainTemplate( + agent=agent, + controller=self, + tasks=tasks, + context=self.context, + instructions=get_instructions(), + ) + + instructions = instructions_template.render() + + # filter tools because duplicate names are not allowed + final_tools = [] + final_tool_names = set() + for tool in tools: + if isinstance(tool, FunctionTool): + if tool.function.name in final_tool_names: + continue + final_tool_names.add(tool.function.name) + final_tools.append(wrap_prefect_tool(tool)) + + run = Run( + assistant=agent, + thread=thread or self.flow.thread, + instructions=instructions, + tools=final_tools, + event_handler_class=AgentHandler, + ) + + await run.run_async() create_json_artifact( key="messages", @@ -162,7 +144,41 @@ async def _run_agent(agent: Agent, thread: Thread = None): ) return run - return _run_agent + return await _run_agent(agent=agent, tasks=tasks, thread=thread) + + @expose_sync_method("run_once") + async def run_once_async(self): + # get the tasks to run + if self.run_dependencies: + tasks = self.graph.upstream_dependencies(self.tasks) + else: + tasks = self.tasks + + # get the agents + if self.agents: + agents = self.agents + else: + # if we are running dependencies, only load agents for tasks that are ready + if self.run_dependencies: + agents = list({a for t in tasks for a in t.agents if t.is_ready()}) + else: + agents = list({a for t in tasks for a in t.agents}) + + # select the next agent + if len(agents) == 0: + agent = Agent() + elif len(agents) == 1: + agent = agents[0] + else: + agent = marvin_moderator( + agents=agents, + tasks=tasks, + context=dict( + history=get_flow_messages(), instructions=get_instructions() + ), + ) + + return await self._run_agent(agent, tasks=tasks) class AgentHandler(PrintHandler): diff --git a/src/control_flow/core/controller/instruction_template.py b/src/control_flow/core/controller/instruction_template.py index d18799e1..3d13d2a3 100644 --- a/src/control_flow/core/controller/instruction_template.py +++ b/src/control_flow/core/controller/instruction_template.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from control_flow.core.agent import Agent +from control_flow.core.task import Task from control_flow.utilities.jinja import jinja_env from control_flow.utilities.types import ControlFlowModel @@ -60,17 +61,16 @@ class TasksTemplate(Template): template: str = """ ## Tasks + ### Your assignments + You have been assigned to complete certain tasks. Each task has an objective and criteria for success. Your job is to perform any required actions and then mark each assigned task as successful. If a task also - requires a result, you must provide it. + requires a result, you must provide it. Only work on tasks that are + assigned to you. If the task requires a result, do not also post the + result in a message, as this would be redundant. Messages should be used + only to provide context that is not captured in task results. - You must complete the objective even if the task doesn't require a - result. For example, a task that asks you to choose, discuss, or perform - an action must be completed by posting messages before the task is - marked complete. The objective may require participation from multiple - agents. Do not mark a task as complete until the objective is fully met. - A "parent" is a task that spawned another task as a subtask. Generally, the subtasks will need to be completed BEFORE the parent task. If you can complete a parent task before its subtasks, you should mark the @@ -80,6 +80,26 @@ class TasksTemplate(Template): before the task itself can be completed. The `mark_success` tool will not be available until all dependencies are met. + ### Current tasks + + At the start of your turn, these are all the tasks being worked on. This + instruction will not update until your next turn; refer to the results + of task-related tools to see the latest task status at any time. + + {% for task in tasks %} + #### Task {{ task.id }} + {{ task.model_dump_json() }} + + {% endfor %} + + ### Completing a task + + You must complete the objective even if the task doesn't require a + result. For example, a task that asks you to choose, discuss, or perform + an action must be completed by posting messages before the task is + marked complete. The objective may require participation from multiple + agents. Do not mark a task as complete until the objective is fully met. + Some tasks may require collaboration with other agents to be completed; others may take you multiple attempts. A task can only be marked complete one time, so make sure you perform all actions and collect all results before @@ -87,24 +107,19 @@ class TasksTemplate(Template): (both you and others) can no longer act on it. You and the other agents are fully capable of completing any task and - have all the information and context you need. Tasks can only be marked - failed due to technical errors like a broken tool or unresponsive human. + have all the information and context you need. Tasks should only be marked + failed due to technical errors like a broken or erroring tool or unresponsive human. Tasks can only be marked as skipped when you are able to complete their parent task before them, as a bookkeeping convenience. Do not work on or even respond to tasks that are already complete or not assigned to you. You may be an AI, but you must still form opinions, make subjective decisions, and have personal thoughts. - {% for task in controller.all_tasks() %} - ### Task {{ task.id }} - {{ task.model_dump_json() }} - - {% endfor %} """ - controller: Controller + tasks: list[Task] def should_render(self): - return any(self.controller.tasks) + return bool(self.tasks) class CommunicationTemplate(Template): @@ -179,18 +194,15 @@ class MainTemplate(BaseModel): controller: Controller context: dict instructions: list[str] + tasks: list[Task] def render(self): - all_agents = [self.agent] + self.controller.agents - for task in self.controller.tasks: - all_agents += task.agents - # other_agents = [agent for agent in all_agents if agent != self.agent] templates = [ AgentTemplate( agent=self.agent, ), TasksTemplate( - controller=self.controller, + tasks=self.tasks, ), InstructionsTemplate( agent=self.agent, diff --git a/src/control_flow/core/controller/moderators.py b/src/control_flow/core/controller/moderators.py index 8a1e52db..937ed9a5 100644 --- a/src/control_flow/core/controller/moderators.py +++ b/src/control_flow/core/controller/moderators.py @@ -7,7 +7,6 @@ from control_flow.core.agent import Agent from control_flow.core.flow import Flow, get_flow_messages from control_flow.core.task import Task -from control_flow.instructions import get_instructions if TYPE_CHECKING: from control_flow.core.agent import Agent @@ -57,29 +56,29 @@ def run(self, agents: list[Agent], tasks: list[Task]) -> Generator[Any, Any, Age agents=[self.agent], parent=None, ) - agent_name = task.run_until_complete() + agent_name = task.run() yield next(a for a in agents if a.name == agent_name) -class Moderator(BaseModerator): - model: str = None - - def run(self, agents: list[Agent], tasks: list[Task]) -> Generator[Any, Any, Agent]: - while True: - instructions = get_instructions() - history = get_flow_messages() - context = dict( - tasks=tasks, messages=history, global_instructions=instructions - ) - agent = marvin.classify( - context, - agents, - instructions=""" - Given the conversation context, choose the AI agent most - qualified to take the next turn at completing the tasks. Take into - account any tasks, history, instructions, and tools. - """, - model_kwargs=dict(model=self.model) if self.model else None, - ) - - yield agent +def marvin_moderator( + agents: list[Agent], + tasks: list[Task], + context: dict = None, + model: str = None, +) -> Agent: + context = context or {} + context.update(tasks=tasks) + agent = marvin.classify( + context, + agents, + instructions=""" + Given the context, choose the AI agent best suited to take the + next turn at completing the tasks in the task graph. Take into account + any descriptions, tasks, history, instructions, and tools. Focus on + agents assigned to upstream dependencies or subtasks that need to be + completed before their downstream/parents can be completed. An agent + can only work on a task that it is assigned to. + """, + model_kwargs=dict(model=model) if model else None, + ) + return agent diff --git a/src/control_flow/core/flow.py b/src/control_flow/core/flow.py index 1ca38b04..40b67b4f 100644 --- a/src/control_flow/core/flow.py +++ b/src/control_flow/core/flow.py @@ -48,6 +48,9 @@ def __exit__(self, *exc_info): return self.__cm.__exit__(*exc_info) +GLOBAL_FLOW = Flow() + + def get_flow() -> Flow: """ Loads the flow from the context. @@ -56,10 +59,15 @@ def get_flow() -> Flow: """ flow: Flow | None = ctx.get("flow") if not flow: - return Flow() + return GLOBAL_FLOW return flow +def reset_global_flow(): + global GLOBAL_FLOW + GLOBAL_FLOW = Flow() + + def get_flow_messages(limit: int = None) -> list[Message]: """ Loads messages from the flow's thread. diff --git a/src/control_flow/core/graph.py b/src/control_flow/core/graph.py new file mode 100644 index 00000000..639ee479 --- /dev/null +++ b/src/control_flow/core/graph.py @@ -0,0 +1,152 @@ +from enum import Enum + +from pydantic import BaseModel + +from control_flow.core.task import Task + + +class EdgeType(Enum): + """ + Edges represent the relationship between two tasks in a graph. + + - `DEPENDENCY_OF` means that the downstream task depends on the upstream task. + - `PARENT` means that the downstream task is the parent of the upstream task. + + Example: + + # write paper + ## write outline + ## write draft based on outline + + Edges: + outline -> paper # child_of (outline is a child of paper) + draft -> paper # child_of (draft is a child of paper) + outline -> draft # dependency_of (outline is a dependency of draft) + + """ + + DEPENDENCY_OF = "dependency_of" + CHILD_OF = "child_of" + + +class Edge(BaseModel): + upstream: Task + downstream: Task + type: EdgeType + + def __repr__(self): + return f"{self.type}: {self.upstream.id} -> {self.downstream.id}" + + def __hash__(self) -> int: + return id(self) + + +class Graph(BaseModel): + tasks: set[Task] = set() + edges: set[Edge] = set() + _cache: dict[str, dict[Task, list[Task]]] = {} + + def __init__(self): + super().__init__() + + @classmethod + def from_tasks(cls, tasks: list[Task]) -> "Graph": + graph = cls() + for task in tasks: + graph.add_task(task) + return graph + + def add_task(self, task: Task): + if task in self.tasks: + return + self.tasks.add(task) + if task.parent: + self.add_edge( + Edge( + upstream=task.parent, + downstream=task, + type=EdgeType.CHILD_OF, + ) + ) + if task.depends_on: + for upstream in task.depends_on: + self.add_edge( + Edge( + upstream=upstream, + downstream=task, + type=EdgeType.DEPENDENCY_OF, + ) + ) + self._cache.clear() + + def add_edge(self, edge: Edge): + if edge in self.edges: + return + self.edges.add(edge) + self.add_task(edge.upstream) + self.add_task(edge.downstream) + self._cache.clear() + + def upstream_edges(self) -> dict[Task, list[Edge]]: + if "upstream_edges" not in self._cache: + graph = {} + for task in self.tasks: + graph[task] = [] + for edge in self.edges: + graph[edge.downstream].append(edge) + self._cache["upstream_edges"] = graph + return self._cache["upstream_edges"] + + def downstream_edges(self) -> dict[Task, list[Edge]]: + if "downstream_edges" not in self._cache: + graph = {} + for task in self.tasks: + graph[task] = [] + for edge in self.edges: + graph[edge.upstream].append(edge) + self._cache["downstream_edges"] = graph + return self._cache["downstream_edges"] + + def upstream_dependencies( + self, tasks: list[Task], prune_completed: bool = True + ) -> list[Task]: + """ + From a list of tasks, returns the subgraph of tasks that are directly or + indirectly dependencies of those tasks. A dependency means following + upstream edges, so it includes tasks that are considered explicit + dependencies as well as any subtasks that are considered implicit + dependencies. + + If `prune_completed` is True, the subgraph will be pruned to stop traversal after adding any completed tasks. + """ + subgraph = set() + upstreams = self.upstream_edges() + stack = tasks + while stack: + current = stack.pop() + if current in subgraph: + continue + + subgraph.add(current) + # if prune_completed, stop traversal if the current task is complete + if prune_completed and current.is_complete(): + continue + stack.extend([edge.upstream for edge in upstreams[current]]) + + return list(subgraph) + + def ready_tasks(self, tasks: list[Task] = None) -> list[Task]: + """ + Returns a list of tasks that are ready to run, meaning that all of their + dependencies have been completed. If `tasks` is provided, only tasks in + the upstream dependency subgraph of those tasks will be considered. + + Ready tasks will be returned in the order they were created. + """ + if tasks is None: + candidates = self.tasks + else: + candidates = self.upstream_dependencies(tasks) + return sorted( + [task for task in candidates if task.is_ready()], key=lambda t: t.created_at + ) diff --git a/src/control_flow/core/task.py b/src/control_flow/core/task.py index 5d6f15b7..b8d9d7de 100644 --- a/src/control_flow/core/task.py +++ b/src/control_flow/core/task.py @@ -1,10 +1,10 @@ +import datetime import uuid from contextlib import contextmanager from enum import Enum from typing import ( TYPE_CHECKING, Callable, - Generator, GenericAlias, Literal, TypeVar, @@ -30,6 +30,7 @@ if TYPE_CHECKING: from control_flow.core.agent import Agent + from control_flow.core.graph import Graph T = TypeVar("T") logger = get_logger(__name__) @@ -46,9 +47,13 @@ class TaskStatus(Enum): class Task(ControlFlowModel): id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:4])) - objective: str - instructions: str | None = None - agents: list["Agent"] = [] + objective: str = Field( + ..., description="A brief description of the required result." + ) + instructions: str | None = Field( + None, description="Detailed instructions for completing the task." + ) + agents: list["Agent"] = Field(None, validate_default=True) context: dict = {} parent: "Task | None" = Field( NOTSET, @@ -62,6 +67,7 @@ class Task(ControlFlowModel): error: str | None = None tools: list[AssistantTool | Callable] = [] user_access: bool = False + created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) _children: list["Task"] = [] _downstream: list["Task"] = [] model_config = dict(extra="forbid", arbitrary_types_allowed=True) @@ -74,9 +80,16 @@ def __init__(self, objective=None, result_type=None, **kwargs): # allow certain args to be provided as a positional args super().__init__(**kwargs) + def __repr__(self): + return str(self.model_dump()) + @field_validator("agents", mode="before") - def _turn_none_into_empty_list(cls, v): - return v or [] + def _default_agent(cls, v): + if v is None: + from control_flow.core.agent import DEFAULT_AGENT + + return [DEFAULT_AGENT] + return v @field_validator("result_type", mode="before") def _turn_list_into_literal_result_type(cls, v): @@ -121,6 +134,11 @@ def _serialize_agents(agents: list["Agent"]): for a in agents ] + def as_graph(self) -> "Graph": + from control_flow.core.graph import Graph + + return Graph.from_tasks(tasks=[self]) + def trace_dependencies(self) -> list["Task"]: """ Returns a list of all tasks related to this task, including upstream and downstream tasks, parents, and children. @@ -152,47 +170,28 @@ def trace_dependencies(self) -> list["Task"]: return list(tasks) - def dependency_agents(self) -> list["Agent"]: - deps = self.trace_dependencies() - agents = [] - for task in deps: - agents.extend(task.agents) - return agents - - def run(self, agent: "Agent" = None): + def run_once(self, agent: "Agent" = None, run_dependencies: bool = True): """ - Runs the task with provided agent. If no agent is provided, a default agent is used. + Runs the task with provided agent. If no agent is provided, one will be selected from the task's agents. """ - from control_flow.core.agent import Agent - - if agent is None: - all_agents = self.dependency_agents() - if not all_agents: - agent = Agent() - elif len(all_agents) == 1: - agent = all_agents[0] - else: - raise ValueError( - f"Task {self.id} has multiple agents assigned to it or its " - "children. Please specify one to run the task or call run_until_complete()." - ) - - run_gen = run_iter(tasks=[self], agents=[agent]) - return next(run_gen) - - def run_until_complete( - self, - agents: list["Agent"] = None, - moderator: Callable[[list["Agent"]], Generator[None, None, "Agent"]] = None, - ) -> T: + from control_flow.core.controller import Controller + + controller = Controller( + tasks=[self], agents=agent, run_dependencies=run_dependencies + ) + + controller.run_once() + + def run(self, run_dependencies: bool = True) -> T: """ Runs the task with provided agents until it is complete. """ - - run_until_complete(tasks=[self], agents=agents, moderator=moderator) - if self.is_failed(): - raise ValueError(f"Task {self.id} failed: {self.error}") - return self.result + while self.is_incomplete(): + self.run_once(run_dependencies=run_dependencies) + if self.is_successful(): + return self.result + elif self.is_failed(): + raise ValueError(f"Task {self.id} failed: {self.error}") @contextmanager def _context(self): @@ -223,6 +222,12 @@ def is_failed(self) -> bool: def is_skipped(self) -> bool: return self.status == TaskStatus.SKIPPED + def is_ready(self) -> bool: + """ + Returns True if all dependencies are complete and this task is incomplete. + """ + return self.is_incomplete() and all(t.is_complete() for t in self.depends_on) + def __hash__(self): return id(self) @@ -232,14 +237,13 @@ def _create_success_tool(self) -> FunctionTool: """ # wrap the method call to get the correct result type signature - def succeed(result: self.result_type): - # validate the result - self.mark_successful(result=result) + def succeed(result: self.result_type) -> str: + return self.mark_successful(result=result) tool = marvin.utilities.tools.tool_from_function( succeed, - name=f"succeed_task_{self.id}", - description=f"Mark task {self.id} as successful and provide a result.", + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful and optionally provide a result.", ) return tool @@ -250,8 +254,8 @@ def _create_fail_tool(self) -> FunctionTool: """ tool = marvin.utilities.tools.tool_from_function( self.mark_failed, - name=f"fail_task_{self.id}", - description=f"Mark task {self.id} as failed. Only use when a technical issue prevents completion.", + name=f"mark_task_{self.id}_failed", + description=f"Mark task {self.id} as failed. Only use when a technical issue like a broken tool or unresponsive human prevents completion.", ) return tool @@ -261,7 +265,7 @@ def _create_skip_tool(self) -> FunctionTool: """ tool = marvin.utilities.tools.tool_from_function( self.mark_skipped, - name=f"skip_task_{self.id}", + name=f"mark_task_{self.id}_skipped", description=f"Mark task {self.id} as skipped. Only use when completing its parent task early.", ) return tool @@ -269,34 +273,41 @@ def _create_skip_tool(self) -> FunctionTool: def get_tools(self) -> list[AssistantTool | Callable]: tools = self.tools.copy() if self.is_incomplete(): - tools.append(self._create_fail_tool()) + tools.extend([self._create_fail_tool(), self._create_success_tool()]) # add skip tool if this task has a parent task if self.parent is not None: tools.append(self._create_skip_tool()) - # add success tools if this task has no upstream tasks or all upstream tasks are complete - if all(t.is_complete() for t in self.depends_on): - tools.append(self._create_success_tool()) if self.user_access: tools.append(marvin.utilities.tools.tool_from_function(talk_to_human)) return [wrap_prefect_tool(t) for t in tools] - def mark_successful(self, result: T = None): + def mark_successful(self, result: T = None, validate_upstreams: bool = True): if self.result_type is None and result is not None: raise ValueError( f"Task {self.objective} specifies no result type, but a result was provided." ) elif self.result_type is not None: + if validate_upstreams: + if any(t.is_incomplete() for t in self.depends_on): + raise ValueError( + f"Task {self.objective} cannot be marked successful until all of its " + "upstream dependencies are completed. Incomplete dependencies " + f"are: {[t for t in self.depends_on if t.is_incomplete()]}" + ) result = TypeAdapter(self.result_type).validate_python(result) self.result = result self.status = TaskStatus.SUCCESSFUL + return f"Task {self.id} marked successful. Updated task definition: {self.model_dump()}" def mark_failed(self, message: str | None = None): self.error = message self.status = TaskStatus.FAILED + return f"Task {self.id} marked failed. Updated task definition: {self.model_dump()}" def mark_skipped(self): self.status = TaskStatus.SKIPPED + return f"Task {self.id} marked skipped. Updated task definition: {self.model_dump()}" def any_incomplete(tasks: list[Task]) -> bool: @@ -317,49 +328,3 @@ def any_failed(tasks: list[Task]) -> bool: def none_failed(tasks: list[Task]) -> bool: return not any_failed(tasks) - - -def run_iter( - tasks: list["Task"], - agents: list["Agent"] = None, - moderator: Callable[[list["Agent"]], Generator[None, None, "Agent"]] = None, -): - from control_flow.core.controller.moderators import round_robin - - if moderator is None: - moderator = round_robin - - if agents is None: - agents = list(set([a for t in tasks for a in t.dependency_agents()])) - - if not agents: - raise ValueError("Tasks have no agents assigned. Please specify agents.") - - all_tasks = list(set([a for t in tasks for a in t.trace_dependencies()])) - - for agent in moderator(agents, tasks=all_tasks): - if any(t.is_failed() for t in tasks): - break - elif all(t.is_complete() for t in tasks): - break - agent.run(tasks=all_tasks) - yield True - - -def run_until_complete( - tasks: list["Task"], - agents: list["Agent"] = None, - moderator: Callable[[list["Agent"]], Generator[None, None, "Agent"]] = None, - raise_on_error: bool = True, -) -> T: - """ - Runs the task with provided agents until it is complete. - """ - - for _ in run_iter(tasks=tasks, agents=agents, moderator=moderator): - continue - - if raise_on_error and any(t.is_failed() for t in tasks): - raise ValueError( - f"At least one task failed: {', '.join(t.id for t in tasks if t.is_failed())}" - ) diff --git a/src/control_flow/dx.py b/src/control_flow/dx.py index e061c4a7..b57fc187 100644 --- a/src/control_flow/dx.py +++ b/src/control_flow/dx.py @@ -109,7 +109,7 @@ def wrapper(*args, _agents: list[Agent] = None, **kwargs): tools=tools or [], ) - task.run_until_complete() + task.run() return task.result return wrapper