diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 44010868..274dfcd4 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -165,7 +165,7 @@ def test_context_manager(self): class TestHandlers: - class TestHandler(Handler): + class ExampleHandler(Handler): def __init__(self): self.events = [] self.agent_messages = [] @@ -176,8 +176,9 @@ def on_event(self, event: Event): def on_agent_message(self, event: AgentMessage): self.agent_messages.append(event) - def test_agent_run_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + @pytest.mark.usefixtures("default_fake_llm") + def test_agent_run_with_handlers(self): + handler = self.ExampleHandler() agent = Agent() agent.run( "Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 @@ -187,8 +188,9 @@ def test_agent_run_with_handlers(self, default_fake_llm): assert len(handler.agent_messages) == 1 @pytest.mark.asyncio - async def test_agent_run_async_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + @pytest.mark.usefixtures("default_fake_llm") + async def test_agent_run_async_with_handlers(self): + handler = self.ExampleHandler() agent = Agent() await agent.run_async( "Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index f2da7123..a294544d 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -429,7 +429,7 @@ class Person(BaseModel): class TestHandlers: - class TestHandler(Handler): + class ExampleHandler(Handler): def __init__(self): self.events = [] self.agent_messages = [] @@ -441,16 +441,15 @@ def on_agent_message(self, event: AgentMessage): self.agent_messages.append(event) def test_task_run_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() task = Task(objective="Calculate 2 + 2", result_type=int) task.run(handlers=[handler], max_llm_calls=1) assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 - @pytest.mark.asyncio async def test_task_run_async_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() task = Task(objective="Calculate 2 + 2", result_type=int) await task.run_async(handlers=[handler], max_llm_calls=1) diff --git a/tests/test_run.py b/tests/test_run.py index 41f5d470..d59d3ab8 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -5,7 +5,7 @@ class TestHandlers: - class TestHandler(Handler): + class ExampleHandler(Handler): def __init__(self): self.events = [] self.agent_messages = [] @@ -17,13 +17,13 @@ def on_agent_message(self, event: AgentMessage): self.agent_messages.append(event) def test_run_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() run("what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1) assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 async def test_run_async_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() await run_async( "what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 ) diff --git a/tests/tools/test_lc_tools.py b/tests/tools/test_lc_tools.py index 0a6d54b0..58d5220d 100644 --- a/tests/tools/test_lc_tools.py +++ b/tests/tools/test_lc_tools.py @@ -13,8 +13,8 @@ class LCBaseToolInput(BaseModel): class LCBaseTool(BaseTool): - name = "TestTool" - description = "A test tool" + name: str = "TestTool" + description: str = "A test tool" args_schema: type[BaseModel] = LCBaseToolInput def _run(self, x: int) -> str: