Skip to content

Ensure that agents always load the then-default model if one is not specified #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/controlflow/core/agent/agent.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer

import controlflow
@@ -60,11 +61,11 @@ class Agent(ControlFlowModel):
description="The memory object used by the agent. If not specified, an in-memory memory object will be used. Pass None to disable memory.",
)

# note: `model` should be typed as a BaseChatModel but V2 models can't have
# note: `model` should be typed as Optional[BaseChatModel] but V2 models can't have
# V1 attributes without erroring, so we have to use Any.
model: Any = Field(
model: Optional[Any] = Field(
None,
description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used.",
default_factory=get_default_model,
exclude=True,
)

@@ -85,6 +86,9 @@ def __init__(self, name=None, **kwargs):
kwargs["name"] = name
super().__init__(**kwargs)

def get_model(self) -> BaseChatModel:
return self.model or get_default_model()

def get_tools(self) -> list[Callable]:
tools = self.tools.copy()
if self.user_access:
4 changes: 2 additions & 2 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -228,7 +228,7 @@ async def run_once_async(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = await completion_async(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
@@ -264,7 +264,7 @@ def run_once(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = completion(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
36 changes: 29 additions & 7 deletions tests/core/test_agents.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from controlflow.core.agent import Agent, get_default_agent
from controlflow.core.agent.names import NAMES
from controlflow.core.task import Task
from langchain_openai import ChatOpenAI


class TestAgentInitialization:
@@ -13,17 +14,30 @@ def test_agent_gets_random_name(self):
def test_agent_default_model(self):
agent = Agent()

assert agent.model is controlflow.get_default_model()
# None indicates it will be loaded from the default model
assert agent.model is None
assert agent.get_model() is controlflow.get_default_model()

def test_agent_model(self):
model = ChatOpenAI(model="gpt-3.5-turbo")
agent = Agent(model=model)

# None indicates it will be loaded from the default model
assert agent.model is model
assert agent.get_model() is model


class TestDefaultAgent:
def test_default_agent_is_marvin(self):
agent = get_default_agent()
assert agent.name == "Marvin"
def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()

def test_default_agent_has_no_tools(self):
assert get_default_agent().tools == []

def test_default_agent_has_no_model(self):
assert get_default_agent().model is None

def test_default_agent_can_be_assigned(self):
# baseline
assert get_default_agent().name == "Marvin"
@@ -35,6 +49,14 @@ def test_default_agent_can_be_assigned(self):
assert Task("task").get_agents()[0] is new_default_agent
assert [a.name for a in Task("task").get_agents()] == ["New Agent"]

def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()
def test_updating_the_default_model_updates_the_default_agent_model(self):
new_model = ChatOpenAI(model="gpt-3.5-turbo")
controlflow.default_model = new_model

new_agent = get_default_agent()
assert new_agent.model is None
assert new_agent.get_model() is new_model

task = Task("task")
assert task.get_agents()[0].model is None
assert task.get_agents()[0].get_model() is new_model