Skip to content

Commit

Permalink
Merge pull request #312 from PrefectHQ/small-updates
Browse files Browse the repository at this point in the history
remove old `typer` extra and update `json` -> `model_dump_json`
  • Loading branch information
jlowin authored Sep 18, 2024
2 parents 38f7c44 + 7940567 commit 623e485
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 25 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ authors = [
dependencies = [
"prefect>=3.0",
"jinja2>=3.1.4",
"langchain_core>=0.2,<0.3",
"langchain_core>=0.3",
"langchain_openai>=0.1.8",
"langchain-anthropic>=0.1.19",
"markdownify>=0.12.1",
"pydantic-settings>=2.2.1",
"textual>=0.61.1",
"tiktoken>=0.7.0",
"typer[all]>=0.10",
"typer>=0.10",
]
readme = "README.md"
requires-python = ">= 3.9"
Expand Down
6 changes: 3 additions & 3 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _validate_model(cls, model: Optional[Union[str, Any]]):
def _serialize_tools(self, tools: list[Tool]):
tools = controlflow.tools.as_tools(tools)
# tools are Pydantic 1 objects
return [t.dict(include={"name", "description"}) for t in tools]
return [t.model_dump(include={"name", "description"}) for t in tools]

def serialize_for_prompt(self) -> dict:
dct = self.model_dump(
Expand Down Expand Up @@ -304,7 +304,7 @@ def _run_model(
#### Payload
```json
{response.json(indent=2)}
{response.model_dump_json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
Expand Down Expand Up @@ -361,7 +361,7 @@ async def _run_model_async(
#### Payload
```json
{response.json(indent=2)}
{response.model_dump_json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class AgentMessage(Event):
@field_validator("message", mode="before")
def _message(cls, v):
if isinstance(v, BaseMessage):
v = v.dict()
v = v.model_dump()
v["type"] = "ai"
return v

Expand Down Expand Up @@ -93,7 +93,7 @@ class AgentMessageDelta(UnpersistedEvent):
@field_validator("delta", "snapshot", mode="before")
def _message(cls, v):
if isinstance(v, BaseMessage):
v = v.dict()
v = v.model_dump()
v["type"] = "AIMessageChunk"
return v

Expand Down
4 changes: 3 additions & 1 deletion src/controlflow/events/message_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def format_message_name(

def count_tokens(message: BaseMessage) -> int:
# always use gpt-3.5 token counter with the entire message object; we only need to be approximate here
return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json()))
return len(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.model_dump_json())
)


def trim_messages(
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def plan(

agent_dict = dict(enumerate(agents))
tool_dict = dict(
enumerate([t.dict(include={"name", "description"}) for t in tools])
enumerate([t.model_dump(include={"name", "description"}) for t in tools])
)

def validate_plan(plan: list[PlanTask]):
Expand Down
12 changes: 7 additions & 5 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_context_manager(self):


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -176,8 +176,9 @@ def on_event(self, event: Event):
def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_agent_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
@pytest.mark.usefixtures("default_fake_llm")
def test_agent_run_with_handlers(self):
handler = self.ExampleHandler()
agent = Agent()
agent.run(
"Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
Expand All @@ -187,8 +188,9 @@ def test_agent_run_with_handlers(self, default_fake_llm):
assert len(handler.agent_messages) == 1

@pytest.mark.asyncio
async def test_agent_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
@pytest.mark.usefixtures("default_fake_llm")
async def test_agent_run_async_with_handlers(self):
handler = self.ExampleHandler()
agent = Agent()
await agent.run_async(
"Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
Expand Down
9 changes: 4 additions & 5 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_list_of_literals_result(self):

def test_map_labels_to_values(self):
task = Task(
"Choose the right label",
"Choose the right label, in order provided in context",
context=dict(goals=["the second letter", "the first letter"]),
result_type=list[Literal["a", "b", "c"]],
)
Expand Down Expand Up @@ -523,7 +523,7 @@ class Person(BaseModel):


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -535,16 +535,15 @@ def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_task_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
task.run(handlers=[handler], max_llm_calls=1)

assert len(handler.events) > 0
assert len(handler.agent_messages) == 1

@pytest.mark.asyncio
async def test_task_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
await task.run_async(handlers=[handler], max_llm_calls=1)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -17,13 +17,13 @@ def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
run("what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1)
assert len(handler.events) > 0
assert len(handler.agent_messages) == 1

async def test_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
await run_async(
"what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib

import openai
import pytest
from prefect.logging import get_logger

Expand Down Expand Up @@ -78,7 +79,9 @@ def test_import_without_default_api_key_errors_when_loading_model(monkeypatch):
importlib.reload(defaults_module)
importlib.reload(controlflow)

with pytest.raises(ValueError, match="Did not find openai_api_key"):
with pytest.raises(
openai.OpenAIError, match="api_key client option must be set"
):
controlflow.llm.models.get_default_model()

with pytest.raises(
Expand Down
4 changes: 2 additions & 2 deletions tests/tools/test_lc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class LCBaseToolInput(BaseModel):


class LCBaseTool(BaseTool):
name = "TestTool"
description = "A test tool"
name: str = "TestTool"
description: str = "A test tool"
args_schema: type[BaseModel] = LCBaseToolInput

def _run(self, x: int) -> str:
Expand Down

0 comments on commit 623e485

Please sign in to comment.