Skip to content

Commit

Permalink
Merge pull request #34 from jlowin/mock
Browse files: browse the repository at this point in the history
Improve task testing with mocks
  • Loading branch information
jlowin authored May 14, 2024
2 parents 2d301a0 + b4e494f commit bb6bc6f
Show file tree
Hide file tree
Showing 26 changed files with 301 additions and 308 deletions.
18 changes: 8 additions & 10 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ on:
pull_request:
paths:
- .github/workflows/run-tests.yml
- src**
- src/**
- tests/**
- pyproject.toml
- setup.py
Expand All @@ -34,32 +34,30 @@ jobs:
timeout-minutes: 15
strategy:
matrix:
# run no_llm tests across all Python versions and operating systems
# os: [ubuntu-latest, macos-latest, windows-latest]
# python-version: ['3.9', '3.10', '3.11', '3.12']
os: [ubuntu-latest]
python-version: ['3.9']

# python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ["3.9", "3.12"]

runs-on: ${{ matrix.os }}

env:
CONTROLFLOW_OPENAI_API_KEY: ${{ secrets.CONTROLFLOW_OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: download uv
run: curl -LsSf https://astral.sh/uv/install.sh | sh

- name: Install ControlFlow
run: uv pip install --system ".[tests]"

- name: Run tests
run: pytest -n auto -vv
if: ${{ !(github.event.pull_request.head.repo.fork) }}
run: pytest -vv
if: ${{ !(github.event.pull_request.head.repo.fork) }}
2 changes: 1 addition & 1 deletion docs/concepts/flows.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ flow = Flow()
Flows have several key properties that define their behavior and capabilities:

- `thread` (Thread): The thread associated with the flow, which stores the conversation history and context.
- `tools` (list[AssistantTool | Callable]): A list of tools that are available to all agents in the flow.
- `tools` (list[ToolType]): A list of tools that are available to all agents in the flow.
- `agents` (list[Agent]): The default agents for the flow, which are used for tasks that do not specify agents explicitly.
- `context` (dict): Additional context or information that is shared across tasks and agents in the flow.

Expand Down
2 changes: 1 addition & 1 deletion docs/concepts/tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Tasks have several key properties that define their behavior and requirements:
- `agents` (list[Agent], optional): The AI agents assigned to work on the task.
- `context` (dict, optional): Additional context or information required for the task.
- `result_type` (type, optional): The expected type of the task's result.
- `tools` (list[AssistantTool | Callable], optional): Tools or functions available to the agents for completing the task.
- `tools` (list[ToolType], optional): Tools or functions available to the agents for completing the task.
- `user_access` (bool, optional): Indicates whether the task requires human user interaction.

## Task Execution and Results
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ tests = [
"pytest-xdist",
"pre-commit>=3.7.0",
]
dev = ["controlflow[tests]", "ipython>=8.22.2", "pdbpp>=0.10.3", "ruff>=0.3.4"]
dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4"]

[build-system]
requires = ["hatchling"]
Expand Down Expand Up @@ -77,3 +77,6 @@ skip-magic-trailing-comma = false
"conftest.py" = ["F401", "F403"]
'tests/fixtures/*.py' = ['F401', 'F403']
"src/controlflow/utilities/types.py" = ['F401']

[tool.pytest.ini_options]
timeout = 120
2 changes: 0 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from .settings import settings

# from .agent_old import task, Agent, run_ai
from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.task import Task, task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .dx import run_ai

Flow.model_rebuild()
Task.model_rebuild()
Expand Down
8 changes: 4 additions & 4 deletions src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable
from typing import Union

from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.tools import tool_from_function
Expand All @@ -10,7 +10,7 @@
from controlflow.utilities.prefect import (
wrap_prefect_tool,
)
from controlflow.utilities.types import Assistant, AssistantTool, ControlFlowModel
from controlflow.utilities.types import Assistant, ControlFlowModel, ToolType
from controlflow.utilities.user_access import talk_to_human

logger = logging.getLogger(__name__)
Expand All @@ -33,15 +33,15 @@ class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin):
description="If True, the agent is given tools for interacting with a human user.",
)

def get_tools(self) -> list[AssistantTool | Callable]:
def get_tools(self) -> list[ToolType]:
tools = super().get_tools()
if self.user_access:
tools.append(tool_from_function(talk_to_human))

return [wrap_prefect_tool(tool) for tool in tools]

@expose_sync_method("run")
async def run_async(self, tasks: list[Task] | Task | None = None):
async def run_async(self, tasks: Union[list[Task], Task, None] = None):
from controlflow.core.controller import Controller

if isinstance(tasks, Task):
Expand Down
6 changes: 3 additions & 3 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, Union

import marvin.utilities
import marvin.utilities.tools
Expand Down Expand Up @@ -56,7 +56,7 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
description="Tasks that the controller will complete.",
validate_default=True,
)
agents: list[Agent] | None = None
agents: Union[list[Agent], None] = None
context: dict = {}
graph: Graph = None
model_config: dict = dict(extra="forbid")
Expand Down Expand Up @@ -173,7 +173,7 @@ async def run_once_async(self):
Run the controller for a single iteration of the provided tasks. An agent will be selected to run the tasks.
"""
# get the tasks to run
tasks = self.graph.upstream_dependencies(self.tasks)
tasks = self.graph.upstream_dependencies(self.tasks, include_tasks=True)

# get the agents
agent_candidates = {a for t in tasks for a in t.agents if t.is_ready()}
Expand Down
18 changes: 6 additions & 12 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Union

import prefect
from marvin.beta.assistants import Thread
from openai.types.beta.threads import Message
from prefect import task as prefect_task
from pydantic import Field, field_validator

import controlflow
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.marvin import patch_marvin
from controlflow.utilities.types import AssistantTool, ControlFlowModel
from controlflow.utilities.types import ControlFlowModel, ToolType

if TYPE_CHECKING:
from controlflow.core.agent import Agent
Expand All @@ -23,7 +22,7 @@

class Flow(ControlFlowModel):
thread: Thread = Field(None, validate_default=True)
tools: list[AssistantTool | Callable] = Field(
tools: list[ToolType] = Field(
default_factory=list,
description="Tools that will be available to every agent in the flow",
)
Expand All @@ -41,8 +40,6 @@ def _load_thread_from_ctx(cls, v):
v = ctx.get("thread", None)
if v is None:
v = Thread()
if not v.id:
v.create()

return v

Expand All @@ -53,9 +50,6 @@ def add_task(self, task: "Task"):
)
self._tasks[task.id] = task

def add_message(self, message: str, role: Literal["user", "assistant"] = None):
prefect_task(self.thread.add)(message, role=role)

@contextmanager
def _context(self):
with ctx(flow=self, tasks=[]):
Expand All @@ -79,7 +73,7 @@ def get_flow() -> Flow:
Will error if no flow is found in the context, unless the global flow is
enabled in settings
"""
flow: Flow | None = ctx.get("flow")
flow: Union[Flow, None] = ctx.get("flow")
if not flow:
if controlflow.settings.enable_global_flow:
return GLOBAL_FLOW
Expand Down Expand Up @@ -108,7 +102,7 @@ def flow(
*,
thread: Thread = None,
instructions: str = None,
tools: list[AssistantTool | Callable] = None,
tools: list[ToolType] = None,
agents: list["Agent"] = None,
):
"""
Expand Down Expand Up @@ -153,7 +147,7 @@ def wrapper(
)

with ctx(flow=flow_obj), patch_marvin():
with controlflow.instructions.instructions(instructions):
with controlflow.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
17 changes: 13 additions & 4 deletions src/controlflow/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def downstream_edges(self) -> dict[Task, list[Edge]]:
return self._cache["downstream_edges"]

def upstream_dependencies(
self, tasks: list[Task], prune_completed: bool = True
self,
tasks: list[Task],
prune_completed: bool = True,
include_tasks: bool = False,
) -> list[Task]:
"""
From a list of tasks, returns the subgraph of tasks that are directly or
Expand All @@ -117,11 +120,15 @@ def upstream_dependencies(
dependencies as well as any subtasks that are considered implicit
dependencies.
If `prune_completed` is True, the subgraph will be pruned to stop traversal after adding any completed tasks.
If `prune_completed` is True, the subgraph will be pruned to stop
traversal after adding any completed tasks.
If `include_tasks` is True, the subgraph will include the tasks provided.
"""
subgraph = set()
upstreams = self.upstream_edges()
stack = tasks
# copy stack to allow difference update with original tasks
stack = [t for t in tasks]
while stack:
current = stack.pop()
if current in subgraph:
Expand All @@ -133,6 +140,8 @@ def upstream_dependencies(
continue
stack.extend([edge.upstream for edge in upstreams[current]])

if not include_tasks:
subgraph.difference_update(tasks)
return list(subgraph)

def ready_tasks(self, tasks: list[Task] = None) -> list[Task]:
Expand All @@ -146,7 +155,7 @@ def ready_tasks(self, tasks: list[Task] = None) -> list[Task]:
if tasks is None:
candidates = self.tasks
else:
candidates = self.upstream_dependencies(tasks)
candidates = self.upstream_dependencies(tasks, include_tasks=True)
return sorted(
[task for task in candidates if task.is_ready()], key=lambda t: t.created_at
)
Loading

0 comments on commit bb6bc6f

Please sign in to comment.