Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that agents always load the then-default model if one is not specified #125

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/controlflow/core/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer

import controlflow
Expand Down Expand Up @@ -60,11 +61,11 @@ class Agent(ControlFlowModel):
description="The memory object used by the agent. If not specified, an in-memory memory object will be used. Pass None to disable memory.",
)

# note: `model` should be typed as a BaseChatModel but V2 models can't have
# note: `model` should be typed as Optional[BaseChatModel] but V2 models can't have
# V1 attributes without erroring, so we have to use Any.
model: Any = Field(
model: Optional[Any] = Field(
None,
description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used.",
default_factory=get_default_model,
exclude=True,
)

Expand All @@ -85,6 +86,9 @@ def __init__(self, name=None, **kwargs):
kwargs["name"] = name
super().__init__(**kwargs)

def get_model(self) -> BaseChatModel:
return self.model or get_default_model()

def get_tools(self) -> list[Callable]:
tools = self.tools.copy()
if self.user_access:
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def run_once_async(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = await completion_async(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
Expand Down Expand Up @@ -264,7 +264,7 @@ def run_once(self) -> list[MessageType]:
with ctx(agent=agent, flow=self.flow, controller=self):
response_gen = completion(
messages=payload["messages"],
model=agent.model,
model=agent.get_model(),
tools=payload["tools"],
handlers=payload["handlers"],
max_iterations=1,
Expand Down
36 changes: 29 additions & 7 deletions tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from controlflow.core.agent import Agent, get_default_agent
from controlflow.core.agent.names import NAMES
from controlflow.core.task import Task
from langchain_openai import ChatOpenAI


class TestAgentInitialization:
Expand All @@ -13,17 +14,30 @@ def test_agent_gets_random_name(self):
def test_agent_default_model(self):
agent = Agent()

assert agent.model is controlflow.get_default_model()
# None indicates it will be loaded from the default model
assert agent.model is None
assert agent.get_model() is controlflow.get_default_model()

def test_agent_model(self):
model = ChatOpenAI(model="gpt-3.5-turbo")
agent = Agent(model=model)

# None indicates it will be loaded from the default model
assert agent.model is model
assert agent.get_model() is model


class TestDefaultAgent:
def test_default_agent_is_marvin(self):
agent = get_default_agent()
assert agent.name == "Marvin"
def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()

def test_default_agent_has_no_tools(self):
assert get_default_agent().tools == []

def test_default_agent_has_no_model(self):
assert get_default_agent().model is None

def test_default_agent_can_be_assigned(self):
# baseline
assert get_default_agent().name == "Marvin"
Expand All @@ -35,6 +49,14 @@ def test_default_agent_can_be_assigned(self):
assert Task("task").get_agents()[0] is new_default_agent
assert [a.name for a in Task("task").get_agents()] == ["New Agent"]

def test_default_agent(self):
assert get_default_agent().name == "Marvin"
assert Task("task").get_agents()[0] is get_default_agent()
def test_updating_the_default_model_updates_the_default_agent_model(self):
new_model = ChatOpenAI(model="gpt-3.5-turbo")
controlflow.default_model = new_model

new_agent = get_default_agent()
assert new_agent.model is None
assert new_agent.get_model() is new_model

task = Task("task")
assert task.get_agents()[0].model is None
assert task.get_agents()[0].get_model() is new_model