diff --git a/src/control_flow/core/flow.py b/src/control_flow/core/flow.py
index a3737f0d..2dc018e0 100644
--- a/src/control_flow/core/flow.py
+++ b/src/control_flow/core/flow.py
@@ -6,6 +6,7 @@
from prefect import task as prefect_task
from pydantic import Field, field_validator
+import control_flow
from control_flow.utilities.context import ctx
from control_flow.utilities.logging import get_logger
from control_flow.utilities.types import AssistantTool, ControlFlowModel
@@ -36,7 +37,6 @@ class Flow(ControlFlowModel):
description="The default agents for the flow. These agents will be used "
"for any task that does not specify agents.",
)
- model: str | None = None
context: dict = {}
@field_validator("thread", mode="before")
@@ -73,11 +73,15 @@ def get_flow() -> Flow:
"""
Loads the flow from the context.
- Will error if no flow is found in the context.
+    Will error if no flow is found in the context, unless the global flow is
+    enabled in settings.
"""
flow: Flow | None = ctx.get("flow")
if not flow:
- return GLOBAL_FLOW
+ if control_flow.settings.enable_global_flow:
+ return GLOBAL_FLOW
+ else:
+ raise ValueError("No flow found in context.")
return flow
diff --git a/src/control_flow/instructions.py b/src/control_flow/instructions.py
index 97e89664..c5c97c3b 100644
--- a/src/control_flow/instructions.py
+++ b/src/control_flow/instructions.py
@@ -1,8 +1,6 @@
-import inspect
from contextlib import contextmanager
from typing import Generator, List
-from control_flow.core.flow import Flow
from control_flow.utilities.context import ctx
from control_flow.utilities.logging import get_logger
@@ -10,75 +8,21 @@
@contextmanager
-def instructions(
- *instructions: str,
- post_add_message: bool = False,
- post_remove_message: bool = False,
-) -> Generator[list[str], None, None]:
+def instructions(*instructions: str) -> Generator[list[str], None, None]:
"""
Temporarily add instructions to the current instruction stack. The
instruction is removed when the context is exited.
- If `post_add_message` is True, a message will be added to the flow when the
- instruction is added. If `post_remove_message` is True, a message will be
- added to the flow when the instruction is removed. These explicit reminders
- can help when agents infer instructions more from history.
-
with instructions("talk like a pirate"):
...
"""
- if post_add_message or post_remove_message:
- flow: Flow = ctx.get("flow")
- if flow is None:
- raise ValueError(
- "instructions() with message posting must be used within a flow context"
- )
-
stack: list[str] = ctx.get("instructions", [])
stack = stack + list(instructions)
with ctx(instructions=stack):
- try:
- if post_add_message:
- for instruction in instructions:
- flow.add_message(
- inspect.cleandoc(
- """
- # SYSTEM MESSAGE: INSTRUCTION ADDED
-
- The following instruction is now active:
-
-
- {instruction}
-
-
- Always consult your current instructions before acting.
- """
- ).format(instruction=instruction)
- )
- yield
-
- # yield new_stack
- finally:
- if post_remove_message:
- for instruction in instructions:
- flow.add_message(
- inspect.cleandoc(
- """
- # SYSTEM MESSAGE: INSTRUCTION REMOVED
-
- The following instruction is no longer active:
-
-
- {instruction}
-
-
- Always consult your current instructions before acting.
- """
- ).format(instruction=instruction)
- )
+ yield
def get_instructions() -> List[str]:
diff --git a/src/control_flow/settings.py b/src/control_flow/settings.py
index e8d083cf..17972a1e 100644
--- a/src/control_flow/settings.py
+++ b/src/control_flow/settings.py
@@ -1,6 +1,9 @@
import os
import sys
import warnings
+from contextlib import contextmanager
+from copy import deepcopy
+from typing import Any
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -46,6 +49,10 @@ class Settings(ControlFlowSettings):
assistant_model: str = "gpt-4-1106-preview"
max_agent_iterations: int = 10
prefect: PrefectSettings = Field(default_factory=PrefectSettings)
+ enable_global_flow: bool = Field(
+ True,
+ description="If True, a global flow is created for convenience, so users don't have to wrap every invocation in a flow function. Disable to avoid accidentally sharing context between agents.",
+ )
def __init__(self, **data):
super().__init__(**data)
@@ -53,3 +60,52 @@ def __init__(self, **data):
settings = Settings()
+
+
+@contextmanager
+def temporary_settings(**kwargs: Any):
+    """
+    Apply ControlFlow setting overrides for the duration of a `with` block.
+
+    Nested names are joined with `__`; settings and `os.environ` are reset on exit.
+
+    Args:
+        **kwargs: Setting names (optionally `__`-nested) and their temporary values.
+
+    Example:
+        Temporarily override log level and OpenAI API key:
+        ```python
+        import control_flow
+        from control_flow.settings import temporary_settings
+
+        # Override top-level settings
+        with temporary_settings(log_level="INFO"):
+            assert control_flow.settings.log_level == "INFO"
+        assert control_flow.settings.log_level == "DEBUG"
+
+        # Override nested settings
+        with temporary_settings(openai__api_key="new-api-key"):
+            assert control_flow.settings.openai.api_key.get_secret_value() == "new-api-key"
+        assert control_flow.settings.openai.api_key.get_secret_value().startswith("sk-")
+        ```
+    """
+    env_snapshot = dict(os.environ)
+    settings_snapshot = deepcopy(settings)
+
+    def set_nested_attr(obj: object, attr_path: str, value: Any):
+        *parents, leaf = attr_path.split("__")
+        for name in parents:
+            obj = getattr(obj, name)
+        setattr(obj, leaf, value)
+
+    try:
+        for path, override in kwargs.items():
+            set_nested_attr(settings, path, override)
+        yield
+
+    finally:
+        os.environ.clear()
+        os.environ.update(env_snapshot)
+
+        for name, saved in settings_snapshot:
+            set_nested_attr(settings, name, saved)
diff --git a/tests/core/test_controller.py b/tests/core/test_controller.py
new file mode 100644
index 00000000..78c8cc9c
--- /dev/null
+++ b/tests/core/test_controller.py
@@ -0,0 +1,85 @@
+from unittest.mock import AsyncMock
+
+import pytest
+from control_flow.core.agent import Agent
+from control_flow.core.controller.controller import Controller
+from control_flow.core.flow import Flow
+from control_flow.core.graph import EdgeType
+from control_flow.core.task import Task
+
+
+class TestController:
+ @pytest.fixture
+ def flow(self):
+ return Flow()
+
+ @pytest.fixture
+ def agent(self):
+ return Agent(name="Test Agent")
+
+ @pytest.fixture
+ def task(self):
+ return Task(objective="Test Task")
+
+ def test_controller_initialization(self, flow, agent, task):
+ controller = Controller(flow=flow, tasks=[task], agents=[agent])
+ assert controller.flow == flow
+ assert controller.tasks == [task]
+ assert controller.agents == [agent]
+ assert controller.run_dependencies is True
+ assert len(controller.context) == 0
+ assert len(controller.graph.tasks) == 1
+ assert len(controller.graph.edges) == 0
+
+ def test_controller_missing_tasks(self, flow):
+ with pytest.raises(ValueError, match="At least one task is required."):
+ Controller(flow=flow, tasks=[])
+
+ async def test_run_agent(self, flow, agent, task, monkeypatch):
+ controller = Controller(flow=flow, tasks=[task], agents=[agent])
+ mocked_run = AsyncMock()
+ monkeypatch.setattr(Agent, "run", mocked_run)
+ await controller._run_agent(agent, tasks=[task])
+ mocked_run.assert_called_once_with(tasks=[task])
+
+ async def test_run_once(self, flow, agent, task, monkeypatch):
+ controller = Controller(flow=flow, tasks=[task], agents=[agent])
+ mocked_run_agent = AsyncMock()
+ monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent)
+ await controller.run_once_async()
+ mocked_run_agent.assert_called_once_with(agent, tasks=[task])
+
+ def test_create_end_run_tool(self, flow, agent, task):
+ controller = Controller(flow=flow, tasks=[task], agents=[agent])
+ end_run_tool = controller._create_end_run_tool()
+ assert end_run_tool.function.name == "end_run"
+ assert end_run_tool.function.description.startswith("End your turn")
+
+    def test_controller_graph_creation(self, flow, agent):
+        task1 = Task(objective="Task 1")
+        task2 = Task(objective="Task 2", depends_on=[task1])
+        controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent])
+        assert len(controller.graph.tasks) == 2
+        assert len(controller.graph.edges) == 1  # only the depends_on edge
+        assert next(iter(controller.graph.edges)).type == EdgeType.DEPENDENCY
+
+ def test_controller_agent_selection(self, flow, monkeypatch):
+ agent1 = Agent(name="Agent 1")
+ agent2 = Agent(name="Agent 2")
+ task = Task(objective="Test Task", agents=[agent1, agent2])
+ controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2])
+ mocked_marvin_moderator = AsyncMock(return_value=agent1)
+ monkeypatch.setattr(
+ "control_flow.core.controller.moderators.marvin_moderator",
+ mocked_marvin_moderator,
+ )
+ assert controller.agents == [agent1, agent2]
+
+ async def test_controller_run_dependencies(self, flow, agent):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ controller = Controller(flow=flow, tasks=[task2], agents=[agent])
+ mocked_run_agent = AsyncMock()
+ controller._run_agent = mocked_run_agent
+ await controller.run_once_async()
+ mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2])
diff --git a/tests/core/test_flows.py b/tests/core/test_flows.py
new file mode 100644
index 00000000..ce064e7b
--- /dev/null
+++ b/tests/core/test_flows.py
@@ -0,0 +1,74 @@
+# test_flows.py
+from unittest.mock import MagicMock
+
+from control_flow.core.agent import Agent
+from control_flow.core.flow import Flow, get_flow
+from control_flow.utilities.context import ctx
+
+
+class TestFlow:
+ def test_flow_initialization(self):
+ flow = Flow()
+ assert flow.thread is not None
+ assert len(flow.tools) == 0
+ assert len(flow.agents) == 1
+ assert isinstance(flow.agents[0], Agent)
+ assert len(flow.context) == 0
+
+ def test_flow_with_custom_agents(self):
+ agent1 = Agent(name="Agent 1")
+ agent2 = Agent(name="Agent 2")
+ flow = Flow(agents=[agent1, agent2])
+ assert len(flow.agents) == 2
+ assert agent1 in flow.agents
+ assert agent2 in flow.agents
+
+ def test_flow_with_custom_tools(self):
+ def tool1():
+ pass
+
+ def tool2():
+ pass
+
+ flow = Flow(tools=[tool1, tool2])
+ assert len(flow.tools) == 2
+ assert tool1 in flow.tools
+ assert tool2 in flow.tools
+
+ def test_flow_with_custom_context(self):
+ flow = Flow(context={"key": "value"})
+ assert len(flow.context) == 1
+ assert flow.context["key"] == "value"
+
+ def test_add_message(self, monkeypatch):
+ flow = Flow()
+ mocked_add = MagicMock()
+ monkeypatch.setattr(flow.thread, "add", mocked_add)
+ flow.add_message("Test message", role="user")
+ mocked_add.assert_called_once_with("Test message", role="user")
+
+    def test_flow_context_manager(self):
+        with Flow() as flow:
+            assert ctx.get("flow") == flow
+            assert ctx.get("tasks") == []
+        assert ctx.get("flow") is None
+        assert ctx.get("tasks") == []  # baseline outside a flow is an empty list
+
+ def test_get_flow_within_context(self):
+ with Flow() as flow:
+ assert get_flow() == flow
+
+    def test_get_flow_without_context(self):
+        from control_flow.settings import temporary_settings
+        with temporary_settings(enable_global_flow=True):  # autouse fixture may disable it
+            flow1 = get_flow()
+            with Flow() as flow2:
+                pass
+            assert get_flow() == flow1 and flow1 != flow2
+
+ def test_get_flow_nested_contexts(self):
+ with Flow() as flow1:
+ assert get_flow() == flow1
+ with Flow() as flow2:
+ assert get_flow() == flow2
+ assert get_flow() == flow1
diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py
new file mode 100644
index 00000000..2f20b719
--- /dev/null
+++ b/tests/core/test_graph.py
@@ -0,0 +1,100 @@
+# test_graph.py
+from control_flow.core.graph import Edge, EdgeType, Graph
+from control_flow.core.task import Task
+
+
+class TestGraph:
+ def test_graph_initialization(self):
+ graph = Graph()
+ assert len(graph.tasks) == 0
+ assert len(graph.edges) == 0
+
+ def test_add_task(self):
+ graph = Graph()
+ task = Task(objective="Test objective")
+ graph.add_task(task)
+ assert len(graph.tasks) == 1
+ assert task in graph.tasks
+
+ def test_add_edge(self):
+ graph = Graph()
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2")
+ edge = Edge(upstream=task1, downstream=task2, type=EdgeType.DEPENDENCY)
+ graph.add_edge(edge)
+ assert len(graph.tasks) == 2
+ assert task1 in graph.tasks
+ assert task2 in graph.tasks
+ assert len(graph.edges) == 1
+ assert edge in graph.edges
+
+ def test_from_tasks(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ task3 = Task(objective="Task 3", parent=task2)
+ graph = Graph.from_tasks([task1, task2, task3])
+ assert len(graph.tasks) == 3
+ assert task1 in graph.tasks
+ assert task2 in graph.tasks
+ assert task3 in graph.tasks
+ assert len(graph.edges) == 2
+ assert any(
+ edge.upstream == task1
+ and edge.downstream == task2
+ and edge.type == EdgeType.DEPENDENCY
+ for edge in graph.edges
+ )
+ assert any(
+ edge.upstream == task3
+ and edge.downstream == task2
+ and edge.type == EdgeType.SUBTASK
+ for edge in graph.edges
+ )
+
+ def test_upstream_edges(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ graph = Graph.from_tasks([task1, task2])
+ upstream_edges = graph.upstream_edges()
+ assert len(upstream_edges[task1]) == 0
+ assert len(upstream_edges[task2]) == 1
+ assert upstream_edges[task2][0].upstream == task1
+
+ def test_downstream_edges(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ graph = Graph.from_tasks([task1, task2])
+ downstream_edges = graph.downstream_edges()
+ assert len(downstream_edges[task1]) == 1
+ assert len(downstream_edges[task2]) == 0
+ assert downstream_edges[task1][0].downstream == task2
+
+ def test_upstream_dependencies(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ task3 = Task(objective="Task 3", parent=task2)
+ graph = Graph.from_tasks([task1, task2, task3])
+ dependencies = graph.upstream_dependencies([task3])
+ assert len(dependencies) == 3
+ assert task1 in dependencies
+ assert task2 in dependencies
+ assert task3 in dependencies
+
+ def test_ready_tasks(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ task3 = Task(objective="Task 3", parent=task2)
+ graph = Graph.from_tasks([task1, task2, task3])
+ ready_tasks = graph.ready_tasks()
+ assert len(ready_tasks) == 1
+ assert task1 in ready_tasks
+
+ task1.mark_successful()
+ ready_tasks = graph.ready_tasks()
+ assert len(ready_tasks) == 1
+ assert task2 in ready_tasks
+
+ task2.mark_successful()
+ ready_tasks = graph.ready_tasks()
+ assert len(ready_tasks) == 1
+ assert task3 in ready_tasks
diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py
index feebca57..1ad7caf0 100644
--- a/tests/core/test_tasks.py
+++ b/tests/core/test_tasks.py
@@ -1,22 +1,180 @@
-from control_flow.core.task import Task, get_tasks
+from control_flow.core.agent import Agent
+from control_flow.core.flow import Flow
+from control_flow.core.graph import EdgeType
+from control_flow.core.task import Task, TaskStatus, get_tasks
from control_flow.utilities.context import ctx
-class TestTaskContext:
- def test_context_open_and_close(self):
- assert ctx.get("tasks") == []
- with Task("a") as ta:
- assert ctx.get("tasks") == [ta]
- with Task("b") as tb:
- assert ctx.get("tasks") == [ta, tb]
- assert ctx.get("tasks") == [ta]
- assert ctx.get("tasks") == []
-
- def test_get_tasks_function(self):
- # assert get_tasks() == []
- with Task("a") as ta:
- assert get_tasks() == [ta]
- with Task("b") as tb:
- assert get_tasks() == [ta, tb]
- assert get_tasks() == [ta]
- assert get_tasks() == []
+def test_context_open_and_close():
+ assert ctx.get("tasks") == []
+ with Task("a") as ta:
+ assert ctx.get("tasks") == [ta]
+ with Task("b") as tb:
+ assert ctx.get("tasks") == [ta, tb]
+ assert ctx.get("tasks") == [ta]
+ assert ctx.get("tasks") == []
+
+
+def test_get_tasks_function():
+ # assert get_tasks() == []
+ with Task("a") as ta:
+ assert get_tasks() == [ta]
+ with Task("b") as tb:
+ assert get_tasks() == [ta, tb]
+ assert get_tasks() == [ta]
+ assert get_tasks() == []
+
+
+def test_task_initialization():
+ task = Task(objective="Test objective")
+ assert task.objective == "Test objective"
+ assert task.status == TaskStatus.INCOMPLETE
+ assert task.result is None
+ assert task.error is None
+
+
+def test_task_dependencies():
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ assert task1 in task2.depends_on
+ assert task2 in task1._downstreams
+
+
+def test_task_subtasks():
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", parent=task1)
+ assert task2 in task1.subtasks
+ assert task2._parent == task1
+
+
+def test_task_agent_assignment():
+ agent = Agent(name="Test Agent")
+ task = Task(objective="Test objective", agents=[agent])
+ assert agent in task.agents
+
+
+def test_task_context():
+ with Flow():
+ task = Task(objective="Test objective")
+ assert task in Task._context_stack
+
+
+def test_task_status_transitions():
+ task = Task(objective="Test objective")
+ assert task.is_incomplete()
+ assert not task.is_complete()
+ assert not task.is_successful()
+ assert not task.is_failed()
+ assert not task.is_skipped()
+
+ task.mark_successful()
+ assert not task.is_incomplete()
+ assert task.is_complete()
+ assert task.is_successful()
+ assert not task.is_failed()
+ assert not task.is_skipped()
+
+ task = Task(objective="Test objective")
+ task.mark_failed()
+ assert not task.is_incomplete()
+ assert task.is_complete()
+ assert not task.is_successful()
+ assert task.is_failed()
+ assert not task.is_skipped()
+
+ task = Task(objective="Test objective")
+ task.mark_skipped()
+ assert not task.is_incomplete()
+ assert task.is_complete()
+ assert not task.is_successful()
+ assert not task.is_failed()
+ assert task.is_skipped()
+
+
+def test_task_ready():
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ assert not task2.is_ready()
+
+ task1.mark_successful()
+ assert task2.is_ready()
+
+
+def test_task_hash():
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2")
+ assert hash(task1) != hash(task2)
+
+
+def test_task_tools():
+ task = Task(objective="Test objective")
+ tools = task.get_tools()
+ assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools)
+ assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools)
+
+ task.mark_successful()
+ tools = task.get_tools()
+ assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools)
+ assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools)
+
+
+class TestTaskToGraph:
+ def test_single_task_graph(self):
+ task = Task(objective="Test objective")
+ graph = task.as_graph()
+ assert len(graph.tasks) == 1
+ assert task in graph.tasks
+ assert len(graph.edges) == 0
+
+ def test_task_with_subtasks_graph(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", parent=task1)
+ graph = task1.as_graph()
+ assert len(graph.tasks) == 2
+ assert task1 in graph.tasks
+ assert task2 in graph.tasks
+ assert len(graph.edges) == 1
+ assert any(
+ edge.upstream == task2
+ and edge.downstream == task1
+ and edge.type == EdgeType.SUBTASK
+ for edge in graph.edges
+ )
+
+ def test_task_with_dependencies_graph(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ graph = task2.as_graph()
+ assert len(graph.tasks) == 2
+ assert task1 in graph.tasks
+ assert task2 in graph.tasks
+ assert len(graph.edges) == 1
+ assert any(
+ edge.upstream == task1
+ and edge.downstream == task2
+ and edge.type == EdgeType.DEPENDENCY
+ for edge in graph.edges
+ )
+
+ def test_task_with_subtasks_and_dependencies_graph(self):
+ task1 = Task(objective="Task 1")
+ task2 = Task(objective="Task 2", depends_on=[task1])
+ task3 = Task(objective="Task 3", parent=task2)
+ graph = task2.as_graph()
+ assert len(graph.tasks) == 3
+ assert task1 in graph.tasks
+ assert task2 in graph.tasks
+ assert task3 in graph.tasks
+ assert len(graph.edges) == 2
+ assert any(
+ edge.upstream == task1
+ and edge.downstream == task2
+ and edge.type == EdgeType.DEPENDENCY
+ for edge in graph.edges
+ )
+ assert any(
+ edge.upstream == task3
+ and edge.downstream == task2
+ and edge.type == EdgeType.SUBTASK
+ for edge in graph.edges
+ )
diff --git a/tests/fixtures/flows.py b/tests/fixtures/flows.py
new file mode 100644
index 00000000..052d8f5e
--- /dev/null
+++ b/tests/fixtures/flows.py
@@ -0,0 +1,8 @@
+import pytest
+from control_flow.settings import temporary_settings
+
+
+@pytest.fixture(autouse=True, scope="session")
+def disable_global_flow():
+ with temporary_settings(enable_global_flow=False):
+ yield
diff --git a/tests/test_instructions.py b/tests/test_instructions.py
index 97e8ff5c..28da6cd6 100644
--- a/tests/test_instructions.py
+++ b/tests/test_instructions.py
@@ -16,3 +16,17 @@ def test_instructions_context_nested():
assert get_instructions() == ["abc", "def"]
assert get_instructions() == ["abc"]
assert get_instructions() == []
+
+
+def test_instructions_context_multiple():
+ assert get_instructions() == []
+ with instructions("abc", "def", "ghi"):
+ assert get_instructions() == ["abc", "def", "ghi"]
+ assert get_instructions() == []
+
+
+def test_instructions_context_empty():
+ assert get_instructions() == []
+ with instructions():
+ assert get_instructions() == []
+ assert get_instructions() == []