From 37f531cb0a93e7896b53743aaa0467c29351341f Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 14 May 2024 09:45:39 -0400 Subject: [PATCH] Improve handling of default values --- pyproject.toml | 12 +++-- src/controlflow/__init__.py | 6 ++- src/controlflow/core/agent.py | 10 ++++ src/controlflow/core/controller/controller.py | 36 +++++++++++--- .../core/controller/instruction_template.py | 4 +- src/controlflow/core/flow.py | 42 +++++++++------- src/controlflow/core/graph.py | 2 +- src/controlflow/core/task.py | 49 ++++++++++++++----- src/controlflow/settings.py | 2 +- 9 files changed, 116 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed18ea4a..aedb7293 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "controlflow" -version = "0.1.0" +version = "0.3.0" description = "AI Workflows" authors = [ { name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" }, @@ -17,8 +17,10 @@ keywords = [ "ai", "chatbot", "llm", - "NLP", - "natural language processing", + "ai orchestration", + "llm orchestration", + "agentic workflows", + "flow engineering", "prefect", "workflow", "orchestration", @@ -26,7 +28,9 @@ keywords = [ "GPT", "openai", "assistant", - "agent", + "agents", + "AI agents", + "natural language processing", ] [project.urls] diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 9e0b44b1..01ac31b5 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -2,10 +2,14 @@ # from .agent_old import task, Agent, run_ai from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow -from .core.agent import Agent from .core.task import Task, task +from .core.agent import Agent from .core.controller.controller import Controller from .instructions import instructions from .dx import run_ai +Flow.model_rebuild() +Task.model_rebuild() +Agent.model_rebuild() + _reset_global_flow() diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index 354da39c..a2efaa6a 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -16,6 +16,16 @@ logger = logging.getLogger(__name__) +def default_agent(): + return Agent( + name="Marvin", + instructions=""" + You are a diligent AI assistant. You complete + your tasks efficiently and without error. + """, + ) + + class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin): name: str user_access: bool = Field( diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index db51be68..a7997308 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -43,6 +43,10 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): """ + # the flow is tracked by the Controller, not the Task, so that tasks can be + # defined and even instantiated outside a flow. When a Controller is + # created, we know we're inside a flow context and ready to load defaults + # and run. flow: Flow = Field( default_factory=get_flow, description="The flow that the controller is a part of.", @@ -65,6 +69,12 @@ def _create_graph(cls, data: Any) -> Any: data["graph"] = Graph.from_tasks(data.get("tasks", [])) return data + @model_validator(mode="after") + def _finalize(self): + for task in self.tasks: + self.flow.add_task(task) + return self + @field_validator("tasks", mode="before") def _validate_tasks(cls, v): if v is None: @@ -92,12 +102,21 @@ async def _run_agent( """ @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') - async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None): + async def _run_agent( + controller: Controller, + agent: Agent, + tasks: list[Task], + thread: Thread = None, + ): from controlflow.core.controller.instruction_template import MainTemplate - tasks = tasks or self.tasks + tasks = tasks or controller.tasks - tools = self.flow.tools + agent.get_tools() + [self._create_end_run_tool()] + tools = ( + controller.flow.tools + + agent.get_tools() + + [controller._create_end_run_tool()] + ) # add tools for any inactive tasks that the agent is assigned to for task in tasks: @@ -106,12 +125,11 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None): instructions_template = MainTemplate( agent=agent, - controller=self, + controller=controller, tasks=tasks, - context=self.context, + context=controller.context, instructions=get_instructions(), ) - instructions = instructions_template.render() # filter tools because duplicate names are not allowed @@ -126,7 +144,7 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None): run = Run( assistant=agent, - thread=thread or self.flow.thread, + thread=thread or controller.flow.thread, instructions=instructions, tools=final_tools, event_handler_class=AgentHandler, @@ -146,7 +164,9 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None): ) return run - return await _run_agent(agent=agent, tasks=tasks, thread=thread) + return await _run_agent( + controller=self, agent=agent, tasks=tasks, thread=thread + ) @expose_sync_method("run_once") async def run_once_async(self): diff --git a/src/controlflow/core/controller/instruction_template.py b/src/controlflow/core/controller/instruction_template.py index 03efd774..0c836254 100644 --- a/src/controlflow/core/controller/instruction_template.py +++ b/src/controlflow/core/controller/instruction_template.py @@ -1,7 +1,5 @@ import inspect -from pydantic import BaseModel - from controlflow.core.agent import Agent from controlflow.core.task import Task from controlflow.utilities.jinja import jinja_env @@ -187,7 +185,7 @@ def should_render(self): return bool(self.flow_context or self.controller_context) -class MainTemplate(BaseModel): +class MainTemplate(ControlFlowModel): agent: Agent controller: Controller context: dict diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py index df27a943..6b2250a9 100644 --- a/src/controlflow/core/flow.py +++ b/src/controlflow/core/flow.py @@ -1,6 +1,7 @@ import functools +import inspect from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal import prefect from marvin.beta.assistants import Thread @@ -16,31 +17,23 @@ if TYPE_CHECKING: from controlflow.core.agent import Agent + from controlflow.core.task import Task logger = get_logger(__name__) -def default_agent(): - from controlflow.core.agent import Agent - - return [ - Agent( - name="Marvin", - description="I am Marvin, the default agent for Control Flow.", - ) - ] - - class Flow(ControlFlowModel): thread: Thread = Field(None, validate_default=True) tools: list[AssistantTool | Callable] = Field( - [], description="Tools that will be available to every agent in the flow" + default_factory=list, + description="Tools that will be available to every agent in the flow", ) agents: list["Agent"] = Field( - default_factory=default_agent, description="The default agents for the flow. These agents will be used " "for any task that does not specify agents.", + default_factory=list, ) - context: dict = {} + _tasks: dict[str, "Task"] = {} + context: dict[str, Any] = {} @field_validator("thread", mode="before") def _load_thread_from_ctx(cls, v): @@ -53,6 +46,13 @@ def _load_thread_from_ctx(cls, v): return v + def add_task(self, task: "Task"): + if self._tasks.get(task.id, task) is not task: + raise ValueError( + f"A different task with id '{task.id}' already exists in flow." + ) + self._tasks[task.id] = task + def add_message(self, message: str, role: Literal["user", "assistant"] = None): prefect_task(self.thread.add)(message, role=role) @@ -107,6 +107,7 @@ def flow( fn=None, *, thread: Thread = None, + instructions: str = None, tools: list[AssistantTool | Callable] = None, agents: list["Agent"] = None, ): @@ -122,12 +123,18 @@ def flow( agents=agents, ) + sig = inspect.signature(fn) + @functools.wraps(fn) def wrapper( *args, flow_kwargs: dict = None, **kwargs, ): + # first process callargs + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + flow_kwargs = flow_kwargs or {} if thread is not None: @@ -139,13 +146,14 @@ def wrapper( p_fn = prefect.flow(fn) - flow_obj = Flow(**flow_kwargs) + flow_obj = Flow(**flow_kwargs, context=bound.arguments) logger.info( f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"' ) with ctx(flow=flow_obj), patch_marvin(): - return p_fn(*args, **kwargs) + with controlflow.instructions.instructions(instructions): + return p_fn(*args, **kwargs) return wrapper diff --git a/src/controlflow/core/graph.py b/src/controlflow/core/graph.py index 99e361c4..ba532e68 100644 --- a/src/controlflow/core/graph.py +++ b/src/controlflow/core/graph.py @@ -35,7 +35,7 @@ class Edge(BaseModel): type: EdgeType def __repr__(self): - return f"{self.type}: {self.upstream.id} -> {self.downstream.id}" + return f"{self.type}: {self.upstream.friendly_name()} -> {self.downstream.friendly_name()}" def __hash__(self) -> int: return id(self) diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 65331066..f6b88da2 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -25,7 +25,6 @@ model_validator, ) -from controlflow.core.flow import get_flow from controlflow.instructions import get_instructions from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger @@ -77,14 +76,18 @@ def visit_task_collection( class Task(ControlFlowModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:4])) + id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5])) objective: str = Field( ..., description="A brief description of the required result." ) instructions: str | None = Field( None, description="Detailed instructions for completing the task." ) - agents: list["Agent"] = Field(None, validate_default=True) + agents: list["Agent"] | None = Field( + None, + description="The agents assigned to the task. If None, the task will use its flow's default agents.", + validate_default=True, + ) context: dict = Field( default_factory=dict, description="Additional context for the task. If tasks are provided as context, they are automatically added as `depends_on`", @@ -108,7 +111,11 @@ class Task(ControlFlowModel): model_config = dict(extra="forbid", arbitrary_types_allowed=True) def __init__( - self, objective=None, result_type=None, parent: "Task" = None, **kwargs + self, + objective=None, + result_type=None, + parent: "Task" = None, + **kwargs, ): # allow certain args to be provided as a positional args if result_type is not None: @@ -138,9 +145,18 @@ def __repr__(self): return str(self.model_dump()) @field_validator("agents", mode="before") - def _default_agent(cls, v): + def _default_agents(cls, v): + from controlflow.core.agent import default_agent + from controlflow.core.flow import get_flow + if v is None: - v = get_flow().agents + flow = get_flow() + if flow.agents: + v = flow.agents + else: + v = [default_agent()] + if not v: + raise ValueError("At least one agent is required.") return v @field_validator("result_type", mode="before") @@ -150,7 +166,8 @@ def _turn_list_into_literal_result_type(cls, v): return v @model_validator(mode="after") - def _load_context_dependencies(self): + def _finalize(self): + # create dependencies to tasks passed in as context tasks = [] def visitor(task): @@ -158,6 +175,7 @@ def visitor(task): return task visit_task_collection(self.context, visitor) + for task in tasks: if task not in self.depends_on: self.depends_on.append(task) @@ -189,6 +207,13 @@ def _serialize_agents(agents: list["Agent"]): for a in agents ] + def friendly_name(self): + if len(self.objective) > 50: + objective = self.objective[:50] + "..." + else: + objective = self.objective + return f"Task {self.id} ({objective})" + def as_graph(self) -> "Graph": from controlflow.core.graph import Graph @@ -201,7 +226,7 @@ def add_subtask(self, task: "Task"): if task._parent is None: task._parent = self elif task._parent is not self: - raise ValueError(f"Task {task.id} already has a parent.") + raise ValueError(f"{self.friendly_name()} already has a parent.") if task not in self.subtasks: self.subtasks.append(task) @@ -235,7 +260,7 @@ def run(self, run_dependencies: bool = True) -> T: if self.is_successful(): return self.result elif self.is_failed(): - raise ValueError(f"Task {self.id} failed: {self.error}") + raise ValueError(f"{self.friendly_name()} failed: {self.error}") @contextmanager def _context(self): @@ -349,16 +374,16 @@ def mark_successful(self, result: T = None, validate: bool = True): self.result = result self.status = TaskStatus.SUCCESSFUL - return f"Task {self.id} marked successful. Updated task definition: {self.model_dump()}" + return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}" def mark_failed(self, message: str | None = None): self.error = message self.status = TaskStatus.FAILED - return f"Task {self.id} marked failed. Updated task definition: {self.model_dump()}" + return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}" def mark_skipped(self): self.status = TaskStatus.SKIPPED - return f"Task {self.id} marked skipped. Updated task definition: {self.model_dump()}" + return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}" def any_incomplete(tasks: list[Task]) -> bool: diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index eaa92c1f..1a563249 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -46,7 +46,7 @@ def apply(self): class Settings(ControlFlowSettings): - assistant_model: str = "gpt-4-1106-preview" + assistant_model: str = "gpt-4o" max_agent_iterations: int = 10 prefect: PrefectSettings = Field(default_factory=PrefectSettings) enable_global_flow: bool = Field(