Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed May 12, 2024
1 parent 2596a02 commit 538a419
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 80 deletions.
10 changes: 7 additions & 3 deletions src/control_flow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from prefect import task as prefect_task
from pydantic import Field, field_validator

import control_flow
from control_flow.utilities.context import ctx
from control_flow.utilities.logging import get_logger
from control_flow.utilities.types import AssistantTool, ControlFlowModel
Expand Down Expand Up @@ -36,7 +37,6 @@ class Flow(ControlFlowModel):
description="The default agents for the flow. These agents will be used "
"for any task that does not specify agents.",
)
model: str | None = None
context: dict = {}

@field_validator("thread", mode="before")
Expand Down Expand Up @@ -73,11 +73,15 @@ def get_flow() -> Flow:
"""
Loads the flow from the context.
Will error if no flow is found in the context.
Will error if no flow is found in the context, unless the global flow is
enabled in settings
"""
flow: Flow | None = ctx.get("flow")
if not flow:
return GLOBAL_FLOW
if control_flow.settings.enable_global_flow:
return GLOBAL_FLOW
else:
raise ValueError("No flow found in context.")
return flow


Expand Down
60 changes: 2 additions & 58 deletions src/control_flow/instructions.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,28 @@
import inspect
from contextlib import contextmanager
from typing import Generator, List

from control_flow.core.flow import Flow
from control_flow.utilities.context import ctx
from control_flow.utilities.logging import get_logger

logger = get_logger(__name__)


@contextmanager
def instructions(
*instructions: str,
post_add_message: bool = False,
post_remove_message: bool = False,
) -> Generator[list[str], None, None]:
def instructions(*instructions: str) -> Generator[list[str], None, None]:
"""
Temporarily add instructions to the current instruction stack. The
instruction is removed when the context is exited.
If `post_add_message` is True, a message will be added to the flow when the
instruction is added. If `post_remove_message` is True, a message will be
added to the flow when the instruction is removed. These explicit reminders
can help when agents infer instructions more from history.
with instructions("talk like a pirate"):
...
"""

if post_add_message or post_remove_message:
flow: Flow = ctx.get("flow")
if flow is None:
raise ValueError(
"instructions() with message posting must be used within a flow context"
)

stack: list[str] = ctx.get("instructions", [])
stack = stack + list(instructions)

with ctx(instructions=stack):
try:
if post_add_message:
for instruction in instructions:
flow.add_message(
inspect.cleandoc(
"""
# SYSTEM MESSAGE: INSTRUCTION ADDED
The following instruction is now active:
<instruction>
{instruction}
</instruction>
Always consult your current instructions before acting.
"""
).format(instruction=instruction)
)
yield

# yield new_stack
finally:
if post_remove_message:
for instruction in instructions:
flow.add_message(
inspect.cleandoc(
"""
# SYSTEM MESSAGE: INSTRUCTION REMOVED
The following instruction is no longer active:
<instruction>
{instruction}
</instruction>
Always consult your current instructions before acting.
"""
).format(instruction=instruction)
)
yield


def get_instructions() -> List[str]:
Expand Down
56 changes: 56 additions & 0 deletions src/control_flow/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys
import warnings
from contextlib import contextmanager
from copy import deepcopy
from typing import Any

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
Expand Down Expand Up @@ -46,10 +49,63 @@ class Settings(ControlFlowSettings):
assistant_model: str = "gpt-4-1106-preview"
max_agent_iterations: int = 10
prefect: PrefectSettings = Field(default_factory=PrefectSettings)
enable_global_flow: bool = Field(
True,
description="If True, a global flow is created for convenience, so users don't have to wrap every invocation in a flow function. Disable to avoid accidentally sharing context between agents.",
)

def __init__(self, **data):
super().__init__(**data)
self.prefect.apply()


settings = Settings()


@contextmanager
def temporary_settings(**kwargs: Any):
"""
Temporarily override ControlFlow setting values, including nested settings objects.
To override nested settings, use `__` to separate nested attribute names.
Args:
**kwargs: The settings to override, including nested settings.
Example:
Temporarily override log level and OpenAI API key:
```python
import control_flow
from control_flow.settings import temporary_settings
# Override top-level settings
with temporary_settings(log_level="INFO"):
assert control_flow.settings.log_level == "INFO"
assert control_flow.settings.log_level == "DEBUG"
# Override nested settings
with temporary_settings(openai__api_key="new-api-key"):
assert control_flow.settings.openai.api_key.get_secret_value() == "new-api-key"
assert control_flow.settings.openai.api_key.get_secret_value().startswith("sk-")
```
"""
old_env = os.environ.copy()
old_settings = deepcopy(settings)

def set_nested_attr(obj: object, attr_path: str, value: Any):
parts = attr_path.split("__")
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)

try:
for attr_path, value in kwargs.items():
set_nested_attr(settings, attr_path, value)
yield

finally:
os.environ.clear()
os.environ.update(old_env)

for attr, value in old_settings:
set_nested_attr(settings, attr, value)
85 changes: 85 additions & 0 deletions tests/core/test_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from unittest.mock import AsyncMock

import pytest
from control_flow.core.agent import Agent
from control_flow.core.controller.controller import Controller
from control_flow.core.flow import Flow
from control_flow.core.graph import EdgeType
from control_flow.core.task import Task


class TestController:
@pytest.fixture
def flow(self):
return Flow()

@pytest.fixture
def agent(self):
return Agent(name="Test Agent")

@pytest.fixture
def task(self):
return Task(objective="Test Task")

def test_controller_initialization(self, flow, agent, task):
controller = Controller(flow=flow, tasks=[task], agents=[agent])
assert controller.flow == flow
assert controller.tasks == [task]
assert controller.agents == [agent]
assert controller.run_dependencies is True
assert len(controller.context) == 0
assert len(controller.graph.tasks) == 1
assert len(controller.graph.edges) == 0

def test_controller_missing_tasks(self, flow):
with pytest.raises(ValueError, match="At least one task is required."):
Controller(flow=flow, tasks=[])

async def test_run_agent(self, flow, agent, task, monkeypatch):
controller = Controller(flow=flow, tasks=[task], agents=[agent])
mocked_run = AsyncMock()
monkeypatch.setattr(Agent, "run", mocked_run)
await controller._run_agent(agent, tasks=[task])
mocked_run.assert_called_once_with(tasks=[task])

async def test_run_once(self, flow, agent, task, monkeypatch):
controller = Controller(flow=flow, tasks=[task], agents=[agent])
mocked_run_agent = AsyncMock()
monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent)
await controller.run_once_async()
mocked_run_agent.assert_called_once_with(agent, tasks=[task])

def test_create_end_run_tool(self, flow, agent, task):
controller = Controller(flow=flow, tasks=[task], agents=[agent])
end_run_tool = controller._create_end_run_tool()
assert end_run_tool.function.name == "end_run"
assert end_run_tool.function.description.startswith("End your turn")

def test_controller_graph_creation(self, flow, agent):
task1 = Task(objective="Task 1")
task2 = Task(objective="Task 2", depends_on=[task1])
controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent])
assert len(controller.graph.tasks) == 2
assert len(controller.graph.edges) == 1
assert controller.graph.edges.pop().type == EdgeType.dependency

def test_controller_agent_selection(self, flow, monkeypatch):
agent1 = Agent(name="Agent 1")
agent2 = Agent(name="Agent 2")
task = Task(objective="Test Task", agents=[agent1, agent2])
controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2])
mocked_marvin_moderator = AsyncMock(return_value=agent1)
monkeypatch.setattr(
"control_flow.core.controller.moderators.marvin_moderator",
mocked_marvin_moderator,
)
assert controller.agents == [agent1, agent2]

async def test_controller_run_dependencies(self, flow, agent):
task1 = Task(objective="Task 1")
task2 = Task(objective="Task 2", depends_on=[task1])
controller = Controller(flow=flow, tasks=[task2], agents=[agent])
mocked_run_agent = AsyncMock()
controller._run_agent = mocked_run_agent
await controller.run_once_async()
mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2])
74 changes: 74 additions & 0 deletions tests/core/test_flows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# test_flow.py
from unittest.mock import MagicMock

from control_flow.core.agent import Agent
from control_flow.core.flow import Flow, get_flow
from control_flow.utilities.context import ctx


class TestFlow:
def test_flow_initialization(self):
flow = Flow()
assert flow.thread is not None
assert len(flow.tools) == 0
assert len(flow.agents) == 1
assert isinstance(flow.agents[0], Agent)
assert len(flow.context) == 0

def test_flow_with_custom_agents(self):
agent1 = Agent(name="Agent 1")
agent2 = Agent(name="Agent 2")
flow = Flow(agents=[agent1, agent2])
assert len(flow.agents) == 2
assert agent1 in flow.agents
assert agent2 in flow.agents

def test_flow_with_custom_tools(self):
def tool1():
pass

def tool2():
pass

flow = Flow(tools=[tool1, tool2])
assert len(flow.tools) == 2
assert tool1 in flow.tools
assert tool2 in flow.tools

def test_flow_with_custom_context(self):
flow = Flow(context={"key": "value"})
assert len(flow.context) == 1
assert flow.context["key"] == "value"

def test_add_message(self, monkeypatch):
flow = Flow()
mocked_add = MagicMock()
monkeypatch.setattr(flow.thread, "add", mocked_add)
flow.add_message("Test message", role="user")
mocked_add.assert_called_once_with("Test message", role="user")

def test_flow_context_manager(self):
with Flow() as flow:
assert ctx.get("flow") == flow
assert ctx.get("tasks") == []
assert ctx.get("flow") is None
assert ctx.get("tasks") is None

def test_get_flow_within_context(self):
with Flow() as flow:
assert get_flow() == flow

def test_get_flow_without_context(self):
flow1 = get_flow()
with Flow() as flow2:
pass
flow3 = get_flow()
assert flow1 == flow3
assert flow1 != flow2

def test_get_flow_nested_contexts(self):
with Flow() as flow1:
assert get_flow() == flow1
with Flow() as flow2:
assert get_flow() == flow2
assert get_flow() == flow1
Loading

0 comments on commit 538a419

Please sign in to comment.