diff --git a/src/control_flow/core/flow.py b/src/control_flow/core/flow.py index a3737f0d..2dc018e0 100644 --- a/src/control_flow/core/flow.py +++ b/src/control_flow/core/flow.py @@ -6,6 +6,7 @@ from prefect import task as prefect_task from pydantic import Field, field_validator +import control_flow from control_flow.utilities.context import ctx from control_flow.utilities.logging import get_logger from control_flow.utilities.types import AssistantTool, ControlFlowModel @@ -36,7 +37,6 @@ class Flow(ControlFlowModel): description="The default agents for the flow. These agents will be used " "for any task that does not specify agents.", ) - model: str | None = None context: dict = {} @field_validator("thread", mode="before") @@ -73,11 +73,15 @@ def get_flow() -> Flow: """ Loads the flow from the context. - Will error if no flow is found in the context. + Will error if no flow is found in the context, unless the global flow is + enabled in settings """ flow: Flow | None = ctx.get("flow") if not flow: - return GLOBAL_FLOW + if control_flow.settings.enable_global_flow: + return GLOBAL_FLOW + else: + raise ValueError("No flow found in context.") return flow diff --git a/src/control_flow/instructions.py b/src/control_flow/instructions.py index 97e89664..c5c97c3b 100644 --- a/src/control_flow/instructions.py +++ b/src/control_flow/instructions.py @@ -1,8 +1,6 @@ -import inspect from contextlib import contextmanager from typing import Generator, List -from control_flow.core.flow import Flow from control_flow.utilities.context import ctx from control_flow.utilities.logging import get_logger @@ -10,75 +8,21 @@ @contextmanager -def instructions( - *instructions: str, - post_add_message: bool = False, - post_remove_message: bool = False, -) -> Generator[list[str], None, None]: +def instructions(*instructions: str) -> Generator[list[str], None, None]: """ Temporarily add instructions to the current instruction stack. The instruction is removed when the context is exited. - If `post_add_message` is True, a message will be added to the flow when the - instruction is added. If `post_remove_message` is True, a message will be - added to the flow when the instruction is removed. These explicit reminders - can help when agents infer instructions more from history. - with instructions("talk like a pirate"): ... """ - if post_add_message or post_remove_message: - flow: Flow = ctx.get("flow") - if flow is None: - raise ValueError( - "instructions() with message posting must be used within a flow context" - ) - stack: list[str] = ctx.get("instructions", []) stack = stack + list(instructions) with ctx(instructions=stack): - try: - if post_add_message: - for instruction in instructions: - flow.add_message( - inspect.cleandoc( - """ - # SYSTEM MESSAGE: INSTRUCTION ADDED - - The following instruction is now active: - - - {instruction} - - - Always consult your current instructions before acting. - """ - ).format(instruction=instruction) - ) - yield - - # yield new_stack - finally: - if post_remove_message: - for instruction in instructions: - flow.add_message( - inspect.cleandoc( - """ - # SYSTEM MESSAGE: INSTRUCTION REMOVED - - The following instruction is no longer active: - - - {instruction} - - - Always consult your current instructions before acting. - """ - ).format(instruction=instruction) - ) + yield def get_instructions() -> List[str]: diff --git a/src/control_flow/settings.py b/src/control_flow/settings.py index e8d083cf..17972a1e 100644 --- a/src/control_flow/settings.py +++ b/src/control_flow/settings.py @@ -1,6 +1,9 @@ import os import sys import warnings +from contextlib import contextmanager +from copy import deepcopy +from typing import Any from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -46,6 +49,10 @@ class Settings(ControlFlowSettings): assistant_model: str = "gpt-4-1106-preview" max_agent_iterations: int = 10 prefect: PrefectSettings = Field(default_factory=PrefectSettings) + enable_global_flow: bool = Field( + True, + description="If True, a global flow is created for convenience, so users don't have to wrap every invocation in a flow function. Disable to avoid accidentally sharing context between agents.", + ) def __init__(self, **data): super().__init__(**data) @@ -53,3 +60,52 @@ def __init__(self, **data): settings = Settings() + + +@contextmanager +def temporary_settings(**kwargs: Any): + """ + Temporarily override ControlFlow setting values, including nested settings objects. + + To override nested settings, use `__` to separate nested attribute names. + + Args: + **kwargs: The settings to override, including nested settings. + + Example: + Temporarily override log level and OpenAI API key: + ```python + import control_flow + from control_flow.settings import temporary_settings + + # Override top-level settings + with temporary_settings(log_level="INFO"): + assert control_flow.settings.log_level == "INFO" + assert control_flow.settings.log_level == "DEBUG" + + # Override nested settings + with temporary_settings(openai__api_key="new-api-key"): + assert control_flow.settings.openai.api_key.get_secret_value() == "new-api-key" + assert control_flow.settings.openai.api_key.get_secret_value().startswith("sk-") + ``` + """ + old_env = os.environ.copy() + old_settings = deepcopy(settings) + + def set_nested_attr(obj: object, attr_path: str, value: Any): + parts = attr_path.split("__") + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + try: + for attr_path, value in kwargs.items(): + set_nested_attr(settings, attr_path, value) + yield + + finally: + os.environ.clear() + os.environ.update(old_env) + + for attr, value in old_settings: + set_nested_attr(settings, attr, value) diff --git a/tests/core/test_controller.py b/tests/core/test_controller.py new file mode 100644 index 00000000..78c8cc9c --- /dev/null +++ b/tests/core/test_controller.py @@ -0,0 +1,85 @@ +from unittest.mock import AsyncMock + +import pytest +from control_flow.core.agent import Agent +from control_flow.core.controller.controller import Controller +from control_flow.core.flow import Flow +from control_flow.core.graph import EdgeType +from control_flow.core.task import Task + + +class TestController: + @pytest.fixture + def flow(self): + return Flow() + + @pytest.fixture + def agent(self): + return Agent(name="Test Agent") + + @pytest.fixture + def task(self): + return Task(objective="Test Task") + + def test_controller_initialization(self, flow, agent, task): + controller = Controller(flow=flow, tasks=[task], agents=[agent]) + assert controller.flow == flow + assert controller.tasks == [task] + assert controller.agents == [agent] + assert controller.run_dependencies is True + assert len(controller.context) == 0 + assert len(controller.graph.tasks) == 1 + assert len(controller.graph.edges) == 0 + + def test_controller_missing_tasks(self, flow): + with pytest.raises(ValueError, match="At least one task is required."): + Controller(flow=flow, tasks=[]) + + async def test_run_agent(self, flow, agent, task, monkeypatch): + controller = Controller(flow=flow, tasks=[task], agents=[agent]) + mocked_run = AsyncMock() + monkeypatch.setattr(Agent, "run", mocked_run) + await controller._run_agent(agent, tasks=[task]) + mocked_run.assert_called_once_with(tasks=[task]) + + async def test_run_once(self, flow, agent, task, monkeypatch): + controller = Controller(flow=flow, tasks=[task], agents=[agent]) + mocked_run_agent = AsyncMock() + monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent) + await controller.run_once_async() + mocked_run_agent.assert_called_once_with(agent, tasks=[task]) + + def test_create_end_run_tool(self, flow, agent, task): + controller = Controller(flow=flow, tasks=[task], agents=[agent]) + end_run_tool = controller._create_end_run_tool() + assert end_run_tool.function.name == "end_run" + assert end_run_tool.function.description.startswith("End your turn") + + def test_controller_graph_creation(self, flow, agent): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent]) + assert len(controller.graph.tasks) == 2 + assert len(controller.graph.edges) == 1 + assert controller.graph.edges.pop().type == EdgeType.dependency + + def test_controller_agent_selection(self, flow, monkeypatch): + agent1 = Agent(name="Agent 1") + agent2 = Agent(name="Agent 2") + task = Task(objective="Test Task", agents=[agent1, agent2]) + controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2]) + mocked_marvin_moderator = AsyncMock(return_value=agent1) + monkeypatch.setattr( + "control_flow.core.controller.moderators.marvin_moderator", + mocked_marvin_moderator, + ) + assert controller.agents == [agent1, agent2] + + async def test_controller_run_dependencies(self, flow, agent): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + controller = Controller(flow=flow, tasks=[task2], agents=[agent]) + mocked_run_agent = AsyncMock() + controller._run_agent = mocked_run_agent + await controller.run_once_async() + mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2]) diff --git a/tests/core/test_flows.py b/tests/core/test_flows.py new file mode 100644 index 00000000..ce064e7b --- /dev/null +++ b/tests/core/test_flows.py @@ -0,0 +1,74 @@ +# test_flow.py +from unittest.mock import MagicMock + +from control_flow.core.agent import Agent +from control_flow.core.flow import Flow, get_flow +from control_flow.utilities.context import ctx + + +class TestFlow: + def test_flow_initialization(self): + flow = Flow() + assert flow.thread is not None + assert len(flow.tools) == 0 + assert len(flow.agents) == 1 + assert isinstance(flow.agents[0], Agent) + assert len(flow.context) == 0 + + def test_flow_with_custom_agents(self): + agent1 = Agent(name="Agent 1") + agent2 = Agent(name="Agent 2") + flow = Flow(agents=[agent1, agent2]) + assert len(flow.agents) == 2 + assert agent1 in flow.agents + assert agent2 in flow.agents + + def test_flow_with_custom_tools(self): + def tool1(): + pass + + def tool2(): + pass + + flow = Flow(tools=[tool1, tool2]) + assert len(flow.tools) == 2 + assert tool1 in flow.tools + assert tool2 in flow.tools + + def test_flow_with_custom_context(self): + flow = Flow(context={"key": "value"}) + assert len(flow.context) == 1 + assert flow.context["key"] == "value" + + def test_add_message(self, monkeypatch): + flow = Flow() + mocked_add = MagicMock() + monkeypatch.setattr(flow.thread, "add", mocked_add) + flow.add_message("Test message", role="user") + mocked_add.assert_called_once_with("Test message", role="user") + + def test_flow_context_manager(self): + with Flow() as flow: + assert ctx.get("flow") == flow + assert ctx.get("tasks") == [] + assert ctx.get("flow") is None + assert ctx.get("tasks") is None + + def test_get_flow_within_context(self): + with Flow() as flow: + assert get_flow() == flow + + def test_get_flow_without_context(self): + flow1 = get_flow() + with Flow() as flow2: + pass + flow3 = get_flow() + assert flow1 == flow3 + assert flow1 != flow2 + + def test_get_flow_nested_contexts(self): + with Flow() as flow1: + assert get_flow() == flow1 + with Flow() as flow2: + assert get_flow() == flow2 + assert get_flow() == flow1 diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py new file mode 100644 index 00000000..2f20b719 --- /dev/null +++ b/tests/core/test_graph.py @@ -0,0 +1,100 @@ +# test_graph.py +from control_flow.core.graph import Edge, EdgeType, Graph +from control_flow.core.task import Task + + +class TestGraph: + def test_graph_initialization(self): + graph = Graph() + assert len(graph.tasks) == 0 + assert len(graph.edges) == 0 + + def test_add_task(self): + graph = Graph() + task = Task(objective="Test objective") + graph.add_task(task) + assert len(graph.tasks) == 1 + assert task in graph.tasks + + def test_add_edge(self): + graph = Graph() + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + edge = Edge(upstream=task1, downstream=task2, type=EdgeType.DEPENDENCY) + graph.add_edge(edge) + assert len(graph.tasks) == 2 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert len(graph.edges) == 1 + assert edge in graph.edges + + def test_from_tasks(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = Graph.from_tasks([task1, task2, task3]) + assert len(graph.tasks) == 3 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert task3 in graph.tasks + assert len(graph.edges) == 2 + assert any( + edge.upstream == task1 + and edge.downstream == task2 + and edge.type == EdgeType.DEPENDENCY + for edge in graph.edges + ) + assert any( + edge.upstream == task3 + and edge.downstream == task2 + and edge.type == EdgeType.SUBTASK + for edge in graph.edges + ) + + def test_upstream_edges(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + graph = Graph.from_tasks([task1, task2]) + upstream_edges = graph.upstream_edges() + assert len(upstream_edges[task1]) == 0 + assert len(upstream_edges[task2]) == 1 + assert upstream_edges[task2][0].upstream == task1 + + def test_downstream_edges(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + graph = Graph.from_tasks([task1, task2]) + downstream_edges = graph.downstream_edges() + assert len(downstream_edges[task1]) == 1 + assert len(downstream_edges[task2]) == 0 + assert downstream_edges[task1][0].downstream == task2 + + def test_upstream_dependencies(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = Graph.from_tasks([task1, task2, task3]) + dependencies = graph.upstream_dependencies([task3]) + assert len(dependencies) == 3 + assert task1 in dependencies + assert task2 in dependencies + assert task3 in dependencies + + def test_ready_tasks(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = Graph.from_tasks([task1, task2, task3]) + ready_tasks = graph.ready_tasks() + assert len(ready_tasks) == 1 + assert task1 in ready_tasks + + task1.mark_successful() + ready_tasks = graph.ready_tasks() + assert len(ready_tasks) == 1 + assert task2 in ready_tasks + + task2.mark_successful() + ready_tasks = graph.ready_tasks() + assert len(ready_tasks) == 1 + assert task3 in ready_tasks diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index feebca57..1ad7caf0 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,22 +1,180 @@ -from control_flow.core.task import Task, get_tasks +from control_flow.core.agent import Agent +from control_flow.core.flow import Flow +from control_flow.core.graph import EdgeType +from control_flow.core.task import Task, TaskStatus, get_tasks from control_flow.utilities.context import ctx -class TestTaskContext: - def test_context_open_and_close(self): - assert ctx.get("tasks") == [] - with Task("a") as ta: - assert ctx.get("tasks") == [ta] - with Task("b") as tb: - assert ctx.get("tasks") == [ta, tb] - assert ctx.get("tasks") == [ta] - assert ctx.get("tasks") == [] - - def test_get_tasks_function(self): - # assert get_tasks() == [] - with Task("a") as ta: - assert get_tasks() == [ta] - with Task("b") as tb: - assert get_tasks() == [ta, tb] - assert get_tasks() == [ta] - assert get_tasks() == [] +def test_context_open_and_close(): + assert ctx.get("tasks") == [] + with Task("a") as ta: + assert ctx.get("tasks") == [ta] + with Task("b") as tb: + assert ctx.get("tasks") == [ta, tb] + assert ctx.get("tasks") == [ta] + assert ctx.get("tasks") == [] + + +def test_get_tasks_function(): + # assert get_tasks() == [] + with Task("a") as ta: + assert get_tasks() == [ta] + with Task("b") as tb: + assert get_tasks() == [ta, tb] + assert get_tasks() == [ta] + assert get_tasks() == [] + + +def test_task_initialization(): + task = Task(objective="Test objective") + assert task.objective == "Test objective" + assert task.status == TaskStatus.INCOMPLETE + assert task.result is None + assert task.error is None + + +def test_task_dependencies(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + assert task1 in task2.depends_on + assert task2 in task1._downstreams + + +def test_task_subtasks(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", parent=task1) + assert task2 in task1.subtasks + assert task2._parent == task1 + + +def test_task_agent_assignment(): + agent = Agent(name="Test Agent") + task = Task(objective="Test objective", agents=[agent]) + assert agent in task.agents + + +def test_task_context(): + with Flow(): + task = Task(objective="Test objective") + assert task in Task._context_stack + + +def test_task_status_transitions(): + task = Task(objective="Test objective") + assert task.is_incomplete() + assert not task.is_complete() + assert not task.is_successful() + assert not task.is_failed() + assert not task.is_skipped() + + task.mark_successful() + assert not task.is_incomplete() + assert task.is_complete() + assert task.is_successful() + assert not task.is_failed() + assert not task.is_skipped() + + task = Task(objective="Test objective") + task.mark_failed() + assert not task.is_incomplete() + assert task.is_complete() + assert not task.is_successful() + assert task.is_failed() + assert not task.is_skipped() + + task = Task(objective="Test objective") + task.mark_skipped() + assert not task.is_incomplete() + assert task.is_complete() + assert not task.is_successful() + assert not task.is_failed() + assert task.is_skipped() + + +def test_task_ready(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + assert not task2.is_ready() + + task1.mark_successful() + assert task2.is_ready() + + +def test_task_hash(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + assert hash(task1) != hash(task2) + + +def test_task_tools(): + task = Task(objective="Test objective") + tools = task.get_tools() + assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) + assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + + task.mark_successful() + tools = task.get_tools() + assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) + assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) + + +class TestTaskToGraph: + def test_single_task_graph(self): + task = Task(objective="Test objective") + graph = task.as_graph() + assert len(graph.tasks) == 1 + assert task in graph.tasks + assert len(graph.edges) == 0 + + def test_task_with_subtasks_graph(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", parent=task1) + graph = task1.as_graph() + assert len(graph.tasks) == 2 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert len(graph.edges) == 1 + assert any( + edge.upstream == task2 + and edge.downstream == task1 + and edge.type == EdgeType.SUBTASK + for edge in graph.edges + ) + + def test_task_with_dependencies_graph(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + graph = task2.as_graph() + assert len(graph.tasks) == 2 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert len(graph.edges) == 1 + assert any( + edge.upstream == task1 + and edge.downstream == task2 + and edge.type == EdgeType.DEPENDENCY + for edge in graph.edges + ) + + def test_task_with_subtasks_and_dependencies_graph(self): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2", depends_on=[task1]) + task3 = Task(objective="Task 3", parent=task2) + graph = task2.as_graph() + assert len(graph.tasks) == 3 + assert task1 in graph.tasks + assert task2 in graph.tasks + assert task3 in graph.tasks + assert len(graph.edges) == 2 + assert any( + edge.upstream == task1 + and edge.downstream == task2 + and edge.type == EdgeType.DEPENDENCY + for edge in graph.edges + ) + assert any( + edge.upstream == task3 + and edge.downstream == task2 + and edge.type == EdgeType.SUBTASK + for edge in graph.edges + ) diff --git a/tests/fixtures/flows.py b/tests/fixtures/flows.py new file mode 100644 index 00000000..052d8f5e --- /dev/null +++ b/tests/fixtures/flows.py @@ -0,0 +1,8 @@ +import pytest +from control_flow.settings import temporary_settings + + +@pytest.fixture(autouse=True, scope="session") +def disable_global_flow(): + with temporary_settings(enable_global_flow=False): + yield diff --git a/tests/test_instructions.py b/tests/test_instructions.py index 97e8ff5c..28da6cd6 100644 --- a/tests/test_instructions.py +++ b/tests/test_instructions.py @@ -16,3 +16,17 @@ def test_instructions_context_nested(): assert get_instructions() == ["abc", "def"] assert get_instructions() == ["abc"] assert get_instructions() == [] + + +def test_instructions_context_multiple(): + assert get_instructions() == [] + with instructions("abc", "def", "ghi"): + assert get_instructions() == ["abc", "def", "ghi"] + assert get_instructions() == [] + + +def test_instructions_context_empty(): + assert get_instructions() == [] + with instructions(): + assert get_instructions() == [] + assert get_instructions() == []