Skip to content

Commit

Permalink
Merge pull request #270 from PrefectHQ/agents
Browse files Browse the repository at this point in the history
Name is not required for agents; instructions can be provided as a positional arg
  • Loading branch information
jlowin authored Sep 3, 2024
2 parents 909aa83 + 0e6ad55 commit 64ca7f0
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 31 deletions.
17 changes: 11 additions & 6 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import logging
import random
import warnings
from contextlib import contextmanager
from typing import (
Expand All @@ -15,6 +16,7 @@
from pydantic import Field, field_serializer

import controlflow
from controlflow.agents.names import AGENT_NAMES
from controlflow.events.base import Event
from controlflow.instructions import get_instructions
from controlflow.llm.messages import AIMessage, BaseMessage
Expand Down Expand Up @@ -44,7 +46,10 @@ class Agent(ControlFlowModel, abc.ABC):
model_config = dict(arbitrary_types_allowed=True)

id: str = Field(None)
name: str = Field(description="The name of the agent.")
name: str = Field(
default_factory=lambda: random.choice(AGENT_NAMES),
description="The name of the agent.",
)
description: Optional[str] = Field(
None, description="A description of the agent, visible to other agents."
)
Expand Down Expand Up @@ -80,17 +85,17 @@ class Agent(ControlFlowModel, abc.ABC):

_cm_stack: list[contextmanager] = []

def __init__(self, name: str = None, user_access: bool = None, **kwargs):
if name is not None:
kwargs["name"] = name
def __init__(self, instructions: str = None, **kwargs):
if instructions is not None:
kwargs["instructions"] = instructions

# deprecated in 0.9
if user_access is not None:
if "user_access" in kwargs:
warnings.warn(
"The `user_access` argument is deprecated. Use `interactive=True` instead.",
DeprecationWarning,
)
kwargs["interactive"] = True
kwargs["interactive"] = kwargs.pop("user_access")

if additional_instructions := get_instructions():
kwargs["instructions"] = (
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/agents/names.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
AGENTS = [
AGENT_NAMES = [
"HAL 9000",
"R2-D2",
"C-3PO",
Expand All @@ -21,7 +21,7 @@
"Norby",
]

TEAMS = [
TEAM_NAMES = [
"Autobots",
"Decepticons",
"Borg",
Expand Down
24 changes: 14 additions & 10 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,28 @@


class TestAgentInitialization:
def test_positional_arg(self):
agent = Agent("talk like a pirate")
assert agent.instructions == "talk like a pirate"

def test_agent_default_model(self):
agent = Agent(name="Marvin")
agent = Agent()

# None indicates it will be loaded from the default model
assert agent.model is None
assert agent.get_model() is controlflow.defaults.model

def test_agent_model(self):
model = ChatOpenAI(model="gpt-4o-mini")
agent = Agent(name="Marvin", model=model)
agent = Agent(model=model)

# None indicates it will be loaded from the default model
assert agent.model is model
assert agent.get_model() is model

def test_agent_loads_instructions_at_creation(self):
with instructions("test instruction"):
agent = Agent(name="Marvin")
agent = Agent()

assert "test instruction" in agent.instructions

Expand All @@ -34,10 +38,10 @@ def test_stable_id(self):
assert agent.id == "69dd1abd"

def test_id_includes_instructions(self):
a1 = Agent(name="Test Agent")
a2 = Agent(name="Test Agent", instructions="abc")
a3 = Agent(name="Test Agent", instructions="def")
a4 = Agent(name="Test Agent", instructions="abc", description="xyz")
a1 = Agent()
a2 = Agent(instructions="abc")
a3 = Agent(instructions="def")
a4 = Agent(instructions="abc", description="xyz")

assert a1.id != a2.id != a3.id != a4.id

Expand Down Expand Up @@ -79,16 +83,16 @@ def test_updating_the_default_model_updates_the_default_agent_model(self):

class TestAgentPrompt:
def test_default_prompt(self):
agent = Agent(name="Marvin")
agent = Agent()
assert agent.prompt is None

def test_default_template(self):
agent = Agent(name="Marvin")
agent = Agent()
prompt = agent.get_prompt()
assert prompt.startswith("# Agent")

def test_custom_prompt(self):
agent = Agent(name="Marvin", prompt="Custom Prompt")
agent = Agent(prompt="Custom Prompt")
prompt = agent.get_prompt()
assert prompt == "Custom Prompt"

Expand Down
4 changes: 2 additions & 2 deletions tests/deprecated/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
# deprecated in 0.9
def test_user_access():
with pytest.warns(DeprecationWarning):
a = controlflow.Agent(name="test", user_access=True)
assert a.interactive
a = controlflow.Agent(user_access=True)
assert a.interactive is True
4 changes: 2 additions & 2 deletions tests/flows/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ def test_flow_sets_thread_id_for_history(self, tmpdir):

class TestFlowCreatesDefaults:
def test_flow_with_custom_agents(self):
agent1 = Agent(name="Agent 1")
agent1 = Agent()
flow = Flow(agent=agent1)
assert flow.agent == agent1

def test_flow_agent_becomes_task_default(self):
agent = Agent(name="BB8")
agent = Agent()
t1 = Task("t1")
assert agent not in t1.get_agents()
assert len(t1.get_agents()) == 1
Expand Down
10 changes: 5 additions & 5 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_task_parent_context():


def test_task_agent_assignment():
agent = Agent(name="Test Agent")
agent = Agent()
task = SimpleTask(agents=[agent])
assert task.agents == [agent]
assert task.get_agents() == [agent]
Expand All @@ -126,7 +126,7 @@ def test_task_bad_agent_assignment():


def test_task_loads_agent_from_parent():
agent = Agent(name="Test Agent")
agent = Agent()
with SimpleTask(agents=[agent]):
child = SimpleTask()

Expand All @@ -136,7 +136,7 @@ def test_task_loads_agent_from_parent():

def test_task_loads_agent_from_flow():
def_agent = controlflow.defaults.agent
agent = Agent(name="Test Agent")
agent = Agent()
with Flow(agent=agent):
task = SimpleTask()

Expand All @@ -156,8 +156,8 @@ def test_task_loads_agent_from_default_if_none_otherwise():


def test_task_loads_agent_from_parent_before_flow():
agent1 = Agent(name="Test Agent 1")
agent2 = Agent(name="Test Agent 2")
agent1 = Agent()
agent2 = Agent()
with Flow(agent=agent1):
with SimpleTask(agents=[agent2]):
child = SimpleTask()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_default_model_failed_validation():
def test_set_default_model():
model = ChatOpenAI(temperature=0.1)
controlflow.defaults.model = model
assert controlflow.Agent(name="Marvin").get_model() is model
assert controlflow.Agent().get_model() is model


def test_default_agent_failed_validation():
Expand All @@ -29,9 +29,9 @@ def test_default_agent_failed_validation():


def test_set_default_agent():
agent = controlflow.Agent(name="Marvin")
agent = controlflow.Agent()
controlflow.defaults.agent = agent
assert controlflow.Task("").get_agents() == [agent] # Updated to check agents list
assert controlflow.Task("").get_agents() == [agent]


def test_default_history_failed_validation():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_import_without_default_api_key_errors_when_loading_model(monkeypatch):
ValueError,
match="No model provided and no default model could be loaded",
):
controlflow.Agent(name="Marvin").get_model()
controlflow.Agent().get_model()
finally:
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(defaults_module)
Expand Down

0 comments on commit 64ca7f0

Please sign in to comment.