diff --git a/pyproject.toml b/pyproject.toml index fd215e77..adb28bd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,3 +83,8 @@ skip-magic-trailing-comma = false [tool.pytest.ini_options] timeout = 120 +asyncio_mode = "auto" +filterwarnings = [ + "ignore::DeprecationWarning:litellm.*", + "ignore::PendingDeprecationWarning:litellm.*", +] diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 9f0c39c4..972dbe33 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -1,8 +1,12 @@ from .settings import settings +import controlflow.llm + +default_history = controlflow.llm.history.InMemoryHistory() from .core.flow import Flow from .core.task import Task from .core.agent import Agent from .core.controller.controller import Controller + from .instructions import instructions from .decorators import flow, task diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index ad33705b..763ab152 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -89,7 +89,7 @@ class LiteAgent(ControlFlowModel, ExposeSyncMethodsMixin): description="The model used by the agent. If not provided, the default model will be used.", ) - async def say_async(self, messages: Union[str, dict]) -> Response: + async def run_async(self, messages: Union[str, dict]) -> Response: if not isinstance(messages, list): raise ValueError("Messages must be provided as a list.") @@ -102,7 +102,7 @@ async def say_async(self, messages: Union[str, dict]) -> Response: messages=messages, model=self.model, tools=self.tools ) - async def say(self, messages: Union[str, dict]) -> Response: + async def run(self, messages: Union[str, dict]) -> Response: if not isinstance(messages, list): raise ValueError("Messages must be provided as a list.") diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index 397798eb..7e3bdc3d 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -26,6 +26,7 @@ from controlflow.core.graph import Graph from controlflow.core.task import Task from controlflow.instructions import get_instructions +from controlflow.llm.history import BaseHistory from controlflow.tui.app import TUIApp as TUI from controlflow.utilities.context import ctx from controlflow.utilities.prefect import ( @@ -34,7 +35,7 @@ wrap_prefect_tool, ) from controlflow.utilities.tasks import all_complete, any_incomplete -from controlflow.utilities.types import FunctionTool, Thread +from controlflow.utilities.types import FunctionTool logger = logging.getLogger(__name__) @@ -67,6 +68,7 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): description="Tasks that the controller will complete.", ) agents: Union[list[Agent], None] = None + history: BaseHistory = Field() context: dict = {} model_config: dict = dict(extra="forbid") enable_tui: bool = Field(default_factory=lambda: controlflow.settings.enable_tui) @@ -101,20 +103,13 @@ def help_im_stuck(): return help_im_stuck - async def _run_agent( - self, agent: Agent, tasks: list[Task] = None, thread: Thread = None - ) -> Run: + async def _run_agent(self, agent: Agent, tasks: list[Task] = None) -> Run: """ Run a single agent. """ @prefect_task(task_run_name=f'Run Agent: "{agent.name}"') - async def _run_agent( - controller: Controller, - agent: Agent, - tasks: list[Task], - thread: Thread = None, - ): + async def _run_agent(controller: Controller, agent: Agent, tasks: list[Task]): from controlflow.core.controller.instruction_template import MainTemplate tasks = tasks or controller.tasks diff --git a/src/controlflow/llm/__init__.py b/src/controlflow/llm/__init__.py index e69de29b..a4df82c1 100644 --- a/src/controlflow/llm/__init__.py +++ b/src/controlflow/llm/__init__.py @@ -0,0 +1,3 @@ +import controlflow.llm.history +import controlflow.llm.tools +import controlflow.llm.completions diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py index 3e548800..2c994dfa 100644 --- a/src/controlflow/llm/completions.py +++ b/src/controlflow/llm/completions.py @@ -1,31 +1,46 @@ -from typing import AsyncGenerator, Callable, Generator, Tuple, Union +import math +from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, Union import litellm +from litellm.utils import trim_messages import controlflow from controlflow.llm.tools import ( - function_to_tool_dict, + as_tools, handle_tool_calls, handle_tool_calls_async, + handle_tool_calls_gen, + handle_tool_calls_gen_async, has_tool_calls, ) -from controlflow.utilities.types import ControlFlowModel +from controlflow.utilities.types import ControlFlowModel, ToolCall class Response(ControlFlowModel): - message: litellm.Message - response: litellm.ModelResponse - intermediate_messages: list[litellm.Message] = [] - intermediate_responses: list[litellm.ModelResponse] = [] + messages: list[litellm.Message] = [] + responses: list[litellm.ModelResponse] = [] + + def last_message(self) -> Optional[litellm.Message]: + return self.messages[-1] if self.messages else None + + def last_response(self) -> Optional[litellm.ModelResponse]: + return self.responses[-1] if self.responses else None + + def tool_calls(self) -> list[ToolCall]: + return [ + m["_tool_call"] + for m in self.messages + if m.role == "tool" and m.get("_tool_call") is not None + ] def completion( messages: list[Union[dict, litellm.Message]], model=None, tools: list[Callable] = None, - use_tools=True, + max_iterations=None, **kwargs, -) -> litellm.ModelResponse: +) -> Response: """ Perform completion using the LLM model. @@ -33,53 +48,48 @@ def completion( messages: A list of messages to be used for completion. model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. tools: A list of callable tools to be used during completion. - use_tools: A boolean indicating whether to use the provided tools during completion. + call_tools: A boolean indicating whether to use the provided tools during completion. **kwargs: Additional keyword arguments to be passed to the litellm.completion function. Returns: - A litellm.ModelResponse object representing the completion response. + A Response object representing the completion response. """ - intermediate_messages = [] - intermediate_responses = [] + response = None + responses = [] + new_messages = [] if model is None: model = controlflow.settings.model - tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None + tools = as_tools(tools or []) - response = litellm.completion( - model=model, - messages=messages, - tools=tool_dicts, - **kwargs, - ) - - while use_tools and has_tool_calls(response): - intermediate_responses.append(response) - intermediate_messages.append(response.choices[0].message) - tool_messages = handle_tool_calls(response, tools) - intermediate_messages.extend(tool_messages) + while not response or has_tool_calls(response): response = litellm.completion( model=model, - messages=messages + intermediate_messages, - tools=tool_dicts, + messages=trim_messages(messages + new_messages, model=model), + tools=[t.model_dump() for t in tools] if tools else None, **kwargs, ) + responses.append(response) + new_messages.append(response.choices[0].message) + new_messages.extend(handle_tool_calls(response, tools)) + + if len(responses) >= (max_iterations or math.inf): + break + return Response( - message=response.choices[0].message, - response=response, - intermediate_messages=intermediate_messages, - intermediate_responses=intermediate_responses, + messages=new_messages, + responses=responses, ) -def stream_completion( +def completion_stream( messages: list[Union[dict, litellm.Message]], model=None, tools: list[Callable] = None, - use_tools: bool = True, + max_iterations: int = None, **kwargs, ) -> Generator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None, None]: """ @@ -89,56 +99,54 @@ def stream_completion( messages: A list of messages to be used for completion. model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. tools: A list of callable tools to be used during completion. - use_tools: A boolean indicating whether to use the provided tools during completion. + call_tools: A boolean indicating whether to use the provided tools during completion. **kwargs: Additional keyword arguments to be passed to the litellm.completion function. Yields: - A tuple containing the current completion chunk and the snapshot of the completion response. + A tuple containing the current completion delta and the snapshot of the completion response. Returns: The final completion response as a litellm.ModelResponse object. """ + response = None + messages = messages.copy() + if model is None: model = controlflow.settings.model - tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None - - chunks = [] - for chunk in litellm.completion( - model=model, - messages=messages, - stream=True, - tools=tool_dicts, - **kwargs, - ): - chunks.append(chunk) - snapshot = litellm.stream_chunk_builder(chunks) - yield chunk, snapshot + tools = as_tools(tools or []) - response = snapshot + i = 0 + while not response or has_tool_calls(response): + deltas = [] - while use_tools and has_tool_calls(response): - messages.append(response.choices[0].message) - tool_messages = handle_tool_calls(response, tools) - messages.extend(tool_messages) - chunks = [] - for chunk in litellm.completion( + for delta in litellm.completion( model=model, - messages=messages, - tools=tool_dicts, - stream=True**kwargs, + messages=trim_messages(messages, model=model), + tools=[t.model_dump() for t in tools] if tools else None, + stream=True, + **kwargs, ): - chunks.append(chunk) - snapshot = litellm.stream_chunk_builder(chunks) - yield chunk, snapshot - response = snapshot + deltas.append(delta) + response = litellm.stream_chunk_builder(deltas) + yield delta, response + + for tool_msg in handle_tool_calls_gen(response, tools): + messages.append(tool_msg) + yield None, tool_msg + + messages.append(response.choices[0].message) + + i += 1 + if i >= (max_iterations or math.inf): + break async def completion_async( messages: list[Union[dict, litellm.Message]], model=None, tools: list[Callable] = None, - use_tools=True, + max_iterations=None, **kwargs, ) -> Response: """ @@ -148,52 +156,44 @@ async def completion_async( messages: A list of messages to be used for completion. model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. tools: A list of callable tools to be used during completion. - use_tools: A boolean indicating whether to use the provided tools during completion. + call_tools: A boolean indicating whether to use the provided tools during completion. **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. Returns: Response """ - intermediate_messages = [] - intermediate_responses = [] + response = None + responses = [] + new_messages = [] if model is None: model = controlflow.settings.model - tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None - - response = await litellm.acompletion( - model=model, - messages=messages, - tools=tool_dicts, - **kwargs, - ) + tools = as_tools(tools or []) - while use_tools and has_tool_calls(response): - intermediate_responses.append(response) - intermediate_messages.append(response.choices[0].message) - tool_messages = await handle_tool_calls_async(response, tools) - intermediate_messages.extend(tool_messages) + while not response or has_tool_calls(response): response = await litellm.acompletion( model=model, - messages=messages + intermediate_messages, - tools=tool_dicts, + messages=trim_messages(messages + new_messages, model=model), + tools=[t.model_dump() for t in tools] if tools else None, **kwargs, ) + responses.append(response) + new_messages.append(response.choices[0].message) + new_messages.extend(await handle_tool_calls_async(response, tools)) + return Response( - message=response.choices[0].message, - response=response, - intermediate_messages=intermediate_messages, - intermediate_responses=intermediate_responses, + messages=new_messages, + responses=responses, ) -async def stream_completion_async( +async def completion_stream_async( messages: list[Union[dict, litellm.Message]], model=None, tools: list[Callable] = None, - use_tools: bool = True, + max_iterations: int = None, **kwargs, ) -> AsyncGenerator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None]: """ @@ -203,48 +203,43 @@ async def stream_completion_async( messages: A list of messages to be used for completion. model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. tools: A list of callable tools to be used during completion. - use_tools: A boolean indicating whether to use the provided tools during completion. + call_tools: A boolean indicating whether to use the provided tools during completion. **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. Yields: - A tuple containing the current completion chunk and the snapshot of the completion response. + A tuple containing the current completion delta and the snapshot of the completion response. Returns: The final completion response as a litellm.ModelResponse object. """ + response = None + messages = messages.copy() + if model is None: model = controlflow.settings.model - tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None - - chunks = [] - async for chunk in litellm.acompletion( - model=model, - messages=messages, - stream=True, - tools=tool_dicts, - **kwargs, - ): - chunks.append(chunk) - snapshot = litellm.stream_chunk_builder(chunks) - yield chunk, snapshot + tools = as_tools(tools or []) - response = snapshot + i = 0 + while not response or has_tool_calls(response): + deltas = [] - while use_tools and has_tool_calls(response): - messages.append(response.choices[0].message) - tool_messages = await handle_tool_calls_async(response, tools) - messages.extend(tool_messages) - chunks = [] - async for chunk in litellm.acompletion( + async for delta in litellm.acompletion( model=model, - messages=messages, - tools=tool_dicts, + messages=trim_messages(messages, model=model), + tools=[t.model_dump() for t in tools] if tools else None, stream=True, **kwargs, ): - chunks.append(chunk) - snapshot = litellm.stream_chunk_builder(chunks) - yield chunk, snapshot + deltas.append(delta) + response = litellm.stream_chunk_builder(deltas) + yield delta, response + + async for tool_msg in handle_tool_calls_gen_async(response, tools): + messages.append(tool_msg) + yield None, tool_msg + messages.append(response.choices[0].message) - response = snapshot + i += 1 + if i >= (max_iterations or math.inf): + break diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py index 39367906..08499e40 100644 --- a/src/controlflow/llm/history.py +++ b/src/controlflow/llm/history.py @@ -1,14 +1,61 @@ -import uuid +import abc +import json +from pathlib import Path -from pydantic import Field +from pydantic import Field, field_validator +import controlflow from controlflow.utilities.types import ControlFlowModel, Message +_IN_MEMORY_HISTORY = dict() -class Thread(ControlFlowModel): - id: str = Field(default_factory=uuid.uuid4().hex[:8]) +class BaseHistory(ControlFlowModel, abc.ABC): + @abc.abstractmethod + def load_messages(self, thread_id: str, limit: int = None) -> list[Message]: + raise NotImplementedError() -class History(ControlFlowModel): - thread: Thread - messages: list[Message] + @abc.abstractmethod + def save_messages(self, thread_id: str, messages: list[Message]): + raise NotImplementedError() + + +class InMemoryHistory(BaseHistory): + def load_messages(self, thread_id: str, limit: int = None) -> list[Message]: + return _IN_MEMORY_HISTORY.get(thread_id)[-limit:] + + def save_messages(self, thread_id: str, messages: list[Message]): + _IN_MEMORY_HISTORY.setdefault(thread_id, []).extend(messages) + + +class FileHistory(BaseHistory): + base_path: Path = Field( + default_factory=lambda: controlflow.settings.home_path / "history" + ) + + def path(self, thread_id: str) -> Path: + return self.base_path / f"{thread_id}.json" + + @field_validator("base_path", mode="before") + def _validate_path(cls, v): + v = Path(v).expanduser() + if not v.exists(): + v.mkdir(parents=True, exist_ok=True) + return v + + def load_messages(self, thread_id: str, limit: int = None) -> list[Message]: + if not self.path(thread_id).exists(): + return [] + with open(self.path(thread_id), "r") as f: + all_messages = json.load(f) + return [Message.model_validate(msg) for msg in all_messages[-limit:]] + + def save_messages(self, thread_id: str, messages: list[Message]): + if self.path(thread_id).exists(): + with open(self.path(thread_id), "r") as f: + all_messages = json.load(f) + else: + all_messages = [] + all_messages.extend([msg.model_dump(mode="json") for msg in messages]) + with open(self.path(thread_id), "w") as f: + json.dump(all_messages, f) diff --git a/src/controlflow/llm/streaming.py b/src/controlflow/llm/streaming.py new file mode 100644 index 00000000..61950f01 --- /dev/null +++ b/src/controlflow/llm/streaming.py @@ -0,0 +1,108 @@ +import inspect +from typing import Generator, Optional + +import litellm + +from controlflow.llm.tools import ToolCall +from controlflow.utilities.types import Message + + +class StreamHandler: + def stream( + self, + gen: Generator[ + tuple[Optional[litellm.ModelResponse], litellm.ModelResponse], None, None + ], + ): + last_snapshot = None + for delta, snapshot in gen: + snapshot_message = snapshot.choices[0].message + + # handle tool call outputs + if delta is None and snapshot_message.role == "tool": + self.on_tool_call(tool_call=snapshot_message._tool_call) + self.on_message_done(snapshot_message) + continue + + delta_message = delta.choices[0].delta + + # handle new messages + if not last_snapshot or snapshot.id != last_snapshot.id: + self.on_message_created(delta_message) + + # handle updated messages + self.on_message_delta(delta=delta_message, snapshot=snapshot_message) + + # handle completed messages + if delta.choices[0].finish_reason: + self.on_message_done(snapshot_message) + + last_snapshot = snapshot + + def on_message_created(self, delta: litellm.utils.Delta): + pass + + def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message): + pass + + def on_message_done(self, message: Message): + pass + + def on_tool_call(self, tool_call: ToolCall): + pass + + +async def _maybe_coro(maybe_coro): + if inspect.isawaitable(maybe_coro): + return await maybe_coro + + +class AsyncStreamHandler(StreamHandler): + async def stream( + self, + gen: Generator[ + tuple[Optional[litellm.ModelResponse], litellm.ModelResponse], None, None + ], + ): + last_snapshot = None + async for delta, snapshot in gen: + snapshot_message = snapshot.choices[0].message + + # handle tool call outputs + if delta is None and snapshot_message.role == "tool": + await _maybe_coro( + self.on_tool_call(tool_call=snapshot_message._tool_call) + ) + await _maybe_coro(self.on_message_done(snapshot_message)) + continue + + delta_message = delta.choices[0].delta + + # handle new messages + if not last_snapshot or snapshot.id != last_snapshot.id: + await _maybe_coro(self.on_message_created(delta_message)) + + # handle updated messages + await _maybe_coro( + self.on_message_delta(delta=delta_message, snapshot=snapshot_message) + ) + + # handle completed messages + if delta.choices[0].finish_reason: + await _maybe_coro(self.on_message_done(snapshot_message)) + + last_snapshot = snapshot + + async def on_message_created(self, delta: litellm.utils.Delta): + pass + + async def on_message_delta( + self, delta: litellm.utils.Delta, snapshot: litellm.Message + ): + pass + + async def on_message_done(self, message: Message): + pass + + async def on_tool_call(self, tool_call: ToolCall): + pass diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py index f73d36a2..336f0f82 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/llm/tools.py @@ -1,12 +1,38 @@ +import datetime import inspect import json -from functools import update_wrapper -from typing import Any, Callable, Optional +from functools import partial, update_wrapper +from typing import Any, AsyncGenerator, Callable, Generator, Optional, Union, cast import litellm import pydantic -from controlflow.utilities.types import Message +from controlflow.utilities.types import Message, Tool, ToolCall + + +def tool( + fn: Optional[Callable] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, +) -> Tool: + if fn is None: + return partial(tool, name=name, description=description) + return Tool.from_function(fn, name=name, description=description) + + +def as_tools(tools: list[Union[Tool, Callable]]) -> list[Tool]: + tools = [t if isinstance(t, Tool) else tool(t) for t in tools] + if len({t.function.name for t in tools}) != len(tools): + duplicates = {t.function.name for t in tools if tools.count(t) > 1} + raise ValueError( + f"Tool names must be unique, but found duplicates: {', '.join(duplicates)}" + ) + return tools + + +def as_tool_lookup(tools: list[Union[Tool, Callable]]) -> dict[str, Tool]: + return {t.function.name: t for t in as_tools(tools)} def custom_partial(func: Callable, **fixed_kwargs: Any) -> Callable: @@ -37,29 +63,6 @@ def wrapper(**kwargs): return wrapper -def function_to_tool_dict( - fn: Callable, - name: Optional[str] = None, - description: Optional[str] = None, -) -> dict: - """ - Creates an OpenAI-compatible tool dict from a Python function. - """ - - schema = pydantic.TypeAdapter( - fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True) - ).json_schema() - - return dict( - type="function", - function=dict( - name=name or fn.__name__, - description=inspect.cleandoc(description or fn.__doc__ or ""), - parameters=schema, - ), - ) - - def has_tool_calls(response: litellm.ModelResponse) -> bool: """ Check if the model response contains tool calls. @@ -81,68 +84,80 @@ def output_to_string(output: Any) -> str: return output -def handle_tool_calls(response: litellm.ModelResponse, tools: list[dict, Callable]): - messages = [] - tool_lookup = {function_to_tool_dict(t)["function"]["name"]: t for t in tools} - - response_message = response.choices[0].message - tool_calls: list[litellm.utils.ChatCompletionMessageToolCall] = ( - response_message.tool_calls - ) +def handle_tool_calls_gen( + response: litellm.ModelResponse, tools: list[dict, Callable] +) -> Generator[Message, None, None]: + tool_lookup = as_tool_lookup(tools) - for tool_call in tool_calls: + for tool_call in response.choices[0].message.get("tool_calls", []): + tool_call = cast(litellm.utils.ChatCompletionMessageToolCall, tool_call) fn_name = tool_call.function.name try: if fn_name not in tool_lookup: raise ValueError(f'Function "{fn_name}" not found.') - fn = tool_lookup[fn_name] + tool = tool_lookup[fn_name] fn_args = json.loads(tool_call.function.arguments) - fn_output = fn(**fn_args) + fn_output = tool(**fn_args) except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' - messages.append( - Message( - role="tool", - name=fn_name, - content=output_to_string(fn_output), + + yield Message( + role="tool", + name=fn_name, + content=output_to_string(fn_output), + tool_call_id=tool_call.id, + _tool_call=ToolCall( tool_call_id=tool_call.id, - ) + tool_name=fn_name, + tool=tool, + args=fn_args, + output=fn_output, + timestamp=datetime.datetime.now(datetime.timezone.utc), + ), ) - return messages - -async def handle_tool_calls_async( +def handle_tool_calls( response: litellm.ModelResponse, tools: list[dict, Callable] -): - messages = [] - tools = [function_to_tool_dict(t) if not isinstance(t, dict) else t for t in tools] - tool_dict = {t["function"]["name"]: t for t in tools} +) -> list[Message]: + return list(handle_tool_calls_gen(response, tools)) - response_message = response.choices[0].message - tool_calls: list[litellm.utils.ChatCompletionMessageToolCall] = ( - response_message.tool_calls - ) - for tool_call in tool_calls: +async def handle_tool_calls_gen_async( + response: litellm.ModelResponse, tools: list[dict, Callable] +) -> AsyncGenerator[Message, None]: + tool_lookup = as_tool_lookup(tools) + + for tool_call in response.choices[0].message.get("tool_calls", []): + tool_call = cast(litellm.utils.ChatCompletionMessageToolCall, tool_call) fn_name = tool_call.function.name try: - if fn_name not in tool_dict: + if fn_name not in tool_lookup: raise ValueError(f'Function "{fn_name}" not found.') - fn = tool_dict[fn_name] + tool = tool_lookup[fn_name] fn_args = json.loads(tool_call.function.arguments) - fn_output = fn(**fn_args) + fn_output = tool(**fn_args) if inspect.isawaitable(fn_output): fn_output = await fn_output except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' - messages.append( - Message( - role="tool", - name=fn_name, - content=output_to_string(fn_output), + yield Message( + role="tool", + name=fn_name, + content=output_to_string(fn_output), + tool_call_id=tool_call.id, + _tool_call=ToolCall( tool_call_id=tool_call.id, - ) + tool_name=fn_name, + tool=tool, + args=fn_args, + output=fn_output, + timestamp=datetime.datetime.now(datetime.timezone.utc), + ), ) - return messages + +async def handle_tool_calls_async( + response: litellm.ModelResponse, tools: list[dict, Callable] +) -> list[Message]: + return [t async for t in handle_tool_calls_gen_async(response, tools)] diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 5990df61..2c1378a3 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -3,12 +3,16 @@ import warnings from contextlib import contextmanager from copy import deepcopy -from typing import Any, Optional, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, Union import litellm from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict +if TYPE_CHECKING: + pass + class ControlFlowSettings(BaseSettings): model_config: SettingsConfigDict = SettingsConfigDict( @@ -58,6 +62,14 @@ class Settings(ControlFlowSettings): prefect: PrefectSettings = Field(default_factory=PrefectSettings) openai_api_key: Optional[str] = Field(None, validate_assignment=True) + # ------------ home settings ------------ + + home_path: Path = Field( + "~/.controlflow", + description="The path to the ControlFlow home directory.", + validate_default=True, + ) + # ------------ flow settings ------------ eager_mode: bool = Field( @@ -103,6 +115,13 @@ def _apply_api_key(cls, v): marvin.settings.openai.api_key = v return v + @field_validator("home_path", mode="before") + def _validate_home_path(cls, v): + v = Path(v).expanduser() + if not v.exists(): + v.mkdir(parents=True, exist_ok=True) + return v + @field_validator("model", mode="before") def _validate_model(cls, v): if not litellm.supports_function_calling(model=v): diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index efa504c2..6a87e967 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -1,11 +1,15 @@ -from typing import Callable, Union +import inspect +import json +from functools import partial, update_wrapper +from typing import Any, Callable, Literal, Optional, Union -from litellm import Message +import litellm +import pydantic from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool from marvin.types import FunctionTool from marvin.utilities.asyncio import ExposeSyncMethodsMixin -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr # flag for unset defaults NOTSET = "__NOTSET__" @@ -37,3 +41,54 @@ class PandasSeries(ControlFlowModel): index: list[str] = None name: str = None dtype: str = None + + +class ToolFunction(ControlFlowModel): + name: str + parameters: dict + description: str = "" + + +class Tool(ControlFlowModel): + type: Literal["function"] = "function" + function: ToolFunction + _fn: Callable = PrivateAttr() + + def __init__(self, *, _fn: Callable, **kwargs): + super().__init__(**kwargs) + self._fn = _fn + + @classmethod + def from_function( + cls, fn: Callable, name: Optional[str] = None, description: Optional[str] = None + ): + return cls( + function=ToolFunction( + name=name or fn.__name__, + description=inspect.cleandoc(description or fn.__doc__ or ""), + parameters=pydantic.TypeAdapter( + fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True) + ).json_schema(), + ), + _fn=fn, + ) + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + +class ToolCall(ControlFlowModel): + model_config = dict(allow_arbitrary_types=True) + tool_call_id: str + tool_name: str + tool: Tool + args: dict + output: Any + + +class Message(litellm.Message): + _tool_call: ToolCall = PrivateAttr() + + def __init__(self, *args, tool_output: Any = None, **kwargs): + super().__init__(*args, **kwargs) + self._tool_output = tool_output diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py index 0ae712f5..4009476b 100644 --- a/tests/fixtures/mocks.py +++ b/tests/fixtures/mocks.py @@ -1,12 +1,20 @@ from typing import Any from unittest.mock import AsyncMock, Mock, patch +import litellm import pytest from controlflow.core.agent import Agent from controlflow.core.task import Task, TaskStatus +from controlflow.llm.completions import Response from marvin.settings import temporary_settings as temporary_marvin_settings +def new_chunk(): + chunk = litellm.ModelResponse() + chunk.choices = [litellm.utils.StreamingChoices()] + return chunk + + @pytest.fixture def prevent_openai_calls(): """Prevent any calls to the OpenAI API from being made.""" @@ -85,3 +93,123 @@ def choose_agent(agents, **kwargs): @pytest.fixture def mock_controller(mock_controller_choose_agent, mock_controller_run_agent): pass + + +@pytest.fixture +def mock_completion(monkeypatch): + """ + Mock the completion function from the LLM module. Use this fixture to set + the response value ahead of calling the completion. + + Example: + + def test_completion(mock_completion): + mock_completion.set_response("Hello, world!") + response = litellm.completion(...) + assert response == "Hello, world!" + """ + response = litellm.ModelResponse() + + def set_response(message: str): + response.choices[0].message.content = message + + def mock_func(*args, **kwargs): + return Response(responses=[response], messages=[]) + + monkeypatch.setattr("controlflow.llm.completions.completion", mock_func) + mock_func.set_response = set_response + + return mock_func + + +@pytest.fixture +def mock_completion_stream(monkeypatch): + """ + Mock the completion function from the LLM module. Use this fixture to set + the response value ahead of calling the completion. + + Example: + + def test_completion(mock_completion): + mock_completion.set_response("Hello, world!") + response = litellm.completion(...) + assert response == "Hello, world!" + """ + response = litellm.ModelResponse() + chunk = litellm.ModelResponse() + chunk.choices = [litellm.utils.StreamingChoices()] + + def set_response(message: str): + response.choices[0].message.content = message + + def mock_func_deltas(*args, **kwargs): + for c in response.choices[0].message.content: + chunk = new_chunk() + chunk.choices[0].delta.content = c + yield chunk, response + + monkeypatch.setattr( + "controlflow.llm.completions.completion_stream", mock_func_deltas + ) + mock_func_deltas.set_response = set_response + + return mock_func_deltas + + +@pytest.fixture +def mock_completion_async(monkeypatch): + """ + Mock the completion function from the LLM module. Use this fixture to set + the response value ahead of calling the completion. + + Example: + + def test_completion(mock_completion): + mock_completion.set_response("Hello, world!") + response = litellm.completion(...) + assert response == "Hello, world!" + """ + response = litellm.ModelResponse() + + def set_response(message: str): + response.choices[0].message.content = message + + async def mock_func(*args, **kwargs): + return Response(responses=[response], messages=[]) + + monkeypatch.setattr("controlflow.llm.completions.completion_async", mock_func) + mock_func.set_response = set_response + + return mock_func + + +@pytest.fixture +def mock_completion_stream_async(monkeypatch): + """ + Mock the completion function from the LLM module. Use this fixture to set + the response value ahead of calling the completion. + + Example: + + def test_completion(mock_completion): + mock_completion.set_response("Hello, world!") + response = litellm.completion(...) + assert response == "Hello, world!" + """ + response = litellm.ModelResponse() + + def set_response(message: str): + response.choices[0].message.content = message + + async def mock_func_deltas(*args, **kwargs): + for c in response.choices[0].message.content: + chunk = new_chunk() + chunk.choices[0].delta.content = c + yield chunk, response + + monkeypatch.setattr( + "controlflow.llm.completions.completion_stream_async", mock_func_deltas + ) + mock_func_deltas.set_response = set_response + + return mock_func_deltas diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/llm/test_completions.py b/tests/llm/test_completions.py new file mode 100644 index 00000000..19eef7e4 --- /dev/null +++ b/tests/llm/test_completions.py @@ -0,0 +1,38 @@ +import controlflow.llm.completions + + +def test_mock_completion(mock_completion): + mock_completion.set_response("Hello, world! xyz") + response = controlflow.llm.completions.completion(messages=[{"content": "Hello"}]) + assert response.last_response().choices[0].message.content == "Hello, world! xyz" + + +async def test_mock_completion_async(mock_completion_async): + mock_completion_async.set_response("Hello, world! xyz") + response = await controlflow.llm.completions.completion_async( + messages=[{"content": "Hello"}] + ) + assert response.last_response().choices[0].message.content == "Hello, world! xyz" + + +def test_mock_completion_stream(mock_completion_stream): + mock_completion_stream.set_response("Hello, world! xyz") + response = controlflow.llm.completions.completion_stream( + messages=[{"content": "Hello"}], + ) + deltas = [] + for delta, snapshot in response: + deltas.append(delta) + + assert [d.choices[0].delta.content for d in deltas[:5]] == ["H", "e", "l", "l", "o"] + + +async def test_mock_completion_stream_async(mock_completion_stream_async): + mock_completion_stream_async.set_response("Hello, world! xyz") + response = controlflow.llm.completions.completion_stream_async( + messages=[{"content": "Hello"}], stream=True + ) + deltas = [] + async for delta, snapshot in response: + deltas.append(delta) + assert [d.choices[0].delta.content for d in deltas[:5]] == ["H", "e", "l", "l", "o"] diff --git a/tests/llm/test_streaming.py b/tests/llm/test_streaming.py new file mode 100644 index 00000000..5cc4b0e2 --- /dev/null +++ b/tests/llm/test_streaming.py @@ -0,0 +1,53 @@ +from collections import Counter + +import litellm +from controlflow.llm.completions import completion_stream +from controlflow.llm.streaming import StreamHandler +from controlflow.llm.tools import ToolCall +from controlflow.utilities.types import Message +from pydantic import BaseModel + + +class StreamCall(BaseModel): + method: str + args: dict + + +class MockStreamHandler(StreamHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.calls: list[StreamCall] = [] + + def on_message_created(self, delta: litellm.utils.Delta): + self.calls.append( + StreamCall(method="on_message_created", args=dict(delta=delta)) + ) + + def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message): + self.calls.append( + StreamCall( + method="on_message_delta", args=dict(delta=delta, snapshot=snapshot) + ) + ) + + def on_message_done(self, message: Message): + self.calls.append( + StreamCall(method="on_message_done", args=dict(message=message)) + ) + + def on_tool_call(self, tool_call: ToolCall): + self.calls.append( + StreamCall(method="on_tool_call", args=dict(tool_call=tool_call)) + ) + + +class TestStreamHandler: + def test_stream(self): + handler = MockStreamHandler() + gen = completion_stream(messages=[{"text": "Hello"}]) + handler.stream(gen) + + method_counts = Counter(call.method for call in handler.calls) + assert method_counts["on_message_created"] == 1 + assert method_counts["on_message_delta"] == 4 + assert method_counts["on_message_done"] == 1