Improve task testing with mocks #34

Merged
merged 8 commits into from
May 14, 2024
18 changes: 8 additions & 10 deletions .github/workflows/run-tests.yml
@@ -18,7 +18,7 @@ on:
pull_request:
paths:
- .github/workflows/run-tests.yml
- src**
- src/**
- tests/**
- pyproject.toml
- setup.py
@@ -34,32 +34,30 @@ jobs:
timeout-minutes: 15
strategy:
matrix:
# run no_llm tests across all python versions and oses
# os: [ubuntu-latest, macos-latest, windows-latest]
# python-version: ['3.9', '3.10', '3.11', '3.12']
os: [ubuntu-latest]
python-version: ['3.9']

# python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ["3.9", "3.12"]

runs-on: ${{ matrix.os }}

env:
CONTROLFLOW_OPENAI_API_KEY: ${{ secrets.CONTROLFLOW_OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: download uv
run: curl -LsSf https://astral.sh/uv/install.sh | sh

- name: Install ControlFlow
run: uv pip install --system ".[tests]"

- name: Run tests
run: pytest -n auto -vv
if: ${{ !(github.event.pull_request.head.repo.fork) }}
run: pytest -vv
if: ${{ !(github.event.pull_request.head.repo.fork) }}
2 changes: 1 addition & 1 deletion docs/concepts/flows.mdx
@@ -42,7 +42,7 @@ flow = Flow()
Flows have several key properties that define their behavior and capabilities:

- `thread` (Thread): The thread associated with the flow, which stores the conversation history and context.
- `tools` (list[AssistantTool | Callable]): A list of tools that are available to all agents in the flow.
- `tools` (list[ToolType]): A list of tools that are available to all agents in the flow.
- `agents` (list[Agent]): The default agents for the flow, which are used for tasks that do not specify agents explicitly.
- `context` (dict): Additional context or information that is shared across tasks and agents in the flow.
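The `tools` property above is now annotated with `ToolType` instead of the inline `AssistantTool | Callable`. The alias itself is defined in `controlflow.utilities.types` and is not shown in this diff, so the following is only a minimal sketch of what such an alias might look like. `AssistantTool` here is a hypothetical stand-in class, and `tool_names` and `roll_die` are illustrative helpers that are not part of ControlFlow.

```python
from typing import Callable, Union


class AssistantTool:
    """Hypothetical stand-in for the real AssistantTool model."""

    def __init__(self, name: str) -> None:
        self.name = name


# Assumed shape of the alias; the real one lives in controlflow.utilities.types.
ToolType = Union[AssistantTool, Callable]


def tool_names(tools: list[ToolType]) -> list[str]:
    """Return a display name for each tool, whether it is a model or a function."""
    names = []
    for tool in tools:
        if isinstance(tool, AssistantTool):
            names.append(tool.name)
        else:
            names.append(tool.__name__)
    return names


def roll_die() -> int:
    """Illustrative plain-function tool."""
    return 4
```

A single alias keeps the annotation consistent across `Flow`, `Task`, and `Agent`, and means the accepted tool types can be changed in one place.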

2 changes: 1 addition & 1 deletion docs/concepts/tasks.mdx
@@ -59,7 +59,7 @@ Tasks have several key properties that define their behavior and requirements:
- `agents` (list[Agent], optional): The AI agents assigned to work on the task.
- `context` (dict, optional): Additional context or information required for the task.
- `result_type` (type, optional): The expected type of the task's result.
- `tools` (list[AssistantTool | Callable], optional): Tools or functions available to the agents for completing the task.
- `tools` (list[ToolType], optional): Tools or functions available to the agents for completing the task.
- `user_access` (bool, optional): Indicates whether the task requires human user interaction.

## Task Execution and Results
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -47,7 +47,7 @@ tests = [
"pytest-xdist",
"pre-commit>=3.7.0",
]
dev = ["controlflow[tests]", "ipython>=8.22.2", "pdbpp>=0.10.3", "ruff>=0.3.4"]
dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4"]

[build-system]
requires = ["hatchling"]
@@ -77,3 +77,6 @@ skip-magic-trailing-comma = false
"conftest.py" = ["F401", "F403"]
'tests/fixtures/*.py' = ['F401', 'F403']
"src/controlflow/utilities/types.py" = ['F401']

[tool.pytest.ini_options]
timeout = 120
2 changes: 0 additions & 2 deletions src/controlflow/__init__.py
@@ -1,12 +1,10 @@
from .settings import settings

# from .agent_old import task, Agent, run_ai
from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.task import Task, task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .dx import run_ai

Flow.model_rebuild()
Task.model_rebuild()
8 changes: 4 additions & 4 deletions src/controlflow/core/agent.py
@@ -1,5 +1,5 @@
import logging
from typing import Callable
from typing import Union

from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.tools import tool_from_function
@@ -10,7 +10,7 @@
from controlflow.utilities.prefect import (
wrap_prefect_tool,
)
from controlflow.utilities.types import Assistant, AssistantTool, ControlFlowModel
from controlflow.utilities.types import Assistant, ControlFlowModel, ToolType
from controlflow.utilities.user_access import talk_to_human

logger = logging.getLogger(__name__)
@@ -33,15 +33,15 @@ class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin):
description="If True, the agent is given tools for interacting with a human user.",
)

def get_tools(self) -> list[AssistantTool | Callable]:
def get_tools(self) -> list[ToolType]:
tools = super().get_tools()
if self.user_access:
tools.append(tool_from_function(talk_to_human))

return [wrap_prefect_tool(tool) for tool in tools]

@expose_sync_method("run")
async def run_async(self, tasks: list[Task] | Task | None = None):
async def run_async(self, tasks: Union[list[Task], Task, None] = None):
from controlflow.core.controller import Controller

if isinstance(tasks, Task):
6 changes: 3 additions & 3 deletions src/controlflow/core/controller/controller.py
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, Union

import marvin.utilities
import marvin.utilities.tools
@@ -56,7 +56,7 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
description="Tasks that the controller will complete.",
validate_default=True,
)
agents: list[Agent] | None = None
agents: Union[list[Agent], None] = None
context: dict = {}
graph: Graph = None
model_config: dict = dict(extra="forbid")
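Several hunks in this PR replace `X | None` annotations with `Union[X, None]`. A likely reason, given that the test matrix includes Python 3.9, is that the `|` operator on types (PEP 604) is only available from Python 3.10, so an annotation such as `list[Agent] | None` raises a `TypeError` when it is evaluated on 3.9. The sketch below shows the 3.9-compatible spelling on a function modeled loosely on `run_async`. `Task` is a hypothetical stand-in class and `normalize_tasks` is an illustrative helper, not part of ControlFlow.

```python
from typing import Union


class Task:
    """Hypothetical stand-in for controlflow's Task."""


def normalize_tasks(tasks: Union[list[Task], Task, None] = None) -> list[Task]:
    """Accept a single task, a list of tasks, or nothing, and return a list."""
    if tasks is None:
        return []
    if isinstance(tasks, Task):
        return [tasks]
    # Copy so that callers keep ownership of their own list.
    return list(tasks)
```

`list[Task]` itself is fine on 3.9 (PEP 585); only the `|` between types needs the `Union` spelling.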
@@ -173,7 +173,7 @@ async def run_once_async(self):
Run the controller for a single iteration of the provided tasks. An agent will be selected to run the tasks.
"""
# get the tasks to run
tasks = self.graph.upstream_dependencies(self.tasks)
tasks = self.graph.upstream_dependencies(self.tasks, include_tasks=True)

# get the agents
agent_candidates = {a for t in tasks for a in t.agents if t.is_ready()}
18 changes: 6 additions & 12 deletions src/controlflow/core/flow.py
@@ -1,19 +1,18 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Union

import prefect
from marvin.beta.assistants import Thread
from openai.types.beta.threads import Message
from prefect import task as prefect_task
from pydantic import Field, field_validator

import controlflow
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.marvin import patch_marvin
from controlflow.utilities.types import AssistantTool, ControlFlowModel
from controlflow.utilities.types import ControlFlowModel, ToolType

if TYPE_CHECKING:
from controlflow.core.agent import Agent
@@ -23,7 +22,7 @@

class Flow(ControlFlowModel):
thread: Thread = Field(None, validate_default=True)
tools: list[AssistantTool | Callable] = Field(
tools: list[ToolType] = Field(
default_factory=list,
description="Tools that will be available to every agent in the flow",
)
@@ -41,8 +40,6 @@ def _load_thread_from_ctx(cls, v):
v = ctx.get("thread", None)
if v is None:
v = Thread()
if not v.id:
v.create()

return v

@@ -53,9 +50,6 @@ def add_task(self, task: "Task"):
)
self._tasks[task.id] = task

def add_message(self, message: str, role: Literal["user", "assistant"] = None):
prefect_task(self.thread.add)(message, role=role)

@contextmanager
def _context(self):
with ctx(flow=self, tasks=[]):
@@ -79,7 +73,7 @@ def get_flow() -> Flow:
Will error if no flow is found in the context, unless the global flow is
enabled in settings
"""
flow: Flow | None = ctx.get("flow")
flow: Union[Flow, None] = ctx.get("flow")
if not flow:
if controlflow.settings.enable_global_flow:
return GLOBAL_FLOW
@@ -108,7 +102,7 @@ def flow(
*,
thread: Thread = None,
instructions: str = None,
tools: list[AssistantTool | Callable] = None,
tools: list[ToolType] = None,
agents: list["Agent"] = None,
):
"""
@@ -153,7 +147,7 @@ def wrapper(
)

with ctx(flow=flow_obj), patch_marvin():
with controlflow.instructions.instructions(instructions):
with controlflow.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
17 changes: 13 additions & 4 deletions src/controlflow/core/graph.py
@@ -108,7 +108,10 @@ def downstream_edges(self) -> dict[Task, list[Edge]]:
return self._cache["downstream_edges"]

def upstream_dependencies(
self, tasks: list[Task], prune_completed: bool = True
self,
tasks: list[Task],
prune_completed: bool = True,
include_tasks: bool = False,
) -> list[Task]:
"""
From a list of tasks, returns the subgraph of tasks that are directly or
@@ -117,11 +120,15 @@ def upstream_dependencies(
dependencies as well as any subtasks that are considered implicit
dependencies.

If `prune_completed` is True, the subgraph will be pruned to stop traversal after adding any completed tasks.
If `prune_completed` is True, the subgraph will be pruned to stop
traversal after adding any completed tasks.

If `include_tasks` is True, the subgraph will include the tasks provided.
"""
subgraph = set()
upstreams = self.upstream_edges()
stack = tasks
# copy stack to allow difference update with original tasks
stack = [t for t in tasks]
while stack:
current = stack.pop()
if current in subgraph:
@@ -133,6 +140,8 @@ def upstream_dependencies(
continue
stack.extend([edge.upstream for edge in upstreams[current]])

if not include_tasks:
subgraph.difference_update(tasks)
return list(subgraph)

def ready_tasks(self, tasks: list[Task] = None) -> list[Task]:
@@ -146,7 +155,7 @@ def ready_tasks(self, tasks: list[Task] = None) -> list[Task]:
if tasks is None:
candidates = self.tasks
else:
candidates = self.upstream_dependencies(tasks)
candidates = self.upstream_dependencies(tasks, include_tasks=True)
return sorted(
[task for task in candidates if task.is_ready()], key=lambda t: t.created_at
)
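The `upstream_dependencies` change above can be illustrated with a standalone sketch. It assumes a simplified graph in which tasks are strings, `upstreams` is a plain dict from each task to its direct upstream tasks, and completion is tracked in a `completed` set. Those three names are illustrative assumptions; the real method works on `Task` and `Edge` objects, and the lines of its loop that the diff collapses are guessed here.

```python
def upstream_dependencies(
    upstreams: dict[str, list[str]],
    tasks: list[str],
    completed: frozenset = frozenset(),
    prune_completed: bool = True,
    include_tasks: bool = False,
) -> list[str]:
    """Collect everything upstream of `tasks` with a depth-first walk."""
    subgraph = set()
    # Copy the input: popping from the caller's list would empty it, and the
    # original tasks are needed again for the difference update below.
    stack = list(tasks)
    while stack:
        current = stack.pop()
        if current in subgraph:
            continue
        subgraph.add(current)
        # Assumed pruning rule: keep a completed task but do not walk past it.
        if prune_completed and current in completed:
            continue
        stack.extend(upstreams.get(current, []))

    if not include_tasks:
        subgraph.difference_update(tasks)
    return sorted(subgraph)


# c depends on b, which depends on a.
graph = {"c": ["b"], "b": ["a"], "a": []}
requested = ["c"]
only_upstream = upstream_dependencies(graph, requested)
with_requested = upstream_dependencies(graph, requested, include_tasks=True)
pruned = upstream_dependencies(graph, requested, completed=frozenset({"b"}))
```

With `include_tasks=False` the requested tasks are removed from the result, which is why `ready_tasks` and `run_once_async` now pass `include_tasks=True` to keep them among the candidates.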