Skip to content

Commit

Permalink
Merge pull request #313 from PrefectHQ/small-updates-test
Browse files Browse the repository at this point in the history
update tests
  • Loading branch information
zzstoatzz authored Sep 17, 2024
2 parents 3d493f6 + 699b276 commit a1aa9dd
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
12 changes: 7 additions & 5 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_context_manager(self):


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -176,8 +176,9 @@ def on_event(self, event: Event):
def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_agent_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
@pytest.mark.usefixtures("default_fake_llm")
def test_agent_run_with_handlers(self):
handler = self.ExampleHandler()
agent = Agent()
agent.run(
"Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
Expand All @@ -187,8 +188,9 @@ def test_agent_run_with_handlers(self, default_fake_llm):
assert len(handler.agent_messages) == 1

@pytest.mark.asyncio
async def test_agent_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
@pytest.mark.usefixtures("default_fake_llm")
async def test_agent_run_async_with_handlers(self):
handler = self.ExampleHandler()
agent = Agent()
await agent.run_async(
"Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
Expand Down
7 changes: 3 additions & 4 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ class Person(BaseModel):


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -441,16 +441,15 @@ def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_task_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
task.run(handlers=[handler], max_llm_calls=1)

assert len(handler.events) > 0
assert len(handler.agent_messages) == 1

@pytest.mark.asyncio
async def test_task_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
task = Task(objective="Calculate 2 + 2", result_type=int)
await task.run_async(handlers=[handler], max_llm_calls=1)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestHandlers:
class TestHandler(Handler):
class ExampleHandler(Handler):
def __init__(self):
self.events = []
self.agent_messages = []
Expand All @@ -17,13 +17,13 @@ def on_agent_message(self, event: AgentMessage):
self.agent_messages.append(event)

def test_run_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
run("what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1)
assert len(handler.events) > 0
assert len(handler.agent_messages) == 1

async def test_run_async_with_handlers(self, default_fake_llm):
handler = self.TestHandler()
handler = self.ExampleHandler()
await run_async(
"what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1
)
Expand Down
4 changes: 2 additions & 2 deletions tests/tools/test_lc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class LCBaseToolInput(BaseModel):


class LCBaseTool(BaseTool):
name = "TestTool"
description = "A test tool"
name: str = "TestTool"
description: str = "A test tool"
args_schema: type[BaseModel] = LCBaseToolInput

def _run(self, x: int) -> str:
Expand Down

0 comments on commit a1aa9dd

Please sign in to comment.