diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index b0d2cb5b..e8571832 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -18,7 +18,7 @@ on: pull_request: paths: - .github/workflows/run-tests.yml - - src** + - src/** - tests/** - pyproject.toml - setup.py @@ -34,13 +34,11 @@ jobs: timeout-minutes: 15 strategy: matrix: - # run no_llm tests across all python versions and oses # os: [ubuntu-latest, macos-latest, windows-latest] - # python-version: ['3.9', '3.10', '3.11', '3.12'] os: [ubuntu-latest] - python-version: ['3.9'] - - + # python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ["3.9", "3.12"] + runs-on: ${{ matrix.os }} env: @@ -48,12 +46,12 @@ jobs: steps: - uses: actions/checkout@v4 - + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - + - name: download uv run: curl -LsSf https://astral.sh/uv/install.sh | sh @@ -61,5 +59,5 @@ jobs: run: uv pip install --system ".[tests]" - name: Run tests - run: pytest -n auto -vv - if: ${{ !(github.event.pull_request.head.repo.fork) }} \ No newline at end of file + run: pytest -vv + if: ${{ !(github.event.pull_request.head.repo.fork) }} diff --git a/docs/concepts/flows.mdx b/docs/concepts/flows.mdx index aca3309f..a6a36e15 100644 --- a/docs/concepts/flows.mdx +++ b/docs/concepts/flows.mdx @@ -42,7 +42,7 @@ flow = Flow() Flows have several key properties that define their behavior and capabilities: - `thread` (Thread): The thread associated with the flow, which stores the conversation history and context. -- `tools` (list[AssistantTool | Callable]): A list of tools that are available to all agents in the flow. +- `tools` (list[ToolType]): A list of tools that are available to all agents in the flow. - `agents` (list[Agent]): The default agents for the flow, which are used for tasks that do not specify agents explicitly. - `context` (dict): Additional context or information that is shared across tasks and agents in the flow. diff --git a/docs/concepts/tasks.mdx b/docs/concepts/tasks.mdx index 65d8ae2a..50790f78 100644 --- a/docs/concepts/tasks.mdx +++ b/docs/concepts/tasks.mdx @@ -59,7 +59,7 @@ Tasks have several key properties that define their behavior and requirements: - `agents` (list[Agent], optional): The AI agents assigned to work on the task. - `context` (dict, optional): Additional context or information required for the task. - `result_type` (type, optional): The expected type of the task's result. -- `tools` (list[AssistantTool | Callable], optional): Tools or functions available to the agents for completing the task. +- `tools` (list[ToolType], optional): Tools or functions available to the agents for completing the task. - `user_access` (bool, optional): Indicates whether the task requires human user interaction. ## Task Execution and Results diff --git a/pyproject.toml b/pyproject.toml index aedb7293..186cd47c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ tests = [ "pytest-xdist", "pre-commit>=3.7.0", ] -dev = ["controlflow[tests]", "ipython>=8.22.2", "pdbpp>=0.10.3", "ruff>=0.3.4"] +dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4"] [build-system] requires = ["hatchling"] @@ -77,3 +77,6 @@ skip-magic-trailing-comma = false "conftest.py" = ["F401", "F403"] 'tests/fixtures/*.py' = ['F401', 'F403'] "src/controlflow/utilities/types.py" = ['F401'] + +[tool.pytest.ini_options] +timeout = 120 diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 01ac31b5..fdfe8b45 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -1,12 +1,10 @@ from .settings import settings -# from .agent_old import task, Agent, run_ai from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow from .core.task import Task, task from .core.agent import Agent from .core.controller.controller import Controller from .instructions import instructions -from .dx import run_ai Flow.model_rebuild() Task.model_rebuild() diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index a2efaa6a..acdf173f 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from typing import Union from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.tools import tool_from_function @@ -10,7 +10,7 @@ from controlflow.utilities.prefect import ( wrap_prefect_tool, ) -from controlflow.utilities.types import Assistant, AssistantTool, ControlFlowModel +from controlflow.utilities.types import Assistant, ControlFlowModel, ToolType from controlflow.utilities.user_access import talk_to_human logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin): description="If True, the agent is given tools for interacting with a human user.", ) - def get_tools(self) -> list[AssistantTool | Callable]: + def get_tools(self) -> list[ToolType]: tools = super().get_tools() if self.user_access: tools.append(tool_from_function(talk_to_human)) @@ -41,7 +41,7 @@ def get_tools(self) -> list[AssistantTool | Callable]: return [wrap_prefect_tool(tool) for tool in tools] @expose_sync_method("run") - async def run_async(self, tasks: list[Task] | Task | None = None): + async def run_async(self, tasks: Union[list[Task], Task, None] = None): from controlflow.core.controller import Controller if isinstance(tasks, Task): diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index 631c66a3..58c39161 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, Union import marvin.utilities import marvin.utilities.tools @@ -56,7 +56,7 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): description="Tasks that the controller will complete.", validate_default=True, ) - agents: list[Agent] | None = None + agents: Union[list[Agent], None] = None context: dict = {} graph: Graph = None model_config: dict = dict(extra="forbid") @@ -173,7 +173,7 @@ async def run_once_async(self): Run the controller for a single iteration of the provided tasks. An agent will be selected to run the tasks. """ # get the tasks to run - tasks = self.graph.upstream_dependencies(self.tasks) + tasks = self.graph.upstream_dependencies(self.tasks, include_tasks=True) # get the agents agent_candidates = {a for t in tasks for a in t.agents if t.is_ready()} diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py index 6b2250a9..4a4ba79c 100644 --- a/src/controlflow/core/flow.py +++ b/src/controlflow/core/flow.py @@ -1,19 +1,18 @@ import functools import inspect from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Union import prefect from marvin.beta.assistants import Thread from openai.types.beta.threads import Message -from prefect import task as prefect_task from pydantic import Field, field_validator import controlflow from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger from controlflow.utilities.marvin import patch_marvin -from controlflow.utilities.types import AssistantTool, ControlFlowModel +from controlflow.utilities.types import ControlFlowModel, ToolType if TYPE_CHECKING: from controlflow.core.agent import Agent @@ -23,7 +22,7 @@ class Flow(ControlFlowModel): thread: Thread = Field(None, validate_default=True) - tools: list[AssistantTool | Callable] = Field( + tools: list[ToolType] = Field( default_factory=list, description="Tools that will be available to every agent in the flow", ) @@ -41,8 +40,6 @@ def _load_thread_from_ctx(cls, v): v = ctx.get("thread", None) if v is None: v = Thread() - if not v.id: - v.create() return v @@ -53,9 +50,6 @@ def add_task(self, task: "Task"): ) self._tasks[task.id] = task - def add_message(self, message: str, role: Literal["user", "assistant"] = None): - prefect_task(self.thread.add)(message, role=role) - @contextmanager def _context(self): with ctx(flow=self, tasks=[]): @@ -79,7 +73,7 @@ def get_flow() -> Flow: Will error if no flow is found in the context, unless the global flow is enabled in settings """ - flow: Flow | None = ctx.get("flow") + flow: Union[Flow, None] = ctx.get("flow") if not flow: if controlflow.settings.enable_global_flow: return GLOBAL_FLOW @@ -108,7 +102,7 @@ def flow( *, thread: Thread = None, instructions: str = None, - tools: list[AssistantTool | Callable] = None, + tools: list[ToolType] = None, agents: list["Agent"] = None, ): """ @@ -153,7 +147,7 @@ def wrapper( ) with ctx(flow=flow_obj), patch_marvin(): - with controlflow.instructions.instructions(instructions): + with controlflow.instructions(instructions): return p_fn(*args, **kwargs) return wrapper diff --git a/src/controlflow/core/graph.py b/src/controlflow/core/graph.py index ba532e68..065e0e08 100644 --- a/src/controlflow/core/graph.py +++ b/src/controlflow/core/graph.py @@ -108,7 +108,10 @@ def downstream_edges(self) -> dict[Task, list[Edge]]: return self._cache["downstream_edges"] def upstream_dependencies( - self, tasks: list[Task], prune_completed: bool = True + self, + tasks: list[Task], + prune_completed: bool = True, + include_tasks: bool = False, ) -> list[Task]: """ From a list of tasks, returns the subgraph of tasks that are directly or @@ -117,11 +120,15 @@ def upstream_dependencies( dependencies as well as any subtasks that are considered implicit dependencies. - If `prune_completed` is True, the subgraph will be pruned to stop traversal after adding any completed tasks. + If `prune_completed` is True, the subgraph will be pruned to stop + traversal after adding any completed tasks. + + If `include_tasks` is True, the subgraph will include the tasks provided. """ subgraph = set() upstreams = self.upstream_edges() - stack = tasks + # copy stack to allow difference update with original tasks + stack = [t for t in tasks] while stack: current = stack.pop() if current in subgraph: @@ -133,6 +140,8 @@ def upstream_dependencies( continue stack.extend([edge.upstream for edge in upstreams[current]]) + if not include_tasks: + subgraph.difference_update(tasks) return list(subgraph) def ready_tasks(self, tasks: list[Task] = None) -> list[Task]: @@ -146,7 +155,7 @@ def ready_tasks(self, tasks: list[Task] = None) -> list[Task]: if tasks is None: candidates = self.tasks else: - candidates = self.upstream_dependencies(tasks) + candidates = self.upstream_dependencies(tasks, include_tasks=True) return sorted( [task for task in candidates if task.is_ready()], key=lambda t: t.created_at ) diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 7150b6ba..37b013df 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -11,6 +11,7 @@ GenericAlias, Literal, TypeVar, + Union, _LiteralGenericAlias, ) @@ -25,11 +26,17 @@ model_validator, ) +import controlflow from controlflow.instructions import get_instructions from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import wrap_prefect_tool -from controlflow.utilities.types import NOTSET, AssistantTool, ControlFlowModel +from controlflow.utilities.types import ( + NOTSET, + AssistantTool, + ControlFlowModel, + ToolType, +) from controlflow.utilities.user_access import talk_to_human if TYPE_CHECKING: @@ -77,10 +84,10 @@ class Task(ControlFlowModel): objective: str = Field( ..., description="A brief description of the required result." ) - instructions: str | None = Field( + instructions: Union[str, None] = Field( None, description="Detailed instructions for completing the task." ) - agents: list["Agent"] | None = Field( + agents: Union[list["Agent"], None] = Field( None, description="The agents assigned to the task. If None, the task will use its flow's default agents.", validate_default=True, @@ -98,12 +105,12 @@ class Task(ControlFlowModel): ) status: TaskStatus = TaskStatus.INCOMPLETE result: T = None - result_type: type[T] | GenericAlias | _LiteralGenericAlias | None = None - error: str | None = None - tools: list[AssistantTool | Callable] = [] + result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = None + error: Union[str, None] = None + tools: list[ToolType] = [] user_access: bool = False created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - _parent: "Task | None" = None + _parent: "Union[Task, None]" = None _downstreams: list["Task"] = [] model_config = dict(extra="forbid", arbitrary_types_allowed=True) @@ -147,8 +154,11 @@ def _default_agents(cls, v): from controlflow.core.flow import get_flow if v is None: - flow = get_flow() - if flow.agents: + try: + flow = get_flow() + except ValueError: + flow = None + if flow and flow.agents: v = flow.agents else: v = [default_agent()] @@ -204,11 +214,20 @@ def _serialize_agents(agents: list["Agent"]): for a in agents ] + @field_serializer("tools") + def _serialize_tools(tools: list[ToolType]): + return [ + marvin.utilities.tools.tool_from_function(t) + if not isinstance(t, AssistantTool) + else t + for t in tools + ] + def friendly_name(self): if len(self.objective) > 50: - objective = self.objective[:50] + "..." + objective = f'"{self.objective[:50]}..."' else: - objective = self.objective + objective = f'"{self.objective}"' return f"Task {self.id} ({objective})" def as_graph(self) -> "Graph": @@ -253,7 +272,7 @@ def run(self, max_iterations: int = NOTSET) -> T: If max_iterations is provided, the task will run at most that many times before raising an error. """ if max_iterations == NOTSET: - max_iterations = marvin.settings.max_task_iterations + max_iterations = controlflow.settings.max_task_iterations if max_iterations is None: max_iterations = float("inf") @@ -264,6 +283,7 @@ def run(self, max_iterations: int = NOTSET) -> T: f"{self.friendly_name()} did not complete after {max_iterations} iterations." ) self.run_once() + counter += 1 if self.is_successful(): return self.result elif self.is_failed(): @@ -272,8 +292,7 @@ def run(self, max_iterations: int = NOTSET) -> T: @contextmanager def _context(self): stack = ctx.get("tasks", []) - stack.append(self) - with ctx(tasks=stack): + with ctx(tasks=stack + [self]): yield self def __enter__(self): @@ -346,7 +365,7 @@ def _create_skip_tool(self) -> FunctionTool: ) return tool - def get_tools(self) -> list[AssistantTool | Callable]: + def get_tools(self) -> list[ToolType]: tools = self.tools.copy() if self.is_incomplete(): tools.extend([self._create_fail_tool(), self._create_success_tool()]) @@ -363,13 +382,13 @@ def mark_successful(self, result: T = None, validate: bool = True): raise ValueError( f"Task {self.objective} cannot be marked successful until all of its " "upstream dependencies are completed. Incomplete dependencies " - f"are: {[t.id for t in self.depends_on if t.is_incomplete()]}" + f"are: {', '.join(t.friendly_name() for t in self.depends_on if t.is_incomplete())}" ) elif any(t.is_incomplete() for t in self.subtasks): raise ValueError( f"Task {self.objective} cannot be marked successful until all of its " "subtasks are completed. Incomplete subtasks " - f"are: {[t.id for t in self.subtasks if t.is_incomplete()]}" + f"are: {', '.join(t.friendly_name() for t in self.subtasks if t.is_incomplete())}" ) if self.result_type is None and result is not None: @@ -383,7 +402,7 @@ def mark_successful(self, result: T = None, validate: bool = True): self.status = TaskStatus.SUCCESSFUL return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}" - def mark_failed(self, message: str | None = None): + def mark_failed(self, message: Union[str, None] = None): self.error = message self.status = TaskStatus.FAILED return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}" @@ -419,7 +438,7 @@ def task( objective: str = None, instructions: str = None, agents: list["Agent"] = None, - tools: list[AssistantTool | Callable] = None, + tools: list[ToolType] = None, user_access: bool = None, ): """ diff --git a/src/controlflow/dx.py b/src/controlflow/dx.py deleted file mode 100644 index 59f768aa..00000000 --- a/src/controlflow/dx.py +++ /dev/null @@ -1,152 +0,0 @@ -import functools -import inspect -from typing import Callable, TypeVar - -from prefect import task as prefect_task - -from controlflow.core.agent import Agent -from controlflow.core.task import Task, TaskStatus -from controlflow.utilities.context import ctx -from controlflow.utilities.logging import get_logger -from controlflow.utilities.types import AssistantTool - -logger = get_logger(__name__) -T = TypeVar("T") -NOT_PROVIDED = object() - - -def task( - fn=None, - *, - objective: str = None, - agents: list[Agent] = None, - tools: list[AssistantTool | Callable] = None, - user_access: bool = None, -): - """ - Use a Python function to create an AI task. When the function is called, an - agent is created to complete the task and return the result. - """ - - if fn is None: - return functools.partial( - task, - objective=objective, - agents=agents, - tools=tools, - user_access=user_access, - ) - - sig = inspect.signature(fn) - - if objective is None: - if fn.__doc__: - objective = f"{fn.__name__}: {fn.__doc__}" - else: - objective = fn.__name__ - - @functools.wraps(fn) - def wrapper(*args, _agents: list[Agent] = None, **kwargs): - # first process callargs - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - - task = Task( - objective=objective, - agents=_agents or agents, - context=bound.arguments, - result_type=fn.__annotations__.get("return"), - user_access=user_access or False, - tools=tools or [], - ) - - task.run() - return task.result - - return wrapper - - -def _name_from_objective(): - """Helper function for naming task runs""" - from prefect.runtime import task_run - - objective = task_run.parameters.get("task") - - if not objective: - objective = "Follow general instructions" - if len(objective) > 75: - return f"Task: {objective[:75]}..." - return f"Task: {objective}" - - -@prefect_task(task_run_name=_name_from_objective) -def run_ai( - tasks: str | list[str], - agents: list[Agent] = None, - cast: T = NOT_PROVIDED, - context: dict = None, - tools: list[AssistantTool | Callable] = None, - user_access: bool = False, -) -> T | list[T]: - """ - Create and run an agent to complete a task with the given objective and - context. This function is similar to an inline version of the @task - decorator. - - This inline version is useful when you want to create and run an ad-hoc AI - task, without defining a function or using decorator syntax. It provides - more flexibility in terms of dynamically setting the task parameters. - Additional detail can be provided as `context`. - """ - - single_result = False - if isinstance(tasks, str): - single_result = True - - tasks = [tasks] - - if cast is NOT_PROVIDED: - if not tasks: - cast = None - else: - cast = str - - # load flow - flow = ctx.get("flow", None) - - # create tasks - if tasks: - tasks = [ - Task( - objective=t, - context=context or {}, - user_access=user_access or False, - tools=tools or [], - ) - for t in tasks - ] - else: - tasks = [] - - # create agent - if agents is None: - agents = [Agent(user_access=user_access or False)] - - # create Controller - from controlflow.core.controller.controller import Controller - - controller = Controller(tasks=tasks, agents=agents, flow=flow) - controller.run() - - if tasks: - if all(task.status == TaskStatus.SUCCESSFUL for task in tasks): - result = [task.result for task in tasks] - if single_result: - result = result[0] - return result - elif failed_tasks := [ - task for task in tasks if task.status == TaskStatus.FAILED - ]: - raise ValueError( - f'Failed tasks: {", ".join([task.objective for task in failed_tasks])}' - ) diff --git a/src/controlflow/instructions.py b/src/controlflow/instructions.py index db226366..55cc9faf 100644 --- a/src/controlflow/instructions.py +++ b/src/controlflow/instructions.py @@ -17,11 +17,13 @@ def instructions(*instructions: str) -> Generator[list[str], None, None]: ... """ + filtered_instructions = [i for i in instructions if i] + if not filtered_instructions: + yield + return stack: list[str] = ctx.get("instructions", []) - stack = stack + list(instructions) - - with ctx(instructions=stack): + with ctx(instructions=stack + list(filtered_instructions)): yield diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index ad4a2802..5538b149 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -3,9 +3,9 @@ import warnings from contextlib import contextmanager from copy import deepcopy -from typing import Any +from typing import Any, Optional, Union -from pydantic import Field +from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -47,17 +47,32 @@ def apply(self): class Settings(ControlFlowSettings): assistant_model: str = "gpt-4o" - max_task_iterations: int = None + max_task_iterations: Union[int, None] = Field( + None, + description="The maximum number of iterations to attempt to complete a task " + "before raising an error. If None, the task will run indefinitely. " + "This setting can be overridden by the `max_iterations` attribute " + "on a task.", + ) prefect: PrefectSettings = Field(default_factory=PrefectSettings) enable_global_flow: bool = Field( True, description="If True, a global flow is created for convenience, so users don't have to wrap every invocation in a flow function. Disable to avoid accidentally sharing context between agents.", ) + openai_api_key: Optional[str] = Field(None, validate_assignment=True) def __init__(self, **data): super().__init__(**data) self.prefect.apply() + @field_validator("openai_api_key", mode="after") + def _apply_api_key(cls, v): + if v is not None: + import marvin + + marvin.settings.openai.api_key = v + return v + settings = Settings() diff --git a/src/controlflow/utilities/context.py b/src/controlflow/utilities/context.py index c8d5cd1c..bafd489c 100644 --- a/src/controlflow/utilities/context.py +++ b/src/controlflow/utilities/context.py @@ -1,3 +1,8 @@ from marvin.utilities.context import ScopedContext -ctx = ScopedContext() +ctx = ScopedContext( + dict( + flow=None, + tasks=[], + ) +) diff --git a/src/controlflow/utilities/prefect.py b/src/controlflow/utilities/prefect.py index 348fbcf6..305d4b0a 100644 --- a/src/controlflow/utilities/prefect.py +++ b/src/controlflow/utilities/prefect.py @@ -13,7 +13,7 @@ from prefect.context import FlowRunContext, TaskRunContext from pydantic import TypeAdapter -from controlflow.utilities.types import AssistantTool +from controlflow.utilities.types import AssistantTool, ToolType def create_markdown_artifact( @@ -117,14 +117,28 @@ def create_python_artifact( ) -def wrap_prefect_tool(tool: AssistantTool | Callable) -> AssistantTool: +def safe_isinstance(obj, type_) -> bool: + # FunctionTool objects are typed generics, and + # Python 3.9 will raise an error if you try to isinstance a typed generic... + try: + return isinstance(obj, type_) + except TypeError: + try: + return issubclass(type(obj), type_) + except TypeError: + return False + + +def wrap_prefect_tool(tool: ToolType) -> AssistantTool: """ Wraps a Marvin tool in a prefect task """ - if not isinstance(tool, AssistantTool): + if not ( + safe_isinstance(tool, AssistantTool) or safe_isinstance(tool, FunctionTool) + ): tool = tool_from_function(tool) - if isinstance(tool, FunctionTool): + if safe_isinstance(tool, FunctionTool): # for functions, we modify the function to become a Prefect task and # publish an artifact that contains details about the function call diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index fe7292eb..baf32057 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -1,3 +1,5 @@ +from typing import Callable, Union + from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool from marvin.types import FunctionTool @@ -7,6 +9,8 @@ # flag for unset defaults NOTSET = "__NOTSET__" +ToolType = Union[FunctionTool, AssistantTool, Callable] + class ControlFlowModel(BaseModel): model_config = dict(validate_assignment=True, extra="forbid") diff --git a/tests/conftest.py b/tests/conftest.py index d4bb2090..7799c921 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1,10 @@ +import pytest +from controlflow.settings import temporary_settings + from .fixtures import * + + +@pytest.fixture(autouse=True, scope="session") +def temp_settings(): + with temporary_settings(enable_global_flow=False, max_task_iterations=3): + yield diff --git a/tests/core/test_agents.py b/tests/core/test_agents.py deleted file mode 100644 index f532412c..00000000 --- a/tests/core/test_agents.py +++ /dev/null @@ -1,16 +0,0 @@ -from controlflow.core.agent import Agent -from pytest import patch - - -class TestAgent: - pass - - -class TestAgentRun: - def test_agent_run(self): - with patch( - "controlflow.core.controller.Controller._get_prefect_run_agent_task" - ) as mock_task: - agent = Agent() - agent.run() - mock_task.assert_called_once() diff --git a/tests/core/test_controller.py b/tests/core/test_controller.py index 670ccb83..a9953828 100644 --- a/tests/core/test_controller.py +++ b/tests/core/test_controller.py @@ -26,7 +26,6 @@ def test_controller_initialization(self, flow, agent, task): assert controller.flow == flow assert controller.tasks == [task] assert controller.agents == [agent] - assert controller.run_dependencies is True assert len(controller.context) == 0 assert len(controller.graph.tasks) == 1 assert len(controller.graph.edges) == 0 @@ -61,7 +60,7 @@ def test_controller_graph_creation(self, flow, agent): controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent]) assert len(controller.graph.tasks) == 2 assert len(controller.graph.edges) == 1 - assert controller.graph.edges.pop().type == EdgeType.dependency + assert controller.graph.edges.pop().type == EdgeType.DEPENDENCY def test_controller_agent_selection(self, flow, monkeypatch): agent1 = Agent(name="Agent 1") @@ -74,12 +73,3 @@ def test_controller_agent_selection(self, flow, monkeypatch): mocked_marvin_moderator, ) assert controller.agents == [agent1, agent2] - - async def test_controller_run_dependencies(self, flow, agent): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - controller = Controller(flow=flow, tasks=[task2], agents=[agent]) - mocked_run_agent = AsyncMock() - controller._run_agent = mocked_run_agent - await controller.run_once_async() - mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2]) diff --git a/tests/core/test_flows.py b/tests/core/test_flows.py index 123016df..acf8c63b 100644 --- a/tests/core/test_flows.py +++ b/tests/core/test_flows.py @@ -1,6 +1,4 @@ -# test_flow.py -from unittest.mock import MagicMock - +import pytest from controlflow.core.agent import Agent from controlflow.core.flow import Flow, get_flow from controlflow.utilities.context import ctx @@ -11,8 +9,7 @@ def test_flow_initialization(self): flow = Flow() assert flow.thread is not None assert len(flow.tools) == 0 - assert len(flow.agents) == 1 - assert isinstance(flow.agents[0], Agent) + assert len(flow.agents) == 0 assert len(flow.context) == 0 def test_flow_with_custom_agents(self): @@ -40,31 +37,20 @@ def test_flow_with_custom_context(self): assert len(flow.context) == 1 assert flow.context["key"] == "value" - def test_add_message(self, monkeypatch): - flow = Flow() - mocked_add = MagicMock() - monkeypatch.setattr(flow.thread, "add", mocked_add) - flow.add_message("Test message", role="user") - mocked_add.assert_called_once_with("Test message", role="user") - def test_flow_context_manager(self): with Flow() as flow: assert ctx.get("flow") == flow assert ctx.get("tasks") == [] assert ctx.get("flow") is None - assert ctx.get("tasks") is None + assert ctx.get("tasks") == [] def test_get_flow_within_context(self): with Flow() as flow: assert get_flow() == flow def test_get_flow_without_context(self): - flow1 = get_flow() - with Flow() as flow2: - pass - flow3 = get_flow() - assert flow1 == flow3 - assert flow1 != flow2 + with pytest.raises(ValueError, match="No flow found in context."): + get_flow() def test_get_flow_nested_contexts(self): with Flow() as flow1: diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py index 27951ced..e462f86c 100644 --- a/tests/core/test_graph.py +++ b/tests/core/test_graph.py @@ -1,6 +1,6 @@ # test_graph.py from controlflow.core.graph import Edge, EdgeType, Graph -from controlflow.core.task import Task +from controlflow.core.task import Task, TaskStatus class TestGraph: @@ -74,27 +74,52 @@ def test_upstream_dependencies(self): task2 = Task(objective="Task 2", depends_on=[task1]) task3 = Task(objective="Task 3", parent=task2) graph = Graph.from_tasks([task1, task2, task3]) - dependencies = graph.upstream_dependencies([task3]) + dependencies = graph.upstream_dependencies([task2]) + assert len(dependencies) == 2 + assert task1 in dependencies + assert task3 in dependencies + + def test_upstream_dependencies_include_tasks(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = Graph.from_tasks([task1, task2, task3]) + dependencies = graph.upstream_dependencies([task2], include_tasks=True) assert len(dependencies) == 3 assert task1 in dependencies assert task2 in dependencies assert task3 in dependencies + def test_upstream_dependencies_prune(self): + task1 = Task(objective="Task 1", status=TaskStatus.SUCCESSFUL) + task2 = Task(objective="Task 2", depends_on=[task1], status=TaskStatus.FAILED) + task3 = Task(objective="Task 3", depends_on=[task2]) + graph = Graph.from_tasks([task1, task2, task3]) + dependencies = graph.upstream_dependencies([task3]) + assert len(dependencies) == 1 + assert task2 in dependencies + dependencies = graph.upstream_dependencies([task3], prune_completed=False) + assert len(dependencies) == 2 + assert task1 in dependencies + assert task2 in dependencies + def test_ready_tasks(self): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2", depends_on=[task1]) task3 = Task(objective="Task 3", parent=task2) graph = Graph.from_tasks([task1, task2, task3]) ready_tasks = graph.ready_tasks() - assert len(ready_tasks) == 1 + assert len(ready_tasks) == 2 assert task1 in ready_tasks + assert task3 in ready_tasks task1.mark_successful() ready_tasks = graph.ready_tasks() - assert len(ready_tasks) == 1 + assert len(ready_tasks) == 2 assert task2 in ready_tasks + assert task3 in ready_tasks - task2.mark_successful() + task3.mark_successful() ready_tasks = graph.ready_tasks() assert len(ready_tasks) == 1 - assert task3 in ready_tasks + assert task2 in ready_tasks diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 590933dd..6e1c98c2 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,7 +1,10 @@ +from unittest.mock import AsyncMock + +import pytest from controlflow.core.agent import Agent from controlflow.core.flow import Flow from controlflow.core.graph import EdgeType -from controlflow.core.task import Task, TaskStatus, get_tasks +from controlflow.core.task import Task, TaskStatus from controlflow.utilities.context import ctx @@ -15,16 +18,6 @@ def test_context_open_and_close(): assert ctx.get("tasks") == [] -def test_get_tasks_function(): - # assert get_tasks() == [] - with Task("a") as ta: - assert get_tasks() == [ta] - with Task("b") as tb: - assert get_tasks() == [ta, tb] - assert get_tasks() == [ta] - assert get_tasks() == [] - - def test_task_initialization(): task = Task(objective="Test objective") assert task.objective == "Test objective" @@ -53,10 +46,11 @@ def test_task_agent_assignment(): assert agent in task.agents -def test_task_context(): - with Flow(): +def test_task_tracking(mock_run): + with Flow() as flow: task = Task(objective="Test objective") - assert task in Task._context_stack + task.run_once() + assert task in flow._tasks.values() def test_task_status_transitions(): @@ -91,6 +85,24 @@ def test_task_status_transitions(): assert task.is_skipped() +def test_validate_upstream_dependencies_on_success(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + with pytest.raises(ValueError, match="cannot be marked successful"): + task2.mark_successful() + task1.mark_successful() + task2.mark_successful() + + +def test_validate_subtask_dependencies_on_success(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", parent=task1) + with pytest.raises(ValueError, match="cannot be marked successful"): + task1.mark_successful() + task2.mark_successful() + task1.mark_successful() + + def test_task_ready(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2", depends_on=[task1]) @@ -109,13 +121,19 @@ def test_task_hash(): def test_task_tools(): task = Task(objective="Test objective") tools = task.get_tools() - assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + assert any(tool.function.name == f"mark_task_{task.id}_failed" for tool in tools) + assert any( + tool.function.name == f"mark_task_{task.id}_successful" for tool in tools + ) task.mark_successful() tools = task.get_tools() - assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + assert not any( + tool.function.name == f"mark_task_{task.id}_failed" for tool in tools + ) + assert not any( + tool.function.name == f"mark_task_{task.id}_successful" for tool in tools + ) class TestTaskToGraph: @@ -178,3 +196,52 @@ def test_task_with_subtasks_and_dependencies_graph(self): and edge.type == EdgeType.SUBTASK for edge in graph.edges ) + + +@pytest.mark.usefixtures("mock_run") +class TestTaskRun: + def test_run_task_max_iterations(self, mock_run: AsyncMock): + task = Task(objective="Say hello") + + with Flow(): + with pytest.raises(ValueError): + task.run() + + assert mock_run.await_count == 3 + + def test_run_task_mark_successful(self, mock_run: AsyncMock): + task = Task(objective="Say hello") + + def mark_complete(): + task.mark_successful() + + mock_run.side_effect = mark_complete + with Flow(): + result = task.run() + assert task.is_successful() + assert result is None + + def test_run_task_mark_successful_with_result(self, mock_run: AsyncMock): + task = Task(objective="Say hello", result_type=int) + + def mark_complete(): + task.mark_successful(result=42) + + mock_run.side_effect = mark_complete + with Flow(): + result = task.run() + assert task.is_successful() + assert result == 42 + + def test_run_task_mark_failed(self, mock_run: AsyncMock): + task = Task(objective="Say hello") + + def mark_complete(): + task.mark_failed(message="Failed to say hello") + + mock_run.side_effect = mark_complete + with Flow(): + with pytest.raises(ValueError): + task.run() + assert task.is_failed() + assert task.error == "Failed to say hello" diff --git a/tests/fixtures/flows.py b/tests/fixtures/flows.py deleted file mode 100644 index f3d80c8c..00000000 --- a/tests/fixtures/flows.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest -from controlflow.settings import temporary_settings - - -@pytest.fixture(autouse=True, scope="session") -def disable_global_flow(): - with temporary_settings(enable_global_flow=False): - yield diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py index e42b599e..52458981 100644 --- a/tests/fixtures/mocks.py +++ b/tests/fixtures/mocks.py @@ -1,9 +1,8 @@ -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from controlflow.utilities.user_access import talk_to_human - # @pytest.fixture(autouse=True) # def mock_talk_to_human(): # """Return an empty default handler instead of a print handler to avoid @@ -18,3 +17,27 @@ # "controlflow.utilities.user_access.mock_talk_to_human", new=talk_to_human # ): # yield + + +@pytest.fixture +def mock_run(monkeypatch): + """ + This fixture mocks the calls to OpenAI. Use it in a test and assign any desired side effects (like completing a task) + to the mock object's `.side_effect` attribute. + + For example: + + def test_example(mock_run): + task = Task(objective="Say hello") + + def side_effect(): + task.mark_complete() + + mock_run.side_effect = side_effect + + task.run() + + """ + MockRun = AsyncMock() + monkeypatch.setattr("controlflow.core.controller.controller.Run.run_async", MockRun) + yield MockRun diff --git a/tests/flows/test_sign_guestbook.py b/tests/flows/test_sign_guestbook.py index 1f841c86..cd85a3cd 100644 --- a/tests/flows/test_sign_guestbook.py +++ b/tests/flows/test_sign_guestbook.py @@ -1,4 +1,5 @@ -from controlflow import Agent, flow, run_ai +import pytest +from controlflow import Agent, Task, flow # define assistants @@ -27,7 +28,7 @@ def view_guestbook(): @flow def guestbook_flow(): - run_ai( + task = Task( """ Add your name to the list using the `sign` tool. All assistants must sign their names for the task to be complete. You can read the sign to @@ -36,11 +37,13 @@ def guestbook_flow(): agents=[a, b, c], tools=[sign, view_guestbook], ) + task.run() # run test +@pytest.mark.skip(reason="Skipping test for now") def test(): guestbook_flow() assert GUESTBOOK == ["a", "b", "c"] diff --git a/tests/flows/test_user_access.py b/tests/flows/test_user_access.py index 13dff2a0..d66ba3f5 100644 --- a/tests/flows/test_user_access.py +++ b/tests/flows/test_user_access.py @@ -1,8 +1,9 @@ import pytest -from controlflow import Agent, flow, run_ai +from controlflow import Agent, Task, flow -# define assistants +pytest.skip("Skipping the entire file", allow_module_level=True) +# define assistants user_agent = Agent(name="user-agent", user_access=True) non_user_agent = Agent(name="non-user-agent", user_access=False) @@ -10,10 +11,11 @@ def test_no_user_access_fails(): @flow def user_access_flow(): - run_ai( + task = Task( "This task requires human user access. Inform the user that today is a good day.", agents=[non_user_agent], ) + task.run() with pytest.raises(ValueError): user_access_flow() @@ -22,10 +24,11 @@ def user_access_flow(): def test_user_access_agent_succeeds(): @flow def user_access_flow(): - run_ai( + task = Task( "This task requires human user access. Inform the user that today is a good day.", agents=[user_agent], ) + task.run() assert user_access_flow() @@ -33,11 +36,12 @@ def user_access_flow(): def test_user_access_task_succeeds(): @flow def user_access_flow(): - run_ai( + task = Task( "This task requires human user access. Inform the user that today is a good day.", agents=[non_user_agent], user_access=True, ) + task.run() assert user_access_flow() @@ -45,10 +49,11 @@ def user_access_flow(): def test_user_access_agent_and_task_succeeds(): @flow def user_access_flow(): - run_ai( + task = Task( "This task requires human user access. Inform the user that today is a good day.", agents=[user_agent], user_access=True, ) + task.run() assert user_access_flow()