Skip to content

Commit

Permalink
Merge pull request #124 from PrefectHQ/tests
Browse files Browse the repository at this point in the history
Update tests
  • Loading branch information
jlowin authored Jun 17, 2024
2 parents 0b818b9 + 6b0161d commit c9a0531
Show file tree
Hide file tree
Showing 24 changed files with 522 additions and 649 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}

env:
CONTROLFLOW_OPENAI_API_KEY: ${{ secrets.CONTROLFLOW_OPENAI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v4
Expand Down
24 changes: 14 additions & 10 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
from .settings import settings
import controlflow.llm

# --- Default model ---
# assign to controlflow.default_model to change the default model
from .llm.models import DEFAULT_MODEL as default_model

from .core.flow import Flow
from .core.task import Task
from .core.agent import Agent
from .core.task import Task
from .core.flow import Flow
from .core.controller.controller import Controller

from .instructions import instructions
from .decorators import flow, task

# --- Default settings ---

from .llm.models import model_from_string, get_default_model
from .llm.history import InMemoryHistory, get_default_history

# assign to controlflow.default_model to change the default model
default_model = model_from_string(controlflow.settings.llm_model)
del model_from_string

# --- Default history ---
# assign to controlflow.default_history to change the default history
from .llm.history import DEFAULT_HISTORY as default_history, get_default_history
default_history = InMemoryHistory()
del InMemoryHistory

# --- Default agent ---
# assign to controlflow.default_agent to change the default agent
from .core.agent.agent import DEFAULT_AGENT as default_agent, get_default_agent
default_agent = Agent(name="Marvin")

# --- Version ---

try:
from ._version import version as __version__ # type: ignore
except ImportError:
Expand Down
14 changes: 13 additions & 1 deletion src/controlflow/core/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import random
import re
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional
Expand All @@ -24,12 +25,19 @@ def get_default_agent() -> "Agent":
return controlflow.default_agent


def sanitize_name(name):
"""
Replace any invalid characters with `-`, due to restrictions on names in the API
"""
sanitized_string = re.sub(r"[^a-zA-Z0-9_-]", "-", name)
return sanitized_string


class Agent(ControlFlowModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
model_config = dict(arbitrary_types_allowed=True)
name: str = Field(
description="The name of the agent.",
pattern=r"^[a-zA-Z0-9_-]+$",
default_factory=lambda: random.choice(NAMES),
)
description: Optional[str] = Field(
Expand Down Expand Up @@ -68,6 +76,10 @@ def _serialize_tools(self, tools: list[Callable]):
# tools are Pydantic 1 objects
return [t.dict(include={"name", "description"}) for t in tools]

@field_serializer("name")
def _serialize_name(self, name: str):
return sanitize_name(name)

def __init__(self, name=None, **kwargs):
if name is not None:
kwargs["name"] = name
Expand Down
8 changes: 4 additions & 4 deletions src/controlflow/core/agent/names.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
NAMES = [
"HAL-9000",
"HAL 9000",
"R2-D2",
"C-3PO",
"WALL-E",
"T-800",
"GLaDOS",
"JARVIS",
"J.A.R.V.I.S",
"EVE",
"KITT",
"Johnny-5",
"Johnny 5",
"BB-8",
"Ultron",
"TARS",
"Agent-Smith",
"Agent Smith",
"CLU",
"Deckard",
"HK-47",
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def graph(self) -> Graph:
@model_validator(mode="after")
def _finalize(self):
if self.tasks is None:
self.tasks = list(self.flow._tasks.values())
self.tasks = list(self.flow.tasks.values())
for task in self.tasks:
self.flow.add_task(task)
return self
Expand Down
21 changes: 10 additions & 11 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
import datetime
import uuid
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union

from pydantic import Field

import controlflow
import controlflow.llm
from controlflow.core.agent import Agent
from controlflow.core.task import Task
from controlflow.llm.history import History, get_default_history
from controlflow.llm.messages import MessageType
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import prefect_flow_context
from controlflow.utilities.types import ControlFlowModel

if TYPE_CHECKING:
from controlflow.core.agent import Agent
from controlflow.core.task import Task
logger = get_logger(__name__)


Expand All @@ -31,13 +30,13 @@ class Flow(ControlFlowModel):
default_factory=list,
description="Tools that will be available to every agent in the flow",
)
agents: list["Agent"] = Field(
agents: list[Agent] = Field(
description="The default agents for the flow. These agents will be used "
"for any task that does not specify agents.",
default_factory=list,
)
context: dict[str, Any] = {}
_tasks: dict[str, "Task"] = {}
tasks: dict[str, Task] = {}
_cm_stack: list[contextmanager] = []

def __init__(self, *, copy_parent_history: bool = True, **kwargs):
Expand Down Expand Up @@ -74,12 +73,12 @@ def get_messages(
def add_messages(self, messages: list[MessageType]):
self.history.save_messages(thread_id=self.thread_id, messages=messages)

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
def add_task(self, task: Task):
if self.tasks.get(task.id, task) is not task:
raise ValueError(
f"A different task with id '{task.id}' already exists in flow."
)
self._tasks[task.id] = task
self.tasks[task.id] = task

@contextmanager
def create_context(self, create_prefect_flow_context: bool = True):
Expand All @@ -94,15 +93,15 @@ async def run_async(self):
"""
Runs the flow asynchronously.
"""
if self._tasks:
if self.tasks:
controller = controlflow.Controller(flow=self)
await controller.run_async()

def run(self):
"""
Runs the flow.
"""
if self._tasks:
if self.tasks:
controller = controlflow.Controller(flow=self)
controller.run()

Expand Down
19 changes: 10 additions & 9 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import controlflow
import controlflow.core
from controlflow.core.agent import Agent
from controlflow.instructions import get_instructions
from controlflow.llm.tools import Tool
from controlflow.tools.talk_to_human import talk_to_human
Expand All @@ -47,7 +48,6 @@
)

if TYPE_CHECKING:
from controlflow.core.agent import Agent
from controlflow.core.flow import Flow
from controlflow.core.graph import Graph

Expand All @@ -57,7 +57,7 @@

def get_task_run_name() -> str:
context = TaskRunContext.get()
return f'Run {context.parameters['self'].friendly_name()}'
return f'Run {context.parameters["self"].friendly_name()}'


class TaskStatus(Enum):
Expand Down Expand Up @@ -128,11 +128,11 @@ class Task(ControlFlowModel):
def __init__(
self,
objective=None,
result_type=None,
result_type=NOTSET,
**kwargs,
):
# allow certain args to be provided as a positional args
if result_type is not None:
if result_type is not NOTSET:
kwargs["result_type"] = result_type
if objective is not None:
kwargs["objective"] = objective
Expand Down Expand Up @@ -442,7 +442,7 @@ def is_ready(self) -> bool:
"""
return self.is_incomplete() and all(t.is_complete() for t in self.depends_on)

def _create_success_tool(self) -> Callable:
def _create_success_tool(self) -> Tool:
"""
Create an agent-compatible tool for marking this task as successful.
"""
Expand All @@ -466,7 +466,7 @@ def succeed(result: result_schema) -> str: # type: ignore
metadata=dict(is_task_status_tool=True),
)

def _create_fail_tool(self) -> Callable:
def _create_fail_tool(self) -> Tool:
"""
Create an agent-compatible tool for failing this task.
"""
Expand All @@ -478,7 +478,7 @@ def _create_fail_tool(self) -> Callable:
metadata=dict(is_task_status_tool=True),
)

def _create_skip_tool(self) -> Callable:
def _create_skip_tool(self) -> Tool:
"""
Create an agent-compatible tool for skipping this task.
"""
Expand Down Expand Up @@ -525,9 +525,10 @@ def get_agent_strategy(self) -> Callable:

return controlflow.agent_strategies.round_robin

def get_tools(self) -> list[Callable]:
def get_tools(self) -> list[Union[Tool, Callable]]:
tools = self.tools.copy()
if self.is_incomplete():
# if this task is ready to run, generate tools
if self.is_ready:
tools.extend([self._create_fail_tool(), self._create_success_tool()])
# add skip tool if this task has a parent task
# if self.parent is not None:
Expand Down
3 changes: 0 additions & 3 deletions src/controlflow/llm/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,3 @@ def save_messages(self, thread_id: str, messages: list[MessageType]):
all_messages.extend([msg.model_dump(mode="json") for msg in messages])
with open(self.path(thread_id), "w") as f:
json.dump(all_messages, f)


DEFAULT_HISTORY = InMemoryHistory()
8 changes: 6 additions & 2 deletions src/controlflow/llm/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from typing import Any, Optional

from langchain_core.language_models import BaseChatModel

import controlflow


def get_default_model() -> BaseChatModel:
if controlflow.default_model is None:
if getattr(controlflow, "default_model", None) is None:
return model_from_string(controlflow.settings.llm_model)
else:
return controlflow.default_model


def model_from_string(model: str, temperature: float = None, **kwargs) -> BaseChatModel:
def model_from_string(
model: str, temperature: Optional[float] = None, **kwargs: Any
) -> BaseChatModel:
if "/" not in model:
provider, model = "openai", model
provider, model = model.split("/")
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TUIApp(App):

def __init__(self, flow: "controlflow.Flow", **kwargs):
self._flow = flow
self._tasks = flow._tasks
self._tasks = flow.tasks
self._is_ready = False
super().__init__(**kwargs)

Expand Down
1 change: 1 addition & 0 deletions tests/ai_tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_task_pydantic_result(self):
assert isinstance(result, Name)
assert result == Name(first="John", last="Doe")

@pytest.mark.xfail(reason="Need to revisit dataframe handling")
def test_task_dataframe_result(self):
task = Task(
'return a dataframe with column "x" that has values 1 and 2 and column "y" that has values 3 and 4',
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from controlflow.llm.messages import MessageType
from controlflow.settings import temporary_settings
from prefect.testing.utilities import prefect_test_harness

Expand Down
23 changes: 0 additions & 23 deletions tests/core/agents.py

This file was deleted.

40 changes: 40 additions & 0 deletions tests/core/test_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import controlflow
from controlflow.core.agent import Agent, get_default_agent
from controlflow.core.agent.names import NAMES
from controlflow.core.task import Task


class TestAgentInitialization:
def test_agent_gets_random_name(self):
agent = Agent()

assert agent.name in NAMES

def test_agent_default_model(self):
agent = Agent()

assert agent.model is controlflow.get_default_model()


class TestDefaultAgent:
def test_default_agent_is_marvin(self):
agent = get_default_agent()
assert agent.name == "Marvin"

def test_default_agent_has_no_tools(self):
assert get_default_agent().tools == []

def test_default_agent_can_be_assigned(self):
# baseline
assert get_default_agent().name == "Marvin"

new_default_agent = Agent(name="New Agent")
controlflow.default_agent = new_default_agent

assert get_default_agent().name == "New Agent"
assert Task("task").get_agents()[0] is new_default_agent
assert [a.name for a in Task("task").get_agents()] == ["New Agent"]

def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()
Loading

0 comments on commit c9a0531

Please sign in to comment.