diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 7fa52d97..fb98d1b8 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -264,6 +264,7 @@ def _run_model( messages: list[BaseMessage], tools: list["Tool"], stream: bool = True, + model_kwargs: Optional[dict] = None, ) -> Generator[Event, None, None]: from controlflow.events.events import ( AgentMessage, @@ -281,7 +282,7 @@ def _run_model( if stream: response = None - for delta in model.stream(messages): + for delta in model.stream(messages, **(model_kwargs or {})): if response is None: response = delta else: @@ -320,6 +321,7 @@ async def _run_model_async( messages: list[BaseMessage], tools: list["Tool"], stream: bool = True, + model_kwargs: Optional[dict] = None, ) -> AsyncGenerator[Event, None]: from controlflow.events.events import ( AgentMessage, @@ -337,7 +339,7 @@ async def _run_model_async( if stream: response = None - async for delta in model.astream(messages): + async for delta in model.astream(messages, **(model_kwargs or {})): if response is None: response = delta else: diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 35757aa1..960f8a6a 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -137,7 +137,10 @@ def get_memories(self) -> list[Memory]: @prefect_task(task_run_name="Orchestrator.run()") def run( - self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None + self, + max_llm_calls: Optional[int] = None, + max_agent_turns: Optional[int] = None, + model_kwargs: Optional[dict] = None, ): import controlflow.events.orchestrator_events @@ -176,7 +179,10 @@ def run( ) ) turn_count += 1 - call_count += self.run_agent_turn(max_llm_calls - call_count) + call_count += self.run_agent_turn( + max_llm_calls - call_count, + model_kwargs=model_kwargs, + ) self.handle_event( controlflow.events.orchestrator_events.AgentTurnEnd( orchestrator=self, agent=self.agent @@ -207,7 +213,10 @@ def run( @prefect_task async def run_async( - self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None + self, + max_llm_calls: Optional[int] = None, + max_agent_turns: Optional[int] = None, + model_kwargs: Optional[dict] = None, ): """ Run the orchestration process asynchronously until completion or limits are reached. @@ -255,7 +264,8 @@ async def run_async( ) turn_count += 1 call_count += await self.run_agent_turn_async( - max_llm_calls - call_count + max_llm_calls - call_count, + model_kwargs=model_kwargs, ) self.handle_event( controlflow.events.orchestrator_events.AgentTurnEnd( @@ -286,7 +296,11 @@ async def run_async( ) @prefect_task(task_run_name="Agent turn: {self.agent.name}") - def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: + def run_agent_turn( + self, + max_llm_calls: Optional[int], + model_kwargs: Optional[dict] = None, + ) -> int: """ Run a single agent turn, which may consist of multiple LLM calls. @@ -324,11 +338,19 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: logger.debug("No `ready` tasks to run") break + if not any(t.is_incomplete() for t in self.tasks): + logger.debug("No incomplete tasks left") + break + call_count += 1 messages = self.compile_messages() tools = self.get_tools() - for event in self.agent._run_model(messages=messages, tools=tools): + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): self.handle_event(event) # Check if we've reached the call limit within a turn @@ -339,7 +361,11 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: return call_count @prefect_task - async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: + async def run_agent_turn_async( + self, + max_llm_calls: Optional[int], + model_kwargs: Optional[dict] = None, + ) -> int: """ Run a single agent turn asynchronously, which may consist of multiple LLM calls. @@ -377,12 +403,18 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: logger.debug("No `ready` tasks to run") break + if not any(t.is_incomplete() for t in self.tasks): + logger.debug("No incomplete tasks left") + break + call_count += 1 messages = self.compile_messages() tools = self.get_tools() async for event in self.agent._run_model_async( - messages=messages, tools=tools + messages=messages, + tools=tools, + model_kwargs=model_kwargs, ): self.handle_event(event) diff --git a/src/controlflow/run.py b/src/controlflow/run.py index 0d38eb96..10a538ed 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from prefect.context import TaskRunContext @@ -27,6 +27,7 @@ def run_tasks( max_llm_calls: int = None, max_agent_turns: int = None, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, ) -> list[Any]: """ Run a list of tasks. @@ -45,6 +46,7 @@ def run_tasks( orchestrator.run( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, ) if raise_on_failure and any(t.is_failed() for t in tasks): @@ -68,6 +70,7 @@ async def run_tasks_async( max_llm_calls: int = None, max_agent_turns: int = None, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, ): """ Run a list of tasks. @@ -83,6 +86,7 @@ async def run_tasks_async( await orchestrator.run_async( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, ) if raise_on_failure and any(t.is_failed() for t in tasks): @@ -104,6 +108,7 @@ def run( max_agent_turns: int = None, raise_on_failure: bool = True, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -114,6 +119,7 @@ def run( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, handlers=handlers, + model_kwargs=model_kwargs, ) return results[0] @@ -128,6 +134,7 @@ async def run_async( max_agent_turns: int = None, raise_on_failure: bool = True, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -140,5 +147,6 @@ async def run_async( max_agent_turns=max_agent_turns, raise_on_failure=raise_on_failure, handlers=handlers, + model_kwargs=model_kwargs, ) return results[0] diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index bffe411e..c728d3b4 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -379,6 +379,7 @@ def run( max_agent_turns: int = None, handlers: list["Handler"] = None, raise_on_failure: bool = True, + model_kwargs: Optional[dict] = None, ) -> T: """ Run the task @@ -393,6 +394,7 @@ def run( max_agent_turns=max_agent_turns, raise_on_failure=False, handlers=handlers, + model_kwargs=model_kwargs, ) if self.is_successful():