-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
523 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from unittest.mock import AsyncMock | ||
|
||
import pytest | ||
from control_flow.core.agent import Agent | ||
from control_flow.core.controller.controller import Controller | ||
from control_flow.core.flow import Flow | ||
from control_flow.core.graph import EdgeType | ||
from control_flow.core.task import Task | ||
|
||
|
||
class TestController: | ||
@pytest.fixture | ||
def flow(self): | ||
return Flow() | ||
|
||
@pytest.fixture | ||
def agent(self): | ||
return Agent(name="Test Agent") | ||
|
||
@pytest.fixture | ||
def task(self): | ||
return Task(objective="Test Task") | ||
|
||
def test_controller_initialization(self, flow, agent, task): | ||
controller = Controller(flow=flow, tasks=[task], agents=[agent]) | ||
assert controller.flow == flow | ||
assert controller.tasks == [task] | ||
assert controller.agents == [agent] | ||
assert controller.run_dependencies is True | ||
assert len(controller.context) == 0 | ||
assert len(controller.graph.tasks) == 1 | ||
assert len(controller.graph.edges) == 0 | ||
|
||
def test_controller_missing_tasks(self, flow): | ||
with pytest.raises(ValueError, match="At least one task is required."): | ||
Controller(flow=flow, tasks=[]) | ||
|
||
async def test_run_agent(self, flow, agent, task, monkeypatch): | ||
controller = Controller(flow=flow, tasks=[task], agents=[agent]) | ||
mocked_run = AsyncMock() | ||
monkeypatch.setattr(Agent, "run", mocked_run) | ||
await controller._run_agent(agent, tasks=[task]) | ||
mocked_run.assert_called_once_with(tasks=[task]) | ||
|
||
async def test_run_once(self, flow, agent, task, monkeypatch): | ||
controller = Controller(flow=flow, tasks=[task], agents=[agent]) | ||
mocked_run_agent = AsyncMock() | ||
monkeypatch.setattr(Controller, "_run_agent", mocked_run_agent) | ||
await controller.run_once_async() | ||
mocked_run_agent.assert_called_once_with(agent, tasks=[task]) | ||
|
||
def test_create_end_run_tool(self, flow, agent, task): | ||
controller = Controller(flow=flow, tasks=[task], agents=[agent]) | ||
end_run_tool = controller._create_end_run_tool() | ||
assert end_run_tool.function.name == "end_run" | ||
assert end_run_tool.function.description.startswith("End your turn") | ||
|
||
def test_controller_graph_creation(self, flow, agent): | ||
task1 = Task(objective="Task 1") | ||
task2 = Task(objective="Task 2", depends_on=[task1]) | ||
controller = Controller(flow=flow, tasks=[task1, task2], agents=[agent]) | ||
assert len(controller.graph.tasks) == 2 | ||
assert len(controller.graph.edges) == 1 | ||
assert controller.graph.edges.pop().type == EdgeType.dependency | ||
|
||
def test_controller_agent_selection(self, flow, monkeypatch): | ||
agent1 = Agent(name="Agent 1") | ||
agent2 = Agent(name="Agent 2") | ||
task = Task(objective="Test Task", agents=[agent1, agent2]) | ||
controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2]) | ||
mocked_marvin_moderator = AsyncMock(return_value=agent1) | ||
monkeypatch.setattr( | ||
"control_flow.core.controller.moderators.marvin_moderator", | ||
mocked_marvin_moderator, | ||
) | ||
assert controller.agents == [agent1, agent2] | ||
|
||
async def test_controller_run_dependencies(self, flow, agent): | ||
task1 = Task(objective="Task 1") | ||
task2 = Task(objective="Task 2", depends_on=[task1]) | ||
controller = Controller(flow=flow, tasks=[task2], agents=[agent]) | ||
mocked_run_agent = AsyncMock() | ||
controller._run_agent = mocked_run_agent | ||
await controller.run_once_async() | ||
mocked_run_agent.assert_called_once_with(agent, tasks=[task1, task2]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# test_flow.py | ||
from unittest.mock import MagicMock | ||
|
||
from control_flow.core.agent import Agent | ||
from control_flow.core.flow import Flow, get_flow | ||
from control_flow.utilities.context import ctx | ||
|
||
|
||
class TestFlow: | ||
def test_flow_initialization(self): | ||
flow = Flow() | ||
assert flow.thread is not None | ||
assert len(flow.tools) == 0 | ||
assert len(flow.agents) == 1 | ||
assert isinstance(flow.agents[0], Agent) | ||
assert len(flow.context) == 0 | ||
|
||
def test_flow_with_custom_agents(self): | ||
agent1 = Agent(name="Agent 1") | ||
agent2 = Agent(name="Agent 2") | ||
flow = Flow(agents=[agent1, agent2]) | ||
assert len(flow.agents) == 2 | ||
assert agent1 in flow.agents | ||
assert agent2 in flow.agents | ||
|
||
def test_flow_with_custom_tools(self): | ||
def tool1(): | ||
pass | ||
|
||
def tool2(): | ||
pass | ||
|
||
flow = Flow(tools=[tool1, tool2]) | ||
assert len(flow.tools) == 2 | ||
assert tool1 in flow.tools | ||
assert tool2 in flow.tools | ||
|
||
def test_flow_with_custom_context(self): | ||
flow = Flow(context={"key": "value"}) | ||
assert len(flow.context) == 1 | ||
assert flow.context["key"] == "value" | ||
|
||
def test_add_message(self, monkeypatch): | ||
flow = Flow() | ||
mocked_add = MagicMock() | ||
monkeypatch.setattr(flow.thread, "add", mocked_add) | ||
flow.add_message("Test message", role="user") | ||
mocked_add.assert_called_once_with("Test message", role="user") | ||
|
||
def test_flow_context_manager(self): | ||
with Flow() as flow: | ||
assert ctx.get("flow") == flow | ||
assert ctx.get("tasks") == [] | ||
assert ctx.get("flow") is None | ||
assert ctx.get("tasks") is None | ||
|
||
def test_get_flow_within_context(self): | ||
with Flow() as flow: | ||
assert get_flow() == flow | ||
|
||
def test_get_flow_without_context(self): | ||
flow1 = get_flow() | ||
with Flow() as flow2: | ||
pass | ||
flow3 = get_flow() | ||
assert flow1 == flow3 | ||
assert flow1 != flow2 | ||
|
||
def test_get_flow_nested_contexts(self): | ||
with Flow() as flow1: | ||
assert get_flow() == flow1 | ||
with Flow() as flow2: | ||
assert get_flow() == flow2 | ||
assert get_flow() == flow1 |
Oops, something went wrong.