diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py
index 03aa652f..1a013c10 100644
--- a/src/controlflow/core/controller/controller.py
+++ b/src/controlflow/core/controller/controller.py
@@ -20,7 +20,7 @@
 from controlflow.tui.app import TUIApp as TUI
 from controlflow.utilities.context import ctx
 from controlflow.utilities.tasks import all_complete, any_incomplete
-from controlflow.utilities.types import FunctionTool, Message
+from controlflow.utilities.types import FunctionTool, SystemMessage
 
 logger = logging.getLogger(__name__)
 
@@ -119,25 +119,23 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
         instructions = instructions_template.render()
 
         # prepare messages
-        system_message = Message(content=instructions, role="system")
+        system_message = SystemMessage(content=instructions)
         messages = self.history.load_messages(thread_id=self.flow.thread_id)
 
         # call llm
-        r = []
-        async for _ in completion_stream_async(
+        response_messages = []
+        async for msg in completion_stream_async(
             messages=[system_message] + messages,
             model=agent.model,
             tools=tools,
             handlers=[TUIHandler()] if controlflow.settings.enable_tui else None,
             max_iterations=1,
-            response_callback=r.append,
         ):
-            pass
-        response = r[0]
+            response_messages.append(msg)
 
         # save history
         self.history.save_messages(
-            thread_id=self.flow.thread_id, messages=response.messages
+            thread_id=self.flow.thread_id, messages=response_messages
         )
 
         # create_json_artifact(
diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py
index c458bced..529e40d8 100644
--- a/src/controlflow/core/flow.py
+++ b/src/controlflow/core/flow.py
@@ -6,7 +6,7 @@
 from controlflow.utilities.context import ctx
 from controlflow.utilities.logging import get_logger
-from controlflow.utilities.types import ControlFlowModel, Message
+from controlflow.utilities.types import ControlFlowModel, MessageType
 
 if TYPE_CHECKING:
     from controlflow.core.agent import Agent
@@ -68,7 +68,7 @@ def get_flow() -> Optional[Flow]:
     return flow
 
 
-def get_flow_messages(limit: int = None) -> list[Message]:
+def get_flow_messages(limit: int = None) -> list[MessageType]:
    """
    Loads messages from the flow's thread.
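For context, a minimal sketch (not part of the patch) of the consumption pattern the controller now uses: with `response_callback` gone, callers simply collect the typed messages the stream yields and persist them. The model and history wiring here is illustrative.

```python
# Hypothetical usage sketch -- not part of this diff. Assumes a configured
# default model; completion_stream_async now yields ControlFlowMessage
# objects directly instead of (delta, response) tuples.
import asyncio

from controlflow.llm.completions import completion_stream_async
from controlflow.utilities.types import SystemMessage


async def main():
    response_messages = []
    async for msg in completion_stream_async(
        messages=[SystemMessage(content="You are a helpful agent.")],
        max_iterations=1,
    ):
        response_messages.append(msg)
    # the collected messages can be passed straight to History.save_messages()
    return response_messages


asyncio.run(main())
```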
diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py
index 2cfb1124..592370f0 100644
--- a/src/controlflow/core/task.py
+++ b/src/controlflow/core/task.py
@@ -36,7 +36,6 @@
 from controlflow.utilities.types import (
     NOTSET,
     ControlFlowModel,
-    Message,
     PandasDataFrame,
     PandasSeries,
     ToolType,
@@ -58,54 +57,6 @@ class TaskStatus(Enum):
     SKIPPED = "SKIPPED"
 
 
-class LoadMessage(ControlFlowModel):
-    """
-    This special object can be used to indicate that a task result should be
-    loaded from a recent message posted to the flow's thread.
-    """
-
-    type: Literal["LoadMessage"] = Field(
-        'You must provide this value as "LoadMessage".'
-    )
-
-    num_messages_ago: int = Field(
-        1,
-        description="The number of messages ago to retrieve. Default is 1, or the most recent message.",
-    )
-
-    strip_prefix: str = Field(
-        None,
-        description="These characters will be removed from the start "
-        "of the message. For example, remove text like your name prefix.",
-    )
-
-    strip_suffix: Optional[str] = Field(
-        None,
-        description="These characters will be removed from the end of "
-        "the message. For example, remove comments like 'I'll mark the task complete now.'",
-    )
-
-    def trim_message(self, message: Message) -> str:
-        content = message.content[0].text.value
-        if self.strip_prefix:
-            if content.startswith(self.strip_prefix):
-                content = content[len(self.strip_prefix) :]
-            else:
-                raise ValueError(
-                    f'Invalid strip prefix "{self.strip_prefix}"; messages '
-                    f'starts with "{content[:len(self.strip_prefix) + 10]}"'
-                )
-        if self.strip_suffix:
-            if content.endswith(self.strip_suffix):
-                content = content[: -len(self.strip_suffix)]
-            else:
-                raise ValueError(
-                    f'Invalid strip suffix "{self.strip_suffix}"; messages '
-                    f'ends with "{content[-len(self.strip_suffix) - 10:]}"'
-                )
-        return content.strip()
-
-
 class Task(ControlFlowModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
     objective: str = Field(
diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py
index 1cda1105..6515799e 100644
--- a/src/controlflow/llm/completions.py
+++ b/src/controlflow/llm/completions.py
@@ -1,10 +1,9 @@
 import inspect
 import math
-from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, Union
+from typing import AsyncGenerator, Callable, Generator, Union
 
 import litellm
 from litellm.utils import trim_messages
-from pydantic import field_validator
 
 import controlflow
 from controlflow.llm.handlers import AsyncStreamHandler, StreamHandler
@@ -12,15 +11,12 @@
     as_tools,
     get_tool_calls,
     handle_tool_call,
-    has_tool_calls,
 )
-from controlflow.utilities.types import ControlFlowModel, Message, ToolResult
-
-
-def as_cf_message(message: Union[Message, litellm.Message]) -> Message:
-    if isinstance(message, Message):
-        return message
-    return Message(**message.model_dump())
+from controlflow.utilities.types import (
+    ControlFlowMessage,
+    as_cf_messages,
+    as_oai_messages,
+)
 
 
 async def maybe_coro(coro):
@@ -28,36 +24,14 @@
     await coro
 
 
-class Response(ControlFlowModel):
-    messages: list[Message] = []
-    responses: list[litellm.ModelResponse] = []
-
-    @field_validator("messages", mode="before")
-    def _validate_messages(cls, v):
-        return [as_cf_message(m) for m in v]
-
-    def last_message(self) -> Optional[Message]:
-        return self.messages[-1] if self.messages else None
-
-    def last_response(self) -> Optional[litellm.ModelResponse]:
-        return self.responses[-1] if self.responses else None
-
-    def tool_calls(self) -> list[ToolResult]:
-        return [
-            m["_tool_call"]
-            for m in self.messages
-            if m.role == "tool" and m.get("_tool_call") is not None
-        ]
-
-
 def completion(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations=None,
     handlers: list[StreamHandler] = None,
     **kwargs,
-) -> Response:
+) -> list[ControlFlowMessage]:
     """
     Perform completion using the LLM model.
 
@@ -65,15 +39,12 @@
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.completion function.
 
     Returns:
-        A Response object representing the completion response.
+        A list of ControlFlowMessage objects representing the completion response.
     """
-
-    response = None
-    responses = []
+    response_messages = []
     new_messages = []
 
     if handlers is None:
@@ -84,45 +55,50 @@
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
+    counter = 0
+    while not response_messages or get_tool_calls(response_messages):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = litellm.completion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
-        responses.append(response)
+        response_messages = as_cf_messages([response])
 
         # on message done
         for h in handlers:
-            h.on_message_done(response)
-        new_messages.append(response.choices[0].message)
+            for msg in response_messages:
+                if msg.has_tool_calls():
+                    h.on_tool_call_done(msg)
+                else:
+                    h.on_message_done(msg)
 
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                h.on_tool_call_done(tool_call=tool_call)
+        new_messages.extend(response_messages)
 
+        for tool_call in get_tool_calls(response_messages):
             tool_message = handle_tool_call(tool_call, tools)
+
+            # on tool result
             for h in handlers:
                 h.on_tool_result(tool_message)
 
             new_messages.append(tool_message)
 
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    return Response(
-        messages=new_messages,
-        responses=responses,
-    )
+    return new_messages
 
 
 def completion_stream(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations: int = None,
     handlers: list[StreamHandler] = None,
-    response_callback: Callable[[Response], None] = None,
     **kwargs,
-) -> Generator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None, None]:
+) -> Generator[ControlFlowMessage, None, None]:
     """
@@ -132,18 +108,16 @@ def completion_stream(
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.completion function.
 
     Yields:
-        A tuple containing the current completion delta and the snapshot of the completion response.
+        Each new message as it is completed, including any tool messages.
 
     Returns:
-        The final completion response as a litellm.ModelResponse object.
+        Nothing; all messages are delivered via the yields above.
     """
-    response = None
-    responses = []
+    snapshot_message = None
     new_messages = []
 
     if handlers is None:
@@ -154,73 +128,72 @@
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
-        deltas = []
-        is_tool_call = False
-        for delta in litellm.completion(
+    counter = 0
+    while not snapshot_message or get_tool_calls([snapshot_message]):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
+        response = litellm.completion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
-        ):
+        )
+
+        deltas = []
+        for delta in response:
             deltas.append(delta)
-            response = litellm.stream_chunk_builder(deltas)
+            snapshot = litellm.stream_chunk_builder(deltas)
+            delta_message, snapshot_message = as_cf_messages([delta, snapshot])
 
             # on message created
             if len(deltas) == 1:
-                if get_tool_calls(response):
-                    is_tool_call = True
                 for h in handlers:
-                    if is_tool_call:
-                        h.on_tool_call_created(delta=delta)
+                    if snapshot_message.has_tool_calls():
+                        h.on_tool_call_created(delta=delta_message)
                     else:
-                        h.on_message_created(delta=delta)
+                        h.on_message_created(delta=delta_message)
 
             # on message delta
             for h in handlers:
-                if is_tool_call:
-                    h.on_tool_call_delta(delta=delta, snapshot=response)
+                if snapshot_message.has_tool_calls():
+                    h.on_tool_call_delta(delta=delta_message, snapshot=snapshot_message)
                 else:
-                    h.on_message_delta(delta=delta, snapshot=response)
+                    h.on_message_delta(delta=delta_message, snapshot=snapshot_message)
 
-            # yield
-            yield delta, response
+            yield snapshot_message
 
-        responses.append(response)
+        new_messages.append(snapshot_message)
 
         # on message done
-        if not is_tool_call:
-            for h in handlers:
-                h.on_message_done(response)
-        new_messages.append(response.choices[0].message)
+        for h in handlers:
+            if snapshot_message.has_tool_calls():
+                h.on_tool_call_done(snapshot_message)
+            else:
+                h.on_message_done(snapshot_message)
 
         # tool calls
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                h.on_tool_call_done(tool_call=tool_call)
+        for tool_call in get_tool_calls([snapshot_message]):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 h.on_tool_result(tool_message)
             new_messages.append(tool_message)
+            yield tool_message
 
-            yield None, tool_message
-
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    if response_callback:
-        response_callback(Response(messages=new_messages, responses=responses))
-
 
 async def completion_async(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations=None,
     handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
     **kwargs,
-) -> Response:
+) -> list[ControlFlowMessage]:
     """
     Perform asynchronous completion using the LLM model.
 
@@ -228,14 +201,12 @@
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function.
 
     Returns:
-        Response
+        A list of ControlFlowMessage objects representing the completion response.
     """
-    response = None
-    responses = []
+    response_messages = []
     new_messages = []
 
     if handlers is None:
@@ -246,47 +217,50 @@
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
+    counter = 0
+    while not response_messages or get_tool_calls(response_messages):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
-        responses.append(response)
+        response_messages = as_cf_messages([response])
 
         # on message done
         for h in handlers:
-            await maybe_coro(h.on_message_done(response))
-        new_messages.append(response.choices[0].message)
+            for msg in response_messages:
+                if msg.has_tool_calls():
+                    await maybe_coro(h.on_tool_call_done(msg))
+                else:
+                    await maybe_coro(h.on_message_done(msg))
 
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
+        new_messages.extend(response_messages)
 
+        for tool_call in get_tool_calls(response_messages):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 await maybe_coro(h.on_tool_result(tool_message))
 
             new_messages.append(tool_message)
 
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    return Response(
-        messages=new_messages,
-        responses=responses,
-    )
+    return new_messages
 
 
 async def completion_stream_async(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations: int = None,
     handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
-    response_callback: Callable[[Response], None] = None,
     **kwargs,
-) -> AsyncGenerator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None]:
+) -> AsyncGenerator[ControlFlowMessage, None]:
     """
     Perform asynchronous streaming completion using the LLM model.
 
@@ -294,18 +268,16 @@
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function.
 
     Yields:
-        A tuple containing the current completion delta and the snapshot of the completion response.
+        Each new message as it is completed, including any tool messages.
 
     Returns:
-        The final completion response as a litellm.ModelResponse object.
+        Nothing; all messages are delivered via the yields above.
     """
-    response = None
-    responses = []
+    snapshot_message = None
    new_messages = []
 
     if handlers is None:
@@ -316,62 +288,67 @@
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
-        deltas = []
-        is_tool_call = False
-        async for delta in await litellm.acompletion(
+    counter = 0
+    while not snapshot_message or get_tool_calls([snapshot_message]):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
+        response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
-        ):
+        )
+
+        deltas = []
+        async for delta in response:
             deltas.append(delta)
-            response = litellm.stream_chunk_builder(deltas)
+            snapshot = litellm.stream_chunk_builder(deltas)
+            delta_message, snapshot_message = as_cf_messages([delta, snapshot])
 
-            # on message / tool call created
+            # on message created
             if len(deltas) == 1:
-                if get_tool_calls(response):
-                    is_tool_call = True
                 for h in handlers:
-                    if is_tool_call:
-                        await maybe_coro(h.on_tool_call_created(delta=delta))
+                    if snapshot_message.has_tool_calls():
+                        await maybe_coro(h.on_tool_call_created(delta=delta_message))
                     else:
-                        await maybe_coro(h.on_message_created(delta=delta))
+                        await maybe_coro(h.on_message_created(delta=delta_message))
 
-            # on message / tool call delta
+            # on message delta
             for h in handlers:
-                if is_tool_call:
+                if snapshot_message.has_tool_calls():
                     await maybe_coro(
-                        h.on_tool_call_delta(delta=delta, snapshot=response)
+                        h.on_tool_call_delta(
+                            delta=delta_message, snapshot=snapshot_message
+                        )
                     )
                 else:
-                    await maybe_coro(h.on_message_delta(delta=delta, snapshot=response))
+                    await maybe_coro(
+                        h.on_message_delta(
+                            delta=delta_message, snapshot=snapshot_message
+                        )
+                    )
 
-            # yield
-            yield delta, response
+            yield snapshot_message
 
-        responses.append(response)
+        new_messages.append(snapshot_message)
 
         # on message done
-        if not is_tool_call:
-            for h in handlers:
-                await maybe_coro(h.on_message_done(response))
-        new_messages.append(response.choices[0].message)
+        for h in handlers:
+            if snapshot_message.has_tool_calls():
+                await maybe_coro(h.on_tool_call_done(snapshot_message))
+            else:
+                await maybe_coro(h.on_message_done(snapshot_message))
 
         # tool calls
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
+        for tool_call in get_tool_calls([snapshot_message]):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 await maybe_coro(h.on_tool_result(tool_message))
 
             new_messages.append(tool_message)
+            yield tool_message
 
-            yield None, tool_message
-
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
-
-    if response_callback:
-        response_callback(Response(messages=new_messages, responses=responses))
diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py
index 59f99d7c..459687a3 100644
--- a/src/controlflow/llm/handlers.py
+++ b/src/controlflow/llm/handlers.py
@@ -1,112 +1,89 @@
-import datetime
-
-import litellm
-
-from controlflow.llm.tools import ToolResult, get_tool_calls
+from controlflow.llm.tools import get_tool_calls
 from controlflow.utilities.context import ctx
-from controlflow.utilities.types import Message
+from controlflow.utilities.types import AssistantMessage, ToolMessage
 
 
 class StreamHandler:
-    def on_message_created(self, delta: litellm.ModelResponse):
+    def on_message_created(self, delta: AssistantMessage):
         pass
 
-    def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
-    ):
+    def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_message_done(self, response: litellm.ModelResponse):
+    def on_message_done(self, message: AssistantMessage):
         pass
 
-    def on_tool_call_created(self, delta: litellm.ModelResponse):
+    def on_tool_call_created(self, delta: AssistantMessage):
         pass
 
-    def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
-    ):
+    def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_tool_call_done(self, tool_call: Message):
+    def on_tool_call_done(self, message: AssistantMessage):
         pass
 
-    def on_tool_result(self, tool_result: ToolResult):
+    def on_tool_result(self, message: ToolMessage):
         pass
 
 
 class AsyncStreamHandler(StreamHandler):
-    async def on_message_created(self, delta: litellm.ModelResponse):
+    async def on_message_created(self, delta: AssistantMessage):
         pass
 
     async def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
    ):
         pass
 
-    async def on_message_done(self, response: litellm.ModelResponse):
+    async def on_message_done(self, message: AssistantMessage):
         pass
 
-    async def on_tool_call_created(self, delta: litellm.ModelResponse):
+    async def on_tool_call_created(self, delta: AssistantMessage):
         pass
 
     async def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ):
         pass
 
-    async def on_tool_call_done(self, tool_call: Message):
+    async def on_tool_call_done(self, message: AssistantMessage):
         pass
 
-    async def on_tool_result(self, tool_result: ToolResult):
+    async def on_tool_result(self, message: ToolMessage):
         pass
 
 
 class TUIHandler(AsyncStreamHandler):
     async def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ) -> None:
         if tui := ctx.get("tui"):
-            tui.update_message(
-                m_id=snapshot.id,
-                message=snapshot.choices[0].message.content,
-                role=snapshot.choices[0].message.role,
-                timestamp=datetime.datetime.fromtimestamp(snapshot.created),
-            )
+            tui.update_message(message=snapshot)
 
     async def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ):
         if tui := ctx.get("tui"):
-            for tool_call in get_tool_calls(snapshot):
-                tui.update_tool_call(
-                    t_id=snapshot.id,
-                    tool_name=tool_call.function.name,
-                    tool_args=tool_call.function.arguments,
-                    timestamp=datetime.datetime.fromtimestamp(snapshot.created),
-                )
-
-    async def on_tool_result(self, message: Message):
+            tui.update_message(message=snapshot)
+
+    async def on_tool_result(self, message: ToolMessage):
         if tui := ctx.get("tui"):
-            tui.update_tool_result(
-                t_id=message.tool_result.tool_call_id,
-                tool_name=message.tool_result.tool_name,
-                tool_result=message.content,
-                timestamp=datetime.datetime.now(),
-            )
+            tui.update_tool_result(message=message)
 
 
 class PrintHandler(AsyncStreamHandler):
-    def on_message_created(self, delta: litellm.ModelResponse):
+    def on_message_created(self, delta: AssistantMessage):
         print(f"Created: {delta}\n")
 
-    def on_message_done(self, response: litellm.ModelResponse):
-        print(f"Done: {response}\n")
+    def on_message_done(self, message: AssistantMessage):
+        print(f"Done: {message}\n")
 
-    def on_tool_call_created(self, delta: litellm.ModelResponse):
+    def on_tool_call_created(self, delta: AssistantMessage):
         print(f"Tool call created: {delta}\n")
 
-    def on_tool_call_done(self, tool_call: Message):
-        print(f"Tool call: {tool_call}\n")
+    def on_tool_call_done(self, message: AssistantMessage):
+        print(f"Tool call: {message}\n")
 
-    def on_tool_result(self, tool_result: ToolResult):
-        print(f"Tool result: {tool_result}\n")
+    def on_tool_result(self, message: ToolMessage):
+        print(f"Tool result: {message}\n")
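The handler hooks now receive typed messages, so a custom handler no longer digs through `response.choices`. A minimal sketch:

```python
# Minimal custom handler under the new message-based callbacks (sketch).
from controlflow.llm.handlers import StreamHandler
from controlflow.utilities.types import AssistantMessage, ToolMessage


class LoggingHandler(StreamHandler):
    def on_message_done(self, message: AssistantMessage):
        print(f"assistant said: {message.content}")

    def on_tool_result(self, message: ToolMessage):
        print(f"{message.tool_call.function.name} returned: {message.content}")
```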
print(f"Done: {response}\n") + def on_message_done(self, message: AssistantMessage): + print(f"Done: {message}\n") - def on_tool_call_created(self, delta: litellm.ModelResponse): + def on_tool_call_created(self, delta: AssistantMessage): print(f"Tool call created: {delta}\n") - def on_tool_call_done(self, tool_call: Message): - print(f"Tool call: {tool_call}\n") + def on_tool_call_done(self, message: AssistantMessage): + print(f"Tool call: {message}\n") - def on_tool_result(self, tool_result: ToolResult): - print(f"Tool result: {tool_result}\n") + def on_tool_result(self, message: ToolMessage): + print(f"Tool result: {message}\n") diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py index 3cefa6c4..52efdd98 100644 --- a/src/controlflow/llm/history.py +++ b/src/controlflow/llm/history.py @@ -9,7 +9,7 @@ from pydantic import Field, field_validator import controlflow -from controlflow.utilities.types import ControlFlowModel, Message +from controlflow.utilities.types import ControlFlowModel, MessageType def get_default_history() -> "History": @@ -24,12 +24,12 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[Message]: + ) -> list[MessageType]: raise NotImplementedError() def load_messages_to_token_limit( self, thread_id: str, model: str = None - ) -> list[Message]: + ) -> list[MessageType]: messages = [] # as long as the messages are not trimmed, keep loading more while messages == (trim := trim_messages(messages, model=model)): @@ -42,12 +42,12 @@ def load_messages_to_token_limit( return trim @abc.abstractmethod - def save_messages(self, thread_id: str, messages: list[Message]): + def save_messages(self, thread_id: str, messages: list[MessageType]): raise NotImplementedError() class InMemoryHistory(History): - _history: ClassVar[dict[str, list[Message]]] = {} + _history: ClassVar[dict[str, list[MessageType]]] = {} def __init__(self, **kwargs): super().__init__(**kwargs) @@ -58,7 +58,7 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[Message]: + ) -> list[MessageType]: messages = InMemoryHistory._history.get(thread_id, []) filtered_messages = [ msg @@ -69,7 +69,7 @@ def load_messages( ] return list(reversed(filtered_messages)) - def save_messages(self, thread_id: str, messages: list[Message]): + def save_messages(self, thread_id: str, messages: list[MessageType]): InMemoryHistory._history.setdefault(thread_id, []).extend(messages) @@ -94,7 +94,7 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[Message]: + ) -> list[MessageType]: if not self.path(thread_id).exists(): return [] @@ -103,7 +103,7 @@ def load_messages( messages = [] for msg in reversed(all_messages): - message = Message.model_validate(msg) + message = MessageType.model_validate(msg) if before is None or message.timestamp < before: if after is None or message.timestamp > after: messages.append(message) @@ -112,7 +112,7 @@ def load_messages( return list(reversed(messages)) - def save_messages(self, thread_id: str, messages: list[Message]): + def save_messages(self, thread_id: str, messages: list[MessageType]): if self.path(thread_id).exists(): with open(self.path(thread_id), "r") as f: all_messages = json.load(f) diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py index 62808817..70a1c41e 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/llm/tools.py @@ -4,10 +4,15 @@ 
diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py
index 62808817..70a1c41e 100644
--- a/src/controlflow/llm/tools.py
+++ b/src/controlflow/llm/tools.py
@@ -4,10 +4,15 @@
 from functools import partial, update_wrapper
 from typing import Any, Callable, Optional, Union
 
-import litellm
 import pydantic
 
-from controlflow.utilities.types import Message, Tool, ToolResult
+from controlflow.utilities.types import (
+    AssistantMessage,
+    ControlFlowMessage,
+    Tool,
+    ToolCall,
+    ToolMessage,
+)
 
 
 def tool(
@@ -77,13 +82,6 @@ def wrapper(**kwargs):
         return wrapper
 
 
-def has_tool_calls(response: litellm.ModelResponse) -> bool:
-    """
-    Check if the model response contains tool calls.
-    """
-    return bool(response.choices[0].message.get("tool_calls"))
-
-
 def output_to_string(output: Any) -> str:
     """
     Function outputs must be provided as strings
@@ -99,55 +97,52 @@
 
 
 def get_tool_calls(
-    response: litellm.ModelResponse,
-) -> list[litellm.utils.ChatCompletionMessageToolCall]:
-    return response.choices[0].message.get("tool_calls", [])
+    messages: list[ControlFlowMessage],
+) -> list[ToolCall]:
+    return [
+        tc
+        for m in messages
+        if isinstance(m, AssistantMessage) and m.tool_calls
+        for tc in m.tool_calls
+    ]
 
 
-def handle_tool_call(
-    tool_call: litellm.utils.ChatCompletionMessageToolCall, tools: list[dict, Callable]
-) -> Message:
+def handle_tool_call(tool_call: ToolCall, tools: list[Union[dict, Callable]]) -> ToolMessage:
     tool_lookup = as_tool_lookup(tools)
     fn_name = tool_call.function.name
     fn_args = (None,)
     try:
-        is_error = False
+        tool_failed = False
         if fn_name not in tool_lookup:
             fn_output = f'Function "{fn_name}" not found.'
-            is_error = True
+            tool_failed = True
         else:
             tool = tool_lookup[fn_name]
             fn_args = json.loads(tool_call.function.arguments)
             fn_output = tool(**fn_args)
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
-        is_error = True
-    return Message(
-        role="tool",
-        name=fn_name,
+        tool_failed = True
+    return ToolMessage(
         content=output_to_string(fn_output),
         tool_call_id=tool_call.id,
-        tool_result=ToolResult(
-            tool_call_id=tool_call.id,
-            tool_name=fn_name,
-            tool=tool,
-            args=fn_args,
-            result=fn_output,
-        ),
+        tool_call=tool_call,
+        tool_result=fn_output,
+        tool_failed=tool_failed,
     )
 
 
 async def handle_tool_call_async(
-    tool_call: litellm.utils.ChatCompletionMessageToolCall, tools: list[dict, Callable]
-) -> Message:
+    tool_call: ToolCall, tools: list[Union[dict, Callable]]
+) -> ToolMessage:
     tool_lookup = as_tool_lookup(tools)
     fn_name = tool_call.function.name
     fn_args = (None,)
     try:
-        is_error = False
+        tool_failed = False
         if fn_name not in tool_lookup:
             fn_output = f'Function "{fn_name}" not found.'
-            is_error = True
+            tool_failed = True
         else:
             tool = tool_lookup[fn_name]
             fn_args = json.loads(tool_call.function.arguments)
@@ -156,18 +151,11 @@ async def handle_tool_call_async(
             fn_output = await fn_output
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
-        is_error = True
-    return Message(
-        role="tool",
-        name=fn_name,
+        tool_failed = True
+    return ToolMessage(
         content=output_to_string(fn_output),
         tool_call_id=tool_call.id,
-        tool_result=ToolResult(
-            tool_call_id=tool_call.id,
-            tool_name=fn_name,
-            tool=tool,
-            args=fn_args,
-            is_error=is_error,
-            result=fn_output,
-        ),
+        tool_call=tool_call,
+        tool_result=fn_output,
+        tool_failed=tool_failed,
     )
diff --git a/src/controlflow/tui/app.py b/src/controlflow/tui/app.py
index 7950338f..2fd58fcf 100644
--- a/src/controlflow/tui/app.py
+++ b/src/controlflow/tui/app.py
@@ -1,7 +1,6 @@
 import asyncio
-import datetime
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 from textual.app import App, ComposeResult
 from textual.containers import Container
@@ -10,10 +9,11 @@
 from textual.widgets import Footer, Header, Label
 
 import controlflow
+from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage
 
 from .basic import Column, Row
 from .task import TUITask
-from .thread import TUIMessage, TUIToolCall, TUIToolResult
+from .thread import TUIMessage, TUIToolMessage
 
 if TYPE_CHECKING:
     import controlflow
@@ -108,54 +108,23 @@ def update_task(self, task: "controlflow.Task"):
         self.query_one("#tasks-container", Column).mount(new_task)
         new_task.scroll_visible()
 
-    def update_message(
-        self, m_id: str, message: str, role: str, timestamp: datetime.datetime = None
-    ):
+    def update_message(self, message: Union[UserMessage, AssistantMessage]):
         try:
-            component = self.query_one(f"#message-{m_id}", TUIMessage)
+            component = self.query_one(f"#message-{message.id}", TUIMessage)
             component.message = message
             component.scroll_visible()
         except NoMatches:
-            new_message = TUIMessage(
-                message=message,
-                role=role,
-                timestamp=timestamp,
-                id=f"message-{m_id}",
-            )
+            new_message = TUIMessage(message=message, id=f"message-{message.id}")
             self.query_one("#thread-container", Column).mount(new_message)
             new_message.scroll_visible()
 
-    def update_tool_call(
-        self, t_id: str, tool_name: str, tool_args: str, timestamp: datetime.datetime
-    ):
+    def update_tool_result(self, message: ToolMessage):
         try:
-            component = self.query_one(f"#tool-call-{t_id}", TUIToolCall)
-            component.tool_args = tool_args
-            component.scroll_visible()
-        except NoMatches:
-            new_step = TUIToolCall(
-                tool_name=tool_name,
-                tool_args=tool_args,
-                timestamp=timestamp,
-                id=f"tool-call-{t_id}",
-            )
-            self.query_one("#thread-container", Column).mount(new_step)
-            new_step.scroll_visible()
-
-    def update_tool_result(
-        self, t_id: str, tool_name: str, tool_result: str, timestamp: datetime.datetime
-    ):
-        try:
-            component = self.query_one(f"#tool-result-{t_id}", TUIToolResult)
-            component.tool_result = tool_result
+            component = self.query_one(f"#message-{message.id}", TUIToolMessage)
+            component.message = message
             component.scroll_visible()
         except NoMatches:
-            new_step = TUIToolResult(
-                tool_name=tool_name,
-                tool_result=tool_result,
-                timestamp=timestamp,
-                id=f"tool-result-{t_id}",
-            )
+            new_step = TUIToolMessage(message=message, id=f"message-{message.id}")
             self.query_one("#thread-container", Column).mount(new_step)
             new_step.scroll_visible()
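Returning to `llm/tools.py` above: a sketch of how a `ToolCall` is resolved into a `ToolMessage` (the call ID and arguments are made up for illustration):

```python
# Sketch: resolving a ToolCall against a tool lookup produces a ToolMessage.
from controlflow.llm.tools import handle_tool_call, tool
from controlflow.utilities.types import ToolCall, ToolCallFunction


@tool
def greet(name: str) -> str:
    """Say hello."""
    return f"Hello, {name}!"


call = ToolCall(
    id="call_123",  # illustrative ID
    function=ToolCallFunction(name="greet", arguments='{"name": "Marvin"}'),
)
message = handle_tool_call(call, [greet])
print(message.content, message.tool_failed)  # Hello, Marvin! False
```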
diff --git a/src/controlflow/tui/thread.py b/src/controlflow/tui/thread.py
index f9851a9f..a4cb22f7 100644
--- a/src/controlflow/tui/thread.py
+++ b/src/controlflow/tui/thread.py
@@ -1,14 +1,16 @@
 import datetime
 import inspect
-from typing import Literal
+from typing import Union
 
 from rich import box
+from rich.console import Group
 from rich.markdown import Markdown
 from rich.panel import Panel
 from textual.reactive import reactive
 from textual.widgets import Static
 
 from controlflow.core.task import TaskStatus
+from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage
 
 
 def bool_to_emoji(value: bool) -> str:
@@ -33,20 +35,12 @@ def format_timestamp(timestamp: datetime.datetime) -> str:
 
 
 class TUIMessage(Static):
-    message: reactive[str] = reactive(None, always_update=True)
-
-    def __init__(
-        self,
-        message: str,
-        role: Literal["user", "assistant"] = "assistant",
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
+    message: reactive[Union[UserMessage, AssistantMessage]] = reactive(
+        None, always_update=True
+    )
+
+    def __init__(self, message: Union[UserMessage, AssistantMessage], **kwargs):
         super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self._role = role
         self.message = message
 
     def render(self):
@@ -54,53 +48,27 @@ def render(self):
         role_colors = {
             "user": "green",
             "assistant": "blue",
         }
+        if isinstance(self.message, AssistantMessage) and self.message.has_tool_calls():
+            # show the first tool call (streamed messages may carry several)
+            tool_call = self.message.tool_calls[0]
+            content = Markdown(
+                inspect.cleandoc("""
+                :hammer_and_wrench: Calling `{name}` with the following arguments:
+
+                ```json
+                {args}
+                ```
+                """).format(
+                    name=tool_call.function.name, args=tool_call.function.arguments
+                )
+            )
+            title = "Tool Call"
+        else:
+            content = self.message.content
+            title = self.message.role.capitalize()
         return Panel(
-            self.message,
-            title=f"[bold]{self._role.capitalize()}[/]",
-            subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]",
-            title_align="left",
-            subtitle_align="right",
-            border_style=role_colors.get(self._role, "red"),
-            box=box.ROUNDED,
-            width=100,
-            expand=True,
-            padding=(1, 2),
-        )
-
-
-class TUIToolCall(Static):
-    tool_name: reactive[str] = reactive(None, always_update=True)
-    tool_args: reactive[str] = reactive(None, always_update=True)
-
-    def __init__(
-        self,
-        tool_name: str,
-        tool_args: str,
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self.tool_name = tool_name
-        self.tool_args = tool_args
-
-    def render(self):
-        content = inspect.cleandoc("""
-            :hammer_and_wrench: Calling `{name}` with the following arguments:
-
-            ```json
-            {args}
-            ```
-            """).format(name=self.tool_name, args=self.tool_args)
-        return Panel(
-            Markdown(content),
-            title="Tool Call",
-            subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]",
+            content,
+            title=f"[bold]{title}[/]",
+            subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
             title_align="left",
             subtitle_align="right",
-            border_style="blue",
+            border_style=role_colors.get(self.message.role, "red"),
             box=box.ROUNDED,
             width=100,
             expand=True,
@@ -108,29 +76,21 @@
             padding=(1, 2),
         )
 
 
-class TUIToolResult(Static):
-    tool_name: reactive[str] = reactive(None, always_update=True)
-    tool_result: reactive[str] = reactive(None, always_update=True)
+class TUIToolMessage(Static):
+    message: reactive[ToolMessage] = reactive(None, always_update=True)
 
-    def __init__(
-        self,
-        tool_name: str,
-        tool_result: str,
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
+    def __init__(self, message: ToolMessage, **kwargs):
         super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self.tool_name = tool_name
-        self.tool_result = tool_result
+        self.message = message
 
     def render(self):
-        content = Markdown(
-            f":white_check_mark: Received output from the [markdown.code]{self.tool_name}[/] tool."
-            f"\n\n```json\n{self.tool_result}\n```",
-        )
+        if self.message.tool_failed:
+            content = f":x: The tool call to [markdown.code]{self.message.tool_call.function.name}[/] failed."
+        else:
+            content = Group(
+                f":white_check_mark: Received output from the [markdown.code]{self.message.tool_call.function.name}[/] tool.\n",
+                Markdown(f"```json\n{self.message.content}\n```"),
+            )
 
         return Panel(
             content,
diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py
index 444aea34..5537f655 100644
--- a/src/controlflow/utilities/types.py
+++ b/src/controlflow/utilities/types.py
@@ -1,8 +1,10 @@
 import datetime
 import inspect
 import json
+import uuid
+from enum import Enum
 from functools import partial, update_wrapper
-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Callable, List, Literal, Optional, Union
 
 import litellm
 import pydantic
@@ -10,7 +12,18 @@
 from marvin.beta.assistants.assistants import AssistantTool
 from marvin.types import FunctionTool
 from marvin.utilities.asyncio import ExposeSyncMethodsMixin
-from pydantic import BaseModel, Field, PrivateAttr
+from pydantic import (
+    BaseModel,
+    Field,
+    PrivateAttr,
+    TypeAdapter,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 from traitlets import default
 
 # flag for unset defaults
@@ -64,6 +77,9 @@
     def from_function(
         cls, fn: Callable, name: Optional[str] = None, description: Optional[str] = None
     ):
+        if name is None and fn.__name__ == "<lambda>":
+            name = "__lambda__"
+
         return cls(
             function=ToolFunction(
                 name=name or fn.__name__,
@@ -79,26 +95,190 @@
     def __call__(self, *args, **kwargs):
         return self._fn(*args, **kwargs)
 
 
-class ToolResult(ControlFlowModel):
-    model_config = dict(allow_arbitrary_types=True)
-    tool_call_id: str
-    tool_name: str
-    tool: Tool
-    args: dict
-    is_error: bool
-    result: Any = Field(None, exclude=True)
+# -----------------------------------------------
+# Messages
+# -----------------------------------------------
+
+
+class _OpenAIBaseType(ControlFlowModel):
+    model_config = dict(extra="allow")
+
+
+Role = Literal["system", "user", "assistant", "tool"]
+
+
+class TextContent(_OpenAIBaseType):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageDetails(_OpenAIBaseType):
+    url: str
+    detail: Literal["auto", "high", "low"] = "auto"
+
+
+class ImageContent(_OpenAIBaseType):
+    type: Literal["image_url"] = "image_url"
+    image_url: ImageDetails
+
+
+class ControlFlowMessage(_OpenAIBaseType):
+    # ---- begin openai fields
+    role: Role = Field(openai_field=True)
+    _openai_fields: set[str] = {"role"}
+    # ---- end openai fields
 
-class Message(litellm.Message):
-    model_config = dict(validate_assignment=True)
+    id: str = Field(default_factory=lambda: uuid.uuid4().hex, repr=False)
     timestamp: datetime.datetime = Field(
-        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
+        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc),
     )
+    llm_response: Optional[litellm.ModelResponse] = Field(None, repr=False)
 
-    tool_result: Optional[ToolResult] = None
+    @field_validator("role", mode="before")
+    def _lowercase_role(cls, v):
+        if isinstance(v, str):
+            v = v.lower()
+        return v
 
-    def __init__(
-        self, content: str, *, role: str = None, tool_result: Any = None, **kwargs
-    ):
-        super().__init__(content=content, role=role, **kwargs)
-        self.tool_result = tool_result
+    @field_validator("timestamp", mode="before")
+    def _validate_timestamp(cls, v):
+        if isinstance(v, int):
+            v = datetime.datetime.fromtimestamp(v)
+        return v
+
+    @model_validator(mode="after")
+    def _finalize(self):
+        self._openai_fields = (
+            getattr(super(), "_openai_fields", set()) | self._openai_fields
+        )
+        return self
+
+    @field_serializer("timestamp")
+    def _serialize_timestamp(self, timestamp: datetime.datetime):
+        return timestamp.isoformat()
+
+    def as_openai_message(self, include: set[str] = None, **kwargs) -> dict:
+        include = self._openai_fields | (include or set())
+        return self.model_dump(include=include, **kwargs)
+
+
+class SystemMessage(ControlFlowMessage):
+    # ---- begin openai fields
+    role: Literal["system"] = "system"
+    content: str
+    name: Optional[str] = None
+    _openai_fields = {"role", "content", "name"}
+    # ---- end openai fields
+
+
+class UserMessage(ControlFlowMessage):
+    # ---- begin openai fields
+    role: Literal["user"] = "user"
+    content: List[Union[TextContent, ImageContent]]
+    name: Optional[str] = None
+    _openai_fields = {"role", "content", "name"}
+    # ---- end openai fields
+
+    @field_validator("content", mode="before")
+    def _validate_content(cls, v):
+        if isinstance(v, str):
+            v = [TextContent(text=v)]
+        return v
+
+
+class AssistantMessage(ControlFlowMessage):
+    """A message from the assistant."""
+
+    # ---- begin openai fields
+    role: Literal["assistant"] = "assistant"
+    content: Optional[str] = None
+    tool_calls: Optional[List["ToolCall"]] = None
+    _openai_fields = {"role", "content", "tool_calls"}
+    # ---- end openai fields
+
+    is_delta: bool = Field(
+        default=False,
+        description="If True, this message is a streamed delta, or chunk, of a full message.",
+    )
+
+    def has_tool_calls(self):
+        return bool(self.tool_calls)
+
+
+class ToolMessage(ControlFlowMessage):
+    """A message for reporting the result of a tool call."""
+
+    # ---- begin openai fields
+    role: Literal["tool"] = "tool"
+    content: str = Field(description="The string result of the tool call.")
+    tool_call_id: str = Field(description="The ID of the tool call.")
+    _openai_fields = {"role", "content", "tool_call_id"}
+    # ---- end openai fields
+
+    tool_call: "ToolCall" = Field(cf_field=True, repr=False)
+    tool_result: Any = Field(None, cf_field=True, exclude=True)
+    tool_failed: bool = Field(False, cf_field=True)
+
+
+MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
+
+
+class ToolCallFunction(_OpenAIBaseType):
+    name: Optional[str] = None
+    arguments: str
+
+    def json_arguments(self):
+        return json.loads(self.arguments)
+
+
+class ToolCall(_OpenAIBaseType):
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: ToolCallFunction
+
+
+def as_cf_messages(
+    messages: list[Union[litellm.Message, litellm.ModelResponse]],
+) -> list[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]]:
+    message_ta = TypeAdapter(
+        Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
+    )
+
+    result = []
+    for msg in messages:
+        if isinstance(msg, ControlFlowMessage):
+            result.append(msg)
+        elif isinstance(msg, litellm.Message):
+            new_msg = message_ta.validate_python(msg.model_dump())
+            result.append(new_msg)
+        elif isinstance(msg, litellm.ModelResponse):
+            for i, choice in enumerate(msg.choices):
+                # handle delta messages streaming from the assistant
+                if hasattr(choice, "delta"):
+                    if choice.delta.role is None:
+                        new_msg = AssistantMessage(is_delta=True)
+                    else:
+                        new_msg = AssistantMessage(
+                            **choice.delta.model_dump(), is_delta=True
+                        )
+                else:
+                    new_msg = message_ta.validate_python(choice.message.model_dump())
+                new_msg.id = f"{msg.id}-{i}"
+                new_msg.timestamp = datetime.datetime.fromtimestamp(msg.created)
+                new_msg.llm_response = msg
+                result.append(new_msg)
+        else:
+            raise ValueError(f"Invalid message type: {type(msg)}")
+    return result
+
+
+def as_oai_messages(messages: list[Union[dict, ControlFlowMessage, litellm.Message]]):
+    result = []
+    for msg in messages:
+        if isinstance(msg, ControlFlowMessage):
+            result.append(msg.as_openai_message())
+        elif isinstance(msg, (dict, litellm.Message)):
+            result.append(msg)
+        else:
+            raise ValueError(f"Invalid message type: {type(msg)}")
+    return result
diff --git a/tests/llm/test_streaming.py b/tests/llm/test_streaming.py
index 604eea5a..18db6612 100644
--- a/tests/llm/test_streaming.py
+++ b/tests/llm/test_streaming.py
@@ -4,7 +4,7 @@
 from controlflow.llm.completions import completion_stream
 from controlflow.llm.handlers import StreamHandler
-from controlflow.llm.tools import ToolResult
-from controlflow.utilities.types import Message
+from controlflow.utilities.types import AssistantMessage
 from pydantic import BaseModel
 
 
@@ -30,7 +30,7 @@ def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message
             )
         )
 
-    def on_message_done(self, message: Message):
+    def on_message_done(self, message: AssistantMessage):
         self.calls.append(
             StreamCall(method="on_message_done", args=dict(message=message))
         )
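Finally, a sketch of the new message taxonomy and conversion helpers: `ControlFlowMessage` subclasses carry ControlFlow-only metadata (`id`, `timestamp`, `llm_response`), and `as_oai_messages` strips it back down to the OpenAI wire format.

```python
# Sketch of the metadata/wire-format split in the new message types.
from controlflow.utilities.types import UserMessage, as_oai_messages

msg = UserMessage(content="What's in this image?")
print(msg.id, msg.timestamp)   # ControlFlow-only metadata
print(as_oai_messages([msg]))  # -> [{'role': 'user', 'content': [{'type': 'text', ...}], ...}]
```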