Skip to content

Commit

Permalink
Ensure that agents load the default model if one is specified
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Jun 17, 2024
1 parent c9a0531 commit fc118c0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
10 changes: 7 additions & 3 deletions src/controlflow/core/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer

import controlflow
Expand Down Expand Up @@ -60,11 +61,11 @@ class Agent(ControlFlowModel):
description="The memory object used by the agent. If not specified, an in-memory memory object will be used. Pass None to disable memory.",
)

# note: `model` should be typed as a BaseChatModel but V2 models can't have
# note: `model` should be typed as Optional[BaseChatModel] but V2 models can't have
# V1 attributes without erroring, so we have to use Any.
model: Any = Field(
model: Optional[Any] = Field(
None,
description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used.",
default_factory=get_default_model,
exclude=True,
)

Expand All @@ -85,6 +86,9 @@ def __init__(self, name=None, **kwargs):
kwargs["name"] = name
super().__init__(**kwargs)

def get_model(self) -> BaseChatModel:
return self.model or get_default_model()

def get_tools(self) -> list[Callable]:
tools = self.tools.copy()
if self.user_access:
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def run_once_async(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = await completion_async(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
Expand Down Expand Up @@ -264,7 +264,7 @@ def run_once(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = completion(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
Expand Down
36 changes: 29 additions & 7 deletions tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from controlflow.core.agent import Agent, get_default_agent
from controlflow.core.agent.names import NAMES
from controlflow.core.task import Task
from langchain_openai import ChatOpenAI


class TestAgentInitialization:
Expand All @@ -13,17 +14,30 @@ def test_agent_gets_random_name(self):
def test_agent_default_model(self):
agent = Agent()

assert agent.model is controlflow.get_default_model()
# None indicates it will be loaded from the default model
assert agent.model is None
assert agent.get_model() is controlflow.get_default_model()

def test_agent_model(self):
model = ChatOpenAI(model="gpt-3.5-turbo")
agent = Agent(model=model)

# None indicates it will be loaded from the default model
assert agent.model is model
assert agent.get_model() is model


class TestDefaultAgent:
def test_default_agent_is_marvin(self):
agent = get_default_agent()
assert agent.name == "Marvin"
def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()

def test_default_agent_has_no_tools(self):
assert get_default_agent().tools == []

def test_default_agent_has_no_model(self):
assert get_default_agent().model is None

def test_default_agent_can_be_assigned(self):
# baseline
assert get_default_agent().name == "Marvin"
Expand All @@ -35,6 +49,14 @@ def test_default_agent_can_be_assigned(self):
assert Task("task").get_agents()[0] is new_default_agent
assert [a.name for a in Task("task").get_agents()] == ["New Agent"]

def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()
def test_updating_the_default_model_updates_the_default_agent_model(self):
new_model = ChatOpenAI(model="gpt-3.5-turbo")
controlflow.default_model = new_model

new_agent = get_default_agent()
assert new_agent.model is None
assert new_agent.get_model() is new_model

task = Task("task")
assert task.get_agents()[0].model is None
assert task.get_agents()[0].get_model() is new_model

0 comments on commit fc118c0

Please sign in to comment.