From 0e6ad55c67c4af17b9426e617d7c951d06277f3a Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:03:37 -0400 Subject: [PATCH] name is not required for agents --- src/controlflow/agents/agent.py | 17 +++++++++++------ src/controlflow/agents/names.py | 4 ++-- tests/agents/test_agents.py | 24 ++++++++++++++---------- tests/deprecated/test_agent.py | 4 ++-- tests/flows/test_flows.py | 4 ++-- tests/tasks/test_tasks.py | 10 +++++----- tests/test_defaults.py | 6 +++--- tests/test_settings.py | 2 +- 8 files changed, 40 insertions(+), 31 deletions(-) diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index d8342d57..58f91692 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -1,5 +1,6 @@ import abc import logging +import random import warnings from contextlib import contextmanager from typing import ( @@ -15,6 +16,7 @@ from pydantic import Field, field_serializer import controlflow +from controlflow.agents.names import AGENT_NAMES from controlflow.events.base import Event from controlflow.instructions import get_instructions from controlflow.llm.messages import AIMessage, BaseMessage @@ -44,7 +46,10 @@ class Agent(ControlFlowModel, abc.ABC): model_config = dict(arbitrary_types_allowed=True) id: str = Field(None) - name: str = Field(description="The name of the agent.") + name: str = Field( + default_factory=lambda: random.choice(AGENT_NAMES), + description="The name of the agent.", + ) description: Optional[str] = Field( None, description="A description of the agent, visible to other agents." ) @@ -80,17 +85,17 @@ class Agent(ControlFlowModel, abc.ABC): _cm_stack: list[contextmanager] = [] - def __init__(self, name: str = None, user_access: bool = None, **kwargs): - if name is not None: - kwargs["name"] = name + def __init__(self, instructions: str = None, **kwargs): + if instructions is not None: + kwargs["instructions"] = instructions # deprecated in 0.9 - if user_access is not None: + if "user_access" in kwargs: warnings.warn( "The `user_access` argument is deprecated. Use `interactive=True` instead.", DeprecationWarning, ) - kwargs["interactive"] = True + kwargs["interactive"] = kwargs.pop("user_access") if additional_instructions := get_instructions(): kwargs["instructions"] = ( diff --git a/src/controlflow/agents/names.py b/src/controlflow/agents/names.py index 5ca039a6..10331132 100644 --- a/src/controlflow/agents/names.py +++ b/src/controlflow/agents/names.py @@ -1,4 +1,4 @@ -AGENTS = [ +AGENT_NAMES = [ "HAL 9000", "R2-D2", "C-3PO", @@ -21,7 +21,7 @@ "Norby", ] -TEAMS = [ +TEAM_NAMES = [ "Autobots", "Decepticons", "Borg", diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 50148bcf..d44f15b7 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -8,8 +8,12 @@ class TestAgentInitialization: + def test_positional_arg(self): + agent = Agent("talk like a pirate") + assert agent.instructions == "talk like a pirate" + def test_agent_default_model(self): - agent = Agent(name="Marvin") + agent = Agent() # None indicates it will be loaded from the default model assert agent.model is None @@ -17,7 +21,7 @@ def test_agent_default_model(self): def test_agent_model(self): model = ChatOpenAI(model="gpt-4o-mini") - agent = Agent(name="Marvin", model=model) + agent = Agent(model=model) # None indicates it will be loaded from the default model assert agent.model is model @@ -25,7 +29,7 @@ def test_agent_model(self): def test_agent_loads_instructions_at_creation(self): with instructions("test instruction"): - agent = Agent(name="Marvin") + agent = Agent() assert "test instruction" in agent.instructions @@ -34,10 +38,10 @@ def test_stable_id(self): assert agent.id == "69dd1abd" def test_id_includes_instructions(self): - a1 = Agent(name="Test Agent") - a2 = Agent(name="Test Agent", instructions="abc") - a3 = Agent(name="Test Agent", instructions="def") - a4 = Agent(name="Test Agent", instructions="abc", description="xyz") + a1 = Agent() + a2 = Agent(instructions="abc") + a3 = Agent(instructions="def") + a4 = Agent(instructions="abc", description="xyz") assert a1.id != a2.id != a3.id != a4.id @@ -79,16 +83,16 @@ def test_updating_the_default_model_updates_the_default_agent_model(self): class TestAgentPrompt: def test_default_prompt(self): - agent = Agent(name="Marvin") + agent = Agent() assert agent.prompt is None def test_default_template(self): - agent = Agent(name="Marvin") + agent = Agent() prompt = agent.get_prompt() assert prompt.startswith("# Agent") def test_custom_prompt(self): - agent = Agent(name="Marvin", prompt="Custom Prompt") + agent = Agent(prompt="Custom Prompt") prompt = agent.get_prompt() assert prompt == "Custom Prompt" diff --git a/tests/deprecated/test_agent.py b/tests/deprecated/test_agent.py index 0859ba85..41d776b9 100644 --- a/tests/deprecated/test_agent.py +++ b/tests/deprecated/test_agent.py @@ -6,5 +6,5 @@ # deprecated in 0.9 def test_user_access(): with pytest.warns(DeprecationWarning): - a = controlflow.Agent(name="test", user_access=True) - assert a.interactive + a = controlflow.Agent(user_access=True) + assert a.interactive is True diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index 5f453849..9da9851f 100644 --- a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -167,12 +167,12 @@ def test_flow_sets_thread_id_for_history(self, tmpdir): class TestFlowCreatesDefaults: def test_flow_with_custom_agents(self): - agent1 = Agent(name="Agent 1") + agent1 = Agent() flow = Flow(agent=agent1) assert flow.agent == agent1 def test_flow_agent_becomes_task_default(self): - agent = Agent(name="BB8") + agent = Agent() t1 = Task("t1") assert agent not in t1.get_agents() assert len(t1.get_agents()) == 1 diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 10bf176e..e70b225b 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -114,7 +114,7 @@ def test_task_parent_context(): def test_task_agent_assignment(): - agent = Agent(name="Test Agent") + agent = Agent() task = SimpleTask(agents=[agent]) assert task.agents == [agent] assert task.get_agents() == [agent] @@ -126,7 +126,7 @@ def test_task_bad_agent_assignment(): def test_task_loads_agent_from_parent(): - agent = Agent(name="Test Agent") + agent = Agent() with SimpleTask(agents=[agent]): child = SimpleTask() @@ -136,7 +136,7 @@ def test_task_loads_agent_from_parent(): def test_task_loads_agent_from_flow(): def_agent = controlflow.defaults.agent - agent = Agent(name="Test Agent") + agent = Agent() with Flow(agent=agent): task = SimpleTask() @@ -156,8 +156,8 @@ def test_task_loads_agent_from_default_if_none_otherwise(): def test_task_loads_agent_from_parent_before_flow(): - agent1 = Agent(name="Test Agent 1") - agent2 = Agent(name="Test Agent 2") + agent1 = Agent() + agent2 = Agent() with Flow(agent=agent1): with SimpleTask(agents=[agent2]): child = SimpleTask() diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 2bc35556..98ec5bf3 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -17,7 +17,7 @@ def test_default_model_failed_validation(): def test_set_default_model(): model = ChatOpenAI(temperature=0.1) controlflow.defaults.model = model - assert controlflow.Agent(name="Marvin").get_model() is model + assert controlflow.Agent().get_model() is model def test_default_agent_failed_validation(): @@ -29,9 +29,9 @@ def test_default_agent_failed_validation(): def test_set_default_agent(): - agent = controlflow.Agent(name="Marvin") + agent = controlflow.Agent() controlflow.defaults.agent = agent - assert controlflow.Task("").get_agents() == [agent] # Updated to check agents list + assert controlflow.Task("").get_agents() == [agent] def test_default_history_failed_validation(): diff --git a/tests/test_settings.py b/tests/test_settings.py index a9cbf391..7b5a392f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -85,7 +85,7 @@ def test_import_without_default_api_key_errors_when_loading_model(monkeypatch): ValueError, match="No model provided and no default model could be loaded", ): - controlflow.Agent(name="Marvin").get_model() + controlflow.Agent().get_model() finally: defaults_module = importlib.import_module("controlflow.defaults") importlib.reload(defaults_module)