Skip to content

Commit

Permalink
Merge pull request #38 from jlowin/deco-tests
Browse files Browse the repository at this point in the history
Add decorator tests for flow
  • Loading branch information
jlowin authored May 15, 2024
2 parents 302fb44 + 5472b03 commit bdefd5f
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 3,407 deletions.
3,393 changes: 0 additions & 3,393 deletions all_files.md

This file was deleted.

16 changes: 15 additions & 1 deletion src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ async def _run_agent(
controller=self, agent=agent, tasks=tasks, thread=thread
)

def choose_agent(
self,
agents: list[Agent],
tasks: list[Task],
history: list = None,
instructions: list[str] = None,
) -> Agent:
return marvin_moderator(
agents=agents,
tasks=tasks,
history=history,
instructions=instructions,
)

@expose_sync_method("run_once")
async def run_once_async(self):
"""
Expand All @@ -190,7 +204,7 @@ async def run_once_async(self):
elif len(agents) == 1:
agent = agents[0]
else:
agent = marvin_moderator(
agent = self.choose_agent(
agents=agents,
tasks=tasks,
history=get_flow_messages(),
Expand Down
21 changes: 12 additions & 9 deletions src/controlflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,19 @@ def wrapper(
if agents is not None:
flow_kwargs.setdefault("agents", agents)

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with flow_obj, patch_marvin():
# create a function to wrap as a Prefect flow
@prefect.flow
def wrapped_flow(*args, **kwargs):
with Task(
fn.__name__,
instructions="Complete all subtasks of this task.",
is_auto_completed_by_subtasks=True,
context=bound.arguments,
) as parent_task:
with controlflow.instructions(instructions):
result = p_fn(*args, **kwargs)
result = fn(*args, **kwargs)

# ensure all subtasks are completed
parent_task.run()
Expand All @@ -107,7 +103,14 @@ def wrapper(
# resolve any returned tasks; this will raise on failure
result = resolve_tasks(result)

return result
return result

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with flow_obj, patch_marvin():
return wrapped_flow(*args, **kwargs)

return wrapper

Expand Down
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import pytest
from controlflow.settings import temporary_settings
from prefect.testing.utilities import prefect_test_harness

from .fixtures import *


@pytest.fixture(autouse=True, scope="session")
def temp_settings():
def temp_controlflow_settings():
with temporary_settings(enable_global_flow=False, max_task_iterations=3):
yield


@pytest.fixture(autouse=True, scope="session")
def prefect_test_fixture():
"""
Run Prefect against temporary sqlite database
"""
with prefect_test_harness():
yield
66 changes: 63 additions & 3 deletions tests/fixtures/mocks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import pytest
from controlflow.core.agent import Agent
from controlflow.core.task import Task, TaskStatus
from controlflow.utilities.user_access import talk_to_human
from marvin.settings import temporary_settings as temporary_marvin_settings

# @pytest.fixture(autouse=True)
# def mock_talk_to_human():
Expand All @@ -20,10 +24,18 @@


@pytest.fixture
def mock_run(monkeypatch):
def prevent_openai_calls():
"""Prevent any calls to the OpenAI API from being made."""
with temporary_marvin_settings(openai__api_key="unset"):
yield


@pytest.fixture
def mock_run(monkeypatch, prevent_openai_calls):
"""
This fixture mocks the calls to OpenAI. Use it in a test and assign any desired side effects (like completing a task)
to the mock object's `.side_effect` attribute.
This fixture mocks the calls to the OpenAI Assistants API. Use it in a test
and assign any desired side effects (like completing a task) to the mock
object's `.side_effect` attribute.
For example:
Expand All @@ -41,3 +53,51 @@ def side_effect():
MockRun = AsyncMock()
monkeypatch.setattr("controlflow.core.controller.controller.Run.run_async", MockRun)
yield MockRun


@pytest.fixture
def mock_controller_run_agent(monkeypatch, prevent_openai_calls):
MockRunAgent = AsyncMock()
MockThreadGetMessages = Mock()

async def _run_agent(agent: Agent, tasks: list[Task] = None, thread=None):
for task in tasks:
if agent in task.agents:
# we can't call mark_successful because we don't know the result
task.status = TaskStatus.SUCCESSFUL

MockRunAgent.side_effect = _run_agent

def get_messages(*args, **kwargs):
return []

MockThreadGetMessages.side_effect = get_messages

monkeypatch.setattr(
"controlflow.core.controller.controller.Controller._run_agent", MockRunAgent
)
monkeypatch.setattr(
"marvin.beta.assistants.Thread.get_messages", MockThreadGetMessages
)
yield MockRunAgent


@pytest.fixture
def mock_controller_choose_agent(monkeypatch):
MockChooseAgent = Mock()

def choose_agent(agents, **kwargs):
return agents[0]

MockChooseAgent.side_effect = choose_agent

monkeypatch.setattr(
"controlflow.core.controller.controller.Controller.choose_agent",
MockChooseAgent,
)
yield MockChooseAgent


@pytest.fixture
def mock_controller(mock_controller_choose_agent, mock_controller_run_agent):
pass
54 changes: 54 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from controlflow import Task
from controlflow.decorators import flow


@pytest.mark.usefixtures("mock_controller")
class TestFlowDecorator:
def test_flow_decorator(self):
@flow
def test_flow():
return 1

result = test_flow()
assert result == 1

def test_flow_decorator_runs_all_tasks(self):
tasks: list[Task] = []

@flow
def test_flow():
task = Task(
"say hello",
result_type=str,
result="Task completed successfully",
)
tasks.append(task)

result = test_flow()
assert result is None
assert tasks[0].is_successful()
assert tasks[0].result == "Task completed successfully"

def test_flow_decorator_resolves_all_tasks(self):
@flow
def test_flow():
task1 = Task("say hello", result="hello")
task2 = Task("say goodbye", result="goodbye")
task3 = Task("say goodnight", result="goodnight")
return dict(a=task1, b=[task2], c=dict(x=dict(y=[[task3]])))

result = test_flow()
assert result == dict(
a="hello", b=["goodbye"], c=dict(x=dict(y=[["goodnight"]]))
)

def test_manually_run_task_in_flow(self):
@flow
def test_flow():
task = Task("say hello", result="hello")
task.run()
return task.result

result = test_flow()
assert result == "hello"

0 comments on commit bdefd5f

Please sign in to comment.