diff --git a/examples/multi_agent_conversation.py b/examples/multi_agent_conversation.py index 3878db9d..bf712dbe 100644 --- a/examples/multi_agent_conversation.py +++ b/examples/multi_agent_conversation.py @@ -64,16 +64,15 @@ @flow -def demo(): - topic = "milk and cereal" +def demo(topic: str): task = Task( "Discuss a topic", agents=[jerry, george, elaine, kramer, newman], context=dict(topic=topic), - instructions="every agent should speak at least once", + instructions="every agent should speak at least once. only one agent per turn.", ) - task.run() + return task if __name__ == "__main__": - demo() + demo(topic="sandwiches") diff --git a/pyproject.toml b/pyproject.toml index 8b39bceb..140b2fb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,11 @@ authors = [ ] dependencies = [ "prefect>=3.0rc1", - "textual>=0.61.1", - "litellm>=1.37.17", "jinja2>=3.1.4", + "langchain_core>=0.2.4", + "langchain_openai>=0.1.8", "pydantic-settings>=2.2.1", + "textual>=0.61.1", "tiktoken>=0.7.0", ] readme = "README.md" @@ -47,10 +48,16 @@ tests = [ "pytest>=7.0", "pytest-timeout", "pytest-xdist", - "pre-commit>=3.7.0", "pandas", ] -dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4", "textual-dev"] +dev = [ + "controlflow[tests]", + "ipython", + "pdbpp", + "pre-commit", + "ruff>=0.3.4", + "textual-dev", +] [build-system] requires = ["hatchling"] @@ -84,7 +91,3 @@ skip-magic-trailing-comma = false [tool.pytest.ini_options] timeout = 120 asyncio_mode = "auto" -filterwarnings = [ - "ignore::DeprecationWarning:litellm.*", - "ignore::PendingDeprecationWarning:litellm.*", -] diff --git a/requirements-dev.lock b/requirements-dev.lock index 15a74a0a..846e246b 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -8,10 +8,6 @@ # with-sources: false -e file:. -aiohttp==3.9.5 - # via litellm -aiosignal==1.3.1 - # via aiohttp aiosqlite==0.20.0 # via prefect alembic==1.13.1 @@ -31,7 +27,6 @@ asgi-lifespan==2.1.0 asyncpg==0.29.0 # via prefect attrs==23.2.0 - # via aiohttp # via jsonschema # via referencing cachetools==5.3.3 @@ -49,7 +44,6 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via apprise - # via litellm # via prefect # via typer # via uvicorn @@ -77,13 +71,7 @@ fastapi==0.111.0 # via prefect fastapi-cli==0.0.4 # via fastapi -filelock==3.13.3 - # via huggingface-hub -frozenlist==1.4.1 - # via aiohttp - # via aiosignal fsspec==2024.3.1 - # via huggingface-hub # via prefect google-auth==2.29.0 # via kubernetes @@ -109,8 +97,6 @@ httpx==0.27.0 # via fastapi # via openai # via prefect -huggingface-hub==0.23.0 - # via tokenizers humanize==4.9.0 # via jinja2-humanize-extension # via prefect @@ -121,20 +107,17 @@ idna==3.6 # via email-validator # via httpx # via requests - # via yarl -importlib-metadata==7.0.0 - # via litellm importlib-resources==6.1.3 # via prefect jinja2==3.1.4 # via controlflow # via fastapi # via jinja2-humanize-extension - # via litellm # via prefect jinja2-humanize-extension==0.4.0 # via prefect jsonpatch==1.33 + # via langchain-core # via prefect jsonpointer==2.4 # via jsonpatch @@ -144,10 +127,15 @@ jsonschema-specifications==2023.12.1 # via jsonschema kubernetes==29.0.0 # via prefect +langchain-core==0.2.4 + # via controlflow + # via langchain-openai +langchain-openai==0.1.8 + # via controlflow +langsmith==0.1.74 + # via langchain-core linkify-it-py==2.0.3 # via markdown-it-py -litellm==1.37.17 - # via controlflow mako==1.3.2 # via alembic markdown==3.6 @@ -163,20 +151,18 @@ mdit-py-plugins==0.4.1 # via markdown-it-py mdurl==0.1.2 # via markdown-it-py -multidict==6.0.5 - # via aiohttp - # via yarl oauthlib==3.2.2 # via kubernetes # via requests-oauthlib openai==1.28.1 - # via litellm + # via langchain-openai orjson==3.10.0 # via fastapi + # via langsmith # via prefect -packaging==24.0 +packaging==23.2 # via docker - # via huggingface-hub + # via langchain-core # via prefect pathspec==0.12.1 # via prefect @@ -193,6 +179,8 @@ pycparser==2.22 # via cffi pydantic==2.7.2 # via fastapi + # via langchain-core + # via langsmith # via openai # via prefect # via pydantic-extra-types @@ -215,7 +203,6 @@ python-dateutil==2.9.0.post0 # via prefect # via time-machine python-dotenv==1.0.1 - # via litellm # via pydantic-settings # via uvicorn python-multipart==0.0.9 @@ -228,8 +215,8 @@ pytz==2024.1 # via prefect pyyaml==6.0.1 # via apprise - # via huggingface-hub # via kubernetes + # via langchain-core # via prefect # via uvicorn readchar==4.0.6 @@ -243,9 +230,8 @@ regex==2023.12.25 requests==2.31.0 # via apprise # via docker - # via huggingface-hub # via kubernetes - # via litellm + # via langsmith # via requests-oauthlib # via tiktoken requests-oauthlib==2.0.0 @@ -285,21 +271,20 @@ sqlalchemy==2.0.29 # via prefect starlette==0.37.2 # via fastapi +tenacity==8.3.0 + # via langchain-core text-unidecode==1.3 # via python-slugify textual==0.61.1 # via controlflow tiktoken==0.7.0 # via controlflow - # via litellm + # via langchain-openai time-machine==2.14.1 # via pendulum -tokenizers==0.19.1 - # via litellm toml==0.10.2 # via prefect tqdm==4.66.2 - # via huggingface-hub # via openai typer==0.12.3 # via fastapi-cli @@ -308,7 +293,6 @@ typing-extensions==4.10.0 # via aiosqlite # via alembic # via fastapi - # via huggingface-hub # via openai # via prefect # via pydantic @@ -342,7 +326,3 @@ websocket-client==1.7.0 websockets==12.0 # via prefect # via uvicorn -yarl==1.9.4 - # via aiohttp -zipp==3.18.1 - # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index 82399bed..846e246b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -8,10 +8,6 @@ # with-sources: false -e file:. -aiohttp==3.9.5 - # via litellm -aiosignal==1.3.1 - # via aiohttp aiosqlite==0.20.0 # via prefect alembic==1.13.1 @@ -31,7 +27,6 @@ asgi-lifespan==2.1.0 asyncpg==0.29.0 # via prefect attrs==23.2.0 - # via aiohttp # via jsonschema # via referencing cachetools==5.3.3 @@ -49,7 +44,6 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via apprise - # via litellm # via prefect # via typer # via uvicorn @@ -77,13 +71,7 @@ fastapi==0.111.0 # via prefect fastapi-cli==0.0.4 # via fastapi -filelock==3.14.0 - # via huggingface-hub -frozenlist==1.4.1 - # via aiohttp - # via aiosignal fsspec==2024.3.1 - # via huggingface-hub # via prefect google-auth==2.29.0 # via kubernetes @@ -109,8 +97,6 @@ httpx==0.27.0 # via fastapi # via openai # via prefect -huggingface-hub==0.23.0 - # via tokenizers humanize==4.9.0 # via jinja2-humanize-extension # via prefect @@ -121,20 +107,17 @@ idna==3.6 # via email-validator # via httpx # via requests - # via yarl -importlib-metadata==7.0.0 - # via litellm importlib-resources==6.1.3 # via prefect jinja2==3.1.4 # via controlflow # via fastapi # via jinja2-humanize-extension - # via litellm # via prefect jinja2-humanize-extension==0.4.0 # via prefect jsonpatch==1.33 + # via langchain-core # via prefect jsonpointer==2.4 # via jsonpatch @@ -144,10 +127,15 @@ jsonschema-specifications==2023.12.1 # via jsonschema kubernetes==29.0.0 # via prefect +langchain-core==0.2.4 + # via controlflow + # via langchain-openai +langchain-openai==0.1.8 + # via controlflow +langsmith==0.1.74 + # via langchain-core linkify-it-py==2.0.3 # via markdown-it-py -litellm==1.37.17 - # via controlflow mako==1.3.2 # via alembic markdown==3.6 @@ -163,20 +151,18 @@ mdit-py-plugins==0.4.1 # via markdown-it-py mdurl==0.1.2 # via markdown-it-py -multidict==6.0.5 - # via aiohttp - # via yarl oauthlib==3.2.2 # via kubernetes # via requests-oauthlib openai==1.28.1 - # via litellm + # via langchain-openai orjson==3.10.0 # via fastapi + # via langsmith # via prefect -packaging==24.0 +packaging==23.2 # via docker - # via huggingface-hub + # via langchain-core # via prefect pathspec==0.12.1 # via prefect @@ -193,6 +179,8 @@ pycparser==2.22 # via cffi pydantic==2.7.2 # via fastapi + # via langchain-core + # via langsmith # via openai # via prefect # via pydantic-extra-types @@ -215,7 +203,6 @@ python-dateutil==2.9.0.post0 # via prefect # via time-machine python-dotenv==1.0.1 - # via litellm # via pydantic-settings # via uvicorn python-multipart==0.0.9 @@ -228,8 +215,8 @@ pytz==2024.1 # via prefect pyyaml==6.0.1 # via apprise - # via huggingface-hub # via kubernetes + # via langchain-core # via prefect # via uvicorn readchar==4.0.6 @@ -243,9 +230,8 @@ regex==2023.12.25 requests==2.31.0 # via apprise # via docker - # via huggingface-hub # via kubernetes - # via litellm + # via langsmith # via requests-oauthlib # via tiktoken requests-oauthlib==2.0.0 @@ -285,21 +271,20 @@ sqlalchemy==2.0.29 # via prefect starlette==0.37.2 # via fastapi +tenacity==8.3.0 + # via langchain-core text-unidecode==1.3 # via python-slugify textual==0.61.1 # via controlflow tiktoken==0.7.0 # via controlflow - # via litellm + # via langchain-openai time-machine==2.14.1 # via pendulum -tokenizers==0.19.1 - # via litellm toml==0.10.2 # via prefect tqdm==4.66.2 - # via huggingface-hub # via openai typer==0.12.3 # via fastapi-cli @@ -308,7 +293,6 @@ typing-extensions==4.10.0 # via aiosqlite # via alembic # via fastapi - # via huggingface-hub # via openai # via prefect # via pydantic @@ -342,7 +326,3 @@ websocket-client==1.7.0 websockets==12.0 # via prefect # via uvicorn -yarl==1.9.4 - # via aiohttp -zipp==3.18.1 - # via importlib-metadata diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index d233c062..d74759d9 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -5,6 +5,7 @@ import controlflow from controlflow.core.task import Task +from controlflow.llm.models import BaseChatModel, get_default_model from controlflow.tools.talk_to_human import talk_to_human from controlflow.utilities.types import ControlFlowModel @@ -33,9 +34,10 @@ class Agent(ControlFlowModel): False, description="If True, the agent is given tools for interacting with a human user.", ) - model: str = Field( + model: BaseChatModel = Field( description="The model used by the agent. If not provided, the default model will be used.", - default_factory=lambda: controlflow.settings.llm_model, + default_factory=get_default_model, + exclude=True, ) def __init__(self, name, **kwargs): diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index efffe940..30fb2c7d 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -1,7 +1,7 @@ import logging import math from collections import defaultdict -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from functools import cached_property from typing import Callable, Union @@ -17,7 +17,7 @@ from controlflow.llm.completions import completion, completion_async from controlflow.llm.handlers import PrintHandler, ResponseHandler, TUIHandler from controlflow.llm.history import History -from controlflow.llm.messages import AssistantMessage, ControlFlowMessage, SystemMessage +from controlflow.llm.messages import AIMessage, MessageType, SystemMessage from controlflow.tui.app import TUIApp as TUI from controlflow.utilities.context import ctx from controlflow.utilities.tasks import all_complete, any_incomplete @@ -25,11 +25,11 @@ logger = logging.getLogger(__name__) -def add_agent_name_to_message(msg: ControlFlowMessage): +def add_agent_name_to_message(msg: MessageType): """ If the message is from a named assistant, prefix the message with the assistant's name. """ - if isinstance(msg, AssistantMessage) and msg.name: + if isinstance(msg, AIMessage) and msg.name: msg = msg.model_copy(update={"content": f"{msg.name}: {msg.content}"}) return msg @@ -70,7 +70,7 @@ class Controller(BaseModel): enable_tui: bool = Field(default_factory=lambda: controlflow.settings.enable_tui) _iteration: int = 0 _should_stop: bool = False - _end_run_counts: dict = PrivateAttr(default_factory=lambda: defaultdict(int)) + _end_turn_counts: dict = PrivateAttr(default_factory=lambda: defaultdict(int)) @computed_field @cached_property @@ -98,13 +98,13 @@ def end_turn(): # the agent's name is used as the key to track the number of times key = getattr(ctx.get("controller_agent", None), "name", None) - self._end_run_counts[key] += 1 - if self._end_run_counts[key] >= 3: + self._end_turn_counts[key] += 1 + if self._end_turn_counts[key] >= 3: self._should_stop = True - self._end_run_counts[key] = 0 + self._end_turn_counts[key] = 0 return ( - f"Ending turn. {3 - self._end_run_counts[key]}" + f"Ending turn. {3 - self._end_turn_counts[key]}" " more uses will abort the workflow." ) @@ -129,96 +129,89 @@ async def tui(self): else: yield - @contextmanager def _run_once_payload(self): """ - Generate the payload for a single run of the controller. This is a context manager so it can be used with - both async and sync code without duplication. + Generate the payload for a single run of the controller. """ if all(t.is_complete() for t in self.tasks): - yield None return - # put the flow in context - with self.flow: - # TODO: show the agent the entire graph, not just immediate upstreams - # get the tasks to run - tasks = self.graph.ready_tasks() - # get the agents - agent_candidates = [a for t in tasks for a in t.get_agents() if t.is_ready] - if len({a.name for a in agent_candidates}) != len(agent_candidates): - raise ValueError( - "Multiple agents with the same name were found. Agents must have unique names." - ) - if self.agents: - agents = [a for a in agent_candidates if a in self.agents] - else: - agents = agent_candidates + # TODO: show the agent the entire graph, not just immediate upstreams + # get the tasks to run + tasks = self.graph.ready_tasks() + # get the agents + agent_candidates = [a for t in tasks for a in t.get_agents() if t.is_ready] + if len({a.name for a in agent_candidates}) != len(agent_candidates): + raise ValueError( + "Multiple agents with the same name were found. Agents must have unique names." + ) + if self.agents: + agents = [a for a in agent_candidates if a in self.agents] + else: + agents = agent_candidates - # select the next agent - if len(agents) == 0: - raise ValueError( - "No agents were provided that are assigned to tasks that are ready to be run." - ) - elif len(agents) == 1: - agent = agents[0] - else: - agent = self.choose_agent(agents=agents, tasks=tasks) - - with ctx(controller_agent=agent): - from controlflow.core.controller.instruction_template import ( - MainTemplate, - ) + # select the next agent + if len(agents) == 0: + raise ValueError( + "No agents were provided that are assigned to tasks that are ready to be run." + ) + elif len(agents) == 1: + agent = agents[0] + else: + agent = self.choose_agent(agents=agents, tasks=tasks) - tools = ( - self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()] - ) + from controlflow.core.controller.instruction_template import ( + MainTemplate, + ) - # add tools for any inactive tasks that the agent is assigned to - assigned_tools = [] - for task in tasks: - if agent in task.get_agents(): - assigned_tools.extend(task.get_tools()) - if not assigned_tools: - raise ValueError( - f"Agent {agent.name} is not assigned to any of the tasks that are ready to be run." - ) - tools.extend(assigned_tools) - - instructions_template = MainTemplate( - agent=agent, - controller=self, - tasks=tasks, - context=self.context, - instructions=get_instructions(), - ) - instructions = instructions_template.render() - - # prepare messages - system_message = SystemMessage(content=instructions) - messages = self.history.load_messages(thread_id=self.flow.thread_id) - - # setup handlers - handlers = [] - if controlflow.settings.enable_tui: - handlers.append(TUIHandler()) - if controlflow.settings.enable_print_handler: - handlers.append(PrintHandler()) - - # yield the agent payload - yield dict( - agent=agent, - messages=[system_message] + messages, - tools=tools, - handlers=handlers, - message_preprocessor=add_agent_name_to_message, - ) + tools = self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()] + + # add tools for any inactive tasks that the agent is assigned to + assigned_tools = [] + for task in tasks: + if agent in task.get_agents(): + assigned_tools.extend(task.get_tools()) + if not assigned_tools: + raise ValueError( + f"Agent {agent.name} is not assigned to any of the tasks that are ready to be run." + ) + tools.extend(assigned_tools) - self._iteration += 1 + # tools = [prefect.task(tool) for tool in tools] + + instructions_template = MainTemplate( + agent=agent, + controller=self, + tasks=tasks, + context=self.context, + instructions=get_instructions(), + ) + instructions = instructions_template.render() + + # prepare messages + system_message = SystemMessage(content=instructions) + messages = self.history.load_messages(thread_id=self.flow.thread_id) + + # setup handlers + handlers = [] + if controlflow.settings.enable_tui: + handlers.append(TUIHandler()) + if controlflow.settings.enable_print_handler: + handlers.append(PrintHandler()) + + # yield the agent payload + return dict( + agent=agent, + messages=[system_message] + messages, + tools=tools, + handlers=handlers, + # message_preprocessor=add_agent_name_to_message, + ) async def run_once_async(self): async with self.tui(): - with self._run_once_payload() as payload: + with self.flow: + payload = self._run_once_payload() if payload is not None: agent: Agent = payload.pop("agent") response_handler = ResponseHandler() @@ -230,8 +223,8 @@ async def run_once_async(self): tools=payload["tools"], handlers=payload["handlers"], max_iterations=1, - assistant_name=agent.name, - message_preprocessor=payload["message_preprocessor"], + # assistant_name=agent.name, + # message_preprocessor=payload["message_preprocessor"], stream=True, ) async for _ in response_gen: @@ -242,9 +235,11 @@ async def run_once_async(self): thread_id=self.flow.thread_id, messages=response_handler.response_messages, ) + self._iteration += 1 def run_once(self): - with self._run_once_payload() as payload: + with self.flow: + payload = self._run_once_payload() if payload is not None: agent: Agent = payload.pop("agent") response_handler = ResponseHandler() @@ -256,8 +251,8 @@ def run_once(self): tools=payload["tools"], handlers=payload["handlers"], max_iterations=1, - assistant_name=agent.name, - message_preprocessor=payload["message_preprocessor"], + # assistant_name=agent.name, + # message_preprocessor=payload["message_preprocessor"], stream=True, ) for _ in response_gen: @@ -268,6 +263,7 @@ def run_once(self): thread_id=self.flow.thread_id, messages=response_handler.response_messages, ) + self._iteration += 1 async def run_async(self): """ diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 77139a9a..d9ec62f0 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -28,7 +28,7 @@ import controlflow import controlflow.core from controlflow.instructions import get_instructions -from controlflow.llm.tools import annotate_fn +from controlflow.llm.tools import Tool from controlflow.tools.talk_to_human import talk_to_human from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger @@ -426,7 +426,7 @@ def succeed() -> str: def succeed(result: result_schema) -> str: # type: ignore return self.mark_successful(result=result) - return annotate_fn( + return Tool.from_function( succeed, name=f"mark_task_{self.id}_successful", description=f"Mark task {self.id} as successful.", @@ -438,7 +438,7 @@ def _create_fail_tool(self) -> Callable: Create an agent-compatible tool for failing this task. """ - return annotate_fn( + return Tool.from_function( self.mark_failed, name=f"mark_task_{self.id}_failed", description=f"Mark task {self.id} as failed. Only use when a technical issue like a broken tool or unresponsive human prevents completion.", @@ -449,7 +449,7 @@ def _create_skip_tool(self) -> Callable: """ Create an agent-compatible tool for skipping this task. """ - return annotate_fn( + return Tool.from_function( self.mark_skipped, name=f"mark_task_{self.id}_skipped", description=f"Mark task {self.id} as skipped. Only use when completing its parent task early.", diff --git a/src/controlflow/llm/classify.py b/src/controlflow/llm/classify.py index 1aceb3de..432613bc 100644 --- a/src/controlflow/llm/classify.py +++ b/src/controlflow/llm/classify.py @@ -1,9 +1,12 @@ -import litellm +from typing import Union + import tiktoken +from langchain_openai import AzureChatOpenAI, ChatOpenAI from pydantic import TypeAdapter import controlflow -from controlflow.llm.messages import AssistantMessage, SystemMessage, UserMessage +from controlflow.llm.messages import AIMessage, HumanMessage, SystemMessage +from controlflow.llm.models import BaseChatModel def classify( @@ -11,7 +14,7 @@ def classify( labels: list, instructions: str = None, context: dict = None, - model: str = None, + model: BaseChatModel = None, ): try: label_strings = [TypeAdapter(type(t)).dump_json(t).decode() for t in labels] @@ -38,7 +41,7 @@ def classify( {% endfor %} """ ).render(labels=label_strings), - UserMessage( + HumanMessage( """ ## Information to classify @@ -61,15 +64,15 @@ def classify( """ ).render(data=data, instructions=instructions, context=context), - AssistantMessage(""" + AIMessage(""" The best label for the data is Label number """), ] - model = model or controlflow.settings.llm_model + model = model or controlflow.llm.models.get_default_model() kwargs = {} - if model in litellm.models_by_provider["openai"]: + if isinstance(model, (ChatOpenAI, AzureChatOpenAI)): openai_kwargs = _openai_kwargs(model=model, n_labels=len(labels)) kwargs.update(openai_kwargs) else: @@ -90,8 +93,8 @@ def classify( return labels[index] -def _openai_kwargs(model: str, n_labels: int): - encoding = tiktoken.encoding_for_model(model) +def _openai_kwargs(model: Union[AzureChatOpenAI, ChatOpenAI], n_labels: int): + encoding = tiktoken.encoding_for_model(model.model_name) logit_bias = {} for i in range(n_labels): diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py index 00cc2718..1322e208 100644 --- a/src/controlflow/llm/completions.py +++ b/src/controlflow/llm/completions.py @@ -1,49 +1,38 @@ +import datetime import math -from typing import AsyncGenerator, Callable, Generator, Union +from typing import AsyncGenerator, Callable, Generator, Optional, Union, cast -import litellm -from litellm.utils import trim_messages +import langchain_core.language_models as lc_models import controlflow +import controlflow.llm.models from controlflow.llm.handlers import ( CompletionEvent, CompletionHandler, ResponseHandler, ) -from controlflow.llm.messages import ( - ControlFlowMessage, - as_cf_messages, - as_oai_messages, -) +from controlflow.llm.messages import AIMessage, AIMessageChunk, MessageType from controlflow.llm.tools import ( as_tools, - get_tool_calls, handle_tool_call, handle_tool_call_async, ) def _completion_generator( - messages: list[Union[dict, ControlFlowMessage]], - model: str, - tools: list[Callable], - assistant_name: str, + messages: list[MessageType], + model: lc_models.BaseChatModel, + tools: Optional[list[Callable]], max_iterations: int, - message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage], stream: bool, **kwargs, ) -> Generator[CompletionEvent, None, None]: response_messages = [] response_message = None - if "api_key" not in kwargs: - kwargs["api_key"] = controlflow.settings.llm_api_key - if "api_version" not in kwargs: - kwargs["api_version"] = controlflow.settings.llm_api_version - if "api_base" not in kwargs: - kwargs["api_base"] = controlflow.settings.llm_api_base - - tools = as_tools(tools or []) + if tools: + tools = as_tools(tools) + model = model.bind_tools(tools) counter = 0 try: @@ -51,54 +40,58 @@ def _completion_generator( # continue as long as the last response message contains tool calls (or # there is no response message yet) - while not response_message or get_tool_calls(response_message): - # the input messages are the provided messages plus all response messages - # including tool calls and results - input_messages = as_oai_messages(messages + response_messages) - # apply message preprocessor if provided - if message_preprocessor: - input_messages = [ - m - for msg in input_messages - if (m := message_preprocessor(msg)) is not None - ] - response = litellm.completion( - model=model, - messages=trim_messages(input_messages, model=model), - tools=[t.model_dump() for t in tools] if tools else None, - stream=stream, - **kwargs, - ) - - # if streaming is enabled, we need to handle the deltas - if stream: - deltas = [] - for delta in response: + while not response_message or response_message.tool_calls: + timestamp = datetime.datetime.now(datetime.timezone.utc) + if not stream: + response_message = model.invoke( + input=messages + response_messages, + **kwargs, + ) + response_message = AIMessage.from_message(response_message) + + else: + deltas: list[AIMessageChunk] = [] + snapshot: AIMessageChunk = None + + for delta in model.stream( + input=messages + response_messages, + **kwargs, + ): + delta = AIMessageChunk.from_chunk(delta) deltas.append(delta) - snapshot = litellm.stream_chunk_builder(deltas) - delta_message, snapshot_message = as_cf_messages([delta, snapshot]) - delta_message.name, snapshot_message.name = ( - assistant_name, - assistant_name, - ) + + if snapshot is None: + snapshot = delta + else: + snapshot = snapshot + delta if len(deltas) == 1: + if delta.tool_call_chunks: + yield CompletionEvent( + type="tool_call_created", payload=dict(delta=delta) + ) + else: + yield CompletionEvent( + type="message_created", payload=dict(delta=delta) + ) + + if delta.tool_call_chunks: yield CompletionEvent( - type="message_created", payload=dict(delta=delta_message) + type="tool_call_delta", + payload=dict(delta=delta, snapshot=snapshot), + ) + else: + yield CompletionEvent( + type="message_delta", + payload=dict(delta=delta, snapshot=snapshot), ) - yield CompletionEvent( - type="message_delta", - payload=dict(delta=delta_message, snapshot=snapshot_message), - ) # the last snapshot message is the response message - response_message = snapshot_message + response_message = snapshot.to_message() - else: - [response_message] = as_cf_messages([response]) - response_message.name = assistant_name + response_message.timestamp = timestamp - if response_message.has_tool_calls(): + if response_message.tool_calls: yield CompletionEvent( type="tool_call_done", payload=dict(message=response_message) ) @@ -111,7 +104,7 @@ def _completion_generator( response_messages.append(response_message) # handle tool calls - for tool_call in get_tool_calls(response_message): + for tool_call in response_message.tool_calls: tool_result_message = handle_tool_call(tool_call, tools) yield CompletionEvent( type="tool_result_done", payload=dict(message=tool_result_message) @@ -130,75 +123,80 @@ def _completion_generator( async def _completion_async_generator( - messages: list[Union[dict, ControlFlowMessage]], - model: str, - tools: list[Callable], - assistant_name: str, + messages: list[MessageType], + model: lc_models.BaseChatModel, + tools: Optional[list[Callable]], max_iterations: int, - message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage], stream: bool, **kwargs, ) -> AsyncGenerator[CompletionEvent, None]: response_messages = [] response_message = None - if "api_key" not in kwargs: - kwargs["api_key"] = controlflow.settings.llm_api_key - if "api_version" not in kwargs: - kwargs["api_version"] = controlflow.settings.llm_api_version - if "api_base" not in kwargs: - kwargs["api_base"] = controlflow.settings.llm_api_base - - tools = as_tools(tools or []) + if tools: + tools = as_tools(tools) + model = model.bind_tools(tools) counter = 0 try: yield CompletionEvent(type="start", payload={}) - while not response_message or get_tool_calls(response_message): - input_messages = as_oai_messages(messages + response_messages) - if message_preprocessor: - input_messages = [ - m - for msg in input_messages - if (m := message_preprocessor(msg)) is not None - ] - - response = await litellm.acompletion( - model=model, - messages=trim_messages(input_messages, model=model), - tools=[t.model_dump() for t in tools] if tools else None, - stream=stream, - **kwargs, - ) - - if stream: - deltas = [] - async for delta in response: + # continue as long as the last response message contains tool calls (or + # there is no response message yet) + while not response_message or response_message.tool_calls: + timestamp = datetime.datetime.now(datetime.timezone.utc) + if not stream: + response_message = await model.ainvoke( + input=messages + response_messages, + tools=tools or None, + **kwargs, + ) + response_message = AIMessage.from_message(response_message) + + else: + deltas: list[AIMessageChunk] = [] + snapshot: AIMessageChunk = None + + async for delta in model.astream( + input=messages + response_messages, + tools=tools or None, + **kwargs, + ): + delta = cast(AIMessageChunk, delta) deltas.append(delta) - snapshot = litellm.stream_chunk_builder(deltas) - delta_message, snapshot_message = as_cf_messages([delta, snapshot]) - delta_message.name, snapshot_message.name = ( - assistant_name, - assistant_name, - ) + + if snapshot is None: + snapshot = delta + else: + snapshot = snapshot + delta if len(deltas) == 1: + if delta.tool_call_chunks: + yield CompletionEvent( + type="tool_call_created", payload=dict(delta=delta) + ) + else: + yield CompletionEvent( + type="message_created", payload=dict(delta=delta) + ) + + if delta.tool_call_chunks: yield CompletionEvent( - type="message_created", payload=dict(delta=delta_message) + type="tool_call_delta", + payload=dict(delta=delta, snapshot=snapshot), + ) + else: + yield CompletionEvent( + type="message_delta", + payload=dict(delta=delta, snapshot=snapshot), ) - yield CompletionEvent( - type="message_delta", - payload=dict(delta=delta_message, snapshot=snapshot_message), - ) - response_message = snapshot_message + # the last snapshot message is the response message + response_message = snapshot.to_message() - else: - [response_message] = as_cf_messages([response]) - response_message.name = assistant_name + response_message.timestamp = timestamp - if response_message.has_tool_calls(): + if response_message.tool_calls: yield CompletionEvent( type="tool_call_done", payload=dict(message=response_message) ) @@ -207,9 +205,11 @@ async def _completion_async_generator( type="message_done", payload=dict(message=response_message) ) + # append the response message to the list of response messages response_messages.append(response_message) - for tool_call in get_tool_calls(response_message): + # handle tool calls + for tool_call in response_message.tool_calls: tool_result_message = await handle_tool_call_async(tool_call, tools) yield CompletionEvent( type="tool_result_done", payload=dict(message=tool_result_message) @@ -246,18 +246,16 @@ async def _handle_events_async( def completion( - messages: list[Union[dict, ControlFlowMessage]], - model: str = None, + messages: list[MessageType], + model: lc_models.BaseChatModel = None, tools: list[Callable] = None, - assistant_name: str = None, max_iterations: int = None, handlers: list[CompletionHandler] = None, - message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage] = None, stream: bool = False, **kwargs, -) -> Union[list[ControlFlowMessage], Generator[ControlFlowMessage, None, None]]: +) -> Union[list[MessageType], Generator[MessageType, None, None]]: if model is None: - model = controlflow.settings.llm_model + model = controlflow.llm.models.get_default_model() response_handler = ResponseHandler() handlers = handlers or [] @@ -267,9 +265,7 @@ def completion( messages=messages, model=model, tools=tools, - assistant_name=assistant_name, max_iterations=max_iterations, - message_preprocessor=message_preprocessor, stream=stream, **kwargs, ) @@ -285,18 +281,16 @@ def completion( async def completion_async( - messages: list[Union[dict, ControlFlowMessage]], - model: str = None, + messages: list[MessageType], + model: lc_models.BaseChatModel = None, tools: list[Callable] = None, - assistant_name: str = None, max_iterations: int = None, handlers: list[CompletionHandler] = None, - message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage] = None, stream: bool = False, **kwargs, -) -> Union[list[ControlFlowMessage], Generator[ControlFlowMessage, None, None]]: +) -> Union[list[MessageType], Generator[MessageType, None, None]]: if model is None: - model = controlflow.settings.llm_model + model = controlflow.llm.models.get_default_model() response_handler = ResponseHandler() handlers = handlers or [] @@ -306,9 +300,7 @@ async def completion_async( messages=messages, model=model, tools=tools, - assistant_name=assistant_name, max_iterations=max_iterations, - message_preprocessor=message_preprocessor, stream=stream, **kwargs, ) diff --git a/src/controlflow/llm/formatting.py b/src/controlflow/llm/formatting.py index a97cc9a2..438145db 100644 --- a/src/controlflow/llm/formatting.py +++ b/src/controlflow/llm/formatting.py @@ -7,11 +7,10 @@ from rich.panel import Panel from controlflow.llm.messages import ( - AssistantMessage, + AIMessage, MessageType, ToolMessage, ) -from controlflow.llm.tools import get_tool_calls ROLE_COLORS = { "system": "gray", @@ -34,8 +33,8 @@ def format_message( ) -> Panel: if isinstance(message, ToolMessage): return format_tool_message(message) - elif get_tool_calls(message): - return format_assistant_message_with_tool_calls(message) + elif isinstance(message, AIMessage) and message.tool_calls: + return format_ai_message_with_tool_calls(message) else: return format_text_message(message) @@ -60,9 +59,9 @@ def format_text_message(message: MessageType) -> Panel: ) -def format_assistant_message_with_tool_calls(message: AssistantMessage) -> Group: +def format_ai_message_with_tool_calls(message: AIMessage) -> Group: panels = [] - for tool_call in get_tool_calls(message): + for tool_call in message.tool_calls: if message.role == "assistant" and message.name: title = f"Tool Call: {message.name}" else: @@ -75,9 +74,7 @@ def format_assistant_message_with_tool_calls(message: AssistantMessage) -> Group ```json {args} ``` - """).format( - name=tool_call.function.name, args=tool_call.function.arguments - ) + """).format(name=tool_call["name"], args=tool_call["args"]) ) panels.append( @@ -99,11 +96,13 @@ def format_assistant_message_with_tool_calls(message: AssistantMessage) -> Group def format_tool_message(message: ToolMessage) -> Panel: if message.tool_metadata.get("is_failed"): - content = f"❌ The tool call to [markdown.code]{message.tool_call.function.name}[/] failed." + content = ( + f"❌ The tool call to [markdown.code]{message.tool_call['name']}[/] failed." + ) elif not message.tool_metadata.get("is_task_status_tool"): content_type = "json" if isinstance(message.tool_result, (dict, list)) else "" content = Group( - f"✅ Received output from the [markdown.code]{message.tool_call.function.name}[/] tool.\n", + f"✅ Received output from the [markdown.code]{message.tool_call['name']}[/] tool.\n", Markdown(f"```{content_type}\n{message.content or ''}\n```"), ) else: diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py index a25f93fb..3e728f8d 100644 --- a/src/controlflow/llm/handlers.py +++ b/src/controlflow/llm/handlers.py @@ -3,8 +3,9 @@ from controlflow.llm.formatting import format_message from controlflow.llm.messages import ( - AssistantMessage, - ControlFlowMessage, + AIMessage, + AIMessageChunk, + MessageType, ToolMessage, ) from controlflow.utilities.context import ctx @@ -34,28 +35,28 @@ def on_end(self): def on_exception(self, exc: Exception): pass - def on_message_created(self, delta: AssistantMessage): + def on_message_created(self, delta: AIMessageChunk): pass - def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage): + def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): pass - def on_message_done(self, message: AssistantMessage): + def on_message_done(self, message: AIMessage): pass - def on_tool_call_created(self, delta: AssistantMessage): + def on_tool_call_created(self, delta: AIMessageChunk): pass - def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage): + def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): pass - def on_tool_call_done(self, message: AssistantMessage): + def on_tool_call_done(self, message: AIMessage): pass def on_tool_result_done(self, message: ToolMessage): pass - def on_response_message(self, message: ControlFlowMessage): + def on_response_message(self, message: MessageType): """ This handler is called whenever a message is generated that should be included in the completion history (e.g. a `message`, `tool_call` or @@ -74,18 +75,16 @@ class ResponseHandler(CompletionHandler): def __init__(self): self.response_messages = [] - def on_response_message(self, message: ControlFlowMessage): + def on_response_message(self, message: MessageType): self.response_messages.append(message) class TUIHandler(CompletionHandler): - def on_message_delta( - self, delta: AssistantMessage, snapshot: AssistantMessage - ) -> None: + def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk) -> None: if tui := ctx.get("tui"): tui.update_message(message=snapshot) - def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage): + def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): if tui := ctx.get("tui"): tui.update_message(message=snapshot) @@ -96,7 +95,7 @@ def on_tool_result_done(self, message: ToolMessage): class PrintHandler(CompletionHandler): def __init__(self): - self.messages: dict[str, ControlFlowMessage] = {} + self.messages: dict[str, MessageType] = {} self.live = Live(auto_refresh=False) def on_start(self): @@ -116,11 +115,11 @@ def update_live(self): self.live.update(Group(*content), refresh=True) - def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage): + def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): self.messages[snapshot.id] = snapshot self.update_live() - def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage): + def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): self.messages[snapshot.id] = snapshot self.update_live() diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py index f26ab623..c9ecb8d6 100644 --- a/src/controlflow/llm/history.py +++ b/src/controlflow/llm/history.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import ClassVar -from litellm.utils import trim_messages from pydantic import Field, field_validator import controlflow @@ -28,20 +27,6 @@ def load_messages( ) -> list[MessageType]: raise NotImplementedError() - def load_messages_to_token_limit( - self, thread_id: str, model: str = None - ) -> list[MessageType]: - messages = [] - # as long as the messages are not trimmed, keep loading more - while messages == (trim := trim_messages(messages, model=model)): - batch = self.load_messages( - thread_id, - limit=50, - before=None if not messages else messages[-1].timestamp, - ) - messages.extend(batch) - return trim - @abc.abstractmethod def save_messages(self, thread_id: str, messages: list[MessageType]): raise NotImplementedError() diff --git a/src/controlflow/llm/messages.py b/src/controlflow/llm/messages.py index 91e2928e..d4e22cca 100644 --- a/src/controlflow/llm/messages.py +++ b/src/controlflow/llm/messages.py @@ -1,263 +1,84 @@ import datetime -import inspect -import json import uuid -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Union -import litellm -from pydantic import ( - Field, - TypeAdapter, - field_serializer, - field_validator, - model_validator, -) +import langchain_core.messages +from langchain_core.messages import ToolCall +from pydantic.v1 import Field as v1_Field from controlflow.utilities.jinja import jinja_env -from controlflow.utilities.types import _OpenAIBaseType -# ----------------------------------------------- -# Messages -# ----------------------------------------------- +class MessageMixin(langchain_core.messages.BaseMessage): + class Config: + validate_assignment = True -Role = Literal["system", "user", "assistant", "tool"] + # default ID value + id: str = v1_Field(default_factory=lambda: uuid.uuid4().hex) - -class TextContent(_OpenAIBaseType): - type: Literal["text"] = "text" - text: str - - -class ImageDetails(_OpenAIBaseType): - url: str - detail: Literal["auto", "high", "low"] = "auto" - - -class ImageContent(_OpenAIBaseType): - type: Literal["image_url"] = "image_url" - image_url: ImageDetails - - -class ControlFlowMessage(_OpenAIBaseType): - # ---- begin openai fields - role: Role = Field(openai_field=True) - content: Optional[Union[str, List[Union[TextContent, ImageContent]]]] = None - _openai_fields: set[str] = {"role"} - # ---- end openai fields - - id: str = Field(default_factory=lambda: uuid.uuid4().hex, repr=False) - timestamp: datetime.datetime = Field( + # add timestamp + timestamp: datetime.datetime = v1_Field( default_factory=lambda: datetime.datetime.now(datetime.timezone.utc), ) - llm_response: Optional[litellm.ModelResponse] = Field(None, repr=False) - - def __init__( - self, - content: Optional[Union[str, List[Union[TextContent, ImageContent]]]] = None, - **kwargs, - ): - # allow content to be passed as a positional argument - super().__init__(content=content, **kwargs) - - @field_validator("role", mode="before") - def _lowercase_role(cls, v): - if isinstance(v, str): - v = v.lower() - return v - - @field_validator("timestamp", mode="before") - def _validate_timestamp(cls, v): - if isinstance(v, int): - v = datetime.datetime.fromtimestamp(v, tz=datetime.timezone.utc) - return v - - @model_validator(mode="after") - def _finalize(self): - self._openai_fields = ( - getattr(super(), "_openai_fields", set()) | self._openai_fields - ) - return self - - @field_serializer("timestamp") - def _serialize_timestamp(self, timestamp: datetime.datetime): - return timestamp.isoformat() - - def as_openai_message(self, include: set[str] = None, **kwargs) -> dict: - include = self._openai_fields | (include or set()) - return self.model_dump(include=include, **kwargs) - - -class SystemMessage(ControlFlowMessage): - # ---- begin openai fields - role: Literal["system"] = "system" - content: str - name: Optional[str] = None - _openai_fields = {"role", "content", "name"} - # ---- end openai fields - - @field_validator("content", mode="before") - def _validate_content(cls, v): - if isinstance(v, str): - v = inspect.cleandoc(v) - return v - def render(self, **kwargs) -> "SystemMessage": + def render(self, **kwargs) -> "MessageType": """ Renders the content as a jinja template with the given keyword arguments - and returns a new SystemMessage. + and returns a new Message. """ content = jinja_env.from_string(self.content).render(**kwargs) - return self.model_copy(update=dict(content=content)) + return self.copy(update=dict(content=content)) -class UserMessage(ControlFlowMessage): - # ---- begin openai fields - role: Literal["user"] = "user" - content: List[Union[TextContent, ImageContent]] - name: Optional[str] = None - _openai_fields = {"role", "content", "name"} - # ---- end openai fields +class HumanMessage(langchain_core.messages.HumanMessage, MessageMixin): + role: Literal["human"] = v1_Field("human", exclude=True) - @field_validator("content", mode="before") - def _validate_content(cls, v): - if isinstance(v, str): - v = inspect.cleandoc(v) - v = [TextContent(text=v)] - return v - def render(self, **kwargs) -> "UserMessage": - """ - Renders the content as a jinja template with the given keyword arguments - and returns a new SystemMessage. - """ - content = [] - for c in self.content: - if isinstance(c, TextContent): - text = jinja_env.from_string(c.text).render(**kwargs) - content.append(TextContent(text=text)) - else: - content.append(c) - return self.model_copy(update=dict(content=content)) - - -class AssistantMessage(ControlFlowMessage): - """A message from the assistant.""" - - # ---- begin openai fields - role: Literal["assistant"] = "assistant" - content: Optional[str] = None - name: Optional[str] = None - tool_calls: Optional[List["ToolCall"]] = None - _openai_fields = {"role", "content", "name", "tool_calls"} - # ---- end openai fields - - is_delta: bool = Field( - default=False, - description="If True, this message is a streamed delta, or chunk, of a full message.", - ) +class AIMessage(langchain_core.messages.AIMessage, MessageMixin): + role: Literal["ai"] = v1_Field("ai", exclude=True) - @field_validator("content", mode="before") - def _validate_content(cls, v): - if isinstance(v, str): - v = inspect.cleandoc(v) - return v + def has_tool_calls(self) -> bool: + return any(self.tool_calls) - def has_tool_calls(self): - return bool(self.tool_calls) - - def render(self, **kwargs) -> "AssistantMessage": - """ - Renders the content as a jinja template with the given keyword arguments - and returns a new AssistantMessage. - """ - if self.content is None: - content = self.content - else: - content = jinja_env.from_string(self.content).render(**kwargs) - return self.model_copy(update=dict(content=content)) + @classmethod + def from_message(cls, message: langchain_core.messages.AIMessage, **kwargs): + return cls(**dict(message) | kwargs | {"role": "ai"}) -class ToolMessage(ControlFlowMessage): - """A message for reporting the result of a tool call.""" +class AIMessageChunk(langchain_core.messages.AIMessageChunk, AIMessage): + role: Literal["ai"] = v1_Field("ai", exclude=True) - # ---- begin openai fields - role: Literal["tool"] = "tool" - content: str = Field(description="The string result of the tool call.") - tool_call_id: str = Field(description="The ID of the tool call.") - _openai_fields = {"role", "content", "tool_call_id"} - # ---- end openai fields + def has_tool_calls(self) -> bool: + return any(self.tool_call_chunks) - tool_call: "ToolCall" = Field(repr=False) - tool_result: Any = Field(None, exclude=True) - tool_metadata: dict = Field(default_factory=dict) + @classmethod + def from_chunk( + cls, chunk: langchain_core.messages.AIMessageChunk, **kwargs + ) -> "AIMessageChunk": + return cls(**dict(chunk) | kwargs | {"role": "ai"}) - @field_validator("content", mode="before") - def _validate_content(cls, v): - if isinstance(v, str): - v = inspect.cleandoc(v) - return v + def to_message(self, **kwargs) -> AIMessage: + return AIMessage(**self.dict(exclude={"type"}) | kwargs) + def __add__(self, other: Any) -> "AIMessageChunk": # type: ignore + result = super().__add__(other) + result.timestamp = self.timestamp + return result -MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage] +class SystemMessage(langchain_core.messages.SystemMessage, MessageMixin): + role: Literal["system"] = v1_Field("system", exclude=True) -class ToolCall(_OpenAIBaseType): - id: Optional[str] - type: Literal["function"] = "function" - function: "ToolCallFunction" +class ToolMessage(langchain_core.messages.ToolMessage, MessageMixin): + class Config: + arbitrary_types_allowed = True -class ToolCallFunction(_OpenAIBaseType): - name: Optional[str] - arguments: str - - def json_arguments(self): - return json.loads(self.arguments) - - -def as_cf_messages( - messages: list[Union[litellm.Message, litellm.ModelResponse]], -) -> list[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]]: - message_ta = TypeAdapter( - Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage] - ) + role: Literal["tool"] = v1_Field("tool", exclude=True) - result = [] - for msg in messages: - if isinstance(msg, ControlFlowMessage): - result.append(msg) - elif isinstance(msg, litellm.Message): - new_msg = message_ta.validate_python(msg.model_dump()) - result.append(new_msg) - elif isinstance(msg, litellm.ModelResponse): - for i, choice in enumerate(msg.choices): - # handle delta messages streaming from the assistant; sometimes these are not fully populated - if hasattr(choice, "delta"): - values = choice.delta.model_dump() - if values["content"] is None: - values["content"] = "" - if values["role"] is None: - values["role"] = "assistant" - new_msg = AssistantMessage(**values, is_delta=True) - else: - new_msg = message_ta.validate_python(choice.message.model_dump()) - new_msg.id = f"{msg.id}-{i}" - new_msg.timestamp = msg.created - new_msg.llm_response = msg - result.append(new_msg) - else: - raise ValueError(f"Invalid message type: {type(msg)}") - return result + tool_call: ToolCall + tool_result: Any = v1_Field(exclude=True) + tool_metadata: dict[str, Any] = v1_Field(default_factory=dict) -def as_oai_messages(messages: list[Union[dict, ControlFlowMessage, litellm.Message]]): - result = [] - for msg in messages: - if isinstance(msg, ControlFlowMessage): - result.append(msg.as_openai_message()) - elif isinstance(msg, (dict, litellm.Message)): - result.append(msg) - else: - raise ValueError(f"Invalid message type: {type(msg)}") - return result +MessageType = Union[HumanMessage, AIMessage, SystemMessage, ToolMessage] diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py new file mode 100644 index 00000000..eccdcb6c --- /dev/null +++ b/src/controlflow/llm/models.py @@ -0,0 +1,44 @@ +from langchain_core.language_models import BaseChatModel + +import controlflow + + +def model_from_string(model: str) -> BaseChatModel: + if "/" not in model: + provider, model = "openai", model + provider, model = model.split("/") + + if provider == "openai": + try: + from langchain_openai import ChatOpenAI + except ImportError: + raise ImportError( + "To use OpenAI models, please install the `langchain-openai` package." + ) + cls = ChatOpenAI + elif provider == "azure_openai": + try: + from langchain_openai import AzureChatOpenAI + except ImportError: + raise ImportError( + "To use Azure OpenAI models, please install the `langchain-openai` package." + ) + cls = AzureChatOpenAI + elif provider == "anthropic": + try: + from langchain_anthropic import ChatAnthropic + except ImportError: + raise ImportError( + "To use Anthropic models, please install the `langchain-anthropic` package." + ) + cls = ChatAnthropic + else: + raise ValueError( + f"Could not load provider automatically: {provider}. Please create your model manually." + ) + + return cls(model=model) + + +def get_default_model() -> BaseChatModel: + return model_from_string(model=controlflow.settings.llm_model) diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py index ad483634..d366674d 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/llm/tools.py @@ -1,134 +1,85 @@ import functools import inspect -from functools import partial, update_wrapper -from typing import Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import langchain_core +import langchain_core.tools import pydantic +import pydantic.v1 +from langchain_core.messages import ToolCall from prefect.utilities.asyncutils import run_coro_as_sync +from pydantic import Field, create_model -from controlflow.llm.messages import ( - AssistantMessage, - ControlFlowMessage, - ToolCall, - ToolMessage, -) -from controlflow.utilities.types import ControlFlowModel +if TYPE_CHECKING: + from controlflow.llm.messages import ToolMessage -class ToolFunction(ControlFlowModel): - name: str - parameters: dict - description: str = "" +def pydantic_model_from_function(fn: Callable): + sig = inspect.signature(fn) + fields = {} + for name, param in sig.parameters.items(): + annotation = ( + param.annotation if param.annotation is not inspect.Parameter.empty else Any + ) + default = param.default if param.default is not inspect.Parameter.empty else ... + fields[name] = (annotation, Field(default=default)) + return create_model(fn.__name__, **fields) + + +def _sync_wrapper(coro): + """ + Wrapper that runs a coroutine as a synchronous function with deffered args + """ + + @functools.wraps(coro) + def wrapper(*args, **kwargs): + return run_coro_as_sync(coro(*args, **kwargs)) + + return wrapper + +class Tool(langchain_core.tools.StructuredTool): + """ + A subclass of StructuredTool that is compatible with Pydantic v1 models + (which Langchain uses) and v2 models (which ControlFlow users). -class Tool(ControlFlowModel): - type: Literal["function"] = "function" - function: ToolFunction - _fn: Callable = pydantic.PrivateAttr() - _metadata: dict = pydantic.PrivateAttr(default_factory=dict) + Note that THIS is a Pydantic v1 model because it subclasses the Langchain class. + """ - def __init__(self, *, _fn: Callable, _metadata: dict = None, **kwargs): - super().__init__(**kwargs) - self._fn = _fn - self._metadata = _metadata or {} + tags: dict[str, Any] = pydantic.v1.Field(default_factory=dict) + args_schema: Optional[type[Union[pydantic.v1.BaseModel, pydantic.BaseModel]]] @classmethod - def from_function( - cls, - fn: Callable, - name: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[dict] = None, - ): - if name is None and fn.__name__ == "": - name = "__lambda__" - - return cls( - function=ToolFunction( - name=name or fn.__name__, - description=inspect.cleandoc(description or fn.__doc__ or ""), - parameters=pydantic.TypeAdapter( - fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True) - ).json_schema(), - ), - _fn=fn, - _metadata=metadata or getattr(fn, "__metadata__", {}), + def from_function(cls, fn=None, *args, **kwargs): + args_schema = pydantic_model_from_function(fn) + if inspect.iscoroutinefunction(fn): + fn, coro = _sync_wrapper(fn), fn + else: + coro = None + return super().from_function( + *args, func=fn, coroutine=coro, args_schema=args_schema, **kwargs ) - def __call__(self, *args, **kwargs): - return self._fn(*args, **kwargs) - def tool( fn: Optional[Callable] = None, *, name: Optional[str] = None, description: Optional[str] = None, - metadata: Optional[dict] = None, + tags: Optional[dict] = None, ) -> Tool: if fn is None: - return partial(tool, name=name, description=description, metadata=metadata) - return Tool.from_function(fn, name=name, description=description, metadata=metadata) - - -def annotate_fn( - fn: Callable, - name: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[dict] = None, -) -> Callable: - """ - Annotate a function with a new name and description without modifying the - original. Useful when you want to provide a custom name and description for - a tool, but without creating a new tool object. - """ - new_fn = functools.partial(fn) - new_fn.__name__ = name or fn.__name__ - new_fn.__doc__ = description or fn.__doc__ - new_fn.__metadata__ = getattr(fn, "__metadata__", {}) | metadata - return new_fn - - -def as_tools(tools: list[Union[Tool, Callable]]) -> list[Tool]: - tools = [t if isinstance(t, Tool) else tool(t) for t in tools] - if len({t.function.name for t in tools}) != len(tools): - duplicates = {t.function.name for t in tools if tools.count(t) > 1} - raise ValueError( - f"Tool names must be unique, but found duplicates: {', '.join(duplicates)}" - ) - return tools - - -def as_tool_lookup(tools: list[Union[Tool, Callable]]) -> dict[str, Tool]: - return {t.function.name: t for t in as_tools(tools)} - - -def custom_partial(func: Callable, **fixed_kwargs: Any) -> Callable: - """ - Returns a new function with partial application of the given keyword arguments. - The new function has the same __name__ and docstring as the original, and its - signature excludes the provided kwargs. - """ - - # Define the new function with a dynamic signature - def wrapper(**kwargs): - # Merge the provided kwargs with the fixed ones, prioritizing the former - all_kwargs = {**fixed_kwargs, **kwargs} - return func(**all_kwargs) + return functools.partial(tool, name=name, description=description, tags=tags) + return Tool.from_function(fn, name=name, description=description, tags=tags or {}) - # Update the wrapper function's metadata to match the original function - update_wrapper(wrapper, func) - # Modify the signature to exclude the fixed kwargs - original_sig = inspect.signature(func) - new_params = [ - param - for param in original_sig.parameters.values() - if param.name not in fixed_kwargs - ] - wrapper.__signature__ = original_sig.replace(parameters=new_params) - - return wrapper +def as_tools(tools: list[Union[Callable, Tool]]) -> list[Tool]: + new_tools = [] + for t in tools: + if not isinstance(t, Tool): + t = Tool.from_function(t) + new_tools.append(t) + return new_tools def output_to_string(output: Any) -> str: @@ -145,23 +96,9 @@ def output_to_string(output: Any) -> str: return output -def get_tool_calls( - messages: list[ControlFlowMessage], -) -> list[ToolCall]: - if not isinstance(messages, list): - messages = [messages] - return [ - tc - for m in messages - if isinstance(m, AssistantMessage) and m.tool_calls - for tc in m.tool_calls - ] - - -def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMessage: - tool_lookup = as_tool_lookup(tools) - fn_name = tool_call.function.name - fn_args = None +def handle_tool_call(tool_call: ToolCall, tools: list[Tool]) -> "ToolMessage": + tool_lookup = {t.name: t for t in tools} + fn_name = tool_call["name"] metadata = {} try: if fn_name not in tool_lookup: @@ -169,17 +106,19 @@ def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMe metadata["is_failed"] = True else: tool = tool_lookup[fn_name] - metadata.update(tool._metadata) - fn_args = tool_call.function.json_arguments() - fn_output = tool(**fn_args) + fn_args = tool_call["args"] + fn_output = tool.invoke(input=fn_args) if inspect.isawaitable(fn_output): fn_output = run_coro_as_sync(fn_output) except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' metadata["is_failed"] = True + + from controlflow.llm.messages import ToolMessage + return ToolMessage( content=output_to_string(fn_output), - tool_call_id=tool_call.id, + tool_call_id=tool_call["id"], tool_call=tool_call, tool_result=fn_output, tool_metadata=metadata, @@ -187,11 +126,10 @@ def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMe async def handle_tool_call_async( - tool_call: ToolCall, tools: list[dict, Callable] -) -> ToolMessage: - tool_lookup = as_tool_lookup(tools) - fn_name = tool_call.function.name - fn_args = None + tool_call: ToolCall, tools: list[Tool] +) -> "ToolMessage": + tool_lookup = {t.name: t for t in tools} + fn_name = tool_call["name"] metadata = {} try: if fn_name not in tool_lookup: @@ -199,17 +137,17 @@ async def handle_tool_call_async( metadata["is_failed"] = True else: tool = tool_lookup[fn_name] - metadata = tool._metadata - fn_args = tool_call.function.json_arguments() - fn_output = tool(**fn_args) - if inspect.isawaitable(fn_output): - fn_output = await fn_output + fn_args = tool_call["args"] + fn_output = await tool.ainvoke(input=fn_args) except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' metadata["is_failed"] = True + + from controlflow.llm.messages import ToolMessage + return ToolMessage( content=output_to_string(fn_output), - tool_call_id=tool_call.id, + tool_call_id=tool_call["id"], tool_call=tool_call, tool_result=fn_output, tool_metadata=metadata, diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 2ec6bc75..fd518e23 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -4,9 +4,8 @@ from contextlib import contextmanager from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union -import litellm from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -53,7 +52,7 @@ def apply(self): class Settings(ControlFlowSettings): assistant_model: str = "gpt-4o" max_task_iterations: Union[int, None] = Field( - None, + 100, description="The maximum number of iterations to attempt to complete a task " "before raising an error. If None, the task will run indefinitely. " "This setting can be overridden by the `max_iterations` attribute " @@ -89,14 +88,7 @@ class Settings(ControlFlowSettings): # ------------ LLM settings ------------ - llm_model: str = Field("gpt-4o", description="The LLM model to use.") - llm_api_key: Optional[str] = Field(None, description="The LLM API key to use.") - llm_api_base: Optional[str] = Field( - None, description="The LLM API base URL to use." - ) - llm_api_version: Optional[str] = Field( - None, description="The LLM API version to use." - ) + llm_model: str = Field("openai/gpt-4o", description="The LLM model to use.") # ------------ Flow visualization settings ------------ @@ -125,12 +117,6 @@ def _validate_home_path(cls, v): v.mkdir(parents=True, exist_ok=True) return v - @field_validator("llm_model", mode="before") - def _validate_model(cls, v): - if not litellm.supports_function_calling(model=v): - raise ValueError(f"Model '{v}' does not support function calling.") - return v - settings = Settings() diff --git a/src/controlflow/tui/app.py b/src/controlflow/tui/app.py index 3a189fc0..cb28f834 100644 --- a/src/controlflow/tui/app.py +++ b/src/controlflow/tui/app.py @@ -11,7 +11,7 @@ import controlflow import controlflow.utilities import controlflow.utilities.asyncio -from controlflow.llm.messages import AssistantMessage, ToolMessage, UserMessage +from controlflow.llm.messages import AIMessage, HumanMessage, ToolMessage from .basic import Column, Row from .task import TUITask @@ -116,7 +116,7 @@ def update_task(self, task: "controlflow.Task"): self.query_one("#tasks-container", Column).mount(new_task) new_task.scroll_visible() - def update_message(self, message: Union[UserMessage, AssistantMessage]): + def update_message(self, message: Union[HumanMessage, AIMessage]): try: component = self.query_one(f"#message-{message.id}", TUIMessage) component.message = message diff --git a/src/controlflow/tui/test.py b/src/controlflow/tui/test.py index d5ae3d53..4eb382f8 100644 --- a/src/controlflow/tui/test.py +++ b/src/controlflow/tui/test.py @@ -5,7 +5,7 @@ from controlflow import Task from controlflow.core.flow import Flow -from controlflow.llm.messages import AssistantMessage +from controlflow.llm.messages import AIMessage from controlflow.tui.app import TUIApp @@ -50,13 +50,13 @@ async def run(): ) await asyncio.sleep(1) t0.mark_failed(message="this is my result") - app.update_message(AssistantMessage(content="hello there")) + app.update_message(AIMessage(content="hello there")) await asyncio.sleep(1) - app.update_message(AssistantMessage(content="hello there")) + app.update_message(AIMessage(content="hello there")) await asyncio.sleep(1) - app.update_message(AssistantMessage(content="hello there" * 50)) + app.update_message(AIMessage(content="hello there" * 50)) await asyncio.sleep(1) - app.update_message(AssistantMessage(content="hello there")) + app.update_message(AIMessage(content="hello there")) await asyncio.sleep(1) await asyncio.sleep(inf) diff --git a/src/controlflow/tui/thread.py b/src/controlflow/tui/thread.py index 9c53c5ed..e6268cf6 100644 --- a/src/controlflow/tui/thread.py +++ b/src/controlflow/tui/thread.py @@ -5,7 +5,7 @@ from textual.widgets import Static from controlflow.llm.formatting import format_message, format_tool_message -from controlflow.llm.messages import AssistantMessage, ToolMessage, UserMessage +from controlflow.llm.messages import AIMessage, HumanMessage, ToolMessage def format_timestamp(timestamp: datetime.datetime) -> str: @@ -13,11 +13,11 @@ def format_timestamp(timestamp: datetime.datetime) -> str: class TUIMessage(Static): - message: reactive[Union[UserMessage, AssistantMessage]] = reactive( + message: reactive[Union[HumanMessage, AIMessage]] = reactive( None, always_update=True, layout=True ) - def __init__(self, message: Union[UserMessage, AssistantMessage], **kwargs): + def __init__(self, message: Union[HumanMessage, AIMessage], **kwargs): super().__init__(**kwargs) self.message = message diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py index a4cbf656..b794443d 100644 --- a/tests/fixtures/mocks.py +++ b/tests/fixtures/mocks.py @@ -1,187 +1,187 @@ -from typing import Any -from unittest.mock import AsyncMock, Mock, patch +# from typing import Any +# from unittest.mock import AsyncMock, Mock, patch -import litellm -import pytest -from controlflow.core.agent import Agent -from controlflow.core.task import Task, TaskStatus -from controlflow.llm.completions import Response -from controlflow.settings import temporary_settings +# import litellm +# import pytest +# from controlflow.core.agent import Agent +# from controlflow.core.task import Task, TaskStatus +# from controlflow.llm.completions import Response +# from controlflow.settings import temporary_settings -def new_chunk(): - chunk = litellm.ModelResponse() - chunk.choices = [litellm.utils.StreamingChoices()] - return chunk +# def new_chunk(): +# chunk = litellm.ModelResponse() +# chunk.choices = [litellm.utils.StreamingChoices()] +# return chunk -@pytest.fixture -def prevent_openai_calls(): - """Prevent any calls to the OpenAI API from being made.""" - with temporary_settings(llm_api_key="unset"): - yield +# @pytest.fixture +# def prevent_openai_calls(): +# """Prevent any calls to the OpenAI API from being made.""" +# with temporary_settings(llm_api_key="unset"): +# yield -@pytest.fixture -def mock_controller_run_agent(monkeypatch, prevent_openai_calls): - MockRunAgent = AsyncMock() - MockThreadGetMessages = Mock() +# @pytest.fixture +# def mock_controller_run_agent(monkeypatch, prevent_openai_calls): +# MockRunAgent = AsyncMock() +# MockThreadGetMessages = Mock() - async def _run_agent(agent: Agent, tasks: list[Task] = None, thread=None): - for task in tasks: - if agent in task.get_agents(): - # we can't call mark_successful because we don't know the result - task.status = TaskStatus.SUCCESSFUL +# async def _run_agent(agent: Agent, tasks: list[Task] = None, thread=None): +# for task in tasks: +# if agent in task.get_agents(): +# # we can't call mark_successful because we don't know the result +# task.status = TaskStatus.SUCCESSFUL - MockRunAgent.side_effect = _run_agent +# MockRunAgent.side_effect = _run_agent - def get_messages(*args, **kwargs): - return [] +# def get_messages(*args, **kwargs): +# return [] - MockThreadGetMessages.side_effect = get_messages +# MockThreadGetMessages.side_effect = get_messages - monkeypatch.setattr( - "controlflow.core.controller.controller.Controller._run_agent", MockRunAgent - ) - yield MockRunAgent +# monkeypatch.setattr( +# "controlflow.core.controller.controller.Controller._run_agent", MockRunAgent +# ) +# yield MockRunAgent -@pytest.fixture -def mock_controller_choose_agent(monkeypatch): - MockChooseAgent = Mock() +# @pytest.fixture +# def mock_controller_choose_agent(monkeypatch): +# MockChooseAgent = Mock() - def choose_agent(agents, **kwargs): - return agents[0] +# def choose_agent(agents, **kwargs): +# return agents[0] - MockChooseAgent.side_effect = choose_agent +# MockChooseAgent.side_effect = choose_agent - monkeypatch.setattr( - "controlflow.core.controller.controller.Controller.choose_agent", - MockChooseAgent, - ) - yield MockChooseAgent +# monkeypatch.setattr( +# "controlflow.core.controller.controller.Controller.choose_agent", +# MockChooseAgent, +# ) +# yield MockChooseAgent -@pytest.fixture -def mock_controller(mock_controller_choose_agent, mock_controller_run_agent): - pass +# @pytest.fixture +# def mock_controller(mock_controller_choose_agent, mock_controller_run_agent): +# pass -@pytest.fixture -def mock_completion(monkeypatch): - """ - Mock the completion function from the LLM module. Use this fixture to set - the response value ahead of calling the completion. +# @pytest.fixture +# def mock_completion(monkeypatch): +# """ +# Mock the completion function from the LLM module. Use this fixture to set +# the response value ahead of calling the completion. - Example: +# Example: - def test_completion(mock_completion): - mock_completion.set_response("Hello, world!") - response = litellm.completion(...) - assert response == "Hello, world!" - """ - response = litellm.ModelResponse() +# def test_completion(mock_completion): +# mock_completion.set_response("Hello, world!") +# response = litellm.completion(...) +# assert response == "Hello, world!" +# """ +# response = litellm.ModelResponse() - def set_response(message: str): - response.choices[0].message.content = message +# def set_response(message: str): +# response.choices[0].message.content = message - def mock_func(*args, **kwargs): - return Response(responses=[response], messages=[]) +# def mock_func(*args, **kwargs): +# return Response(responses=[response], messages=[]) - monkeypatch.setattr("controlflow.llm.completions.completion", mock_func) - mock_func.set_response = set_response +# monkeypatch.setattr("controlflow.llm.completions.completion", mock_func) +# mock_func.set_response = set_response - return mock_func +# return mock_func -@pytest.fixture -def mock_completion_stream(monkeypatch): - """ - Mock the completion function from the LLM module. Use this fixture to set - the response value ahead of calling the completion. +# @pytest.fixture +# def mock_completion_stream(monkeypatch): +# """ +# Mock the completion function from the LLM module. Use this fixture to set +# the response value ahead of calling the completion. - Example: +# Example: - def test_completion(mock_completion): - mock_completion.set_response("Hello, world!") - response = litellm.completion(...) - assert response == "Hello, world!" - """ - response = litellm.ModelResponse() - chunk = litellm.ModelResponse() - chunk.choices = [litellm.utils.StreamingChoices()] +# def test_completion(mock_completion): +# mock_completion.set_response("Hello, world!") +# response = litellm.completion(...) +# assert response == "Hello, world!" +# """ +# response = litellm.ModelResponse() +# chunk = litellm.ModelResponse() +# chunk.choices = [litellm.utils.StreamingChoices()] - def set_response(message: str): - response.choices[0].message.content = message +# def set_response(message: str): +# response.choices[0].message.content = message - def mock_func_deltas(*args, **kwargs): - for c in response.choices[0].message.content: - chunk = new_chunk() - chunk.choices[0].delta.content = c - yield chunk, response +# def mock_func_deltas(*args, **kwargs): +# for c in response.choices[0].message.content: +# chunk = new_chunk() +# chunk.choices[0].delta.content = c +# yield chunk, response - monkeypatch.setattr( - "controlflow.llm.completions.completion_stream", mock_func_deltas - ) - mock_func_deltas.set_response = set_response +# monkeypatch.setattr( +# "controlflow.llm.completions.completion_stream", mock_func_deltas +# ) +# mock_func_deltas.set_response = set_response - return mock_func_deltas +# return mock_func_deltas -@pytest.fixture -def mock_completion_async(monkeypatch): - """ - Mock the completion function from the LLM module. Use this fixture to set - the response value ahead of calling the completion. +# @pytest.fixture +# def mock_completion_async(monkeypatch): +# """ +# Mock the completion function from the LLM module. Use this fixture to set +# the response value ahead of calling the completion. - Example: +# Example: - def test_completion(mock_completion): - mock_completion.set_response("Hello, world!") - response = litellm.completion(...) - assert response == "Hello, world!" - """ - response = litellm.ModelResponse() +# def test_completion(mock_completion): +# mock_completion.set_response("Hello, world!") +# response = litellm.completion(...) +# assert response == "Hello, world!" +# """ +# response = litellm.ModelResponse() - def set_response(message: str): - response.choices[0].message.content = message +# def set_response(message: str): +# response.choices[0].message.content = message - async def mock_func(*args, **kwargs): - return Response(responses=[response], messages=[]) +# async def mock_func(*args, **kwargs): +# return Response(responses=[response], messages=[]) - monkeypatch.setattr("controlflow.llm.completions.completion_async", mock_func) - mock_func.set_response = set_response +# monkeypatch.setattr("controlflow.llm.completions.completion_async", mock_func) +# mock_func.set_response = set_response - return mock_func +# return mock_func -@pytest.fixture -def mock_completion_stream_async(monkeypatch): - """ - Mock the completion function from the LLM module. Use this fixture to set - the response value ahead of calling the completion. +# @pytest.fixture +# def mock_completion_stream_async(monkeypatch): +# """ +# Mock the completion function from the LLM module. Use this fixture to set +# the response value ahead of calling the completion. - Example: +# Example: - def test_completion(mock_completion): - mock_completion.set_response("Hello, world!") - response = litellm.completion(...) - assert response == "Hello, world!" - """ - response = litellm.ModelResponse() +# def test_completion(mock_completion): +# mock_completion.set_response("Hello, world!") +# response = litellm.completion(...) +# assert response == "Hello, world!" +# """ +# response = litellm.ModelResponse() - def set_response(message: str): - response.choices[0].message.content = message +# def set_response(message: str): +# response.choices[0].message.content = message - async def mock_func_deltas(*args, **kwargs): - for c in response.choices[0].message.content: - chunk = new_chunk() - chunk.choices[0].delta.content = c - yield chunk, response +# async def mock_func_deltas(*args, **kwargs): +# for c in response.choices[0].message.content: +# chunk = new_chunk() +# chunk.choices[0].delta.content = c +# yield chunk, response - monkeypatch.setattr( - "controlflow.llm.completions.completion_stream_async", mock_func_deltas - ) - mock_func_deltas.set_response = set_response +# monkeypatch.setattr( +# "controlflow.llm.completions.completion_stream_async", mock_func_deltas +# ) +# mock_func_deltas.set_response = set_response - return mock_func_deltas +# return mock_func_deltas diff --git a/tests/llm/test_handlers.py b/tests/llm/test_handlers.py index ff8d684b..0800b98d 100644 --- a/tests/llm/test_handlers.py +++ b/tests/llm/test_handlers.py @@ -1,53 +1,53 @@ -from collections import Counter - -import litellm -from controlflow.llm.completions import _completion_stream -from controlflow.llm.handlers import CompletionHandler -from controlflow.llm.messages import AssistantMessage -from controlflow.llm.tools import ToolResult -from pydantic import BaseModel - - -class StreamCall(BaseModel): - method: str - args: dict - - -class MockCompletionHandler(CompletionHandler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.calls: list[StreamCall] = [] - - def on_message_created(self, delta: litellm.utils.Delta): - self.calls.append( - StreamCall(method="on_message_created", args=dict(delta=delta)) - ) - - def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message): - self.calls.append( - StreamCall( - method="on_message_delta", args=dict(delta=delta, snapshot=snapshot) - ) - ) - - def on_message_done(self, message: AssistantMessage): - self.calls.append( - StreamCall(method="on_message_done", args=dict(message=message)) - ) - - def on_tool_call_done(self, tool_call: ToolResult): - self.calls.append( - StreamCall(method="on_tool_call", args=dict(tool_call=tool_call)) - ) - - -class TestCompletionHandler: - def test_stream(self): - handler = MockCompletionHandler() - gen = _completion_stream(messages=[{"text": "Hello"}]) - handler.stream(gen) - - method_counts = Counter(call.method for call in handler.calls) - assert method_counts["on_message_created"] == 1 - assert method_counts["on_message_delta"] == 4 - assert method_counts["on_message_done"] == 1 +# from collections import Counter + +# import litellm +# from controlflow.llm.completions import _completion_stream +# from controlflow.llm.handlers import CompletionHandler +# from controlflow.llm.messages import AIMessage +# from controlflow.llm.tools import ToolResult +# from pydantic import BaseModel + + +# class StreamCall(BaseModel): +# method: str +# args: dict + + +# class MockCompletionHandler(CompletionHandler): +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.calls: list[StreamCall] = [] + +# def on_message_created(self, delta: litellm.utils.Delta): +# self.calls.append( +# StreamCall(method="on_message_created", args=dict(delta=delta)) +# ) + +# def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message): +# self.calls.append( +# StreamCall( +# method="on_message_delta", args=dict(delta=delta, snapshot=snapshot) +# ) +# ) + +# def on_message_done(self, message: AIMessage): +# self.calls.append( +# StreamCall(method="on_message_done", args=dict(message=message)) +# ) + +# def on_tool_call_done(self, tool_call: ToolResult): +# self.calls.append( +# StreamCall(method="on_tool_call", args=dict(tool_call=tool_call)) +# ) + + +# class TestCompletionHandler: +# def test_stream(self): +# handler = MockCompletionHandler() +# gen = _completion_stream(messages=[{"text": "Hello"}]) +# handler.stream(gen) + +# method_counts = Counter(call.method for call in handler.calls) +# assert method_counts["on_message_created"] == 1 +# assert method_counts["on_message_delta"] == 4 +# assert method_counts["on_message_done"] == 1