From 7a98e2b83c1ec1f8315c2f88b354141e3a8c7a9c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 20 May 2024 23:39:43 -0400 Subject: [PATCH] Add litellm --- pyproject.toml | 2 + requirements-dev.lock | 67 +++++++- requirements.lock | 67 +++++++- src/controlflow/llm/__init__.py | 0 src/controlflow/llm/completions.py | 259 +++++++++++++++++++++++++++++ src/controlflow/llm/history.py | 14 ++ src/controlflow/llm/tools.py | 148 +++++++++++++++++ src/controlflow/settings.py | 17 ++ src/controlflow/utilities/types.py | 1 + 9 files changed, 573 insertions(+), 2 deletions(-) create mode 100644 src/controlflow/llm/__init__.py create mode 100644 src/controlflow/llm/completions.py create mode 100644 src/controlflow/llm/history.py create mode 100644 src/controlflow/llm/tools.py diff --git a/pyproject.toml b/pyproject.toml index 6b172a32..734d755a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,8 @@ dependencies = [ # can remove when prefect fully migrates to pydantic 2 "pydantic>=2", "textual>=0.61.1", + "litellm>=1.37.17", + "numpydoc>=1.7.0", ] readme = "README.md" requires-python = ">= 3.9" diff --git a/requirements-dev.lock b/requirements-dev.lock index 302a431d..0cba9cde 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -8,8 +8,14 @@ # with-sources: false -e file:. +aiohttp==3.9.5 + # via litellm +aiosignal==1.3.1 + # via aiohttp aiosqlite==0.20.0 # via prefect +alabaster==0.7.16 + # via sphinx alembic==1.13.1 # via prefect annotated-types==0.6.0 @@ -29,12 +35,14 @@ asttokens==2.4.1 asyncpg==0.29.0 # via prefect attrs==23.2.0 + # via aiohttp # via cattrs # via ddtrace # via jsonschema # via referencing babel==2.15.0 # via mkdocs-material + # via sphinx boto3==1.34.103 # via moto botocore==1.34.103 @@ -68,6 +76,7 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via apprise + # via litellm # via mkdocs # via mkdocstrings # via prefect @@ -109,6 +118,8 @@ distro==1.9.0 # via openai docker==6.1.3 # via prefect +docutils==0.21.2 + # via sphinx envier==0.5.1 # via ddtrace execnet==2.1.1 @@ -118,8 +129,13 @@ executing==2.0.1 fastapi==0.110.0 # via marvin filelock==3.13.3 + # via huggingface-hub # via virtualenv +frozenlist==1.4.1 + # via aiohttp + # via aiosignal fsspec==2024.3.1 + # via huggingface-hub # via prefect ghp-import==2.1.0 # via mkdocs @@ -147,6 +163,8 @@ httpx==0.27.0 # via openai # via prefect # via respx +huggingface-hub==0.23.0 + # via tokenizers humanize==4.9.0 # via jinja2-humanize-extension # via prefect @@ -158,7 +176,11 @@ idna==3.6 # via anyio # via httpx # via requests + # via yarl +imagesize==1.4.1 + # via sphinx importlib-metadata==7.0.0 + # via litellm # via mike # via opentelemetry-api importlib-resources==6.1.3 @@ -174,6 +196,7 @@ jedi==0.19.1 # via ipython jinja2==3.1.3 # via jinja2-humanize-extension + # via litellm # via marvin # via mike # via mkdocs @@ -181,6 +204,7 @@ jinja2==3.1.3 # via mkdocstrings # via moto # via prefect + # via sphinx jinja2-humanize-extension==0.4.0 # via prefect jmespath==1.0.1 @@ -199,6 +223,8 @@ kubernetes==29.0.0 # via prefect linkify-it-py==2.0.3 # via markdown-it-py +litellm==1.37.17 + # via controlflow mako==1.3.2 # via alembic markdown==3.6 @@ -255,6 +281,9 @@ mkdocstrings-python==1.9.2 # via prefect moto==5.0.6 # via prefect +multidict==6.0.5 + # via aiohttp + # via yarl mypy==1.10.0 # via prefect mypy-extensions==1.0.0 @@ -263,10 +292,13 @@ nodeenv==1.8.0 # via pre-commit numpy==1.26.4 # via prefect +numpydoc==1.7.0 + # via controlflow oauthlib==3.2.2 # 
via kubernetes # via requests-oauthlib openai==1.28.1 + # via litellm # via marvin opentelemetry-api==1.24.0 # via ddtrace @@ -274,9 +306,11 @@ orjson==3.10.0 # via prefect packaging==24.0 # via docker + # via huggingface-hub # via mkdocs # via prefect # via pytest + # via sphinx paginate==0.5.6 # via mkdocs-material parso==0.8.3 @@ -302,7 +336,7 @@ pluggy==1.4.0 # via pytest pre-commit==3.7.0 # via prefect -prefect @ git+https://github.com/prefecthq/prefect@8d56742dd83273af0c9d6c986b752f2a8439e9a8 +prefect @ git+https://github.com/prefecthq/prefect@b7ce5ed9eb0cf813dfb20a9f1fe8815dd2008ca1 # via controlflow prompt-toolkit==3.0.43 # via ipython @@ -338,6 +372,7 @@ pygments==2.17.2 # via ipython # via mkdocs-material # via rich + # via sphinx pymdown-extensions==10.8.1 # via mkdocs-material # via mkdocstrings @@ -377,6 +412,7 @@ python-dateutil==2.9.0.post0 # via prefect # via time-machine python-dotenv==1.0.1 + # via litellm # via pydantic-settings python-multipart==0.0.9 # via prefect @@ -390,6 +426,7 @@ pytz==2024.1 # via prefect pyyaml==6.0.1 # via apprise + # via huggingface-hub # via kubernetes # via mike # via mkdocs @@ -414,12 +451,15 @@ regex==2023.12.25 requests==2.31.0 # via apprise # via docker + # via huggingface-hub # via kubernetes + # via litellm # via mkdocs-material # via moto # via prefect # via requests-oauthlib # via responses + # via sphinx # via tiktoken requests-oauthlib==2.0.0 # via apprise @@ -468,6 +508,22 @@ sniffio==1.3.1 # via httpx # via openai # via prefect +snowballstemmer==2.2.0 + # via sphinx +sphinx==7.3.7 + # via numpydoc +sphinxcontrib-applehelp==1.0.8 + # via sphinx +sphinxcontrib-devhelp==1.0.6 + # via sphinx +sphinxcontrib-htmlhelp==2.0.5 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.7 + # via sphinx +sphinxcontrib-serializinghtml==1.1.10 + # via sphinx sqlalchemy==2.0.29 # via alembic # via prefect @@ -477,20 +533,26 @@ stack-data==0.6.3 # via ipython starlette==0.36.3 # via fastapi +tabulate==0.9.0 + # via numpydoc text-unidecode==1.3 # via python-slugify textual==0.61.1 # via controlflow tiktoken==0.6.0 + # via litellm # via marvin time-machine==2.14.1 # via pendulum tinycss2==1.3.0 # via cairosvg # via cssselect2 +tokenizers==0.19.1 + # via litellm toml==0.10.2 # via prefect tqdm==4.66.2 + # via huggingface-hub # via openai traitlets==5.14.2 # via ipython @@ -512,6 +574,7 @@ typing-extensions==4.10.0 # via alembic # via ddtrace # via fastapi + # via huggingface-hub # via marvin # via mypy # via openai @@ -567,5 +630,7 @@ wrapt==1.16.0 xmltodict==0.13.0 # via ddtrace # via moto +yarl==1.9.4 + # via aiohttp zipp==3.18.1 # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index f06764c6..2dfdbf5c 100644 --- a/requirements.lock +++ b/requirements.lock @@ -8,8 +8,14 @@ # with-sources: false -e file:. 
+aiohttp==3.9.5 + # via litellm +aiosignal==1.3.1 + # via aiohttp aiosqlite==0.20.0 # via prefect +alabaster==0.7.16 + # via sphinx alembic==1.13.1 # via prefect annotated-types==0.6.0 @@ -29,12 +35,14 @@ asttokens==2.4.1 asyncpg==0.29.0 # via prefect attrs==23.2.0 + # via aiohttp # via cattrs # via ddtrace # via jsonschema # via referencing babel==2.15.0 # via mkdocs-material + # via sphinx boto3==1.34.103 # via moto botocore==1.34.103 @@ -68,6 +76,7 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via apprise + # via litellm # via mkdocs # via mkdocstrings # via prefect @@ -109,6 +118,8 @@ distro==1.9.0 # via openai docker==6.1.3 # via prefect +docutils==0.21.2 + # via sphinx envier==0.5.1 # via ddtrace execnet==2.1.1 @@ -118,8 +129,13 @@ executing==2.0.1 fastapi==0.110.0 # via marvin filelock==3.14.0 + # via huggingface-hub # via virtualenv +frozenlist==1.4.1 + # via aiohttp + # via aiosignal fsspec==2024.3.1 + # via huggingface-hub # via prefect ghp-import==2.1.0 # via mkdocs @@ -147,6 +163,8 @@ httpx==0.27.0 # via openai # via prefect # via respx +huggingface-hub==0.23.0 + # via tokenizers humanize==4.9.0 # via jinja2-humanize-extension # via prefect @@ -158,7 +176,11 @@ idna==3.6 # via anyio # via httpx # via requests + # via yarl +imagesize==1.4.1 + # via sphinx importlib-metadata==7.0.0 + # via litellm # via mike # via opentelemetry-api importlib-resources==6.1.3 @@ -174,6 +196,7 @@ jedi==0.19.1 # via ipython jinja2==3.1.3 # via jinja2-humanize-extension + # via litellm # via marvin # via mike # via mkdocs @@ -181,6 +204,7 @@ jinja2==3.1.3 # via mkdocstrings # via moto # via prefect + # via sphinx jinja2-humanize-extension==0.4.0 # via prefect jmespath==1.0.1 @@ -199,6 +223,8 @@ kubernetes==29.0.0 # via prefect linkify-it-py==2.0.3 # via markdown-it-py +litellm==1.37.17 + # via controlflow mako==1.3.2 # via alembic markdown==3.6 @@ -255,6 +281,9 @@ mkdocstrings-python==1.9.2 # via prefect moto==5.0.6 # via prefect +multidict==6.0.5 + # via aiohttp + # via yarl mypy==1.10.0 # via prefect mypy-extensions==1.0.0 @@ -263,10 +292,13 @@ nodeenv==1.8.0 # via pre-commit numpy==1.26.4 # via prefect +numpydoc==1.7.0 + # via controlflow oauthlib==3.2.2 # via kubernetes # via requests-oauthlib openai==1.28.1 + # via litellm # via marvin opentelemetry-api==1.24.0 # via ddtrace @@ -274,9 +306,11 @@ orjson==3.10.0 # via prefect packaging==24.0 # via docker + # via huggingface-hub # via mkdocs # via prefect # via pytest + # via sphinx paginate==0.5.6 # via mkdocs-material parso==0.8.4 @@ -302,7 +336,7 @@ pluggy==1.5.0 # via pytest pre-commit==3.7.1 # via prefect -prefect @ git+https://github.com/prefecthq/prefect@8d56742dd83273af0c9d6c986b752f2a8439e9a8 +prefect @ git+https://github.com/prefecthq/prefect@b7ce5ed9eb0cf813dfb20a9f1fe8815dd2008ca1 # via controlflow prompt-toolkit==3.0.43 # via ipython @@ -338,6 +372,7 @@ pygments==2.17.2 # via ipython # via mkdocs-material # via rich + # via sphinx pymdown-extensions==10.8.1 # via mkdocs-material # via mkdocstrings @@ -377,6 +412,7 @@ python-dateutil==2.9.0.post0 # via prefect # via time-machine python-dotenv==1.0.1 + # via litellm # via pydantic-settings python-multipart==0.0.9 # via prefect @@ -390,6 +426,7 @@ pytz==2024.1 # via prefect pyyaml==6.0.1 # via apprise + # via huggingface-hub # via kubernetes # via mike # via mkdocs @@ -414,12 +451,15 @@ regex==2023.12.25 requests==2.31.0 # via apprise # via docker + # via huggingface-hub # via kubernetes + # via litellm # via mkdocs-material # via moto # via prefect # via requests-oauthlib # 
via responses + # via sphinx # via tiktoken requests-oauthlib==2.0.0 # via apprise @@ -468,6 +508,22 @@ sniffio==1.3.1 # via httpx # via openai # via prefect +snowballstemmer==2.2.0 + # via sphinx +sphinx==7.3.7 + # via numpydoc +sphinxcontrib-applehelp==1.0.8 + # via sphinx +sphinxcontrib-devhelp==1.0.6 + # via sphinx +sphinxcontrib-htmlhelp==2.0.5 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.7 + # via sphinx +sphinxcontrib-serializinghtml==1.1.10 + # via sphinx sqlalchemy==2.0.29 # via alembic # via prefect @@ -477,20 +533,26 @@ stack-data==0.6.3 # via ipython starlette==0.36.3 # via fastapi +tabulate==0.9.0 + # via numpydoc text-unidecode==1.3 # via python-slugify textual==0.61.1 # via controlflow tiktoken==0.6.0 + # via litellm # via marvin time-machine==2.14.1 # via pendulum tinycss2==1.3.0 # via cairosvg # via cssselect2 +tokenizers==0.19.1 + # via litellm toml==0.10.2 # via prefect tqdm==4.66.2 + # via huggingface-hub # via openai traitlets==5.14.3 # via ipython @@ -512,6 +574,7 @@ typing-extensions==4.10.0 # via alembic # via ddtrace # via fastapi + # via huggingface-hub # via marvin # via mypy # via openai @@ -567,5 +630,7 @@ wrapt==1.16.0 xmltodict==0.13.0 # via ddtrace # via moto +yarl==1.9.4 + # via aiohttp zipp==3.18.1 # via importlib-metadata diff --git a/src/controlflow/llm/__init__.py b/src/controlflow/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py new file mode 100644 index 00000000..bc22ed08 --- /dev/null +++ b/src/controlflow/llm/completions.py @@ -0,0 +1,259 @@ +from typing import AsyncGenerator, Callable, Generator, Tuple, Union + +import litellm + +import controlflow +from controlflow.llm.tools import ( + function_to_tool_dict, + handle_tool_calls, + handle_tool_calls_async, + has_tool_calls, +) +from controlflow.utilities.types import ControlFlowModel + + +class Response(ControlFlowModel): + message: litellm.Message + response: litellm.ModelResponse + intermediate_messages: list[litellm.Message] = [] + intermediate_responses: list[litellm.ModelResponse] = [] + + +def completion( + messages: list[Union[dict, litellm.Message]], + model=None, + tools: list[Callable] = None, + use_tools=True, + **kwargs, +) -> litellm.ModelResponse: + """ + Perform completion using the LLM model. + + Args: + messages: A list of messages to be used for completion. + model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. + tools: A list of callable tools to be used during completion. + use_tools: A boolean indicating whether to use the provided tools during completion. + **kwargs: Additional keyword arguments to be passed to the litellm.completion function. + + Returns: + A litellm.ModelResponse object representing the completion response. 
+ """ + intermediate_messages = [] + intermediate_responses = [] + + if model is None: + model = controlflow.settings.model + if tools is not None: + tool_dicts = [function_to_tool_dict(tool) for tool in tools] + else: + tool_dicts = None + response = litellm.completion( + model=model, + messages=messages, + tools=tool_dicts, + **kwargs, + ) + + while use_tools and has_tool_calls(response): + intermediate_responses.append(response) + intermediate_messages.append(response.choices[0].message) + tool_messages = handle_tool_calls(response, tools) + intermediate_messages.extend(tool_messages) + response = litellm.completion( + model=model, + messages=messages + intermediate_messages, + tools=tool_dicts, + **kwargs, + ) + + return Response( + message=response.choices[0].message, + response=response, + intermediate_messages=intermediate_messages, + intermediate_responses=intermediate_responses, + ) + + +def stream_completion( + messages: list[Union[dict, litellm.Message]], + model=None, + tools: list[Callable] = None, + use_tools: bool = True, + **kwargs, +) -> Generator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None, None]: + """ + Perform streaming completion using the LLM model. + + Args: + messages: A list of messages to be used for completion. + model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. + tools: A list of callable tools to be used during completion. + use_tools: A boolean indicating whether to use the provided tools during completion. + **kwargs: Additional keyword arguments to be passed to the litellm.completion function. + + Yields: + A tuple containing the current completion chunk and the snapshot of the completion response. + + Returns: + The final completion response as a litellm.ModelResponse object. + """ + if model is None: + model = controlflow.settings.model + + if tools is not None: + tool_dicts = [function_to_tool_dict(tool) for tool in tools] + else: + tool_dicts = None + + chunks = [] + for chunk in litellm.completion( + model=model, + messages=messages, + stream=True, + tools=tool_dicts, + **kwargs, + ): + chunks.append(chunk) + snapshot = litellm.stream_chunk_builder(chunks) + yield chunk, snapshot + + response = snapshot + + while use_tools and has_tool_calls(response): + messages.append(response.choices[0].message) + tool_messages = handle_tool_calls(response, tools) + messages.extend(tool_messages) + chunks = [] + for chunk in litellm.completion( + model=model, + messages=messages, + tools=tool_dicts, + stream=True**kwargs, + ): + chunks.append(chunk) + snapshot = litellm.stream_chunk_builder(chunks) + yield chunk, snapshot + response = snapshot + + +async def completion_async( + messages: list[Union[dict, litellm.Message]], + model=None, + tools: list[Callable] = None, + use_tools=True, + **kwargs, +) -> Response: + """ + Perform asynchronous completion using the LLM model. + + Args: + messages: A list of messages to be used for completion. + model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. + tools: A list of callable tools to be used during completion. + use_tools: A boolean indicating whether to use the provided tools during completion. + **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. 
+ + Returns: + Response + """ + intermediate_messages = [] + intermediate_responses = [] + + if model is None: + model = controlflow.settings.model + + if tools is not None: + tool_dicts = [function_to_tool_dict(tool) for tool in tools] + else: + tool_dicts = None + + response = await litellm.acompletion( + model=model, + messages=messages, + tools=tool_dicts, + **kwargs, + ) + + while use_tools and has_tool_calls(response): + intermediate_responses.append(response) + intermediate_messages.append(response.choices[0].message) + tool_messages = await handle_tool_calls_async(response, tools) + intermediate_messages.extend(tool_messages) + response = await litellm.acompletion( + model=model, + messages=messages + intermediate_messages, + tools=tool_dicts, + **kwargs, + ) + + return Response( + message=response.choices[0].message, + response=response, + intermediate_messages=intermediate_messages, + intermediate_responses=intermediate_responses, + ) + + +async def stream_completion_async( + messages: list[Union[dict, litellm.Message]], + model=None, + tools: list[Callable] = None, + use_tools: bool = True, + **kwargs, +) -> AsyncGenerator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None]: + """ + Perform asynchronous streaming completion using the LLM model. + + Args: + messages: A list of messages to be used for completion. + model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. + tools: A list of callable tools to be used during completion. + use_tools: A boolean indicating whether to use the provided tools during completion. + **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. + + Yields: + A tuple containing the current completion chunk and the snapshot of the completion response. + + Returns: + The final completion response as a litellm.ModelResponse object. 
+ """ + if model is None: + model = controlflow.settings.model + + if tools is not None: + tool_dicts = [function_to_tool_dict(tool) for tool in tools] + else: + tool_dicts = None + + chunks = [] + async for chunk in litellm.acompletion( + model=model, + messages=messages, + stream=True, + tools=tool_dicts, + **kwargs, + ): + chunks.append(chunk) + snapshot = litellm.stream_chunk_builder(chunks) + yield chunk, snapshot + + response = snapshot + + while use_tools and has_tool_calls(response): + messages.append(response.choices[0].message) + tool_messages = await handle_tool_calls_async(response, tools) + messages.extend(tool_messages) + chunks = [] + async for chunk in litellm.acompletion( + model=model, + messages=messages, + tools=tool_dicts, + stream=True, + **kwargs, + ): + chunks.append(chunk) + snapshot = litellm.stream_chunk_builder(chunks) + yield chunk, snapshot + + response = snapshot diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py new file mode 100644 index 00000000..39367906 --- /dev/null +++ b/src/controlflow/llm/history.py @@ -0,0 +1,14 @@ +import uuid + +from pydantic import Field + +from controlflow.utilities.types import ControlFlowModel, Message + + +class Thread(ControlFlowModel): + id: str = Field(default_factory=uuid.uuid4().hex[:8]) + + +class History(ControlFlowModel): + thread: Thread + messages: list[Message] diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py new file mode 100644 index 00000000..f73d36a2 --- /dev/null +++ b/src/controlflow/llm/tools.py @@ -0,0 +1,148 @@ +import inspect +import json +from functools import update_wrapper +from typing import Any, Callable, Optional + +import litellm +import pydantic + +from controlflow.utilities.types import Message + + +def custom_partial(func: Callable, **fixed_kwargs: Any) -> Callable: + """ + Returns a new function with partial application of the given keyword arguments. + The new function has the same __name__ and docstring as the original, and its + signature excludes the provided kwargs. + """ + + # Define the new function with a dynamic signature + def wrapper(**kwargs): + # Merge the provided kwargs with the fixed ones, prioritizing the former + all_kwargs = {**fixed_kwargs, **kwargs} + return func(**all_kwargs) + + # Update the wrapper function's metadata to match the original function + update_wrapper(wrapper, func) + + # Modify the signature to exclude the fixed kwargs + original_sig = inspect.signature(func) + new_params = [ + param + for param in original_sig.parameters.values() + if param.name not in fixed_kwargs + ] + wrapper.__signature__ = original_sig.replace(parameters=new_params) + + return wrapper + + +def function_to_tool_dict( + fn: Callable, + name: Optional[str] = None, + description: Optional[str] = None, +) -> dict: + """ + Creates an OpenAI-compatible tool dict from a Python function. + """ + + schema = pydantic.TypeAdapter( + fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True) + ).json_schema() + + return dict( + type="function", + function=dict( + name=name or fn.__name__, + description=inspect.cleandoc(description or fn.__doc__ or ""), + parameters=schema, + ), + ) + + +def has_tool_calls(response: litellm.ModelResponse) -> bool: + """ + Check if the model response contains tool calls. 
+ """ + return bool(response.choices[0].message.get("tool_calls")) + + +def output_to_string(output: Any) -> str: + """ + Function outputs must be provided as strings + """ + if output is None: + output = "" + elif not isinstance(output, str): + try: + output = pydantic.TypeAdapter(type(output)).dump_json(output).decode() + except Exception: + output = str(output) + return output + + +def handle_tool_calls(response: litellm.ModelResponse, tools: list[dict, Callable]): + messages = [] + tool_lookup = {function_to_tool_dict(t)["function"]["name"]: t for t in tools} + + response_message = response.choices[0].message + tool_calls: list[litellm.utils.ChatCompletionMessageToolCall] = ( + response_message.tool_calls + ) + + for tool_call in tool_calls: + fn_name = tool_call.function.name + try: + if fn_name not in tool_lookup: + raise ValueError(f'Function "{fn_name}" not found.') + fn = tool_lookup[fn_name] + fn_args = json.loads(tool_call.function.arguments) + fn_output = fn(**fn_args) + except Exception as exc: + fn_output = f'Error calling function "{fn_name}": {exc}' + messages.append( + Message( + role="tool", + name=fn_name, + content=output_to_string(fn_output), + tool_call_id=tool_call.id, + ) + ) + + return messages + + +async def handle_tool_calls_async( + response: litellm.ModelResponse, tools: list[dict, Callable] +): + messages = [] + tools = [function_to_tool_dict(t) if not isinstance(t, dict) else t for t in tools] + tool_dict = {t["function"]["name"]: t for t in tools} + + response_message = response.choices[0].message + tool_calls: list[litellm.utils.ChatCompletionMessageToolCall] = ( + response_message.tool_calls + ) + + for tool_call in tool_calls: + fn_name = tool_call.function.name + try: + if fn_name not in tool_dict: + raise ValueError(f'Function "{fn_name}" not found.') + fn = tool_dict[fn_name] + fn_args = json.loads(tool_call.function.arguments) + fn_output = fn(**fn_args) + if inspect.isawaitable(fn_output): + fn_output = await fn_output + except Exception as exc: + fn_output = f'Error calling function "{fn_name}": {exc}' + messages.append( + Message( + role="tool", + name=fn_name, + content=output_to_string(fn_output), + tool_call_id=tool_call.id, + ) + ) + + return messages diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index b3d97cd7..5990df61 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -5,6 +5,7 @@ from copy import deepcopy from typing import Any, Optional, Union +import litellm from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -56,6 +57,9 @@ class Settings(ControlFlowSettings): ) prefect: PrefectSettings = Field(default_factory=PrefectSettings) openai_api_key: Optional[str] = Field(None, validate_assignment=True) + + # ------------ flow settings ------------ + eager_mode: bool = Field( True, description="If True, @task- and @flow-decorated functions are run immediately. " @@ -71,6 +75,13 @@ class Settings(ControlFlowSettings): description="If False, calling Task.run() outside a flow context will automatically " "create a flow and run the task within it. If True, an error will be raised.", ) + + # ------------ LLM settings ------------ + + model: str = Field("gpt-4o", description="The LLM model to use.") + + # ------------ TUI settings ------------ + enable_tui: bool = Field( False, description="If True, the TUI will be enabled. 
If False, the TUI will be disabled.", @@ -92,6 +103,12 @@ def _apply_api_key(cls, v): marvin.settings.openai.api_key = v return v + @field_validator("model", mode="before") + def _validate_model(cls, v): + if not litellm.supports_function_calling(model=v): + raise ValueError(f"Model '{v}' does not support function calling.") + return v + settings = Settings() diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index 569fe236..efa504c2 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -1,5 +1,6 @@ from typing import Callable, Union +from litellm import Message from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool from marvin.types import FunctionTool
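Usage sketch (illustrative, not part of the patch): the example below shows how the new controlflow.llm.completions.completion() helper and callable tools are expected to fit together. It assumes a litellm-supported API key (e.g. OPENAI_API_KEY) is configured in the environment; the roll_die tool is a hypothetical example.

```python
import random

from controlflow.llm.completions import completion


def roll_die(sides: int = 6) -> int:
    """Roll a die with the given number of sides and return the result."""
    return random.randint(1, sides)


# Tools are plain callables; function_to_tool_dict() turns their signatures and
# docstrings into OpenAI-style tool schemas before the litellm call is made.
response = completion(
    messages=[{"role": "user", "content": "Roll two dice and report the total."}],
    model="gpt-4o",  # optional; defaults to controlflow.settings.model
    tools=[roll_die],
)

# The final assistant message, produced after any tool-call round trips.
print(response.message.content)

# Tool-call requests and tool results exchanged along the way.
for msg in response.intermediate_messages:
    print(msg)
```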
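A similar sketch for the streaming path, under the same assumptions: stream_completion() yields (chunk, snapshot) pairs, where chunk is the latest delta from litellm and snapshot is the running response rebuilt with litellm.stream_chunk_builder, so the last snapshot holds the assembled message.

```python
from controlflow.llm.completions import stream_completion

# Print tokens as they arrive; keep the last snapshot for the full message.
final_snapshot = None
for chunk, snapshot in stream_completion(
    messages=[{"role": "user", "content": "Write a haiku about tool calling."}],
    model="gpt-4o",
):
    delta = chunk.choices[0].delta.content or ""
    print(delta, end="", flush=True)
    final_snapshot = snapshot

print()
print(final_snapshot.choices[0].message.content)
```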