diff --git a/src/controlflow/core/agent/agent.py b/src/controlflow/core/agent/agent.py index e8e40002..ac81f43f 100644 --- a/src/controlflow/core/agent/agent.py +++ b/src/controlflow/core/agent/agent.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Optional +from langchain_core.language_models import BaseChatModel from pydantic import Field, field_serializer import controlflow @@ -60,11 +61,11 @@ class Agent(ControlFlowModel): description="The memory object used by the agent. If not specified, an in-memory memory object will be used. Pass None to disable memory.", ) - # note: `model` should be typed as a BaseChatModel but V2 models can't have + # note: `model` should be typed as Optional[BaseChatModel] but V2 models can't have # V1 attributes without erroring, so we have to use Any. - model: Any = Field( + model: Optional[Any] = Field( + None, description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used.", - default_factory=get_default_model, exclude=True, ) @@ -85,6 +86,9 @@ def __init__(self, name=None, **kwargs): kwargs["name"] = name super().__init__(**kwargs) + def get_model(self) -> BaseChatModel: + return self.model or get_default_model() + def get_tools(self) -> list[Callable]: tools = self.tools.copy() if self.user_access: diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index 39659a08..70d42ac3 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -228,7 +228,7 @@ async def run_once_async(self) -> list[MessageType]: with ctx(agent=agent, flow=self.flow, controller=self): response_gen = await completion_async( messages=payload["messages"], - model=agent.model, + model=agent.get_model(), tools=payload["tools"], handlers=payload["handlers"], max_iterations=1, @@ -264,7 +264,7 @@ def run_once(self) -> list[MessageType]: with ctx(agent=agent, flow=self.flow, controller=self): response_gen = completion( messages=payload["messages"], - model=agent.model, + model=agent.get_model(), tools=payload["tools"], handlers=payload["handlers"], max_iterations=1, diff --git a/tests/core/test_agents.py b/tests/core/test_agents.py index 00760bd4..bddb6783 100644 --- a/tests/core/test_agents.py +++ b/tests/core/test_agents.py @@ -2,6 +2,7 @@ from controlflow.core.agent import Agent, get_default_agent from controlflow.core.agent.names import NAMES from controlflow.core.task import Task +from langchain_openai import ChatOpenAI class TestAgentInitialization: @@ -13,17 +14,30 @@ def test_agent_gets_random_name(self): def test_agent_default_model(self): agent = Agent() - assert agent.model is controlflow.get_default_model() + # None indicates it will be loaded from the default model + assert agent.model is None + assert agent.get_model() is controlflow.get_default_model() + + def test_agent_model(self): + model = ChatOpenAI(model="gpt-3.5-turbo") + agent = Agent(model=model) + + # None indicates it will be loaded from the default model + assert agent.model is model + assert agent.get_model() is model class TestDefaultAgent: - def test_default_agent_is_marvin(self): - agent = get_default_agent() - assert agent.name == "Marvin" + def test_default_agent(self): + assert get_default_agent().name == "Marvin" + assert Task("task").get_agents()[0] is get_default_agent() def test_default_agent_has_no_tools(self): assert get_default_agent().tools == [] + def test_default_agent_has_no_model(self): + assert get_default_agent().model is None + def test_default_agent_can_be_assigned(self): # baseline assert get_default_agent().name == "Marvin" @@ -35,6 +49,14 @@ def test_default_agent_can_be_assigned(self): assert Task("task").get_agents()[0] is new_default_agent assert [a.name for a in Task("task").get_agents()] == ["New Agent"] - def test_default_agent(self): - assert get_default_agent().name == "Marvin" - assert Task("task").get_agents()[0] is get_default_agent() + def test_updating_the_default_model_updates_the_default_agent_model(self): + new_model = ChatOpenAI(model="gpt-3.5-turbo") + controlflow.default_model = new_model + + new_agent = get_default_agent() + assert new_agent.model is None + assert new_agent.get_model() is new_model + + task = Task("task") + assert task.get_agents()[0].model is None + assert task.get_agents()[0].get_model() is new_model