Skip to content

Commit bb1d009

Browse files
authored
Merge pull request #22 from jlowin/tests
Add tests
2 parents 2596a02 + 538a419 commit bb1d009

File tree

9 files changed

+523
-80
lines changed

9 files changed

+523
-80
lines changed

src/control_flow/core/flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from prefect import task as prefect_task
77
from pydantic import Field, field_validator
88

9+
import control_flow
910
from control_flow.utilities.context import ctx
1011
from control_flow.utilities.logging import get_logger
1112
from control_flow.utilities.types import AssistantTool, ControlFlowModel
@@ -36,7 +37,6 @@ class Flow(ControlFlowModel):
3637
description="The default agents for the flow. These agents will be used "
3738
"for any task that does not specify agents.",
3839
)
39-
model: str | None = None
4040
context: dict = {}
4141

4242
@field_validator("thread", mode="before")
@@ -73,11 +73,15 @@ def get_flow() -> Flow:
7373
"""
7474
Loads the flow from the context.
7575
76-
Will error if no flow is found in the context.
76+
Will error if no flow is found in the context, unless the global flow is
77+
enabled in settings
7778
"""
7879
flow: Flow | None = ctx.get("flow")
7980
if not flow:
80-
return GLOBAL_FLOW
81+
if control_flow.settings.enable_global_flow:
82+
return GLOBAL_FLOW
83+
else:
84+
raise ValueError("No flow found in context.")
8185
return flow
8286

8387

src/control_flow/instructions.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,28 @@
1-
import inspect
21
from contextlib import contextmanager
32
from typing import Generator, List
43

5-
from control_flow.core.flow import Flow
64
from control_flow.utilities.context import ctx
75
from control_flow.utilities.logging import get_logger
86

97
logger = get_logger(__name__)
108

119

1210
@contextmanager
13-
def instructions(
14-
*instructions: str,
15-
post_add_message: bool = False,
16-
post_remove_message: bool = False,
17-
) -> Generator[list[str], None, None]:
11+
def instructions(*instructions: str) -> Generator[list[str], None, None]:
1812
"""
1913
Temporarily add instructions to the current instruction stack. The
2014
instruction is removed when the context is exited.
2115
22-
If `post_add_message` is True, a message will be added to the flow when the
23-
instruction is added. If `post_remove_message` is True, a message will be
24-
added to the flow when the instruction is removed. These explicit reminders
25-
can help when agents infer instructions more from history.
26-
2716
with instructions("talk like a pirate"):
2817
...
2918
3019
"""
3120

32-
if post_add_message or post_remove_message:
33-
flow: Flow = ctx.get("flow")
34-
if flow is None:
35-
raise ValueError(
36-
"instructions() with message posting must be used within a flow context"
37-
)
38-
3921
stack: list[str] = ctx.get("instructions", [])
4022
stack = stack + list(instructions)
4123

4224
with ctx(instructions=stack):
43-
try:
44-
if post_add_message:
45-
for instruction in instructions:
46-
flow.add_message(
47-
inspect.cleandoc(
48-
"""
49-
# SYSTEM MESSAGE: INSTRUCTION ADDED
50-
51-
The following instruction is now active:
52-
53-
<instruction>
54-
{instruction}
55-
</instruction>
56-
57-
Always consult your current instructions before acting.
58-
"""
59-
).format(instruction=instruction)
60-
)
61-
yield
62-
63-
# yield new_stack
64-
finally:
65-
if post_remove_message:
66-
for instruction in instructions:
67-
flow.add_message(
68-
inspect.cleandoc(
69-
"""
70-
# SYSTEM MESSAGE: INSTRUCTION REMOVED
71-
72-
The following instruction is no longer active:
73-
74-
<instruction>
75-
{instruction}
76-
</instruction>
77-
78-
Always consult your current instructions before acting.
79-
"""
80-
).format(instruction=instruction)
81-
)
25+
yield
8226

8327

8428
def get_instructions() -> List[str]:

src/control_flow/settings.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
22
import sys
33
import warnings
4+
from contextlib import contextmanager
5+
from copy import deepcopy
6+
from typing import Any
47

58
from pydantic import Field
69
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -46,10 +49,63 @@ class Settings(ControlFlowSettings):
4649
assistant_model: str = "gpt-4-1106-preview"
4750
max_agent_iterations: int = 10
4851
prefect: PrefectSettings = Field(default_factory=PrefectSettings)
52+
enable_global_flow: bool = Field(
53+
True,
54+
description="If True, a global flow is created for convenience, so users don't have to wrap every invocation in a flow function. Disable to avoid accidentally sharing context between agents.",
55+
)
4956

5057
def __init__(self, **data):
5158
super().__init__(**data)
5259
self.prefect.apply()
5360

5461

5562
settings = Settings()
63+
64+
65+
@contextmanager
66+
def temporary_settings(**kwargs: Any):
67+
"""
68+
Temporarily override ControlFlow setting values, including nested settings objects.
69+
70+
To override nested settings, use `__` to separate nested attribute names.
71+
72+
Args:
73+
**kwargs: The settings to override, including nested settings.
74+
75+
Example:
76+
Temporarily override log level and OpenAI API key:
77+
```python
78+
import control_flow
79+
from control_flow.settings import temporary_settings
80+
81+
# Override top-level settings
82+
with temporary_settings(log_level="INFO"):
83+
assert control_flow.settings.log_level == "INFO"
84+
assert control_flow.settings.log_level == "DEBUG"
85+
86+
# Override nested settings
87+
with temporary_settings(openai__api_key="new-api-key"):
88+
assert control_flow.settings.openai.api_key.get_secret_value() == "new-api-key"
89+
assert control_flow.settings.openai.api_key.get_secret_value().startswith("sk-")
90+
```
91+
"""
92+
old_env = os.environ.copy()
93+
old_settings = deepcopy(settings)
94+
95+
def set_nested_attr(obj: object, attr_path: str, value: Any):
96+
parts = attr_path.split("__")
97+
for part in parts[:-1]:
98+
obj = getattr(obj, part)
99+
setattr(obj, parts[-1], value)
100+
101+
try:
102+
for attr_path, value in kwargs.items():
103+
set_nested_attr(settings, attr_path, value)
104+
yield
105+
106+
finally:
107+
os.environ.clear()
108+
os.environ.update(old_env)
109+
110+
for attr, value in old_settings:
111+
set_nested_attr(settings, attr, value)

tests/core/test_controller.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from unittest.mock import AsyncMock
2+
3+
import pytest
4+
from control_flow.core.agent import Agent
5+
from control_flow.core.controller.controller import Controller
6+
from control_flow.core.flow import Flow
7+
from control_flow.core.graph import EdgeType
8+
from control_flow.core.task import Task
9+
10+
11+
class TestController:
12+
@pytest.fixture
13+
def flow(self):
14+
return Flow()
15+
16+
@pytest.fixture
17+
def agent(self):
18+
return Agent(name="Test Agent")
19+
20+
@pytest.fixture
21+
def task(self):
22+
return Task(objective="Test Task")
23+
24+
def test_controller_initialization(self, flow, agent, task):
25+
controller = Controller(flow=flow, tasks=[task], agents=[agent])
26+
assert controller.flow == flow
27+
assert controller.tasks == [task]
28+
assert controller.agents == [agent]
29+
assert controller.run_dependencies is True
30+
assert len(controller.context) == 0
31+
assert len(controller.graph.tasks) == 1
32+
assert len(controller.graph.edges) == 0
33+
34+
def test_controller_missing_tasks(self, flow):
35+
with pytest.raises(ValueError, match="At least one task is required."):
36+
Controller(flow=flow, tasks=[])
37+
38+
async def test_run_agent(self, flow, agent, task, monkeypatch):
39+
controller = Controller(flow=flow, tasks=[task], agents=[agent])
40+
mocked_run = AsyncMock()
41+
monkeypatch.setattr(Agent, "run", mocked_run)
42+
await controller._run_agent(agent, tasks=[task])
43+
mocked_run.assert_called_once_with(tasks=[task])
44+
45+
async def test_run_once(self, flow, agent, task, monkeypatch):
46+
controller = Controller(flow=flow, tasks=[task], agents=[agent])
47+
mocked_run_agent = AsyncMock()
48+
monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent)
49+
await controller.run_once_async()
50+
mocked_run_agent.assert_called_once_with(agent, tasks=[task])
51+
52+
def test_create_end_run_tool(self, flow, agent, task):
53+
controller = Controller(flow=flow, tasks=[task], agents=[agent])
54+
end_run_tool = controller._create_end_run_tool()
55+
assert end_run_tool.function.name == "end_run"
56+
assert end_run_tool.function.description.startswith("End your turn")
57+
58+
def test_controller_graph_creation(self, flow, agent):
59+
task1 = Task(objective="Task 1")
60+
task2 = Task(objective="Task 2", depends_on=[task1])
61+
controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent])
62+
assert len(controller.graph.tasks) == 2
63+
assert len(controller.graph.edges) == 1
64+
assert controller.graph.edges.pop().type == EdgeType.dependency
65+
66+
def test_controller_agent_selection(self, flow, monkeypatch):
67+
agent1 = Agent(name="Agent 1")
68+
agent2 = Agent(name="Agent 2")
69+
task = Task(objective="Test Task", agents=[agent1, agent2])
70+
controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2])
71+
mocked_marvin_moderator = AsyncMock(return_value=agent1)
72+
monkeypatch.setattr(
73+
"control_flow.core.controller.moderators.marvin_moderator",
74+
mocked_marvin_moderator,
75+
)
76+
assert controller.agents == [agent1, agent2]
77+
78+
async def test_controller_run_dependencies(self, flow, agent):
79+
task1 = Task(objective="Task 1")
80+
task2 = Task(objective="Task 2", depends_on=[task1])
81+
controller = Controller(flow=flow, tasks=[task2], agents=[agent])
82+
mocked_run_agent = AsyncMock()
83+
controller._run_agent = mocked_run_agent
84+
await controller.run_once_async()
85+
mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2])

tests/core/test_flows.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# test_flow.py
2+
from unittest.mock import MagicMock
3+
4+
from control_flow.core.agent import Agent
5+
from control_flow.core.flow import Flow, get_flow
6+
from control_flow.utilities.context import ctx
7+
8+
9+
class TestFlow:
10+
def test_flow_initialization(self):
11+
flow = Flow()
12+
assert flow.thread is not None
13+
assert len(flow.tools) == 0
14+
assert len(flow.agents) == 1
15+
assert isinstance(flow.agents[0], Agent)
16+
assert len(flow.context) == 0
17+
18+
def test_flow_with_custom_agents(self):
19+
agent1 = Agent(name="Agent 1")
20+
agent2 = Agent(name="Agent 2")
21+
flow = Flow(agents=[agent1, agent2])
22+
assert len(flow.agents) == 2
23+
assert agent1 in flow.agents
24+
assert agent2 in flow.agents
25+
26+
def test_flow_with_custom_tools(self):
27+
def tool1():
28+
pass
29+
30+
def tool2():
31+
pass
32+
33+
flow = Flow(tools=[tool1, tool2])
34+
assert len(flow.tools) == 2
35+
assert tool1 in flow.tools
36+
assert tool2 in flow.tools
37+
38+
def test_flow_with_custom_context(self):
39+
flow = Flow(context={"key": "value"})
40+
assert len(flow.context) == 1
41+
assert flow.context["key"] == "value"
42+
43+
def test_add_message(self, monkeypatch):
44+
flow = Flow()
45+
mocked_add = MagicMock()
46+
monkeypatch.setattr(flow.thread, "add", mocked_add)
47+
flow.add_message("Test message", role="user")
48+
mocked_add.assert_called_once_with("Test message", role="user")
49+
50+
def test_flow_context_manager(self):
51+
with Flow() as flow:
52+
assert ctx.get("flow") == flow
53+
assert ctx.get("tasks") == []
54+
assert ctx.get("flow") is None
55+
assert ctx.get("tasks") is None
56+
57+
def test_get_flow_within_context(self):
58+
with Flow() as flow:
59+
assert get_flow() == flow
60+
61+
def test_get_flow_without_context(self):
62+
flow1 = get_flow()
63+
with Flow() as flow2:
64+
pass
65+
flow3 = get_flow()
66+
assert flow1 == flow3
67+
assert flow1 != flow2
68+
69+
def test_get_flow_nested_contexts(self):
70+
with Flow() as flow1:
71+
assert get_flow() == flow1
72+
with Flow() as flow2:
73+
assert get_flow() == flow2
74+
assert get_flow() == flow1

0 commit comments

Comments
 (0)