Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Jun 17, 2024
1 parent 0b818b9 commit b8eb1f4
Show file tree
Hide file tree
Showing 23 changed files with 512 additions and 638 deletions.
20 changes: 12 additions & 8 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from .settings import settings
import controlflow.llm

# --- Default model ---
# assign to controlflow.default_model to change the default model
from .llm.models import DEFAULT_MODEL as default_model

from .core.flow import Flow
from .core.task import Task
from .core.agent import Agent
Expand All @@ -13,16 +9,24 @@
from .instructions import instructions
from .decorators import flow, task

# --- Default settings ---

from .llm.models import model_from_string, get_default_model
from .llm.history import InMemoryHistory, get_default_history

# assign to controlflow.default_model to change the default model
default_model = model_from_string(controlflow.settings.llm_model)
del model_from_string

# --- Default history ---
# assign to controlflow.default_history to change the default history
from .llm.history import DEFAULT_HISTORY as default_history, get_default_history
default_history = InMemoryHistory()
del InMemoryHistory

# --- Default agent ---
# assign to controlflow.default_agent to change the default agent
from .core.agent.agent import DEFAULT_AGENT as default_agent, get_default_agent
default_agent = Agent(name="Marvin")

# --- Version ---

try:
from ._version import version as __version__ # type: ignore
except ImportError:
Expand Down
14 changes: 13 additions & 1 deletion src/controlflow/core/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import random
import re
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional
Expand All @@ -24,12 +25,19 @@ def get_default_agent() -> "Agent":
return controlflow.default_agent


def sanitize_name(name):
"""
Replace any invalid characters with `-`, due to restrictions on names in the API
"""
sanitized_string = re.sub(r"[^a-zA-Z0-9_-]", "-", name)
return sanitized_string


class Agent(ControlFlowModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
model_config = dict(arbitrary_types_allowed=True)
name: str = Field(
description="The name of the agent.",
pattern=r"^[a-zA-Z0-9_-]+$",
default_factory=lambda: random.choice(NAMES),
)
description: Optional[str] = Field(
Expand Down Expand Up @@ -68,6 +76,10 @@ def _serialize_tools(self, tools: list[Callable]):
# tools are Pydantic 1 objects
return [t.dict(include={"name", "description"}) for t in tools]

@field_serializer("name")
def _serialize_name(self, name: str):
return sanitize_name(name)

def __init__(self, name=None, **kwargs):
if name is not None:
kwargs["name"] = name
Expand Down
8 changes: 4 additions & 4 deletions src/controlflow/core/agent/names.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
NAMES = [
"HAL-9000",
"HAL 9000",
"R2-D2",
"C-3PO",
"WALL-E",
"T-800",
"GLaDOS",
"JARVIS",
"J.A.R.V.I.S",
"EVE",
"KITT",
"Johnny-5",
"Johnny 5",
"BB-8",
"Ultron",
"TARS",
"Agent-Smith",
"Agent Smith",
"CLU",
"Deckard",
"HK-47",
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def graph(self) -> Graph:
@model_validator(mode="after")
def _finalize(self):
if self.tasks is None:
self.tasks = list(self.flow._tasks.values())
self.tasks = list(self.flow.tasks.values())
for task in self.tasks:
self.flow.add_task(task)
return self
Expand Down
10 changes: 5 additions & 5 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Flow(ControlFlowModel):
default_factory=list,
)
context: dict[str, Any] = {}
_tasks: dict[str, "Task"] = {}
tasks: dict[str, "Task"] = {}
_cm_stack: list[contextmanager] = []

def __init__(self, *, copy_parent_history: bool = True, **kwargs):
Expand Down Expand Up @@ -75,11 +75,11 @@ def add_messages(self, messages: list[MessageType]):
self.history.save_messages(thread_id=self.thread_id, messages=messages)

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
if self.tasks.get(task.id, task) is not task:
raise ValueError(
f"A different task with id '{task.id}' already exists in flow."
)
self._tasks[task.id] = task
self.tasks[task.id] = task

@contextmanager
def create_context(self, create_prefect_flow_context: bool = True):
Expand All @@ -94,15 +94,15 @@ async def run_async(self):
"""
Runs the flow asynchronously.
"""
if self._tasks:
if self.tasks:
controller = controlflow.Controller(flow=self)
await controller.run_async()

def run(self):
"""
Runs the flow.
"""
if self._tasks:
if self.tasks:
controller = controlflow.Controller(flow=self)
controller.run()

Expand Down
15 changes: 8 additions & 7 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ class Task(ControlFlowModel):
def __init__(
self,
objective=None,
result_type=None,
result_type=NOTSET,
**kwargs,
):
# allow certain args to be provided as a positional args
if result_type is not None:
if result_type is not NOTSET:
kwargs["result_type"] = result_type
if objective is not None:
kwargs["objective"] = objective
Expand Down Expand Up @@ -442,7 +442,7 @@ def is_ready(self) -> bool:
"""
return self.is_incomplete() and all(t.is_complete() for t in self.depends_on)

def _create_success_tool(self) -> Callable:
def _create_success_tool(self) -> Tool:
"""
Create an agent-compatible tool for marking this task as successful.
"""
Expand All @@ -466,7 +466,7 @@ def succeed(result: result_schema) -> str: # type: ignore
metadata=dict(is_task_status_tool=True),
)

def _create_fail_tool(self) -> Callable:
def _create_fail_tool(self) -> Tool:
"""
Create an agent-compatible tool for failing this task.
"""
Expand All @@ -478,7 +478,7 @@ def _create_fail_tool(self) -> Callable:
metadata=dict(is_task_status_tool=True),
)

def _create_skip_tool(self) -> Callable:
def _create_skip_tool(self) -> Tool:
"""
Create an agent-compatible tool for skipping this task.
"""
Expand Down Expand Up @@ -525,9 +525,10 @@ def get_agent_strategy(self) -> Callable:

return controlflow.agent_strategies.round_robin

def get_tools(self) -> list[Callable]:
def get_tools(self) -> list[Union[Tool, Callable]]:
tools = self.tools.copy()
if self.is_incomplete():
# if this task is ready to run, generate tools
if self.is_ready:
tools.extend([self._create_fail_tool(), self._create_success_tool()])
# add skip tool if this task has a parent task
# if self.parent is not None:
Expand Down
3 changes: 0 additions & 3 deletions src/controlflow/llm/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,3 @@ def save_messages(self, thread_id: str, messages: list[MessageType]):
all_messages.extend([msg.model_dump(mode="json") for msg in messages])
with open(self.path(thread_id), "w") as f:
json.dump(all_messages, f)


DEFAULT_HISTORY = InMemoryHistory()
8 changes: 6 additions & 2 deletions src/controlflow/llm/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from typing import Any, Optional

from langchain_core.language_models import BaseChatModel

import controlflow


def get_default_model() -> BaseChatModel:
if controlflow.default_model is None:
if getattr(controlflow, "default_model", None) is None:
return model_from_string(controlflow.settings.llm_model)
else:
return controlflow.default_model


def model_from_string(model: str, temperature: float = None, **kwargs) -> BaseChatModel:
def model_from_string(
model: str, temperature: Optional[float] = None, **kwargs: Any
) -> BaseChatModel:
if "/" not in model:
provider, model = "openai", model
provider, model = model.split("/")
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TUIApp(App):

def __init__(self, flow: "controlflow.Flow", **kwargs):
self._flow = flow
self._tasks = flow._tasks
self._tasks = flow.tasks
self._is_ready = False
super().__init__(**kwargs)

Expand Down
1 change: 1 addition & 0 deletions tests/ai_tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_task_pydantic_result(self):
assert isinstance(result, Name)
assert result == Name(first="John", last="Doe")

@pytest.xfail(reason="Need to revisit dataframe handling")
def test_task_dataframe_result(self):
task = Task(
'return a dataframe with column "x" that has values 1 and 2 and column "y" that has values 3 and 4',
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from controlflow.llm.messages import MessageType
from controlflow.settings import temporary_settings
from prefect.testing.utilities import prefect_test_harness

Expand Down
23 changes: 0 additions & 23 deletions tests/core/agents.py

This file was deleted.

40 changes: 40 additions & 0 deletions tests/core/test_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import controlflow
from controlflow.core.agent import Agent, get_default_agent
from controlflow.core.agent.names import NAMES
from controlflow.core.task import Task


class TestAgentInitialization:
def test_agent_gets_random_name(self):
agent = Agent()

assert agent.name in NAMES

def test_agent_default_model(self):
agent = Agent()

assert agent.model is controlflow.get_default_model()


class TestDefaultAgent:
def test_default_agent_is_marvin(self):
agent = get_default_agent()
assert agent.name == "Marvin"

def test_default_agent_has_no_tools(self):
assert get_default_agent().tools == []

def test_default_agent_can_be_assigned(self):
# baseline
assert get_default_agent().name == "Marvin"

new_default_agent = Agent(name="New Agent")
controlflow.default_agent = new_default_agent

assert get_default_agent().name == "New Agent"
assert Task("task").get_agents()[0] is new_default_agent
assert [a.name for a in Task("task").get_agents()] == ["New Agent"]

def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()
75 changes: 0 additions & 75 deletions tests/core/test_controller.py

This file was deleted.

Loading

0 comments on commit b8eb1f4

Please sign in to comment.