From 0e6ad55c67c4af17b9426e617d7c951d06277f3a Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Tue, 3 Sep 2024 16:03:37 -0400
Subject: [PATCH] name is not required for agents

---
 src/controlflow/agents/agent.py | 17 +++++++++++------
 src/controlflow/agents/names.py |  4 ++--
 tests/agents/test_agents.py     | 24 ++++++++++++++----------
 tests/deprecated/test_agent.py  |  4 ++--
 tests/flows/test_flows.py       |  4 ++--
 tests/tasks/test_tasks.py       | 10 +++++-----
 tests/test_defaults.py          |  6 +++---
 tests/test_settings.py          |  2 +-
 8 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py
index d8342d57..58f91692 100644
--- a/src/controlflow/agents/agent.py
+++ b/src/controlflow/agents/agent.py
@@ -1,5 +1,6 @@
 import abc
 import logging
+import random
 import warnings
 from contextlib import contextmanager
 from typing import (
@@ -15,6 +16,7 @@
 from pydantic import Field, field_serializer
 
 import controlflow
+from controlflow.agents.names import AGENT_NAMES
 from controlflow.events.base import Event
 from controlflow.instructions import get_instructions
 from controlflow.llm.messages import AIMessage, BaseMessage
@@ -44,7 +46,10 @@ class Agent(ControlFlowModel, abc.ABC):
     model_config = dict(arbitrary_types_allowed=True)
 
     id: str = Field(None)
-    name: str = Field(description="The name of the agent.")
+    name: str = Field(
+        default_factory=lambda: random.choice(AGENT_NAMES),
+        description="The name of the agent.",
+    )
     description: Optional[str] = Field(
         None, description="A description of the agent, visible to other agents."
     )
@@ -80,17 +85,17 @@ class Agent(ControlFlowModel, abc.ABC):
 
     _cm_stack: list[contextmanager] = []
 
-    def __init__(self, name: str = None, user_access: bool = None, **kwargs):
-        if name is not None:
-            kwargs["name"] = name
+    def __init__(self, instructions: str = None, **kwargs):
+        if instructions is not None:
+            kwargs["instructions"] = instructions
 
         # deprecated in 0.9
-        if user_access is not None:
+        if "user_access" in kwargs:
             warnings.warn(
                 "The `user_access` argument is deprecated. Use `interactive=True` instead.",
                 DeprecationWarning,
             )
-            kwargs["interactive"] = True
+            kwargs["interactive"] = kwargs.pop("user_access")
 
         if additional_instructions := get_instructions():
             kwargs["instructions"] = (
diff --git a/src/controlflow/agents/names.py b/src/controlflow/agents/names.py
index 5ca039a6..10331132 100644
--- a/src/controlflow/agents/names.py
+++ b/src/controlflow/agents/names.py
@@ -1,4 +1,4 @@
-AGENTS = [
+AGENT_NAMES = [
     "HAL 9000",
     "R2-D2",
     "C-3PO",
@@ -21,7 +21,7 @@
     "Norby",
 ]
 
-TEAMS = [
+TEAM_NAMES = [
     "Autobots",
     "Decepticons",
     "Borg",
diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py
index 50148bcf..d44f15b7 100644
--- a/tests/agents/test_agents.py
+++ b/tests/agents/test_agents.py
@@ -8,8 +8,12 @@
 
 
 class TestAgentInitialization:
+    def test_positional_arg(self):
+        agent = Agent("talk like a pirate")
+        assert agent.instructions == "talk like a pirate"
+
     def test_agent_default_model(self):
-        agent = Agent(name="Marvin")
+        agent = Agent()
 
         # None indicates it will be loaded from the default model
         assert agent.model is None
@@ -17,7 +21,7 @@ def test_agent_default_model(self):
 
     def test_agent_model(self):
         model = ChatOpenAI(model="gpt-4o-mini")
-        agent = Agent(name="Marvin", model=model)
+        agent = Agent(model=model)
 
         # None indicates it will be loaded from the default model
         assert agent.model is model
@@ -25,7 +29,7 @@ def test_agent_model(self):
 
     def test_agent_loads_instructions_at_creation(self):
         with instructions("test instruction"):
-            agent = Agent(name="Marvin")
+            agent = Agent()
 
         assert "test instruction" in agent.instructions
 
@@ -34,10 +38,10 @@ def test_stable_id(self):
         assert agent.id == "69dd1abd"
 
     def test_id_includes_instructions(self):
-        a1 = Agent(name="Test Agent")
-        a2 = Agent(name="Test Agent", instructions="abc")
-        a3 = Agent(name="Test Agent", instructions="def")
-        a4 = Agent(name="Test Agent", instructions="abc", description="xyz")
+        a1 = Agent()
+        a2 = Agent(instructions="abc")
+        a3 = Agent(instructions="def")
+        a4 = Agent(instructions="abc", description="xyz")
 
         assert a1.id != a2.id != a3.id != a4.id
 
@@ -79,16 +83,16 @@ def test_updating_the_default_model_updates_the_default_agent_model(self):
 
 class TestAgentPrompt:
     def test_default_prompt(self):
-        agent = Agent(name="Marvin")
+        agent = Agent()
         assert agent.prompt is None
 
     def test_default_template(self):
-        agent = Agent(name="Marvin")
+        agent = Agent()
         prompt = agent.get_prompt()
         assert prompt.startswith("# Agent")
 
     def test_custom_prompt(self):
-        agent = Agent(name="Marvin", prompt="Custom Prompt")
+        agent = Agent(prompt="Custom Prompt")
         prompt = agent.get_prompt()
         assert prompt == "Custom Prompt"
 
diff --git a/tests/deprecated/test_agent.py b/tests/deprecated/test_agent.py
index 0859ba85..41d776b9 100644
--- a/tests/deprecated/test_agent.py
+++ b/tests/deprecated/test_agent.py
@@ -6,5 +6,5 @@
 # deprecated in 0.9
 def test_user_access():
     with pytest.warns(DeprecationWarning):
-        a = controlflow.Agent(name="test", user_access=True)
-        assert a.interactive
+        a = controlflow.Agent(user_access=True)
+        assert a.interactive is True
diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py
index 5f453849..9da9851f 100644
--- a/tests/flows/test_flows.py
+++ b/tests/flows/test_flows.py
@@ -167,12 +167,12 @@ def test_flow_sets_thread_id_for_history(self, tmpdir):
 
 class TestFlowCreatesDefaults:
     def test_flow_with_custom_agents(self):
-        agent1 = Agent(name="Agent 1")
+        agent1 = Agent()
         flow = Flow(agent=agent1)
         assert flow.agent == agent1
 
     def test_flow_agent_becomes_task_default(self):
-        agent = Agent(name="BB8")
+        agent = Agent()
         t1 = Task("t1")
         assert agent not in t1.get_agents()
         assert len(t1.get_agents()) == 1
diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py
index 10bf176e..e70b225b 100644
--- a/tests/tasks/test_tasks.py
+++ b/tests/tasks/test_tasks.py
@@ -114,7 +114,7 @@ def test_task_parent_context():
 
 
 def test_task_agent_assignment():
-    agent = Agent(name="Test Agent")
+    agent = Agent()
     task = SimpleTask(agents=[agent])
     assert task.agents == [agent]
     assert task.get_agents() == [agent]
@@ -126,7 +126,7 @@ def test_task_bad_agent_assignment():
 
 
 def test_task_loads_agent_from_parent():
-    agent = Agent(name="Test Agent")
+    agent = Agent()
     with SimpleTask(agents=[agent]):
         child = SimpleTask()
 
@@ -136,7 +136,7 @@ def test_task_loads_agent_from_parent():
 
 def test_task_loads_agent_from_flow():
     def_agent = controlflow.defaults.agent
-    agent = Agent(name="Test Agent")
+    agent = Agent()
     with Flow(agent=agent):
         task = SimpleTask()
 
@@ -156,8 +156,8 @@ def test_task_loads_agent_from_default_if_none_otherwise():
 
 
 def test_task_loads_agent_from_parent_before_flow():
-    agent1 = Agent(name="Test Agent 1")
-    agent2 = Agent(name="Test Agent 2")
+    agent1 = Agent()
+    agent2 = Agent()
     with Flow(agent=agent1):
         with SimpleTask(agents=[agent2]):
             child = SimpleTask()
diff --git a/tests/test_defaults.py b/tests/test_defaults.py
index 2bc35556..98ec5bf3 100644
--- a/tests/test_defaults.py
+++ b/tests/test_defaults.py
@@ -17,7 +17,7 @@ def test_default_model_failed_validation():
 def test_set_default_model():
     model = ChatOpenAI(temperature=0.1)
     controlflow.defaults.model = model
-    assert controlflow.Agent(name="Marvin").get_model() is model
+    assert controlflow.Agent().get_model() is model
 
 
 def test_default_agent_failed_validation():
@@ -29,9 +29,9 @@ def test_default_agent_failed_validation():
 
 
 def test_set_default_agent():
-    agent = controlflow.Agent(name="Marvin")
+    agent = controlflow.Agent()
     controlflow.defaults.agent = agent
-    assert controlflow.Task("").get_agents() == [agent]  # Updated to check agents list
+    assert controlflow.Task("").get_agents() == [agent]
 
 
 def test_default_history_failed_validation():
diff --git a/tests/test_settings.py b/tests/test_settings.py
index a9cbf391..7b5a392f 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -85,7 +85,7 @@ def test_import_without_default_api_key_errors_when_loading_model(monkeypatch):
                 ValueError,
                 match="No model provided and no default model could be loaded",
             ):
-                controlflow.Agent(name="Marvin").get_model()
+                controlflow.Agent().get_model()
     finally:
         defaults_module = importlib.import_module("controlflow.defaults")
         importlib.reload(defaults_module)