diff --git a/README.md b/README.md index 9462058b..ce487ac0 100644 --- a/README.md +++ b/README.md @@ -79,4 +79,8 @@ if __name__ == "__main__": demo() ``` + + + + image diff --git a/examples/choose_a_number.py b/examples/choose_a_number.py new file mode 100644 index 00000000..36609056 --- /dev/null +++ b/examples/choose_a_number.py @@ -0,0 +1,21 @@ +from control_flow import Agent, Task, ai_flow + +a1 = Agent(name="A1", instructions="You struggle to make decisions.") +a2 = Agent( + name="A2", + instructions="You like to make decisions.", +) + + +@ai_flow +def demo(): + task = Task("Choose a number between 1 and 100", agents=[a1, a2], result_type=int) + + while task.is_incomplete(): + a1.run(task) + a2.run(task) + + return task + + +demo() diff --git a/examples/pineapple_pizza.py b/examples/pineapple_pizza.py new file mode 100644 index 00000000..098518e0 --- /dev/null +++ b/examples/pineapple_pizza.py @@ -0,0 +1,38 @@ +from control_flow import Agent, Task, ai_flow +from control_flow.instructions import instructions + +a1 = Agent( + name="Half-full", + instructions="You are an ardent fan and hype-man of whatever topic" + " the user asks you for information on." + " Purely positive, though thorough in your debating skills.", +) +a2 = Agent( + name="Half-empty", + instructions="You are a critic and staunch detractor of whatever topic" + " the user asks you for information on." + " Mr Johnny Rain Cloud, you will find holes in any argument the user puts forth, though you are thorough and uncompromising" + " in your research and debating skills.", +) + + +@ai_flow +def demo(): + user_message = "pineapple on pizza" + + with instructions("one sentence max"): + task = Task( + "All agents must give an argument based on the user message", + agents=[a1, a2], + context={"user_message": user_message}, + ) + task.run_until_complete() + + task2 = Task( + "Post a message saying which argument about the user message is more compelling?" 
+ ) + while task2.is_incomplete(): + task2.run(agents=[Agent(instructions="you always pick a side")]) + + +demo() diff --git a/src/control_flow/__init__.py b/src/control_flow/__init__.py index 6f3237e5..4a8e2bcb 100644 --- a/src/control_flow/__init__.py +++ b/src/control_flow/__init__.py @@ -3,6 +3,7 @@ # from .agent_old import ai_task, Agent, run_ai from .core.flow import Flow from .core.agent import Agent +from .core.task import Task from .core.controller.controller import Controller from .instructions import instructions from .dx import ai_flow, run_ai, ai_task diff --git a/src/control_flow/agents/__init__.py b/src/control_flow/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/control_flow/agents/agents.py b/src/control_flow/agents/agents.py new file mode 100644 index 00000000..e24c835c --- /dev/null +++ b/src/control_flow/agents/agents.py @@ -0,0 +1,42 @@ +import marvin + +from control_flow.core.agent import Agent +from control_flow.instructions import get_instructions +from control_flow.utilities.context import ctx +from control_flow.utilities.threads import get_history + + +def choose_agent( + agents: list[Agent], + instructions: str = None, + context: dict = None, + model: str = None, +) -> Agent: + """ + Given a list of potential agents, choose the most qualified assistant to complete the tasks. + """ + + instructions = get_instructions() + history = [] + if (flow := ctx.get("flow")) and flow.thread.id: + history = get_history(thread_id=flow.thread.id) + + info = dict( + history=history, + global_instructions=instructions, + context=context, + ) + + agent = marvin.classify( + info, + agents, + instructions=""" + Given the conversation context, choose the AI agent most + qualified to take the next turn at completing the tasks. Take into + account the instructions, each agent's own instructions, and the + tools they have available. 
+ """, + model_kwargs=dict(model=model), + ) + + return agent diff --git a/src/control_flow/core/agent.py b/src/control_flow/core/agent.py index dc7a6860..7f8c003b 100644 --- a/src/control_flow/core/agent.py +++ b/src/control_flow/core/agent.py @@ -1,10 +1,12 @@ import logging -from enum import Enum from typing import Callable +from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.tools import tool_from_function from pydantic import Field +from control_flow.core.flow import get_flow +from control_flow.core.task import Task from control_flow.utilities.prefect import ( wrap_prefect_tool, ) @@ -14,12 +16,8 @@ logger = logging.getLogger(__name__) -class AgentStatus(Enum): - INCOMPLETE = "incomplete" - COMPLETE = "complete" - - -class Agent(Assistant, ControlFlowModel): +class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin): + name: str = "Agent" user_access: bool = Field( False, description="If True, the agent is given tools for interacting with a human user.", @@ -35,3 +33,13 @@ def get_tools(self) -> list[AssistantTool | Callable]: tools.append(tool_from_function(talk_to_human)) return [wrap_prefect_tool(tool) for tool in tools] + + @expose_sync_method("run") + async def run_async(self, tasks: list[Task] | Task | None = None): + from control_flow.core.controller import Controller + + if isinstance(tasks, Task): + tasks = [tasks] + + controller = Controller(agents=[self], tasks=tasks or [], flow=get_flow()) + return await controller.run_agent_async(agent=self) diff --git a/src/control_flow/core/controller/controller.py b/src/control_flow/core/controller/controller.py index 5c1b3a9a..0f8253d9 100644 --- a/src/control_flow/core/controller/controller.py +++ b/src/control_flow/core/controller/controller.py @@ -12,12 +12,8 @@ from pydantic import BaseModel, Field, field_validator, model_validator from control_flow.core.agent import Agent -from control_flow.core.controller.delegation import ( - DelegationStrategy, - 
RoundRobin, -) from control_flow.core.flow import Flow -from control_flow.core.task import Task, TaskStatus +from control_flow.core.task import Task from control_flow.instructions import get_instructions as get_context_instructions from control_flow.utilities.prefect import ( create_json_artifact, @@ -30,20 +26,26 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): + """ + A controller contains logic for executing agents with context about the + larger workflow, including the flow itself, any tasks, and any other agents + they are collaborating with. The controller is responsible for orchestrating + agent behavior by generating instructions and tools for each agent. Note + that while the controller accepts details about (potentially multiple) + agents and tasks, its responsibility is to invoke one agent one time. Other + mechanisms should be used to orchestrate multiple agent invocations. This + is done by the controller to avoid tying e.g. agents to tasks or even a + specific flow. + + """ + flow: Flow agents: list[Agent] tasks: list[Task] = Field( description="Tasks that the controller will complete.", default_factory=list, ) - delegation_strategy: DelegationStrategy = Field( - validate_default=True, - description="The strategy for delegating work to assistants.", - default_factory=RoundRobin, - ) - # termination_strategy: TerminationStrategy context: dict = {} - instructions: str = None model_config: dict = dict(extra="forbid") @field_validator("agents", mode="before") @@ -58,48 +60,18 @@ def _add_tasks_to_flow(self) -> Self: self.flow.add_task(task) return self - @expose_sync_method("run") - async def run_async(self): + @expose_sync_method("run_agent") + async def run_agent_async(self, agent: Agent): """ Run the control flow. 
""" + if agent not in self.agents: + raise ValueError("Agent not found in controller agents.") - # continue as long as there are incomplete tasks - while any([t for t in self.tasks if t.status == TaskStatus.PENDING]): - # select the next agent - if len(self.agents) > 1: - agent = self.delegation_strategy(self.agents) - else: - agent = self.agents[0] - if not agent: - return - - # run the agent - task = await self._get_prefect_run_agent_task(agent) - task(agent=agent) - - async def _get_prefect_run_agent_task( - self, agent: Agent, thread: Thread = None - ) -> Callable: - @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') - async def _run_agent(agent: Agent, thread: Thread = None): - run = await self.run_agent(agent=agent, thread=thread) - - create_json_artifact( - key="messages", - data=[m.model_dump() for m in run.messages], - description="All messages sent and received during the run.", - ) - create_json_artifact( - key="actions", - data=[s.model_dump() for s in run.steps], - description="All actions taken by the assistant during the run.", - ) - return run - - return _run_agent + task = await self._get_prefect_run_agent_task(agent) + await task(agent=agent) - async def run_agent(self, agent: Agent, thread: Thread = None) -> Run: + async def _run_agent(self, agent: Agent, thread: Thread = None) -> Run: """ Run a single agent. 
""" @@ -142,6 +114,27 @@ async def run_agent(self, agent: Agent, thread: Thread = None) -> Run: return run + async def _get_prefect_run_agent_task( + self, agent: Agent, thread: Thread = None + ) -> Callable: + @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') + async def _run_agent(agent: Agent, thread: Thread = None): + run = await self._run_agent(agent=agent, thread=thread) + + create_json_artifact( + key="messages", + data=[m.model_dump() for m in run.messages], + description="All messages sent and received during the run.", + ) + create_json_artifact( + key="actions", + data=[s.model_dump() for s in run.steps], + description="All actions taken by the assistant during the run.", + ) + return run + + return _run_agent + def task_ids(self) -> dict[Task, int]: return {task: self.flow.get_task_id(task) for task in self.tasks} diff --git a/src/control_flow/core/controller/instruction_template.py b/src/control_flow/core/controller/instruction_template.py index 26930bca..5f63d07a 100644 --- a/src/control_flow/core/controller/instruction_template.py +++ b/src/control_flow/core/controller/instruction_template.py @@ -24,20 +24,28 @@ def render(self) -> str: class AgentTemplate(Template): template: str = """ - You are an AI agent. Your name is "{{ agent.name }}". - + You are an AI agent. Your name is "{{ agent.name }}". {% if agent.description %} + Your description: "{{ agent.description }}" + {% endif -%} + {% if agent.instructions %} + Your instructions: "{{ agent.instructions }}" + {% endif -%} - The following description has been provided for you: - {{ agent.description }} - {% endif -%} + You have been created by a program to complete certain tasks. Each task has + an objective and criteria for success. Your job is to perform any required + actions and then mark each task as successful. If a task also requires a + result, you must provide it; this is how the program receives data from you + as it can not read your messages. 
- Your job is to work on any pending tasks until you can mark them as either - `complete` or `failed`. The following instructions will provide you with all - the context you need to complete your tasks. Note that using a tool to - complete or fail a task is your ultimate objective, and you should not post - any messages to the thread unless you have a specific reason to do so. + Some tasks may require collaboration before they are complete; others may + take multiple iterations. You are fully capable of completing any task and + have all the information and context you need. Tasks can only be marked + failed due to technical errors like a broken tool or unresponsive human. You + must make a subjective decision if a task requires it. Do not work on or + even respond to tasks that are already complete. + """ agent: Agent @@ -46,21 +54,25 @@ class CommunicationTemplate(Template): template: str = """ ## Communciation - ### Posting messages to the thread + You should only post messages to the thread if you must send information to + other agents or if a task requires it. The human user can not see + these messages. Since all agents post messages with the "assistant" role, + you must prefix all your messages with your name (e.g. "{{ agent.name }}: + (message)") in order to distinguish your messages from others. Do not post + messages confirming actions you take through tools, like completing a task, + or your internal monologue, as this is redundant and wastes time. - You have been created by a Controller in the Python library ControlFlow in - order to complete various tasks or instructions. All messages in this thread - are either from the controller or from AI agents like you. Note that all - agents post to the thread with the `Assistant` role, so if you do need to - post a message, preface with your name (e.g. "{{ agent.name }}: Hello!") in - order to distinguish your messages. 
+ ### Other agents assigned to your tasks - The controller CAN NOT and WILL NOT read your messages, so DO NOT post - messages unless you need to send information to another agent. DO NOT post - messages about information already captured by your tool calls, such as the - tool call itself, its result, human responses, or task completion. + {% for agent in other_agents %} - ### Talking to humans + - Name: {{agent.name}} + - Description: {{ agent.description if agent.description is not none else "No description provided." }} + - Can talk to human users: {{agent.user_access}} + + {% endfor %} + + ## Talking to human users {% if agent.user_access %} You may interact with a human user to complete your tasks by using the @@ -76,91 +88,30 @@ class CommunicationTemplate(Template): fail the task if you truly can not make progress. {% else %} You can not interact with a human at this time. If your task requires human - contact and no agent has user access, you should fail the task. + contact and no agent has user access, you should fail the task. Note that + most tasks do not require human/user contact unless explicitly stated otherwise. {% endif %} """ agent: Agent - - -class CollaborationTemplate(Template): - template: str = """ - ## Collaboration - - You are collaborating with other AI agents. They are listed below by name, - along with a brief description. Note that all agents post messages to the - same thread with the `Assistant` role, so pay attention to the name of the - agent that is speaking. Only one agent needs to indicate that a task is - complete. 
- - ### Agents - {% for agent in other_agents %} - - #### "{{agent.name}}" - Can talk to humans: {{agent.user_access}} - Description: {% if agent.description %}{{agent.description}}{% endif %} - - {% endfor %} - {% if not other_agents %} - (No other agents are currently participating in this workflow) - {% endif %} - """ other_agents: list[Agent] class InstructionsTemplate(Template): template: str = """ ## Instructions - - {% if flow_instructions -%} - ### Workflow instructions - - These instructions apply to the entire workflow: - - {{ flow_instructions }} - {% endif %} - - {% if controller_instructions -%} - ### Controller instructions - - These instructions apply to these tasks: - - {{ controller_instructions }} - {% endif %} - - {% if agent_instructions -%} - ### Agent instructions - - These instructions apply only to you: - - {{ agent_instructions }} - {% endif %} - - {% if additional_instructions -%} - ### Additional instructions - - These instructions were additionally provided for this part of the workflow: + + You must follow these instructions for this part of the workflow: {% for instruction in additional_instructions %} - {{ instruction }} {% endfor %} - {% endif %} """ - flow_instructions: str | None = None - controller_instructions: str | None = None - agent_instructions: str | None = None additional_instructions: list[str] = [] def should_render(self): - return any( - [ - self.flow_instructions, - self.controller_instructions, - self.agent_instructions, - self.additional_instructions, - ] - ) + return bool(self.additional_instructions) class TasksTemplate(Template): @@ -169,17 +120,16 @@ class TasksTemplate(Template): ### Active tasks - The following tasks are pending. You and any other agents are responsible - for completing them and will continue to be invoked until you mark each - task as either "completed" or "failed" with the appropriate tool. The - result of a complete task should be an artifact that fully represents - the completed objective. 
+ The following tasks are incomplete. Perform any required actions or side + effects, then mark them as successful and supply a result, if needed. + Never mark a task successful until its objective is complete. A task + that doesn't require a result may still require action. Note: Task IDs are assigned for identification purposes only and will be resused after tasks complete. {% for task in controller.tasks %} - {% if task.status.value == "pending" %} + {% if task.status.value == "incomplete" %} #### Task {{ controller.flow.get_task_id(task) }} - Status: {{ task.status.value }} - Objective: {{ task.objective }} @@ -187,7 +137,7 @@ class TasksTemplate(Template): {% if task.instructions %} - Instructions: {{ task.instructions }} {% endif %} - {% if task.status.value == "completed" %} + {% if task.status.value == "successful" %} - Result: {{ task.result }} {% elif task.status.value == "failed" %} - Error: {{ task.error }} @@ -195,7 +145,12 @@ class TasksTemplate(Template): {% if task.context %} - Context: {{ task.context }} {% endif %} - + {% if task.agents %} + - Assigned agents: + {% for agent in task.agents %} + - "{{ agent.name }}" + {% endfor %} + {% endif %} {% endif %} {% endfor %} @@ -207,7 +162,7 @@ class TasksTemplate(Template): #### Task {{ controller.flow.get_task_id(task) }} - Status: {{ task.status.value }} - Objective: {{ task.objective }} - {% if task.status.value == "completed" %} + {% if task.status.value == "successful" %} - Result: {{ task.result }} {% elif task.status.value == "failed" %} - Error: {{ task.error }} @@ -227,20 +182,22 @@ def should_render(self): class ContextTemplate(Template): template: str = """ - ## Context + ## Additional context - {% if flow_context %} ### Flow context {% for key, value in flow_context.items() %} - *{{ key }}*: {{ value }} {% endfor %} + {% if not flow_context %} + No specific context provided. 
{% endif %} - {% if controller_context %} ### Controller context {% for key, value in controller_context.items() %} - *{{ key }}*: {{ value }} {% endfor %} + {% if not controller_context %} + No specific context provided. {% endif %} """ flow_context: dict @@ -257,6 +214,10 @@ class MainTemplate(BaseModel): instructions: list[str] def render(self): + all_agents = [self.agent] + self.controller.agents + for task in self.controller.tasks: + all_agents += task.agents + other_agents = [agent for agent in all_agents if agent != self.agent] templates = [ AgentTemplate(agent=self.agent), TasksTemplate(controller=self.controller), @@ -265,15 +226,10 @@ def render(self): controller_context=self.controller.context, ), InstructionsTemplate( - flow_instructions=self.controller.flow.instructions, - controller_instructions=self.controller.instructions, - agent_instructions=self.agent.instructions, additional_instructions=self.instructions, ), - CommunicationTemplate(agent=self.agent), - CollaborationTemplate( - other_agents=[a for a in self.controller.agents if a != self.agent] - ), + CommunicationTemplate(agent=self.agent, other_agents=other_agents), + # CollaborationTemplate(other_agents=other_agents), ] rendered = [ diff --git a/src/control_flow/core/controller/termination.py b/src/control_flow/core/controller/termination.py deleted file mode 100644 index 773582d5..00000000 --- a/src/control_flow/core/controller/termination.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import TYPE_CHECKING - -from pydantic import BaseModel - -from control_flow.task import TaskStatus - -if TYPE_CHECKING: - from control_flow.agent import Agent - - -class TerminationStrategy(BaseModel): - """ - A TerminationStrategy is a strategy for deciding when AI assistants have completed their tasks. - """ - - def run(self, agents: list["Agent"]) -> bool: - """ - Given agents, determine whether they have completed their tasks. 
- """ - - raise NotImplementedError() - - -class AllFinished(TerminationStrategy): - """ - An AllFinished termination strategy terminates when all agents have finished all of their tasks (either COMPLETED or FAILED). - """ - - def run(self, agents: list["Agent"]) -> bool: - """ - Given agents, determine whether they have completed their tasks. - """ - for agent in agents: - if any(task.status == TaskStatus.PENDING for task in agent.tasks): - return False - return True diff --git a/src/control_flow/core/flow.py b/src/control_flow/core/flow.py index 069a57c1..a8407c12 100644 --- a/src/control_flow/core/flow.py +++ b/src/control_flow/core/flow.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Literal from marvin.beta.assistants import Thread from openai.types.beta.threads import Message @@ -18,7 +18,6 @@ class Flow(ControlFlowModel): tools: list[AssistantTool | Callable] = Field( [], description="Tools that will be available to every agent in the flow" ) - instructions: str | None = None model: str | None = None context: dict = {} tasks: dict[Task, int] = Field(repr=False, default_factory=dict) @@ -34,8 +33,8 @@ def _load_thread_from_ctx(cls, v): return v - def add_message(self, message: str): - prefect_task(self.thread.add)(message) + def add_message(self, message: str, role: Literal["user", "assistant"] = None): + prefect_task(self.thread.add)(message, role=role) def add_task(self, task: Task): if task not in self.tasks: @@ -47,15 +46,15 @@ def add_task(self, task: Task): def get_task_id(self, task: Task): return self.tasks[task] - def pending_tasks(self): + def incomplete_tasks(self): return sorted( - (t for t in self.tasks if t.status == TaskStatus.PENDING), + (t for t in self.tasks if t.status == TaskStatus.INCOMPLETE), key=lambda t: t.created_at, ) def completed_tasks(self, reverse=False, limit=None): result = sorted( - (t for t in self.tasks if t.status != TaskStatus.PENDING), + (t for t in self.tasks if t.status != 
TaskStatus.INCOMPLETE), key=lambda t: t.completed_at, reverse=reverse, ) diff --git a/src/control_flow/core/task.py b/src/control_flow/core/task.py index f8882639..8a1cc89a 100644 --- a/src/control_flow/core/task.py +++ b/src/control_flow/core/task.py @@ -1,56 +1,101 @@ import datetime +import itertools from enum import Enum -from typing import Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Callable, TypeVar import marvin import marvin.utilities.tools from marvin.utilities.tools import FunctionTool -from pydantic import Field +from pydantic import Field, TypeAdapter from control_flow.utilities.logging import get_logger from control_flow.utilities.prefect import wrap_prefect_tool from control_flow.utilities.types import AssistantTool, ControlFlowModel from control_flow.utilities.user_access import talk_to_human +if TYPE_CHECKING: + from control_flow.core.agent import Agent T = TypeVar("T") logger = get_logger(__name__) class TaskStatus(Enum): - PENDING = "pending" - COMPLETED = "completed" + INCOMPLETE = "incomplete" + SUCCESSFUL = "successful" FAILED = "failed" -class Task(ControlFlowModel, Generic[T]): +class Task(ControlFlowModel): + model_config = dict(extra="forbid", allow_arbitrary_types=True) objective: str instructions: str | None = None - context: dict = Field({}) - status: TaskStatus = TaskStatus.PENDING + agents: list["Agent"] = [] + context: dict = {} + status: TaskStatus = TaskStatus.INCOMPLETE result: T = None + result_type: type[T] | None = None error: str | None = None tools: list[AssistantTool | Callable] = [] created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) completed_at: datetime.datetime | None = None user_access: bool = False + def __init__(self, objective, **kwargs): + # allow objective as a positional arg + super().__init__(objective=objective, **kwargs) + + def run(self, agents: list["Agent"] = None): + """ + Runs the task with provided agents for up to one cycle through the agents. 
+ """ + if not agents and not self.agents: + raise ValueError("No agents provided to run task.") + + for agent in agents or self.agents: + if self.is_complete(): + break + agent.run(tasks=[self]) + + def run_until_complete(self, agents: list["Agent"] = None): + """ + Runs the task with provided agents until it is complete. + """ + if not agents and not self.agents: + raise ValueError("No agents provided to run task.") + agents = itertools.cycle(agents or self.agents) + while self.is_incomplete(): + agent = next(agents) + agent.run(tasks=[self]) + + def is_incomplete(self) -> bool: + return self.status == TaskStatus.INCOMPLETE + + def is_complete(self) -> bool: + return self.status != TaskStatus.INCOMPLETE + + def is_successful(self) -> bool: + return self.status == TaskStatus.SUCCESSFUL + + def is_failed(self) -> bool: + return self.status == TaskStatus.FAILED + def __hash__(self): return id(self) - def _create_complete_tool(self, task_id: int) -> FunctionTool: + def _create_success_tool(self, task_id: int) -> FunctionTool: """ - Create an agent-compatible tool for completing this task. + Create an agent-compatible tool for marking this task as successful. """ - result_type = self.get_result_type() # wrap the method call to get the correct result type signature - def complete(result: result_type): - self.complete(result=result) + def succeed(result: self.result_type): + # validate the result + self.mark_successful(result=result) tool = marvin.utilities.tools.tool_from_function( - complete, - name=f"complete_task_{task_id}", - description=f"Mark task {task_id} completed", + succeed, + name=f"succeed_task_{task_id}", + description=f"Mark task {task_id} as successful", ) return tool @@ -60,33 +105,54 @@ def _create_fail_tool(self, task_id: int) -> FunctionTool: Create an agent-compatible tool for failing this task. 
""" tool = marvin.utilities.tools.tool_from_function( - self.fail, + self.mark_failed, name=f"fail_task_{task_id}", - description=f"Mark task {task_id} failed", + description=f"Mark task {task_id} as failed", ) return tool def get_tools(self, task_id: int) -> list[AssistantTool | Callable]: tools = self.tools + [ - self._create_complete_tool(task_id), + self._create_success_tool(task_id), self._create_fail_tool(task_id), ] if self.user_access: tools.append(marvin.utilities.tools.tool_from_function(talk_to_human)) return [wrap_prefect_tool(t) for t in tools] - def complete(self, result: T): + def mark_successful(self, result: T = None): + if self.result_type is None and result is not None: + raise ValueError( + f"Task {self.objective} specifies no result type, but a result was provided." + ) + elif self.result_type is not None: + result = TypeAdapter(self.result_type).validate_python(result) + self.result = result - self.status = TaskStatus.COMPLETED + self.status = TaskStatus.SUCCESSFUL self.completed_at = datetime.datetime.now() - def fail(self, message: str | None = None): + def mark_failed(self, message: str | None = None): self.error = message self.status = TaskStatus.FAILED self.completed_at = datetime.datetime.now() - def get_result_type(self) -> T: - """ - Returns the `type` of the task's result field. 
- """ - return self.model_fields["result"].annotation + +def any_incomplete(tasks: list[Task]) -> bool: + return any(t.status == TaskStatus.INCOMPLETE for t in tasks) + + +def all_complete(tasks: list[Task]) -> bool: + return all(t.status != TaskStatus.INCOMPLETE for t in tasks) + + +def all_successful(tasks: list[Task]) -> bool: + return all(t.status == TaskStatus.SUCCESSFUL for t in tasks) + + +def any_failed(tasks: list[Task]) -> bool: + return any(t.status == TaskStatus.FAILED for t in tasks) + + +def none_failed(tasks: list[Task]) -> bool: + return not any_failed(tasks) diff --git a/src/control_flow/dx.py b/src/control_flow/dx.py index 79817606..499d6cdb 100644 --- a/src/control_flow/dx.py +++ b/src/control_flow/dx.py @@ -23,7 +23,6 @@ def ai_flow( *, thread: Thread = None, tools: list[AssistantTool | Callable] = None, - instructions: str = None, model: str = None, ): """ @@ -35,7 +34,6 @@ def ai_flow( ai_flow, thread=thread, tools=tools, - instructions=instructions, model=model, ) @@ -51,7 +49,6 @@ def wrapper( **{ "thread": thread, "tools": tools or [], - "instructions": instructions, "model": model, **(flow_kwargs or {}), } @@ -188,7 +185,7 @@ def run_ai( controller.run() if ai_tasks: - if all(task.status == TaskStatus.COMPLETED for task in ai_tasks): + if all(task.status == TaskStatus.SUCCESSFUL for task in ai_tasks): result = [task.result for task in ai_tasks] if single_result: result = result[0] diff --git a/src/control_flow/loops.py b/src/control_flow/loops.py new file mode 100644 index 00000000..1e939ee5 --- /dev/null +++ b/src/control_flow/loops.py @@ -0,0 +1,31 @@ +import math +from typing import Generator + +import control_flow.core.task +from control_flow.core.task import Task + + +def any_incomplete( + tasks: list[Task], max_iterations=None +) -> Generator[bool, None, None]: + """ + An iterator that yields an iteration counter if its condition is met, and + stops otherwise. Also stops if the max_iterations is reached. 
+ + + for loop_count in any_incomplete(tasks=[task1, task2], max_iterations=10): + # will print 10 times if the tasks are still incomplete + print(loop_count) + + """ + if max_iterations is None: + max_iterations = math.inf + + i = 0 + while i < max_iterations: + i += 1 + if control_flow.core.task.any_incomplete(tasks): + yield i + else: + break + return False diff --git a/src/control_flow/settings.py b/src/control_flow/settings.py index 8b8af737..e8d083cf 100644 --- a/src/control_flow/settings.py +++ b/src/control_flow/settings.py @@ -1,4 +1,6 @@ import os +import sys +import warnings from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -21,13 +23,20 @@ class ControlFlowSettings(BaseSettings): class PrefectSettings(ControlFlowSettings): """ All settings here are used as defaults for Prefect, unless overridden by env vars. + Note that `apply()` must be called before Prefect is imported. """ PREFECT_LOGGING_LEVEL: str = "WARNING" + PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE: str = "true" def apply(self): import os + if "prefect" in sys.modules: + warnings.warn( + "Prefect has already been imported; ControlFlow defaults will not be applied." 
+            )
+
         for k, v in self.model_dump().items():
             if k not in os.environ:
                 os.environ[k] = v
diff --git a/src/control_flow/utilities/marvin.py b/src/control_flow/utilities/marvin.py
index 5fad7e80..b530d876 100644
--- a/src/control_flow/utilities/marvin.py
+++ b/src/control_flow/utilities/marvin.py
@@ -1,4 +1,4 @@
-import functools
+import inspect
 from contextlib import contextmanager
 from typing import Any, Callable
 
@@ -33,21 +33,26 @@ async def _generate_chat(**kwargs):
         create_json_artifact(key="response", data=response)
         return response
 
-    return _generate_chat(**kwargs)
+    return await _generate_chat(**kwargs)
 
 
 def generate_task(name: str, original_fn: Callable):
-    @functools.wraps(marvin.classify_async)
-    async def wrapper(*args, **kwargs):
+    if inspect.iscoroutinefunction(original_fn):
+
         @prefect_task(name=name)
-        async def inner(*args, **kwargs):
+        async def wrapper(*args, **kwargs):
             create_json_artifact(key="args", data=[args, kwargs])
             result = await original_fn(*args, **kwargs)
             create_json_artifact(key="result", data=result)
             return result
+    else:
 
-        # do this to avoid weirdness with async/sync behavior
-        return inner(*args, **kwargs)
+        @prefect_task(name=name)
+        def wrapper(*args, **kwargs):
+            create_json_artifact(key="args", data=[args, kwargs])
+            result = original_fn(*args, **kwargs)
+            create_json_artifact(key="result", data=result)
+            return result
 
     return wrapper
 
diff --git a/src/control_flow/utilities/prefect.py b/src/control_flow/utilities/prefect.py
index cf0025f4..d40e589b 100644
--- a/src/control_flow/utilities/prefect.py
+++ b/src/control_flow/utilities/prefect.py
@@ -61,11 +61,15 @@ def create_json_artifact(
     Create a JSON artifact.
     """
 
-    json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()
+    try:
+        markdown = TypeAdapter(type(data)).dump_json(data, indent=2).decode()
+        markdown = f"```json\n{markdown}\n```"
+    except Exception:
+        markdown = str(data)
 
     create_markdown_artifact(
         key=key,
-        markdown=f"```json\n{json_data}\n```",
+        markdown=markdown,
         description=description,
         task_run_id=task_run_id,
         flow_run_id=flow_run_id,
@@ -127,18 +131,17 @@ def wrap_prefect_tool(tool: AssistantTool | Callable) -> AssistantTool:
         if isinstance(tool.function._python_fn, prefect.tasks.Task):
             return tool
 
-        async def modified_fn(
-            *args,
+        def modified_fn(
             # provide default args to avoid a late-binding issue
             original_fn: Callable = tool.function._python_fn,
             tool: FunctionTool = tool,
             **kwargs,
         ):
             # call fn
-            result = original_fn(*args, **kwargs)
+            result = original_fn(**kwargs)
 
             # prepare artifact
-            passed_args = inspect.signature(original_fn).bind(*args, **kwargs).arguments
+            passed_args = inspect.signature(original_fn).bind(**kwargs).arguments
             try:
                 passed_args = json.dumps(passed_args, indent=2)
             except Exception:
diff --git a/src/control_flow/utilities/threads.py b/src/control_flow/utilities/threads.py
index 7a56b17b..c707a9f8 100644
--- a/src/control_flow/utilities/threads.py
+++ b/src/control_flow/utilities/threads.py
@@ -1,4 +1,4 @@
-from marvin.beta.assistants import Thread
+from marvin.beta.assistants.threads import Message, Thread
 
 THREAD_REGISTRY = {}
 
@@ -18,3 +18,10 @@ def load_thread(name: str):
         thread = Thread()
         save_thread(name, thread)
     return THREAD_REGISTRY[name]
+
+
+def get_history(thread_id: str, limit: int | None = None) -> list[Message]:
+    """
+    Return the messages of the thread with this ID; `limit` is forwarded to `Thread.get_messages`.
+    """
+    return Thread(id=thread_id).get_messages(limit=limit)
diff --git a/tests/core/__init__.py b/tests/core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/core/agents.py b/tests/core/agents.py
new file mode 100644
index 00000000..6d27af7e
--- /dev/null
+++ b/tests/core/agents.py
@@ -0,0 +1,22 @@
+from control_flow.core.agent import Agent
+from control_flow.core.task import Task
+from unittest.mock import patch
+
+
+class TestAgent:
+    pass
+
+
+class TestAgentRun:
+    def test_agent_run(self):
+        with patch(
+            "control_flow.core.controller.Controller._get_prefect_run_agent_task"
+        ) as mock_task:
+            agent = Agent()
+            agent.run()
+            mock_task.assert_called_once()
+
+    def test_agent_run_with_task(self):
+        task = Task("say hello")
+        agent = Agent()
+        agent.run(tasks=[task])
diff --git a/tests/core/test_agents.py b/tests/core/test_agents.py
new file mode 100644
index 00000000..fac88126
--- /dev/null
+++ b/tests/core/test_agents.py
@@ -0,0 +1,16 @@
+from control_flow.core.agent import Agent
+from unittest.mock import patch
+
+
+class TestAgent:
+    pass
+
+
+class TestAgentRun:
+    def test_agent_run(self):
+        with patch(
+            "control_flow.core.controller.Controller._get_prefect_run_agent_task"
+        ) as mock_task:
+            agent = Agent()
+            agent.run()
+            mock_task.assert_called_once()
diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py
new file mode 100644
index 00000000..feebca57
--- /dev/null
+++ b/tests/core/test_tasks.py
@@ -0,0 +1,22 @@
+from control_flow.core.task import Task, get_tasks
+from control_flow.utilities.context import ctx
+
+
+class TestTaskContext:
+    def test_context_open_and_close(self):
+        assert ctx.get("tasks") == []
+        with Task("a") as ta:
+            assert ctx.get("tasks") == [ta]
+            with Task("b") as tb:
+                assert ctx.get("tasks") == [ta, tb]
+            assert ctx.get("tasks") == [ta]
+        assert ctx.get("tasks") == []
+
+    def test_get_tasks_function(self):
+        # assert get_tasks() == []
+        with Task("a") as ta:
+            assert get_tasks() == [ta]
+            with Task("b") as tb:
+                assert get_tasks() == [ta, tb]
+            assert get_tasks() == [ta]
+        assert get_tasks() == []
diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py
index ff85345a..ef307913 100644
--- a/tests/fixtures/mocks.py
+++ b/tests/fixtures/mocks.py
@@ -4,17 +4,17 @@
 from control_flow.utilities.user_access import talk_to_human
 
 
-@pytest.fixture(autouse=True)
-def mock_talk_to_human():
-    """Return an empty default handler instead of a print handler to avoid
-    printing assistant output during tests"""
+# @pytest.fixture(autouse=True)
+# def mock_talk_to_human():
+#     """Return an empty default handler instead of a print handler to avoid
+#     printing assistant output during tests"""
 
-    def mock_talk_to_human(message: str, get_response: bool) -> str:
-        print(dict(message=message, get_response=get_response))
-        return "Message sent to user."
+#     def mock_talk_to_human(message: str, get_response: bool) -> str:
+#         print(dict(message=message, get_response=get_response))
+#         return "Message sent to user."
 
-    mock_talk_to_human.__doc__ = talk_to_human.__doc__
-    with patch(
-        "control_flow.utilities.user_access.mock_talk_to_human", new=talk_to_human
-    ):
-        yield
+#     mock_talk_to_human.__doc__ = talk_to_human.__doc__
+#     with patch(
+#         "control_flow.utilities.user_access.mock_talk_to_human", new=talk_to_human
+#     ):
+#         yield