diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index e8571832..c17daf53 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -42,7 +42,7 @@ jobs: runs-on: ${{ matrix.os }} env: - CONTROLFLOW_OPENAI_API_KEY: ${{ secrets.CONTROLFLOW_OPENAI_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} steps: - uses: actions/checkout@v4 diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index ebfdf816..6e98cc62 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -1,28 +1,32 @@ from .settings import settings import controlflow.llm -# --- Default model --- -# assign to controlflow.default_model to change the default model -from .llm.models import DEFAULT_MODEL as default_model - -from .core.flow import Flow -from .core.task import Task from .core.agent import Agent +from .core.task import Task +from .core.flow import Flow from .core.controller.controller import Controller from .instructions import instructions from .decorators import flow, task +# --- Default settings --- + +from .llm.models import model_from_string, get_default_model +from .llm.history import InMemoryHistory, get_default_history + +# assign to controlflow.default_model to change the default model +default_model = model_from_string(controlflow.settings.llm_model) +del model_from_string -# --- Default history --- # assign to controlflow.default_history to change the default history -from .llm.history import DEFAULT_HISTORY as default_history, get_default_history +default_history = InMemoryHistory() +del InMemoryHistory -# --- Default agent --- # assign to controlflow.default_agent to change the default agent -from .core.agent.agent import DEFAULT_AGENT as default_agent, get_default_agent +default_agent = Agent(name="Marvin") # --- Version --- + try: from ._version import version as __version__ # type: ignore except ImportError: diff --git a/src/controlflow/core/agent/agent.py b/src/controlflow/core/agent/agent.py index db407ac9..e8e40002 100644 --- a/src/controlflow/core/agent/agent.py +++ b/src/controlflow/core/agent/agent.py @@ -1,5 +1,6 @@ import logging import random +import re import uuid from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Optional @@ -24,12 +25,19 @@ def get_default_agent() -> "Agent": return controlflow.default_agent +def sanitize_name(name): + """ + Replace any invalid characters with `-`, due to restrictions on names in the API + """ + sanitized_string = re.sub(r"[^a-zA-Z0-9_-]", "-", name) + return sanitized_string + + class Agent(ControlFlowModel): id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5])) model_config = dict(arbitrary_types_allowed=True) name: str = Field( description="The name of the agent.", - pattern=r"^[a-zA-Z0-9_-]+$", default_factory=lambda: random.choice(NAMES), ) description: Optional[str] = Field( @@ -68,6 +76,10 @@ def _serialize_tools(self, tools: list[Callable]): # tools are Pydantic 1 objects return [t.dict(include={"name", "description"}) for t in tools] + @field_serializer("name") + def _serialize_name(self, name: str): + return sanitize_name(name) + def __init__(self, name=None, **kwargs): if name is not None: kwargs["name"] = name diff --git a/src/controlflow/core/agent/names.py b/src/controlflow/core/agent/names.py index a276235b..1489c142 100644 --- a/src/controlflow/core/agent/names.py +++ b/src/controlflow/core/agent/names.py @@ -1,18 +1,18 @@ NAMES = [ - "HAL-9000", + "HAL 9000", "R2-D2", "C-3PO", "WALL-E", "T-800", "GLaDOS", - "JARVIS", + "J.A.R.V.I.S", "EVE", "KITT", - "Johnny-5", + "Johnny 5", "BB-8", "Ultron", "TARS", - "Agent-Smith", + "Agent Smith", "CLU", "Deckard", "HK-47", diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index fef9254c..39659a08 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -104,7 +104,7 @@ def graph(self) -> Graph: @model_validator(mode="after") def _finalize(self): if self.tasks is None: - self.tasks = list(self.flow._tasks.values()) + self.tasks = list(self.flow.tasks.values()) for task in self.tasks: self.flow.add_task(task) return self diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py index d80fe337..cf91f891 100644 --- a/src/controlflow/core/flow.py +++ b/src/controlflow/core/flow.py @@ -1,12 +1,14 @@ import datetime import uuid from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union from pydantic import Field import controlflow import controlflow.llm +from controlflow.core.agent import Agent +from controlflow.core.task import Task from controlflow.llm.history import History, get_default_history from controlflow.llm.messages import MessageType from controlflow.utilities.context import ctx @@ -14,9 +16,6 @@ from controlflow.utilities.prefect import prefect_flow_context from controlflow.utilities.types import ControlFlowModel -if TYPE_CHECKING: - from controlflow.core.agent import Agent - from controlflow.core.task import Task logger = get_logger(__name__) @@ -31,13 +30,13 @@ class Flow(ControlFlowModel): default_factory=list, description="Tools that will be available to every agent in the flow", ) - agents: list["Agent"] = Field( + agents: list[Agent] = Field( description="The default agents for the flow. These agents will be used " "for any task that does not specify agents.", default_factory=list, ) context: dict[str, Any] = {} - _tasks: dict[str, "Task"] = {} + tasks: dict[str, Task] = {} _cm_stack: list[contextmanager] = [] def __init__(self, *, copy_parent_history: bool = True, **kwargs): @@ -74,12 +73,12 @@ def get_messages( def add_messages(self, messages: list[MessageType]): self.history.save_messages(thread_id=self.thread_id, messages=messages) - def add_task(self, task: "Task"): - if self._tasks.get(task.id, task) is not task: + def add_task(self, task: Task): + if self.tasks.get(task.id, task) is not task: raise ValueError( f"A different task with id '{task.id}' already exists in flow." ) - self._tasks[task.id] = task + self.tasks[task.id] = task @contextmanager def create_context(self, create_prefect_flow_context: bool = True): @@ -94,7 +93,7 @@ async def run_async(self): """ Runs the flow asynchronously. """ - if self._tasks: + if self.tasks: controller = controlflow.Controller(flow=self) await controller.run_async() @@ -102,7 +101,7 @@ def run(self): """ Runs the flow. """ - if self._tasks: + if self.tasks: controller = controlflow.Controller(flow=self) controller.run() diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 2182a346..982cb83f 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -28,6 +28,7 @@ import controlflow import controlflow.core +from controlflow.core.agent import Agent from controlflow.instructions import get_instructions from controlflow.llm.tools import Tool from controlflow.tools.talk_to_human import talk_to_human @@ -47,7 +48,6 @@ ) if TYPE_CHECKING: - from controlflow.core.agent import Agent from controlflow.core.flow import Flow from controlflow.core.graph import Graph @@ -57,7 +57,7 @@ def get_task_run_name() -> str: context = TaskRunContext.get() - return f'Run {context.parameters['self'].friendly_name()}' + return f'Run {context.parameters["self"].friendly_name()}' class TaskStatus(Enum): @@ -128,11 +128,11 @@ class Task(ControlFlowModel): def __init__( self, objective=None, - result_type=None, + result_type=NOTSET, **kwargs, ): # allow certain args to be provided as a positional args - if result_type is not None: + if result_type is not NOTSET: kwargs["result_type"] = result_type if objective is not None: kwargs["objective"] = objective @@ -442,7 +442,7 @@ def is_ready(self) -> bool: """ return self.is_incomplete() and all(t.is_complete() for t in self.depends_on) - def _create_success_tool(self) -> Callable: + def _create_success_tool(self) -> Tool: """ Create an agent-compatible tool for marking this task as successful. """ @@ -466,7 +466,7 @@ def succeed(result: result_schema) -> str: # type: ignore metadata=dict(is_task_status_tool=True), ) - def _create_fail_tool(self) -> Callable: + def _create_fail_tool(self) -> Tool: """ Create an agent-compatible tool for failing this task. """ @@ -478,7 +478,7 @@ def _create_fail_tool(self) -> Callable: metadata=dict(is_task_status_tool=True), ) - def _create_skip_tool(self) -> Callable: + def _create_skip_tool(self) -> Tool: """ Create an agent-compatible tool for skipping this task. """ @@ -525,9 +525,10 @@ def get_agent_strategy(self) -> Callable: return controlflow.agent_strategies.round_robin - def get_tools(self) -> list[Callable]: + def get_tools(self) -> list[Union[Tool, Callable]]: tools = self.tools.copy() - if self.is_incomplete(): + # if this task is ready to run, generate tools + if self.is_ready: tools.extend([self._create_fail_tool(), self._create_success_tool()]) # add skip tool if this task has a parent task # if self.parent is not None: diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py index d70a10e3..40c30d32 100644 --- a/src/controlflow/llm/history.py +++ b/src/controlflow/llm/history.py @@ -104,6 +104,3 @@ def save_messages(self, thread_id: str, messages: list[MessageType]): all_messages.extend([msg.model_dump(mode="json") for msg in messages]) with open(self.path(thread_id), "w") as f: json.dump(all_messages, f) - - -DEFAULT_HISTORY = InMemoryHistory() diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py index a1c56a60..bdc22a99 100644 --- a/src/controlflow/llm/models.py +++ b/src/controlflow/llm/models.py @@ -1,16 +1,20 @@ +from typing import Any, Optional + from langchain_core.language_models import BaseChatModel import controlflow def get_default_model() -> BaseChatModel: - if controlflow.default_model is None: + if getattr(controlflow, "default_model", None) is None: return model_from_string(controlflow.settings.llm_model) else: return controlflow.default_model -def model_from_string(model: str, temperature: float = None, **kwargs) -> BaseChatModel: +def model_from_string( + model: str, temperature: Optional[float] = None, **kwargs: Any +) -> BaseChatModel: if "/" not in model: provider, model = "openai", model provider, model = model.split("/") diff --git a/src/controlflow/tui/app.py b/src/controlflow/tui/app.py index cb28f834..4ba378f7 100644 --- a/src/controlflow/tui/app.py +++ b/src/controlflow/tui/app.py @@ -33,7 +33,7 @@ class TUIApp(App): def __init__(self, flow: "controlflow.Flow", **kwargs): self._flow = flow - self._tasks = flow._tasks + self._tasks = flow.tasks self._is_ready = False super().__init__(**kwargs) diff --git a/tests/ai_tests/test_tasks.py b/tests/ai_tests/test_tasks.py index dae7b13e..b07d21bb 100644 --- a/tests/ai_tests/test_tasks.py +++ b/tests/ai_tests/test_tasks.py @@ -21,6 +21,7 @@ def test_task_pydantic_result(self): assert isinstance(result, Name) assert result == Name(first="John", last="Doe") + @pytest.mark.xfail(reason="Need to revisit dataframe handling") def test_task_dataframe_result(self): task = Task( 'return a dataframe with column "x" that has values 1 and 2 and column "y" that has values 3 and 4', diff --git a/tests/conftest.py b/tests/conftest.py index 027145f2..b3b0aa6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import pytest +from controlflow.llm.messages import MessageType from controlflow.settings import temporary_settings from prefect.testing.utilities import prefect_test_harness diff --git a/tests/core/agents.py b/tests/core/agents.py deleted file mode 100644 index 465d63cf..00000000 --- a/tests/core/agents.py +++ /dev/null @@ -1,23 +0,0 @@ -from unittest.mock import patch - -from controlflow.core.agent import Agent -from controlflow.core.task import Task - - -class TestAgent: - pass - - -class TestAgentRun: - def test_agent_run(self): - with patch( - "controlflow.core.controller.Controller._get_prefect_run_agent_task" - ) as mock_task: - agent = Agent() - agent.run() - mock_task.assert_called_once() - - def test_agent_run_with_task(self): - task = Task("say hello") - agent = Agent() - agent.run(tasks=[task]) diff --git a/tests/core/test_agents.py b/tests/core/test_agents.py new file mode 100644 index 00000000..00760bd4 --- /dev/null +++ b/tests/core/test_agents.py @@ -0,0 +1,40 @@ +import controlflow +from controlflow.core.agent import Agent, get_default_agent +from controlflow.core.agent.names import NAMES +from controlflow.core.task import Task + + +class TestAgentInitialization: + def test_agent_gets_random_name(self): + agent = Agent() + + assert agent.name in NAMES + + def test_agent_default_model(self): + agent = Agent() + + assert agent.model is controlflow.get_default_model() + + +class TestDefaultAgent: + def test_default_agent_is_marvin(self): + agent = get_default_agent() + assert agent.name == "Marvin" + + def test_default_agent_has_no_tools(self): + assert get_default_agent().tools == [] + + def test_default_agent_can_be_assigned(self): + # baseline + assert get_default_agent().name == "Marvin" + + new_default_agent = Agent(name="New Agent") + controlflow.default_agent = new_default_agent + + assert get_default_agent().name == "New Agent" + assert Task("task").get_agents()[0] is new_default_agent + assert [a.name for a in Task("task").get_agents()] == ["New Agent"] + + def test_default_agent(self): + assert get_default_agent().name == "Marvin" + assert Task("task").get_agents()[0] is get_default_agent() diff --git a/tests/core/test_controller.py b/tests/core/test_controller.py deleted file mode 100644 index 6422dc6c..00000000 --- a/tests/core/test_controller.py +++ /dev/null @@ -1,75 +0,0 @@ -from unittest.mock import AsyncMock - -import pytest -from controlflow.core.agent import Agent -from controlflow.core.controller.controller import Controller -from controlflow.core.flow import Flow -from controlflow.core.graph import EdgeType -from controlflow.core.task import Task - - -class TestController: - @pytest.fixture - def flow(self): - return Flow() - - @pytest.fixture - def agent(self): - return Agent(name="Test Agent") - - @pytest.fixture - def task(self): - return Task(objective="Test Task") - - def test_controller_initialization(self, flow, agent, task): - controller = Controller(flow=flow, tasks=[task], agents=[agent]) - assert controller.flow == flow - assert controller.tasks == [task] - assert controller.agents == [agent] - assert len(controller.context) == 0 - assert len(controller.graph.tasks) == 1 - assert len(controller.graph.edges) == 0 - - def test_controller_missing_tasks(self, flow): - with pytest.raises(ValueError, match="At least one task is required."): - Controller(flow=flow, tasks=[]) - - async def test_run_agent(self, flow, agent, task, monkeypatch): - controller = Controller(flow=flow, tasks=[task], agents=[agent]) - mocked_run = AsyncMock() - monkeypatch.setattr(Agent, "run", mocked_run) - await controller._run_agent(agent, tasks=[task]) - mocked_run.assert_called_once_with(tasks=[task]) - - async def test_run_once(self, flow, agent, task, monkeypatch): - controller = Controller(flow=flow, tasks=[task], agents=[agent]) - mocked_run_agent = AsyncMock() - monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent) - await controller.run_once_async() - mocked_run_agent.assert_called_once_with(agent, tasks=[task]) - - def test_create_end_run_tool(self, flow, agent, task): - controller = Controller(flow=flow, tasks=[task], agents=[agent]) - end_run_tool = controller._create_end_run_tool() - assert end_run_tool.function.name == "end_run" - assert end_run_tool.function.description.startswith("End your turn") - - def test_controller_graph_creation(self, flow, agent): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent]) - assert len(controller.graph.tasks) == 2 - assert len(controller.graph.edges) == 1 - assert controller.graph.edges.pop().type == EdgeType.DEPENDENCY - - def test_controller_agent_selection(self, flow, monkeypatch): - agent1 = Agent(name="Agent 1") - agent2 = Agent(name="Agent 2") - task = Task(objective="Test Task", agents=[agent1, agent2]) - controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2]) - mocked_classify_moderator = AsyncMock(return_value=agent1) - monkeypatch.setattr( - "controlflow.core.controller.moderators.classify_moderator", - mocked_classify_moderator, - ) - assert controller.agents == [agent1, agent2] diff --git a/tests/core/test_flows.py b/tests/core/test_flows.py index 180a08f9..488bc031 100644 --- a/tests/core/test_flows.py +++ b/tests/core/test_flows.py @@ -1,9 +1,11 @@ from controlflow.core.agent import Agent from controlflow.core.flow import Flow, get_flow +from controlflow.core.task import Task +from controlflow.llm.messages import HumanMessage from controlflow.utilities.context import ctx -class TestFlow: +class TestFlowInitialization: def test_flow_initialization(self): flow = Flow() assert flow.thread_id is not None @@ -11,14 +13,6 @@ def test_flow_initialization(self): assert len(flow.agents) == 0 assert len(flow.context) == 0 - def test_flow_with_custom_agents(self): - agent1 = Agent(name="Agent 1") - agent2 = Agent(name="Agent 2") - flow = Flow(agents=[agent1, agent2]) - assert len(flow.agents) == 2 - assert agent1 in flow.agents - assert agent2 in flow.agents - def test_flow_with_custom_tools(self): def tool1(): pass @@ -36,6 +30,8 @@ def test_flow_with_custom_context(self): assert len(flow.context) == 1 assert flow.context["key"] == "value" + +class TestFlowContext: def test_flow_context_manager(self): with Flow() as flow: assert ctx.get("flow") == flow @@ -50,9 +46,124 @@ def test_get_flow_within_context(self): def test_get_flow_without_context(self): assert get_flow() is None + def test_reentrant_flow_context(self): + flow = Flow() + with flow: + assert get_flow() is flow + with flow: + assert get_flow() is flow + with flow: + assert get_flow() is flow + assert get_flow() is flow + assert get_flow() is flow + assert get_flow() is None + def test_get_flow_nested_contexts(self): with Flow() as flow1: assert get_flow() == flow1 with Flow() as flow2: assert get_flow() == flow2 assert get_flow() == flow1 + assert get_flow() is None + + def test_tasks_created_in_flow_context(self): + with Flow() as flow: + t1 = Task("test 1") + t2 = Task("test 2") + + assert flow.tasks == {t1.id: t1, t2.id: t2} + + def test_tasks_created_in_nested_flows_only_in_inner_flow(self): + with Flow() as flow1: + t1 = Task("test 1") + with Flow() as flow2: + t2 = Task("test 2") + + assert flow1.tasks == {t1.id: t1} + assert flow2.tasks == {t2.id: t2} + + +class TestFlowHistory: + def test_get_messages_empty(self): + flow = Flow() + messages = flow.get_messages() + assert messages == [] + + def test_add_messages_with_history(self): + flow = Flow() + flow.add_messages( + messages=[HumanMessage(content="hello"), HumanMessage(content="world")] + ) + messages = flow.get_messages() + assert len(messages) == 2 + assert [m.content for m in messages] == ["hello", "world"] + + def test_copy_parent_history(self): + flow1 = Flow() + flow1.add_messages( + messages=[HumanMessage(content="hello"), HumanMessage(content="world")] + ) + + with flow1: + flow2 = Flow() + + messages1 = flow1.get_messages() + assert len(messages1) == 2 + assert [m.content for m in messages1] == ["hello", "world"] + + messages2 = flow2.get_messages() + assert len(messages2) == 2 + assert [m.content for m in messages2] == ["hello", "world"] + + def test_disable_copying_parent_history(self): + flow1 = Flow() + flow1.add_messages( + messages=[HumanMessage(content="hello"), HumanMessage(content="world")] + ) + + with flow1: + flow2 = Flow(copy_parent_history=False) + + messages1 = flow1.get_messages() + assert len(messages1) == 2 + assert [m.content for m in messages1] == ["hello", "world"] + + messages2 = flow2.get_messages() + assert len(messages2) == 0 + + def test_child_flow_messages_dont_go_to_parent(self): + flow1 = Flow() + flow1.add_messages( + messages=[HumanMessage(content="hello"), HumanMessage(content="world")] + ) + + with flow1: + flow2 = Flow() + flow2.add_messages(messages=[HumanMessage(content="goodbye")]) + + messages1 = flow1.get_messages() + assert len(messages1) == 2 + assert [m.content for m in messages1] == ["hello", "world"] + + messages2 = flow2.get_messages() + assert len(messages2) == 3 + assert [m.content for m in messages2] == ["hello", "world", "goodbye"] + + +class TestFlowCreatesDefaults: + def test_flow_with_custom_agents(self): + agent1 = Agent(name="Agent 1") + agent2 = Agent(name="Agent 2") + flow = Flow(agents=[agent1, agent2]) + assert len(flow.agents) == 2 + assert agent1 in flow.agents + assert agent2 in flow.agents + + def test_flow_agent_becomes_task_default(self): + agent = Agent() + t1 = Task("t1") + assert t1.agents != [agent] + + with Flow(agents=[agent]): + t2 = Task("t2") + assert t2.get_agents() == [agent] diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py index e462f86c..6e12eee8 100644 --- a/tests/core/test_graph.py +++ b/tests/core/test_graph.py @@ -1,125 +1,109 @@ # test_graph.py from controlflow.core.graph import Edge, EdgeType, Graph -from controlflow.core.task import Task, TaskStatus - - -class TestGraph: - def test_graph_initialization(self): - graph = Graph() - assert len(graph.tasks) == 0 - assert len(graph.edges) == 0 - - def test_add_task(self): - graph = Graph() - task = Task(objective="Test objective") - graph.add_task(task) - assert len(graph.tasks) == 1 - assert task in graph.tasks - - def test_add_edge(self): - graph = Graph() - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2") - edge = Edge(upstream=task1, downstream=task2, type=EdgeType.DEPENDENCY) - graph.add_edge(edge) - assert len(graph.tasks) == 2 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert len(graph.edges) == 1 - assert edge in graph.edges - - def test_from_tasks(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - task3 = Task(objective="Task 3", parent=task2) - graph = Graph.from_tasks([task1, task2, task3]) - assert len(graph.tasks) == 3 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert task3 in graph.tasks - assert len(graph.edges) == 2 - assert any( - edge.upstream == task1 - and edge.downstream == task2 - and edge.type == EdgeType.DEPENDENCY - for edge in graph.edges - ) - assert any( - edge.upstream == task3 - and edge.downstream == task2 - and edge.type == EdgeType.SUBTASK - for edge in graph.edges - ) - - def test_upstream_edges(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - graph = Graph.from_tasks([task1, task2]) - upstream_edges = graph.upstream_edges() - assert len(upstream_edges[task1]) == 0 - assert len(upstream_edges[task2]) == 1 - assert upstream_edges[task2][0].upstream == task1 - - def test_downstream_edges(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - graph = Graph.from_tasks([task1, task2]) - downstream_edges = graph.downstream_edges() - assert len(downstream_edges[task1]) == 1 - assert len(downstream_edges[task2]) == 0 - assert downstream_edges[task1][0].downstream == task2 - - def test_upstream_dependencies(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - task3 = Task(objective="Task 3", parent=task2) - graph = Graph.from_tasks([task1, task2, task3]) - dependencies = graph.upstream_dependencies([task2]) - assert len(dependencies) == 2 - assert task1 in dependencies - assert task3 in dependencies - - def test_upstream_dependencies_include_tasks(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - task3 = Task(objective="Task 3", parent=task2) - graph = Graph.from_tasks([task1, task2, task3]) - dependencies = graph.upstream_dependencies([task2], include_tasks=True) - assert len(dependencies) == 3 - assert task1 in dependencies - assert task2 in dependencies - assert task3 in dependencies - - def test_upstream_dependencies_prune(self): - task1 = Task(objective="Task 1", status=TaskStatus.SUCCESSFUL) - task2 = Task(objective="Task 2", depends_on=[task1], status=TaskStatus.FAILED) - task3 = Task(objective="Task 3", depends_on=[task2]) - graph = Graph.from_tasks([task1, task2, task3]) - dependencies = graph.upstream_dependencies([task3]) - assert len(dependencies) == 1 - assert task2 in dependencies - dependencies = graph.upstream_dependencies([task3], prune_completed=False) - assert len(dependencies) == 2 - assert task1 in dependencies - assert task2 in dependencies - - def test_ready_tasks(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - task3 = Task(objective="Task 3", parent=task2) - graph = Graph.from_tasks([task1, task2, task3]) - ready_tasks = graph.ready_tasks() - assert len(ready_tasks) == 2 - assert task1 in ready_tasks - assert task3 in ready_tasks - - task1.mark_successful() - ready_tasks = graph.ready_tasks() - assert len(ready_tasks) == 2 - assert task2 in ready_tasks - assert task3 in ready_tasks - - task3.mark_successful() - ready_tasks = graph.ready_tasks() - assert len(ready_tasks) == 1 - assert task2 in ready_tasks +from controlflow.core.task import Task + + +def test_graph_initialization(): + graph = Graph() + assert len(graph.tasks) == 0 + assert len(graph.edges) == 0 + + +def test_add_task(): + graph = Graph() + task = Task(objective="Test objective") + graph.add_task(task) + assert len(graph.tasks) == 1 + assert task in graph.tasks + + +def test_add_edge(): + graph = Graph() + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + edge = Edge(upstream=task1, downstream=task2, type=EdgeType.DEPENDENCY) + graph.add_edge(edge) + assert len(graph.tasks) == 2 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert len(graph.edges) == 1 + assert edge in graph.edges + + +def test_from_tasks(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = Graph.from_tasks([task1, task2, task3]) + assert len(graph.tasks) == 3 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert task3 in graph.tasks + assert len(graph.edges) == 2 + assert any( + edge.upstream == task1 + and edge.downstream == task2 + and edge.type == EdgeType.DEPENDENCY + for edge in graph.edges + ) + assert any( + edge.upstream == task3 + and edge.downstream == task2 + and edge.type == EdgeType.SUBTASK + for edge in graph.edges + ) + + +def test_upstream_edges(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + graph = Graph.from_tasks([task1, task2]) + upstream_edges = graph.upstream_edges() + assert len(upstream_edges[task1]) == 0 + assert len(upstream_edges[task2]) == 1 + assert upstream_edges[task2][0].upstream == task1 + + +def test_downstream_edges(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + graph = Graph.from_tasks([task1, task2]) + downstream_edges = graph.downstream_edges() + assert len(downstream_edges[task1]) == 1 + assert len(downstream_edges[task2]) == 0 + assert downstream_edges[task1][0].downstream == task2 + + +def test_topological_sort(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", depends_on=[task2]) + task4 = Task(objective="Task 4", depends_on=[task3]) + graph = Graph.from_tasks([task1, task2, task3, task4]) + sorted_tasks = graph.topological_sort() + assert len(sorted_tasks) == 4 + assert sorted_tasks.index(task1) < sorted_tasks.index(task2) + assert sorted_tasks.index(task2) < sorted_tasks.index(task3) + assert sorted_tasks.index(task3) < sorted_tasks.index(task4) + + +def test_topological_sort_with_fan_in_and_fan_out(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + task3 = Task(objective="Task 3") + + edge1 = Edge(upstream=task1, downstream=task2, type=EdgeType.DEPENDENCY) + edge2 = Edge(upstream=task1, downstream=task3, type=EdgeType.DEPENDENCY) + edge3 = Edge(upstream=task2, downstream=task3, type=EdgeType.DEPENDENCY) + + graph = Graph() + graph.add_edge(edge1) + graph.add_edge(edge2) + graph.add_edge(edge3) + + sorted_tasks = graph.topological_sort() + + assert len(sorted_tasks) == 3 + assert sorted_tasks.index(task1) < sorted_tasks.index(task2) + assert sorted_tasks.index(task1) < sorted_tasks.index(task3) + assert sorted_tasks.index(task2) < sorted_tasks.index(task3) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index fba0f184..27aa78dd 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,24 +1,30 @@ -from unittest.mock import AsyncMock +from functools import partial import pytest from controlflow.core.agent import Agent, get_default_agent from controlflow.core.flow import Flow from controlflow.core.graph import EdgeType from controlflow.core.task import Task, TaskStatus -from controlflow.settings import temporary_settings from controlflow.utilities.context import ctx +SimpleTask = partial(Task, objective="test", result_type=None) + def test_context_open_and_close(): assert ctx.get("tasks") == [] - with Task("a") as ta: + with SimpleTask() as ta: assert ctx.get("tasks") == [ta] - with Task("b") as tb: + with SimpleTask() as tb: assert ctx.get("tasks") == [ta, tb] assert ctx.get("tasks") == [ta] assert ctx.get("tasks") == [] +def test_task_requires_objective(): + with pytest.raises(ValueError): + Task() + + def test_task_initialization(): task = Task(objective="Test objective") assert task.objective == "Test objective" @@ -28,23 +34,23 @@ def test_task_initialization(): def test_task_dependencies(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) assert task1 in task2.depends_on assert task2 in task1._downstreams def test_task_subtasks(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", parent=task1) + task1 = SimpleTask() + task2 = SimpleTask(parent=task1) assert task2 in task1._subtasks assert task2.parent is task1 def test_task_parent_context(): - with Task("grandparent") as task1: - with Task("parent") as task2: - task3 = Task("child") + with SimpleTask() as task1: + with SimpleTask() as task2: + task3 = SimpleTask() assert task3.parent is task2 assert task2.parent is task1 @@ -57,19 +63,19 @@ def test_task_parent_context(): def test_task_agent_assignment(): agent = Agent(name="Test Agent") - task = Task(objective="Test objective", agents=[agent]) + task = SimpleTask(agents=[agent]) assert agent in task.agents def test_task_bad_agent_assignment(): with pytest.raises(ValueError): - Task(objective="Test objective", agents=[]) + SimpleTask(agents=[]) def test_task_loads_agent_from_parent(): agent = Agent(name="Test Agent") - with Task("parent", agents=[agent]): - child = Task("child") + with SimpleTask(agents=[agent]): + child = SimpleTask() assert child.agents is None assert child.get_agents() == [agent] @@ -79,7 +85,7 @@ def test_task_loads_agent_from_flow(): def_agent = get_default_agent() agent = Agent(name="Test Agent") with Flow(agents=[agent]): - task = Task("task") + task = SimpleTask() assert task.agents is None assert task.get_agents() == [agent] @@ -90,7 +96,7 @@ def test_task_loads_agent_from_flow(): def test_task_loads_agent_from_default_if_none_otherwise(): agent = get_default_agent() - task = Task("task") + task = SimpleTask() assert task.agents is None assert task.get_agents() == [agent] @@ -100,28 +106,28 @@ def test_task_loads_agent_from_parent_before_flow(): agent1 = Agent(name="Test Agent 1") agent2 = Agent(name="Test Agent 2") with Flow(agents=[agent1]): - with Task("parent", agents=[agent2]): - child = Task("child") + with SimpleTask(agents=[agent2]): + child = SimpleTask() assert child.agents is None assert child.get_agents() == [agent2] -def test_task_tracking(mock_controller_run_agent): +def test_task_tracking(): with Flow() as flow: - task = Task(objective="Test objective") - assert task in flow._tasks.values() + task = SimpleTask() + assert task in flow.tasks.values() -def test_task_tracking_on_call(mock_controller_run_agent): - task = Task(objective="Test objective") +def test_task_tracking_on_call(): + task = SimpleTask() with Flow() as flow: task.run_once() - assert task in flow._tasks.values() + assert task in flow.tasks.values() def test_task_status_transitions(): - task = Task(objective="Test objective") + task = SimpleTask() assert task.is_incomplete() assert not task.is_complete() assert not task.is_successful() @@ -135,7 +141,7 @@ def test_task_status_transitions(): assert not task.is_failed() assert not task.is_skipped() - task = Task(objective="Test objective") + task = SimpleTask() task.mark_failed() assert not task.is_incomplete() assert task.is_complete() @@ -143,7 +149,7 @@ def test_task_status_transitions(): assert task.is_failed() assert not task.is_skipped() - task = Task(objective="Test objective") + task = SimpleTask() task.mark_skipped() assert not task.is_incomplete() assert task.is_complete() @@ -153,8 +159,8 @@ def test_task_status_transitions(): def test_validate_upstream_dependencies_on_success(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) with pytest.raises(ValueError, match="cannot be marked successful"): task2.mark_successful() task1.mark_successful() @@ -162,8 +168,8 @@ def test_validate_upstream_dependencies_on_success(): def test_validate_subtask_dependencies_on_success(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", parent=task1) + task1 = SimpleTask() + task2 = SimpleTask(parent=task1) with pytest.raises(ValueError, match="cannot be marked successful"): task1.mark_successful() task2.mark_successful() @@ -171,49 +177,91 @@ def test_validate_subtask_dependencies_on_success(): def test_task_ready(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) + task1 = SimpleTask() + assert task1.is_ready + + +def test_task_not_ready_if_successful(): + task1 = SimpleTask() + task1.mark_successful() + assert not task1.is_ready + + +def test_task_not_ready_if_failed(): + task1 = SimpleTask() + task1.mark_failed() + assert not task1.is_ready + + +def test_task_not_ready_if_dependencies_are_ready(): + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) + assert task1.is_ready assert not task2.is_ready + +def test_task_ready_if_dependencies_are_ready(): + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) task1.mark_successful() + assert not task1.is_ready assert task2.is_ready def test_task_hash(): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2") + task1 = SimpleTask() + task2 = SimpleTask() assert hash(task1) != hash(task2) -def test_task_tools(): - task = Task(objective="Test objective") +def test_ready_task_adds_tools(): + task = SimpleTask() + assert task.is_ready + tools = task.get_tools() - assert any(tool.function.name == f"mark_task_{task.id}_failed" for tool in tools) - assert any( - tool.function.name == f"mark_task_{task.id}_successful" for tool in tools - ) + assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) + assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + +def test_completed_task_does_not_add_tools(): + task = SimpleTask() task.mark_successful() tools = task.get_tools() + assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) + assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + + +def test_task_with_incomplete_upstream_does_not_add_tools(): + upstream_task = SimpleTask() + downstream_task = SimpleTask(depends_on=[upstream_task]) + tools = downstream_task.get_tools() assert not any( - tool.function.name == f"mark_task_{task.id}_failed" for tool in tools + tool.name == f"mark_task_{downstream_task.id}_failed" for tool in tools ) assert not any( - tool.function.name == f"mark_task_{task.id}_successful" for tool in tools + tool.name == f"mark_task_{downstream_task.id}_successful" for tool in tools ) +def test_task_with_incomplete_subtask_does_not_add_tools(): + parent = SimpleTask() + SimpleTask(parent=parent) + tools = parent.get_tools() + assert not any(tool.name == f"mark_task_{parent.id}_failed" for tool in tools) + assert not any(tool.name == f"mark_task_{parent.id}_successful" for tool in tools) + + class TestTaskToGraph: def test_single_task_graph(self): - task = Task(objective="Test objective") + task = SimpleTask() graph = task.as_graph() assert len(graph.tasks) == 1 assert task in graph.tasks assert len(graph.edges) == 0 def test_task_with_subtasks_graph(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", parent=task1) + task1 = SimpleTask() + task2 = SimpleTask(parent=task1) graph = task1.as_graph() assert len(graph.tasks) == 2 assert task1 in graph.tasks @@ -227,8 +275,8 @@ def test_task_with_subtasks_graph(self): ) def test_task_with_dependencies_graph(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) graph = task2.as_graph() assert len(graph.tasks) == 2 assert task1 in graph.tasks @@ -242,9 +290,9 @@ def test_task_with_dependencies_graph(self): ) def test_task_with_subtasks_and_dependencies_graph(self): - task1 = Task(objective="Task 1") - task2 = Task(objective="Task 2", depends_on=[task1]) - task3 = Task(objective="Task 3", parent=task2) + task1 = SimpleTask() + task2 = SimpleTask(depends_on=[task1]) + task3 = SimpleTask(objective="Task 3", parent=task2) graph = task2.as_graph() assert len(graph.tasks) == 3 assert task1 in graph.tasks @@ -265,89 +313,89 @@ def test_task_with_subtasks_and_dependencies_graph(self): ) -@pytest.mark.usefixtures("mock_run") -class TestTaskRun: - def test_run_task_max_iterations(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# @pytest.mark.usefixtures("mock_run") +# class TestTaskRun: +# def test_run_task_max_iterations(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - with Flow(): - with pytest.raises(ValueError): - task.run() +# with Flow(): +# with pytest.raises(ValueError): +# task.run() - assert mock_run.await_count == 3 +# assert mock_run.await_count == 3 - def test_run_task_mark_successful(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# def test_run_task_mark_successful(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - def mark_complete(): - task.mark_successful() +# def mark_complete(): +# task.mark_successful() - mock_run.side_effect = mark_complete - with Flow(): - result = task.run() - assert task.is_successful() - assert result is None +# mock_run.side_effect = mark_complete +# with Flow(): +# result = task.run() +# assert task.is_successful() +# assert result is None - def test_run_task_mark_successful_with_result(self, mock_run: AsyncMock): - task = Task(objective="Say hello", result_type=int) +# def test_run_task_mark_successful_with_result(self, mock_run: AsyncMock): +# task = Task(objective="Say hello", result_type=int) - def mark_complete(): - task.mark_successful(result=42) +# def mark_complete(): +# task.mark_successful(result=42) - mock_run.side_effect = mark_complete - with Flow(): - result = task.run() - assert task.is_successful() - assert result == 42 +# mock_run.side_effect = mark_complete +# with Flow(): +# result = task.run() +# assert task.is_successful() +# assert result == 42 - def test_run_task_mark_failed(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# def test_run_task_mark_failed(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - def mark_complete(): - task.mark_failed(message="Failed to say hello") +# def mark_complete(): +# task.mark_failed(message="Failed to say hello") - mock_run.side_effect = mark_complete - with Flow(): - with pytest.raises(ValueError): - task.run() - assert task.is_failed() - assert task.error == "Failed to say hello" +# mock_run.side_effect = mark_complete +# with Flow(): +# with pytest.raises(ValueError): +# task.run() +# assert task.is_failed() +# assert task.error == "Failed to say hello" - def test_run_task_outside_flow(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# def test_run_task_outside_flow(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - def mark_complete(): - task.mark_successful() +# def mark_complete(): +# task.mark_successful() - mock_run.side_effect = mark_complete - result = task.run() - assert task.is_successful() - assert result is None +# mock_run.side_effect = mark_complete +# result = task.run() +# assert task.is_successful() +# assert result is None - def test_run_task_outside_flow_fails_if_strict_flows_enforced( - self, mock_run: AsyncMock - ): - task = Task(objective="Say hello") +# def test_run_task_outside_flow_fails_if_strict_flows_enforced( +# self, mock_run: AsyncMock +# ): +# task = Task(objective="Say hello") - with temporary_settings(strict_flow_context=True): - with pytest.raises(ValueError): - task.run() +# with temporary_settings(strict_flow_context=True): +# with pytest.raises(ValueError): +# task.run() - def test_task_run_once_outside_flow_fails(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# def test_task_run_once_outside_flow_fails(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - with pytest.raises(ValueError): - task.run_once() +# with pytest.raises(ValueError): +# task.run_once() - def test_task_run_once_with_passed_flow(self, mock_run: AsyncMock): - task = Task(objective="Say hello") +# def test_task_run_once_with_passed_flow(self, mock_run: AsyncMock): +# task = Task(objective="Say hello") - def mark_complete(): - task.mark_successful() +# def mark_complete(): +# task.mark_successful() - mock_run.side_effect = mark_complete - flow = Flow() - while task.is_incomplete(): - task.run_once(flow=flow) - assert task.is_successful() - assert task.result is None +# mock_run.side_effect = mark_complete +# flow = Flow() +# while task.is_incomplete(): +# task.run_once(flow=flow) +# assert task.is_successful() +# assert task.result is None diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 1e9881d3..53fed4ac 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,2 +1,4 @@ from .mocks import * from .instructions import * + +from .controlflow import * diff --git a/tests/fixtures/controlflow.py b/tests/fixtures/controlflow.py new file mode 100644 index 00000000..25e09e8b --- /dev/null +++ b/tests/fixtures/controlflow.py @@ -0,0 +1,26 @@ +import controlflow +import pytest +from controlflow.llm.messages import MessageType +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel + + +@pytest.fixture(autouse=True) +def restore_defaults(monkeypatch): + """ + Monkeypatch defaults to themselves, which will automatically reset them after every test + """ + monkeypatch.setattr(controlflow, "default_agent", controlflow.default_agent) + monkeypatch.setattr(controlflow, "default_model", controlflow.default_model) + monkeypatch.setattr(controlflow, "default_history", controlflow.default_history) + yield + + +@pytest.fixture() +def fake_llm() -> FakeMessagesListChatModel: + return FakeMessagesListChatModel(responses=[]) + + +@pytest.fixture() +def default_fake_llm(fake_llm, restore_defaults) -> FakeMessagesListChatModel: + controlflow.default_agent = fake_llm + return fake_llm diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/llm/test_completions.py b/tests/llm/test_completions.py deleted file mode 100644 index 9d8ce63e..00000000 --- a/tests/llm/test_completions.py +++ /dev/null @@ -1,38 +0,0 @@ -import controlflow.llm.completions - - -def test_mock_completion(mock_completion): - mock_completion.set_response("Hello, world! xyz") - response = controlflow.llm.completions.completion(messages=[{"content": "Hello"}]) - assert response.last_response().choices[0].message.content == "Hello, world! xyz" - - -async def test_mock_completion_async(mock_completion_async): - mock_completion_async.set_response("Hello, world! xyz") - response = await controlflow.llm.completions.completion_async( - messages=[{"content": "Hello"}] - ) - assert response.last_response().choices[0].message.content == "Hello, world! xyz" - - -def test_mock_completion_stream(mock_completion_stream): - mock_completion_stream.set_response("Hello, world! xyz") - response = controlflow.llm.completions._completion_stream( - messages=[{"content": "Hello"}], - ) - deltas = [] - for delta, snapshot in response: - deltas.append(delta) - - assert [d.choices[0].delta.content for d in deltas[:5]] == ["H", "e", "l", "l", "o"] - - -async def test_mock_completion_stream_async(mock_completion_stream_async): - mock_completion_stream_async.set_response("Hello, world! xyz") - response = controlflow.llm.completions._completion_stream_async( - messages=[{"content": "Hello"}], stream=True - ) - deltas = [] - async for delta, snapshot in response: - deltas.append(delta) - assert [d.choices[0].delta.content for d in deltas[:5]] == ["H", "e", "l", "l", "o"] diff --git a/tests/llm/test_handlers.py b/tests/llm/test_handlers.py deleted file mode 100644 index 0800b98d..00000000 --- a/tests/llm/test_handlers.py +++ /dev/null @@ -1,53 +0,0 @@ -# from collections import Counter - -# import litellm -# from controlflow.llm.completions import _completion_stream -# from controlflow.llm.handlers import CompletionHandler -# from controlflow.llm.messages import AIMessage -# from controlflow.llm.tools import ToolResult -# from pydantic import BaseModel - - -# class StreamCall(BaseModel): -# method: str -# args: dict - - -# class MockCompletionHandler(CompletionHandler): -# def __init__(self, *args, **kwargs): -# super().__init__(*args, **kwargs) -# self.calls: list[StreamCall] = [] - -# def on_message_created(self, delta: litellm.utils.Delta): -# self.calls.append( -# StreamCall(method="on_message_created", args=dict(delta=delta)) -# ) - -# def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message): -# self.calls.append( -# StreamCall( -# method="on_message_delta", args=dict(delta=delta, snapshot=snapshot) -# ) -# ) - -# def on_message_done(self, message: AIMessage): -# self.calls.append( -# StreamCall(method="on_message_done", args=dict(message=message)) -# ) - -# def on_tool_call_done(self, tool_call: ToolResult): -# self.calls.append( -# StreamCall(method="on_tool_call", args=dict(tool_call=tool_call)) -# ) - - -# class TestCompletionHandler: -# def test_stream(self): -# handler = MockCompletionHandler() -# gen = _completion_stream(messages=[{"text": "Hello"}]) -# handler.stream(gen) - -# method_counts = Counter(call.method for call in handler.calls) -# assert method_counts["on_message_created"] == 1 -# assert method_counts["on_message_delta"] == 4 -# assert method_counts["on_message_done"] == 1 diff --git a/tests/test_decorators.py b/tests/test_decorators.py deleted file mode 100644 index af333023..00000000 --- a/tests/test_decorators.py +++ /dev/null @@ -1,168 +0,0 @@ -import controlflow -import pytest -from controlflow import Task -from controlflow.core.flow import Flow -from controlflow.decorators import flow, task -from controlflow.settings import temporary_settings - - -@pytest.mark.usefixtures("mock_controller") -class TestFlowDecorator: - def test_flow_decorator(self): - @flow - def test_flow(): - return 1 - - result = test_flow() - assert result == 1 - - def test_flow_decorator_runs_all_tasks(self): - tasks: list[Task] = [] - - @flow - def test_flow(): - task = Task( - "say hello", - result_type=str, - result="Task completed successfully", - ) - tasks.append(task) - - result = test_flow() - assert result is None - assert tasks[0].is_successful() - assert tasks[0].result == "Task completed successfully" - - def test_flow_decorator_resolves_all_tasks(self): - @flow - def test_flow(): - task1 = Task("say hello", result="hello") - task2 = Task("say goodbye", result="goodbye") - task3 = Task("say goodnight", result="goodnight") - return dict(a=task1, b=[task2], c=dict(x=dict(y=[[task3]]))) - - result = test_flow() - assert result == dict( - a="hello", b=["goodbye"], c=dict(x=dict(y=[["goodnight"]])) - ) - - def test_manually_run_task_in_flow(self): - @flow - def test_flow(): - task = Task("say hello", result="hello") - task.run() - return task.result - - result = test_flow() - assert result == "hello" - - -class TestTaskDecorator: - pass - - -@pytest.mark.usefixtures("mock_controller") -class TestTaskEagerMode: - def test_eager_mode_enabled_by_default(self): - assert controlflow.settings.eager_mode is True - - def test_task_eager_mode(self, mock_controller_run_agent): - @task - def return_42() -> int: - """Return the number 42""" - pass - - return_42() - assert mock_controller_run_agent.call_count == 1 - - def test_task_lazy(self, mock_controller_run_agent): - @task(lazy=True) - def return_42() -> int: - """Return the number 42""" - pass - - result = return_42() - assert mock_controller_run_agent.call_count == 0 - assert isinstance(result, Task) - assert result.objective == "return_42" - assert result.result_type == int - assert result.instructions == "Return the number 42" - - def test_task_eager_mode_loads_default_setting(self, mock_controller_run_agent): - @task - def return_42() -> int: - """Return the number 42""" - pass - - with temporary_settings(eager_mode=False): - result = return_42() - - assert mock_controller_run_agent.call_count == 0 - assert isinstance(result, Task) - assert result.objective == "return_42" - assert result.result_type == int - assert result.instructions == "Return the number 42" - - @pytest.mark.parametrize("eager_mode", [True, False]) - def test_override_eager_mode_at_call_time( - self, mock_controller_run_agent, eager_mode - ): - with temporary_settings(eager_mode=eager_mode): - - @task - def return_42() -> int: - """Return the number 42""" - pass - - return_42(lazy_=eager_mode) - if eager_mode: - assert mock_controller_run_agent.call_count == 0 - else: - assert mock_controller_run_agent.call_count == 1 - - -@pytest.mark.usefixtures("mock_controller") -class TestFlowEagerMode: - def test_flow_eager_mode(self, mock_controller_run_agent): - @flow - def test_flow(): - task = Task("say hello", result="hello") - return task - - result = test_flow() - assert mock_controller_run_agent.call_count == 1 - assert result == "hello" - - def test_flow_lazy(self, mock_controller_run_agent): - @flow(lazy=True) - def test_flow(): - """This is a test flow""" - task = Task("say hello", result="hello") - return task - - result = test_flow() - assert mock_controller_run_agent.call_count == 0 - assert isinstance(result, Flow) - assert result.name == "test_flow" - assert result.description == "This is a test flow" - tasks = list(result._tasks.values()) - assert len(tasks) == 1 - assert tasks[0].objective == "say hello" - assert tasks[0].result == "hello" - - def test_flow_lazy_doesnt_affect_tasks_with_eager_mode_on( - self, mock_controller_run_agent - ): - @task - def return_42() -> int: - """Return the number 42""" - pass - - @flow(lazy=True) - def test_flow(): - result = return_42() - return result - - result = test_flow() - assert mock_controller_run_agent.call_count == 1 - assert not isinstance(result, Task)