Commit d821c21

Merge pull request #347 from PrefectHQ/model-kwargs

Allow model_kwargs to be passed to llm API

jlowin authored Oct 3, 2024 · 2 parents eace95a + d237622

Showing 4 changed files with 55 additions and 11 deletions.
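For context, a minimal sketch of how the new parameter is meant to be used from the public API after this change, based on the `run()` signature modified below. It assumes controlflow is installed and exposes `run` at the top level; the objective string and the specific kwargs are illustrative, and which keys are accepted depends on the configured LangChain chat model:

import controlflow as cf

# model_kwargs is forwarded through the orchestrator to the agent's
# model.stream() / model.astream() call, so its keys must be keyword
# arguments the underlying chat model accepts.
result = cf.run(
    "Write a one-line release note for the model_kwargs feature",
    model_kwargs={"temperature": 0.2},
)
print(result)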
6 changes: 4 additions & 2 deletions src/controlflow/agents/agent.py

@@ -264,6 +264,7 @@ def _run_model(
         messages: list[BaseMessage],
         tools: list["Tool"],
         stream: bool = True,
+        model_kwargs: Optional[dict] = None,
     ) -> Generator[Event, None, None]:
         from controlflow.events.events import (
             AgentMessage,
@@ -281,7 +282,7 @@ def _run_model(

         if stream:
             response = None
-            for delta in model.stream(messages):
+            for delta in model.stream(messages, **(model_kwargs or {})):
                 if response is None:
                     response = delta
                 else:
@@ -320,6 +321,7 @@ async def _run_model_async(
         messages: list[BaseMessage],
         tools: list["Tool"],
         stream: bool = True,
+        model_kwargs: Optional[dict] = None,
     ) -> AsyncGenerator[Event, None]:
         from controlflow.events.events import (
             AgentMessage,
@@ -337,7 +339,7 @@ async def _run_model_async(

         if stream:
             response = None
-            async for delta in model.astream(messages):
+            async for delta in model.astream(messages, **(model_kwargs or {})):
                 if response is None:
                     response = delta
                 else:
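The `**(model_kwargs or {})` pattern above keeps the call unchanged when no kwargs are supplied: unpacking an empty dict adds no keyword arguments. A standalone sketch of the same idiom outside ControlFlow (the `fake_stream` helper is purely illustrative):

from typing import Optional

def fake_stream(messages: list[str], **kwargs) -> str:
    # Stand-in for a chat model's streaming call that accepts extra options.
    return f"streamed {len(messages)} message(s) with options {kwargs}"

def run_model(messages: list[str], model_kwargs: Optional[dict] = None) -> str:
    # None and {} behave the same: no extra keyword arguments are forwarded.
    return fake_stream(messages, **(model_kwargs or {}))

print(run_model(["hi"]))                        # options {}
print(run_model(["hi"], {"temperature": 0.2}))  # options {'temperature': 0.2}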
48 changes: 40 additions & 8 deletions src/controlflow/orchestration/orchestrator.py

@@ -137,7 +137,10 @@ def get_memories(self) -> list[Memory]:

     @prefect_task(task_run_name="Orchestrator.run()")
     def run(
-        self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
+        self,
+        max_llm_calls: Optional[int] = None,
+        max_agent_turns: Optional[int] = None,
+        model_kwargs: Optional[dict] = None,
     ):
         import controlflow.events.orchestrator_events

@@ -176,7 +179,10 @@ def run(
                 )
             )
             turn_count += 1
-            call_count += self.run_agent_turn(max_llm_calls - call_count)
+            call_count += self.run_agent_turn(
+                max_llm_calls - call_count,
+                model_kwargs=model_kwargs,
+            )
             self.handle_event(
                 controlflow.events.orchestrator_events.AgentTurnEnd(
                     orchestrator=self, agent=self.agent
@@ -207,7 +213,10 @@ def run(

     @prefect_task
     async def run_async(
-        self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
+        self,
+        max_llm_calls: Optional[int] = None,
+        max_agent_turns: Optional[int] = None,
+        model_kwargs: Optional[dict] = None,
     ):
         """
         Run the orchestration process asynchronously until completion or limits are reached.
@@ -255,7 +264,8 @@ async def run_async(
             )
             turn_count += 1
             call_count += await self.run_agent_turn_async(
-                max_llm_calls - call_count
+                max_llm_calls - call_count,
+                model_kwargs=model_kwargs,
             )
             self.handle_event(
                 controlflow.events.orchestrator_events.AgentTurnEnd(
@@ -286,7 +296,11 @@ async def run_async(
         )

     @prefect_task(task_run_name="Agent turn: {self.agent.name}")
-    def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
+    def run_agent_turn(
+        self,
+        max_llm_calls: Optional[int],
+        model_kwargs: Optional[dict] = None,
+    ) -> int:
         """
         Run a single agent turn, which may consist of multiple LLM calls.
@@ -324,11 +338,19 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
                 logger.debug("No `ready` tasks to run")
                 break

+            if not any(t.is_incomplete() for t in self.tasks):
+                logger.debug("No incomplete tasks left")
+                break
+
             call_count += 1
             messages = self.compile_messages()
             tools = self.get_tools()

-            for event in self.agent._run_model(messages=messages, tools=tools):
+            for event in self.agent._run_model(
+                messages=messages,
+                tools=tools,
+                model_kwargs=model_kwargs,
+            ):
                 self.handle_event(event)

             # Check if we've reached the call limit within a turn
@@ -339,7 +361,11 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
         return call_count

     @prefect_task
-    async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
+    async def run_agent_turn_async(
+        self,
+        max_llm_calls: Optional[int],
+        model_kwargs: Optional[dict] = None,
+    ) -> int:
         """
         Run a single agent turn asynchronously, which may consist of multiple LLM calls.
@@ -377,12 +403,18 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
                 logger.debug("No `ready` tasks to run")
                 break

+            if not any(t.is_incomplete() for t in self.tasks):
+                logger.debug("No incomplete tasks left")
+                break
+
             call_count += 1
             messages = self.compile_messages()
             tools = self.get_tools()

             async for event in self.agent._run_model_async(
-                messages=messages, tools=tools
+                messages=messages,
+                tools=tools,
+                model_kwargs=model_kwargs,
             ):
                 self.handle_event(event)
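Beyond threading model_kwargs through, the loops above show how the per-run LLM-call budget is shared across turns: each turn receives the remaining budget (max_llm_calls - call_count) and reports back how many calls it actually made. A toy sketch of that accounting, independent of ControlFlow (all names and numbers are illustrative):

def run_turn(remaining_budget: int) -> int:
    # Pretend a turn wants up to 3 LLM calls but never exceeds its budget.
    return min(3, remaining_budget)

def run(max_llm_calls: int = 10, max_agent_turns: int = 5) -> tuple[int, int]:
    call_count, turn_count = 0, 0
    while call_count < max_llm_calls and turn_count < max_agent_turns:
        turn_count += 1
        call_count += run_turn(max_llm_calls - call_count)
    return call_count, turn_count

print(run())                   # (10, 4): the call budget runs out first
print(run(max_llm_calls=100))  # (15, 5): the turn limit is reached first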
10 changes: 9 additions & 1 deletion src/controlflow/run.py

@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional

 from prefect.context import TaskRunContext

@@ -27,6 +27,7 @@ def run_tasks(
     max_llm_calls: int = None,
     max_agent_turns: int = None,
     handlers: list[Handler] = None,
+    model_kwargs: Optional[dict] = None,
 ) -> list[Any]:
     """
     Run a list of tasks.
@@ -45,6 +46,7 @@ def run_tasks(
     orchestrator.run(
         max_llm_calls=max_llm_calls,
         max_agent_turns=max_agent_turns,
+        model_kwargs=model_kwargs,
     )

     if raise_on_failure and any(t.is_failed() for t in tasks):
@@ -68,6 +70,7 @@ async def run_tasks_async(
     max_llm_calls: int = None,
     max_agent_turns: int = None,
     handlers: list[Handler] = None,
+    model_kwargs: Optional[dict] = None,
 ):
     """
     Run a list of tasks.
@@ -83,6 +86,7 @@ async def run_tasks_async(
     await orchestrator.run_async(
         max_llm_calls=max_llm_calls,
         max_agent_turns=max_agent_turns,
+        model_kwargs=model_kwargs,
     )

     if raise_on_failure and any(t.is_failed() for t in tasks):
@@ -104,6 +108,7 @@ def run(
     max_agent_turns: int = None,
     raise_on_failure: bool = True,
     handlers: list[Handler] = None,
+    model_kwargs: Optional[dict] = None,
     **task_kwargs,
 ) -> Any:
     task = Task(objective=objective, **task_kwargs)
@@ -114,6 +119,7 @@ def run(
         max_llm_calls=max_llm_calls,
         max_agent_turns=max_agent_turns,
         handlers=handlers,
+        model_kwargs=model_kwargs,
     )
     return results[0]

@@ -128,6 +134,7 @@ async def run_async(
     max_agent_turns: int = None,
     raise_on_failure: bool = True,
     handlers: list[Handler] = None,
+    model_kwargs: Optional[dict] = None,
     **task_kwargs,
 ) -> Any:
     task = Task(objective=objective, **task_kwargs)
@@ -140,5 +147,6 @@ async def run_async(
         max_agent_turns=max_agent_turns,
         raise_on_failure=raise_on_failure,
         handlers=handlers,
+        model_kwargs=model_kwargs,
     )
     return results[0]
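A sketch of the same option at the task-list level, importing run_tasks directly from the module changed above; it assumes a default model and API key are configured, and the task objectives and kwargs are illustrative:

import controlflow as cf
from controlflow.run import run_tasks

tasks = [
    cf.Task(objective="Summarize the release notes"),
    cf.Task(objective="Draft an announcement tweet"),
]

# The same model_kwargs apply to every LLM call made while these tasks run.
results = run_tasks(tasks, model_kwargs={"temperature": 0.0})
print(results)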
2 changes: 2 additions & 0 deletions src/controlflow/tasks/task.py

@@ -379,6 +379,7 @@ def run(
         max_agent_turns: int = None,
         handlers: list["Handler"] = None,
         raise_on_failure: bool = True,
+        model_kwargs: Optional[dict] = None,
     ) -> T:
         """
         Run the task
@@ -393,6 +394,7 @@ def run(
             max_agent_turns=max_agent_turns,
             raise_on_failure=False,
             handlers=handlers,
+            model_kwargs=model_kwargs,
         )

         if self.is_successful():
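And the same option on a single Task, whose run() forwards it to run_tasks as shown above (again a sketch; the objective is illustrative):

import controlflow as cf

task = cf.Task(objective="Classify this ticket as 'bug' or 'feature'")
answer = task.run(model_kwargs={"temperature": 0.0})
print(answer)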
