Skip to content

Commit

Permalink
Clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Nov 12, 2024
1 parent 072e460 commit 0b0c9fb
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 40 deletions.
9 changes: 4 additions & 5 deletions src/controlflow/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
HumanMessage,
ToolMessage,
)
from controlflow.tools.tools import InvalidToolCall, Tool
from controlflow.tools.tools import ToolCall as ToolCallPayload
from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall
from controlflow.tools.tools import ToolResult as ToolResultPayload
from controlflow.utilities.logging import get_logger

Expand Down Expand Up @@ -136,13 +135,13 @@ def _finalize(self):

def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]:
deltas = []
for call_delta in self.message_delta["tool_call_chunks"]:
for call_delta in self.message_delta.get("tool_call_chunks", []):
# First match chunks by index because streaming chunks come in sequence (0,1,2...)
# and this index lets us correlate deltas to their snapshots during streaming
chunk_snapshot = next(
(
c
for c in self.message_snapshot["tool_call_chunks"]
for c in self.message_snapshot.get("tool_call_chunks", [])
if c.get("index", -1) == call_delta.get("index", -2)
),
None,
Expand Down Expand Up @@ -210,7 +209,7 @@ class AgentToolCall(Event):
event: Literal["tool-call"] = "tool-call"
agent: Agent
agent_message_id: Optional[str] = None
tool_call: Union[ToolCallPayload, InvalidToolCall]
tool_call: Union[ToolCall, InvalidToolCall]
tool: Optional[Tool] = None
args: dict = {}

Expand Down
31 changes: 6 additions & 25 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,32 +644,10 @@ async def _run_agent_turn_async(

async def _run_async(
self,
max_llm_calls: Optional[int] = None,
max_agent_turns: Optional[int] = None,
run_context: RunContext,
model_kwargs: Optional[dict] = None,
run_until: Optional[
Union[RunEndCondition, Callable[[RunContext], bool]]
] = None,
) -> AsyncIterator[Event]:
"""Async version of _run."""
# Create the base termination condition
if run_until is None:
run_until = AllComplete()
elif not isinstance(run_until, RunEndCondition):
run_until = FnCondition(run_until)

# Add max_llm_calls condition
if max_llm_calls is None:
max_llm_calls = controlflow.settings.orchestrator_max_llm_calls
run_until = run_until | MaxLLMCalls(max_llm_calls)

# Add max_agent_turns condition
if max_agent_turns is None:
max_agent_turns = controlflow.settings.orchestrator_max_agent_turns
run_until = run_until | MaxAgentTurns(max_agent_turns)

run_context = RunContext(orchestrator=self, run_end_condition=run_until)

"""Run the orchestrator asynchronously, yielding events as they occur."""
# Initialize the agent if not already set
if not self.agent:
self.agent = self.turn_strategy.get_next_agent(
Expand All @@ -686,6 +664,7 @@ async def _run_async(

yield AgentTurnStart(orchestrator=self, agent=self.agent)

# Run turn and yield its events
async for event in self._run_agent_turn_async(
run_context=run_context,
model_kwargs=model_kwargs,
Expand All @@ -701,10 +680,12 @@ async def _run_async(
)

except Exception as exc:
# Yield error event if something goes wrong
yield OrchestratorError(orchestrator=self, error=exc)
raise
finally:
yield OrchestratorEnd(orchestrator=self)
# Signal the end of orchestration
yield OrchestratorEnd(orchestrator=self, run_context=run_context)


# Rebuild all models with forward references after Orchestrator is defined
Expand Down
6 changes: 3 additions & 3 deletions tests/tools/test_lc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel

import controlflow
from controlflow.events.events import AgentToolCall, AIMessage
from controlflow.events.events import AIMessage, ToolCall


class LCBaseToolInput(BaseModel):
Expand All @@ -26,7 +26,7 @@ def test_lc_base_tool(default_fake_llm, monkeypatch):
AIMessage(
content="",
tool_calls=[
AgentToolCall(
ToolCall(
id="abc",
name="TestTool",
args={"x": 3},
Expand All @@ -52,7 +52,7 @@ def test_ddg_tool(default_fake_llm, monkeypatch):
AIMessage(
content="",
tool_calls=[
AgentToolCall(
ToolCall(
id="abc",
name="duckduckgo_search",
args={"query": "top business headlines"},
Expand Down
14 changes: 7 additions & 7 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def test_record_task_events(default_fake_llm):
assert response == events[1].ai_message

assert events[3].event == "tool-result"
assert events[3].tool_call == {
assert events[3].tool_result.tool_call == {
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
assert events[3].tool_result.model_dump() == dict(
tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe",
str_result='Task #12345 ("say hello") marked successful.',
is_error=False,
tool_metadata={"is_completion_tool": True},
)
tool_result = events[3].tool_result.model_dump()
assert tool_result["tool_call"]["id"] == "call_ZEPdV8mCgeBe5UHjKzm6e3pe"
assert tool_result["str_result"] == 'Task #12345 ("say hello") marked successful.'
assert not tool_result["is_error"]
assert tool_result["tool"]["metadata"]["is_completion_tool"]
assert tool_result["tool"]["metadata"]["is_success_tool"]

0 comments on commit 0b0c9fb

Please sign in to comment.