From f915b640517c61a12944afc3be07087e306daff3 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Wed, 22 May 2024 15:54:26 -0400
Subject: [PATCH 1/3] Completions working with new message types
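
The completion functions no longer return a Response wrapper: they
return a plain list of ControlFlowMessage objects, and the streaming
variants yield messages directly instead of (delta, response) tuples.
A minimal sketch of the new calling convention (the model name and the
example tool are illustrative, not part of this patch):

    from controlflow.llm.completions import completion

    def roll_die() -> int:
        """Roll a six-sided die."""
        import random

        return random.randint(1, 6)

    # returns list[ControlFlowMessage]: assistant messages plus any
    # ToolMessage results, ready for History.save_messages()
    messages = completion(
        messages=[{"role": "user", "content": "Roll a die."}],
        model="gpt-4o",
        tools=[roll_die],
        max_iterations=3,
    )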

---
 src/controlflow/core/controller/controller.py |  11 +-
 src/controlflow/llm/completions.py            | 266 ++++++++----------
 src/controlflow/llm/handlers.py               |  75 ++---
 src/controlflow/llm/tools.py                  |  76 +++--
 src/controlflow/tui/app.py                    |  51 +---
 src/controlflow/tui/thread.py                 | 114 +++-----
 src/controlflow/utilities/types.py            | 208 +++++++++++++-
 7 files changed, 439 insertions(+), 362 deletions(-)

diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py
index 03aa652f..9055a3f9 100644
--- a/src/controlflow/core/controller/controller.py
+++ b/src/controlflow/core/controller/controller.py
@@ -123,21 +123,20 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
         messages = self.history.load_messages(thread_id=self.flow.thread_id)
 
         # call llm
-        r = []
-        async for _ in completion_stream_async(
+        response_messages = []
+        async for msg in completion_stream_async(
             messages=[system_message] + messages,
             model=agent.model,
             tools=tools,
             handlers=[TUIHandler()] if controlflow.settings.enable_tui else None,
             max_iterations=1,
-            response_callback=r.append,
+            yield_deltas=False,
         ):
-            pass
-        response = r[0]
+            response_messages.append(msg)
 
         # save history
         self.history.save_messages(
-            thread_id=self.flow.thread_id, messages=response.messages
+            thread_id=self.flow.thread_id, messages=response_messages
         )
 
         # create_json_artifact(
diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py
index 1cda1105..4456b91a 100644
--- a/src/controlflow/llm/completions.py
+++ b/src/controlflow/llm/completions.py
@@ -1,10 +1,9 @@
 import inspect
 import math
-from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, Union
+from typing import AsyncGenerator, Callable, Generator, Union
 
 import litellm
 from litellm.utils import trim_messages
-from pydantic import field_validator
 
 import controlflow
 from controlflow.llm.handlers import AsyncStreamHandler, StreamHandler
@@ -12,9 +11,13 @@
     as_tools,
     get_tool_calls,
     handle_tool_call,
-    has_tool_calls,
 )
-from controlflow.utilities.types import ControlFlowModel, Message, ToolResult
+from controlflow.utilities.types import (
+    ControlFlowMessage,
+    Message,
+    as_cf_messages,
+    as_oai_messages,
+)
 
 
 def as_cf_message(message: Union[Message, litellm.Message]) -> Message:
@@ -28,36 +31,14 @@ async def maybe_coro(coro):
         await coro
 
 
-class Response(ControlFlowModel):
-    messages: list[Message] = []
-    responses: list[litellm.ModelResponse] = []
-
-    @field_validator("messages", mode="before")
-    def _validate_messages(cls, v):
-        return [as_cf_message(m) for m in v]
-
-    def last_message(self) -> Optional[Message]:
-        return self.messages[-1] if self.messages else None
-
-    def last_response(self) -> Optional[litellm.ModelResponse]:
-        return self.responses[-1] if self.responses else None
-
-    def tool_calls(self) -> list[ToolResult]:
-        return [
-            m["_tool_call"]
-            for m in self.messages
-            if m.role == "tool" and m.get("_tool_call") is not None
-        ]
-
-
 def completion(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations=None,
     handlers: list[StreamHandler] = None,
     **kwargs,
-) -> Response:
+) -> list[ControlFlowMessage]:
     """
     Perform completion using the LLM model.
 
@@ -65,15 +46,12 @@ def completion(
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.completion function.
 
     Returns:
-        A Response object representing the completion response.
+        A list of ControlFlowMessage objects representing the completion response.
     """
-
-    response = None
-    responses = []
+    response_messages = []
     new_messages = []
 
     if handlers is None:
@@ -84,45 +62,49 @@ def completion(
 
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
+    counter = 0
+    while not response_messages or get_tool_calls(response_messages):
         response = litellm.completion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=trim_messages(
+                messages + as_oai_messages(new_messages), model=model
+            ),
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
 
-        responses.append(response)
+        response_messages = as_cf_messages([response])
 
         # on message done
         for h in handlers:
-            h.on_message_done(response)
-        new_messages.append(response.choices[0].message)
+            for msg in response_messages:
+                if msg.has_tool_calls():
+                    h.on_tool_call_done(msg)
+                else:
+                    h.on_message_done(msg)
 
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                h.on_tool_call_done(tool_call=tool_call)
+        new_messages.extend(response_messages)
+        for tool_call in get_tool_calls(response_messages):
             tool_message = handle_tool_call(tool_call, tools)
+
+            # on tool result
             for h in handlers:
                 h.on_tool_result(tool_message)
             new_messages.append(tool_message)
 
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    return Response(
-        messages=new_messages,
-        responses=responses,
-    )
+    return new_messages
 
 
 def completion_stream(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations: int = None,
     handlers: list[StreamHandler] = None,
-    response_callback: Callable[[Response], None] = None,
     **kwargs,
-) -> Generator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None, None]:
+) -> Generator[ControlFlowMessage, None, None]:
     """
@@ -132,18 +114,16 @@ def completion_stream(
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.completion function.
 
     Yields:
-        A tuple containing the current completion delta and the snapshot of the completion response.
+        Each completed message: assistant message snapshots and any tool-result messages.
 
     Returns:
-        The final completion response as a litellm.ModelResponse object.
+        Nothing; all messages are delivered via the generator.
     """
 
-    response = None
-    responses = []
+    snapshot_message = None
     new_messages = []
 
     if handlers is None:
@@ -154,73 +134,71 @@ def completion_stream(
 
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
-        deltas = []
-        is_tool_call = False
-        for delta in litellm.completion(
+    counter = 0
+    while not snapshot_message or get_tool_calls([snapshot_message]):
+        response = litellm.completion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=trim_messages(
+                messages + as_oai_messages(new_messages), model=model
+            ),
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
-        ):
+        )
+
+        deltas = []
+        for delta in response:
             deltas.append(delta)
-            response = litellm.stream_chunk_builder(deltas)
+            snapshot = litellm.stream_chunk_builder(deltas)
+            delta_message, snapshot_message = as_cf_messages([delta, snapshot])
 
             # on message created
             if len(deltas) == 1:
-                if get_tool_calls(response):
-                    is_tool_call = True
                 for h in handlers:
-                    if is_tool_call:
-                        h.on_tool_call_created(delta=delta)
+                    if snapshot_message.has_tool_calls():
+                        h.on_tool_call_created(delta=delta_message)
                     else:
-                        h.on_message_created(delta=delta)
+                        h.on_message_created(delta=delta_message)
 
             # on message delta
             for h in handlers:
-                if is_tool_call:
-                    h.on_tool_call_delta(delta=delta, snapshot=response)
+                if snapshot_message.has_tool_calls():
+                    h.on_tool_call_delta(delta=delta_message, snapshot=snapshot_message)
                 else:
-                    h.on_message_delta(delta=delta, snapshot=response)
+                    h.on_message_delta(delta=delta_message, snapshot=snapshot_message)
 
-            # yield
-            yield delta, response
+        yield snapshot_message
 
-        responses.append(response)
+        new_messages.append(snapshot_message)
 
         # on message done
-        if not is_tool_call:
-            for h in handlers:
-                h.on_message_done(response)
-        new_messages.append(response.choices[0].message)
+        for h in handlers:
+            if snapshot_message.has_tool_calls():
+                h.on_tool_call_done(snapshot_message)
+            else:
+                h.on_message_done(snapshot_message)
 
         # tool calls
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                h.on_tool_call_done(tool_call=tool_call)
+        for tool_call in get_tool_calls([snapshot_message]):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 h.on_tool_result(tool_message)
             new_messages.append(tool_message)
+            yield tool_message
 
-            yield None, tool_message
-
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    if response_callback:
-        response_callback(Response(messages=new_messages, responses=responses))
-
 
 async def completion_async(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations=None,
     handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
     **kwargs,
-) -> Response:
+) -> list[ControlFlowMessage]:
     """
     Perform asynchronous completion using the LLM model.
 
@@ -228,14 +206,12 @@ async def completion_async(
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function.
 
     Returns:
-        Response
+        A list of ControlFlowMessage objects representing the completion response.
     """
-    response = None
-    responses = []
+    response_messages = []
     new_messages = []
 
     if handlers is None:
@@ -246,47 +222,49 @@ async def completion_async(
 
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
+    counter = 0
+    while not response_messages or get_tool_calls(response_messages):
         response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=trim_messages(
+                messages + as_oai_messages(new_messages), model=model
+            ),
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
 
-        responses.append(response)
+        response_messages = as_cf_messages([response])
 
         # on message done
         for h in handlers:
-            await maybe_coro(h.on_message_done(response))
-        new_messages.append(response.choices[0].message)
+            for msg in response_messages:
+                if msg.has_tool_calls():
+                    await maybe_coro(h.on_tool_call_done(msg))
+                else:
+                    await maybe_coro(h.on_message_done(msg))
 
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
+        new_messages.extend(response_messages)
+        for tool_call in get_tool_calls(response_messages):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 await maybe_coro(h.on_tool_result(tool_message))
             new_messages.append(tool_message)
 
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
 
-    return Response(
-        messages=new_messages,
-        responses=responses,
-    )
+    return new_messages
 
 
 async def completion_stream_async(
-    messages: list[Union[dict, Message]],
+    messages: list[Union[dict, ControlFlowMessage]],
     model=None,
     tools: list[Callable] = None,
     max_iterations: int = None,
     handlers: list[Union[AsyncStreamHandler, StreamHandler]] = None,
-    response_callback: Callable[[Response], None] = None,
     **kwargs,
-) -> AsyncGenerator[Tuple[litellm.ModelResponse, litellm.ModelResponse], None]:
+) -> AsyncGenerator[ControlFlowMessage, None]:
     """
     Perform asynchronous streaming completion using the LLM model.
 
@@ -294,18 +272,16 @@ async def completion_stream_async(
         messages: A list of messages to be used for completion.
         model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used.
         tools: A list of callable tools to be used during completion.
-        call_tools: A boolean indicating whether to use the provided tools during completion.
         **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function.
 
     Yields:
-        A tuple containing the current completion delta and the snapshot of the completion response.
+        Each completed message: assistant message snapshots and any tool-result messages.
 
     Returns:
-        The final completion response as a litellm.ModelResponse object.
+        Nothing; all messages are delivered via the generator.
     """
 
-    response = None
-    responses = []
+    snapshot_message = None
     new_messages = []
 
     if handlers is None:
@@ -316,62 +292,66 @@ async def completion_stream_async(
 
     tools = as_tools(tools or [])
 
-    while not response or has_tool_calls(response):
-        deltas = []
-        is_tool_call = False
-        async for delta in await litellm.acompletion(
+    counter = 0
+    while not snapshot_message or get_tool_calls([snapshot_message]):
+        response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(messages + new_messages, model=model),
+            messages=trim_messages(
+                messages + as_oai_messages(new_messages), model=model
+            ),
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
-        ):
+        )
+
+        deltas = []
+        async for delta in response:
             deltas.append(delta)
-            response = litellm.stream_chunk_builder(deltas)
+            snapshot = litellm.stream_chunk_builder(deltas)
+            delta_message, snapshot_message = as_cf_messages([delta, snapshot])
 
-            # on message / tool call created
+            # on message created
             if len(deltas) == 1:
-                if get_tool_calls(response):
-                    is_tool_call = True
                 for h in handlers:
-                    if is_tool_call:
-                        await maybe_coro(h.on_tool_call_created(delta=delta))
+                    if snapshot_message.has_tool_calls():
+                        await maybe_coro(h.on_tool_call_created(delta=delta_message))
                     else:
-                        await maybe_coro(h.on_message_created(delta=delta))
+                        await maybe_coro(h.on_message_created(delta=delta_message))
 
-            # on message / tool call delta
+            # on message delta
             for h in handlers:
-                if is_tool_call:
+                if snapshot_message.has_tool_calls():
                     await maybe_coro(
-                        h.on_tool_call_delta(delta=delta, snapshot=response)
+                        h.on_tool_call_delta(
+                            delta=delta_message, snapshot=snapshot_message
+                        )
                     )
                 else:
-                    await maybe_coro(h.on_message_delta(delta=delta, snapshot=response))
+                    await maybe_coro(
+                        h.on_message_delta(
+                            delta=delta_message, snapshot=snapshot_message
+                        )
+                    )
 
-            # yield
-            yield delta, response
+        yield snapshot_message
 
-        responses.append(response)
+        new_messages.append(snapshot_message)
 
         # on message done
-        if not is_tool_call:
-            for h in handlers:
-                await maybe_coro(h.on_message_done(response))
-        new_messages.append(response.choices[0].message)
+        for h in handlers:
+            if snapshot_message.has_tool_calls():
+                await maybe_coro(h.on_tool_call_done(snapshot_message))
+            else:
+                await maybe_coro(h.on_message_done(snapshot_message))
 
         # tool calls
-        for tool_call in get_tool_calls(response):
-            for h in handlers:
-                await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
+        for tool_call in get_tool_calls([snapshot_message]):
             tool_message = handle_tool_call(tool_call, tools)
             for h in handlers:
                 await maybe_coro(h.on_tool_result(tool_message))
             new_messages.append(tool_message)
+            yield tool_message
 
-            yield None, tool_message
-
-        if len(responses) >= (max_iterations or math.inf):
+        counter += 1
+        if counter >= (max_iterations or math.inf):
             break
-
-    if response_callback:
-        response_callback(Response(messages=new_messages, responses=responses))
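
The streaming variants now yield each completed message. A sketch of
consuming the reworked async stream, mirroring the controller change
above (the model and prompt are illustrative):

    import asyncio

    from controlflow.llm.completions import completion_stream_async

    async def main():
        collected = []
        async for msg in completion_stream_async(
            messages=[{"role": "user", "content": "Hello!"}],
            model="gpt-4o",
            max_iterations=1,
        ):
            # each item is a completed AssistantMessage snapshot
            # or a ToolMessage carrying a tool result
            collected.append(msg)

    asyncio.run(main())
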
diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py
index 59f99d7c..c0e23ec2 100644
--- a/src/controlflow/llm/handlers.py
+++ b/src/controlflow/llm/handlers.py
@@ -1,112 +1,89 @@
-import datetime
-
-import litellm
-
-from controlflow.llm.tools import ToolResult, get_tool_calls
+from controlflow.llm.tools import get_tool_calls
 from controlflow.utilities.context import ctx
-from controlflow.utilities.types import Message
+from controlflow.utilities.types import AssistantMessage, Message, ToolMessage
 
 
 class StreamHandler:
-    def on_message_created(self, delta: litellm.ModelResponse):
+    def on_message_created(self, delta: AssistantMessage):
         pass
 
-    def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
-    ):
+    def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_message_done(self, response: litellm.ModelResponse):
+    def on_message_done(self, response: AssistantMessage):
         pass
 
-    def on_tool_call_created(self, delta: litellm.ModelResponse):
+    def on_tool_call_created(self, delta: AssistantMessage):
         pass
 
-    def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
-    ):
+    def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_tool_call_done(self, tool_call: Message):
+    def on_tool_call_done(self, tool_call: AssistantMessage):
         pass
 
-    def on_tool_result(self, tool_result: ToolResult):
+    def on_tool_result(self, tool_result: ToolMessage):
         pass
 
 
 class AsyncStreamHandler(StreamHandler):
-    async def on_message_created(self, delta: litellm.ModelResponse):
+    async def on_message_created(self, delta: AssistantMessage):
         pass
 
     async def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ):
         pass
 
-    async def on_message_done(self, response: litellm.ModelResponse):
+    async def on_message_done(self, response: AssistantMessage):
         pass
 
-    async def on_tool_call_created(self, delta: litellm.ModelResponse):
+    async def on_tool_call_created(self, delta: AssistantMessage):
         pass
 
     async def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ):
         pass
 
-    async def on_tool_call_done(self, tool_call: Message):
+    async def on_tool_call_done(self, tool_call: AssistantMessage):
         pass
 
-    async def on_tool_result(self, tool_result: ToolResult):
+    async def on_tool_result(self, tool_result: ToolMessage):
         pass
 
 
 class TUIHandler(AsyncStreamHandler):
     async def on_message_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ) -> None:
         if tui := ctx.get("tui"):
-            tui.update_message(
-                m_id=snapshot.id,
-                message=snapshot.choices[0].message.content,
-                role=snapshot.choices[0].message.role,
-                timestamp=datetime.datetime.fromtimestamp(snapshot.created),
-            )
+            tui.update_message(message=snapshot)
 
     async def on_tool_call_delta(
-        self, delta: litellm.ModelResponse, snapshot: litellm.ModelResponse
+        self, delta: AssistantMessage, snapshot: AssistantMessage
     ):
         if tui := ctx.get("tui"):
             for tool_call in get_tool_calls(snapshot):
-                tui.update_tool_call(
-                    t_id=snapshot.id,
-                    tool_name=tool_call.function.name,
-                    tool_args=tool_call.function.arguments,
-                    timestamp=datetime.datetime.fromtimestamp(snapshot.created),
-                )
+                tui.update_message(message=snapshot)
 
     async def on_tool_result(self, message: Message):
         if tui := ctx.get("tui"):
-            tui.update_tool_result(
-                t_id=message.tool_result.tool_call_id,
-                tool_name=message.tool_result.tool_name,
-                tool_result=message.content,
-                timestamp=datetime.datetime.now(),
-            )
+            tui.update_tool_result(message=message)
 
 
 class PrintHandler(AsyncStreamHandler):
-    def on_message_created(self, delta: litellm.ModelResponse):
+    def on_message_created(self, delta: AssistantMessage):
         print(f"Created: {delta}\n")
 
-    def on_message_done(self, response: litellm.ModelResponse):
+    def on_message_done(self, response: AssistantMessage):
         print(f"Done: {response}\n")
 
-    def on_tool_call_created(self, delta: litellm.ModelResponse):
+    def on_tool_call_created(self, delta: AssistantMessage):
         print(f"Tool call created: {delta}\n")
 
-    def on_tool_call_done(self, tool_call: Message):
+    def on_tool_call_done(self, tool_call: AssistantMessage):
         print(f"Tool call: {tool_call}\n")
 
-    def on_tool_result(self, tool_result: ToolResult):
+    def on_tool_result(self, tool_result: ToolMessage):
         print(f"Tool result: {tool_result}\n")
diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py
index 62808817..70a1c41e 100644
--- a/src/controlflow/llm/tools.py
+++ b/src/controlflow/llm/tools.py
@@ -4,10 +4,15 @@
 from functools import partial, update_wrapper
 from typing import Any, Callable, Optional, Union
 
-import litellm
 import pydantic
 
-from controlflow.utilities.types import Message, Tool, ToolResult
+from controlflow.utilities.types import (
+    AssistantMessage,
+    ControlFlowMessage,
+    Tool,
+    ToolCall,
+    ToolMessage,
+)
 
 
 def tool(
@@ -77,13 +82,6 @@ def wrapper(**kwargs):
     return wrapper
 
 
-def has_tool_calls(response: litellm.ModelResponse) -> bool:
-    """
-    Check if the model response contains tool calls.
-    """
-    return bool(response.choices[0].message.get("tool_calls"))
-
-
 def output_to_string(output: Any) -> str:
     """
     Function outputs must be provided as strings
@@ -99,55 +97,52 @@ def output_to_string(output: Any) -> str:
 
 
 def get_tool_calls(
-    response: litellm.ModelResponse,
-) -> list[litellm.utils.ChatCompletionMessageToolCall]:
-    return response.choices[0].message.get("tool_calls", [])
+    messages: list[ControlFlowMessage],
+) -> list[ToolCall]:
+    return [
+        tc
+        for m in messages
+        if isinstance(m, AssistantMessage) and m.tool_calls
+        for tc in m.tool_calls
+    ]
 
 
-def handle_tool_call(
-    tool_call: litellm.utils.ChatCompletionMessageToolCall, tools: list[dict, Callable]
-) -> Message:
+def handle_tool_call(tool_call: ToolCall, tools: list[Union[dict, Callable]]) -> ToolMessage:
     tool_lookup = as_tool_lookup(tools)
     fn_name = tool_call.function.name
     fn_args = (None,)
     try:
-        is_error = False
+        tool_failed = False
         if fn_name not in tool_lookup:
             fn_output = f'Function "{fn_name}" not found.'
-            is_error = True
+            tool_failed = True
         else:
             tool = tool_lookup[fn_name]
             fn_args = json.loads(tool_call.function.arguments)
             fn_output = tool(**fn_args)
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
-        is_error = True
-    return Message(
-        role="tool",
-        name=fn_name,
+        tool_failed = True
+    return ToolMessage(
         content=output_to_string(fn_output),
         tool_call_id=tool_call.id,
-        tool_result=ToolResult(
-            tool_call_id=tool_call.id,
-            tool_name=fn_name,
-            tool=tool,
-            args=fn_args,
-            result=fn_output,
-        ),
+        tool_call=tool_call,
+        tool_result=fn_output,
+        tool_failed=tool_failed,
     )
 
 
 async def handle_tool_call_async(
-    tool_call: litellm.utils.ChatCompletionMessageToolCall, tools: list[dict, Callable]
-) -> Message:
+    tool_call: ToolCall, tools: list[Union[dict, Callable]]
+) -> ToolMessage:
     tool_lookup = as_tool_lookup(tools)
     fn_name = tool_call.function.name
     fn_args = (None,)
     try:
-        is_error = False
+        tool_failed = False
         if fn_name not in tool_lookup:
             fn_output = f'Function "{fn_name}" not found.'
-            is_error = True
+            tool_failed = True
         else:
             tool = tool_lookup[fn_name]
             fn_args = json.loads(tool_call.function.arguments)
@@ -156,18 +151,11 @@ async def handle_tool_call_async(
                 fn_output = await fn_output
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
-        is_error = True
-    return Message(
-        role="tool",
-        name=fn_name,
+        tool_failed = True
+    return ToolMessage(
         content=output_to_string(fn_output),
         tool_call_id=tool_call.id,
-        tool_result=ToolResult(
-            tool_call_id=tool_call.id,
-            tool_name=fn_name,
-            tool=tool,
-            args=fn_args,
-            is_error=is_error,
-            result=fn_output,
-        ),
+        tool_call=tool_call,
+        tool_result=fn_output,
+        tool_failed=tool_failed,
     )
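
handle_tool_call now returns a ToolMessage that carries the original
ToolCall and a tool_failed flag instead of a nested ToolResult. A
sketch of the round trip (the ToolCall payload is fabricated for
illustration):

    from controlflow.llm.tools import handle_tool_call
    from controlflow.utilities.types import ToolCall, ToolCallFunction

    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    call = ToolCall(
        id="call-123",
        function=ToolCallFunction(name="add", arguments='{"a": 1, "b": 2}'),
    )
    msg = handle_tool_call(call, tools=[add])
    print(msg.content, msg.tool_failed)  # "3" False (output is stringified)
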
diff --git a/src/controlflow/tui/app.py b/src/controlflow/tui/app.py
index 7950338f..2fd58fcf 100644
--- a/src/controlflow/tui/app.py
+++ b/src/controlflow/tui/app.py
@@ -1,7 +1,6 @@
 import asyncio
-import datetime
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 from textual.app import App, ComposeResult
 from textual.containers import Container
@@ -10,10 +9,11 @@
 from textual.widgets import Footer, Header, Label
 
 import controlflow
+from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage
 
 from .basic import Column, Row
 from .task import TUITask
-from .thread import TUIMessage, TUIToolCall, TUIToolResult
+from .thread import TUIMessage, TUIToolMessage
 
 if TYPE_CHECKING:
     import controlflow
@@ -108,54 +108,23 @@ def update_task(self, task: "controlflow.Task"):
             self.query_one("#tasks-container", Column).mount(new_task)
             new_task.scroll_visible()
 
-    def update_message(
-        self, m_id: str, message: str, role: str, timestamp: datetime.datetime = None
-    ):
+    def update_message(self, message: Union[UserMessage, AssistantMessage]):
         try:
-            component = self.query_one(f"#message-{m_id}", TUIMessage)
+            component = self.query_one(f"#message-{message.id}", TUIMessage)
             component.message = message
             component.scroll_visible()
         except NoMatches:
-            new_message = TUIMessage(
-                message=message,
-                role=role,
-                timestamp=timestamp,
-                id=f"message-{m_id}",
-            )
+            new_message = TUIMessage(message=message, id=f"message-{message.id}")
             self.query_one("#thread-container", Column).mount(new_message)
             new_message.scroll_visible()
 
-    def update_tool_call(
-        self, t_id: str, tool_name: str, tool_args: str, timestamp: datetime.datetime
-    ):
+    def update_tool_result(self, message: ToolMessage):
         try:
-            component = self.query_one(f"#tool-call-{t_id}", TUIToolCall)
-            component.tool_args = tool_args
-            component.scroll_visible()
-        except NoMatches:
-            new_step = TUIToolCall(
-                tool_name=tool_name,
-                tool_args=tool_args,
-                timestamp=timestamp,
-                id=f"tool-call-{t_id}",
-            )
-            self.query_one("#thread-container", Column).mount(new_step)
-            new_step.scroll_visible()
-
-    def update_tool_result(
-        self, t_id: str, tool_name: str, tool_result: str, timestamp: datetime.datetime
-    ):
-        try:
-            component = self.query_one(f"#tool-result-{t_id}", TUIToolResult)
-            component.tool_result = tool_result
+            component = self.query_one(f"#message-{message.id}", TUIToolMessage)
+            component.message = message
             component.scroll_visible()
         except NoMatches:
-            new_step = TUIToolResult(
-                tool_name=tool_name,
-                tool_result=tool_result,
-                timestamp=timestamp,
-                id=f"tool-result-{t_id}",
-            )
+            new_step = TUIToolMessage(message=message, id=f"message-{message.id}")
             self.query_one("#thread-container", Column).mount(new_step)
             new_step.scroll_visible()
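
The TUI now keys widgets off the message's own id instead of threading
separate ids and timestamps through each update method. A sketch of the
constructor this enables (outside a running Textual app, purely
illustrative):

    from controlflow.tui.thread import TUIMessage
    from controlflow.utilities.types import AssistantMessage

    msg = AssistantMessage(content="All tasks complete.")
    widget = TUIMessage(message=msg, id=f"message-{msg.id}")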
 
diff --git a/src/controlflow/tui/thread.py b/src/controlflow/tui/thread.py
index f9851a9f..a4cb22f7 100644
--- a/src/controlflow/tui/thread.py
+++ b/src/controlflow/tui/thread.py
@@ -1,14 +1,16 @@
 import datetime
 import inspect
-from typing import Literal
+from typing import Union
 
 from rich import box
+from rich.console import Group
 from rich.markdown import Markdown
 from rich.panel import Panel
 from textual.reactive import reactive
 from textual.widgets import Static
 
 from controlflow.core.task import TaskStatus
+from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage
 
 
 def bool_to_emoji(value: bool) -> str:
@@ -33,20 +35,12 @@ def format_timestamp(timestamp: datetime.datetime) -> str:
 
 
 class TUIMessage(Static):
-    message: reactive[str] = reactive(None, always_update=True)
-
-    def __init__(
-        self,
-        message: str,
-        role: Literal["user", "assistant"] = "assistant",
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
+    message: reactive[Union[UserMessage, AssistantMessage]] = reactive(
+        None, always_update=True
+    )
+
+    def __init__(self, message: Union[UserMessage, AssistantMessage], **kwargs):
         super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self._role = role
         self.message = message
 
     def render(self):
@@ -54,53 +48,28 @@ def render(self):
             "user": "green",
             "assistant": "blue",
         }
+        if isinstance(self.message, AssistantMessage) and self.message.has_tool_calls():
+            tool_call = self.message.tool_calls[0]  # render the first tool call
+            content = Markdown(
+                inspect.cleandoc("""
+                :hammer_and_wrench: Calling `{name}` with the following arguments:
+
+                ```json
+                {args}
+                ```
+                """).format(name=tool_call.function.name, args=tool_call.function.arguments)
+            )
+            title = "Tool Call"
+        else:
+            content = self.message.content
+            title = self.message.role.capitalize()
         return Panel(
-            self.message,
-            title=f"[bold]{self._role.capitalize()}[/]",
-            subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]",
-            title_align="left",
-            subtitle_align="right",
-            border_style=role_colors.get(self._role, "red"),
-            box=box.ROUNDED,
-            width=100,
-            expand=True,
-            padding=(1, 2),
-        )
-
-
-class TUIToolCall(Static):
-    tool_name: reactive[str] = reactive(None, always_update=True)
-    tool_args: reactive[str] = reactive(None, always_update=True)
-
-    def __init__(
-        self,
-        tool_name: str,
-        tool_args: str,
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self.tool_name = tool_name
-        self.tool_args = tool_args
-
-    def render(self):
-        content = inspect.cleandoc("""
-            :hammer_and_wrench: Calling `{name}` with the following arguments:
-            
-            ```json
-            {args}
-            ```
-            """).format(name=self.tool_name, args=self.tool_args)
-        return Panel(
-            Markdown(content),
-            title="Tool Call",
-            subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]",
+            content,
+            title=f"[bold]{title}[/]",
+            subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
             title_align="left",
             subtitle_align="right",
-            border_style="blue",
+            border_style=role_colors.get(self.message.role, "red"),
             box=box.ROUNDED,
             width=100,
             expand=True,
@@ -108,29 +76,21 @@ def render(self):
         )
 
 
-class TUIToolResult(Static):
-    tool_name: reactive[str] = reactive(None, always_update=True)
-    tool_result: reactive[str] = reactive(None, always_update=True)
+class TUIToolMessage(Static):
+    message: reactive[ToolMessage] = reactive(None, always_update=True)
 
-    def __init__(
-        self,
-        tool_name: str,
-        tool_result: str,
-        timestamp: datetime.datetime = None,
-        **kwargs,
-    ):
+    def __init__(self, message: ToolMessage, **kwargs):
         super().__init__(**kwargs)
-        if timestamp is None:
-            timestamp = datetime.datetime.now()
-        self._timestamp = timestamp
-        self.tool_name = tool_name
-        self.tool_result = tool_result
+        self.message = message
 
     def render(self):
-        content = Markdown(
-            f":white_check_mark: Received output from the [markdown.code]{self.tool_name}[/] tool."
-            f"\n\n```json\n{self.tool_result}\n```",
-        )
+        if self.message.tool_failed:
+            content = f":x: The tool call to [markdown.code]{self.message.tool_name}[/] failed."
+        else:
+            content = Group(
+                f":white_check_mark: Received output from the [markdown.code]{self.message.tool_call.function.name}[/] tool.\n",
+                Markdown(f"```json\n{self.tool_result}\n```"),
+            )
 
         return Panel(
             content,
diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py
index 444aea34..0fd94085 100644
--- a/src/controlflow/utilities/types.py
+++ b/src/controlflow/utilities/types.py
@@ -1,8 +1,10 @@
 import datetime
 import inspect
 import json
+import uuid
+from enum import Enum
 from functools import partial, update_wrapper
-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Callable, List, Literal, Optional, Union
 
 import litellm
 import pydantic
@@ -10,7 +12,17 @@
 from marvin.beta.assistants.assistants import AssistantTool
 from marvin.types import FunctionTool
 from marvin.utilities.asyncio import ExposeSyncMethodsMixin
-from pydantic import BaseModel, Field, PrivateAttr
+from pydantic import (
+    BaseModel,
+    Field,
+    PrivateAttr,
+    TypeAdapter,
+    computed_field,
+    field_serializer,
+    field_validator,
+    model_validator,
+    validator,
+)
 from traitlets import default
 
 # flag for unset defaults
@@ -64,6 +77,9 @@ def __init__(self, *, _fn: Callable, **kwargs):
     def from_function(
         cls, fn: Callable, name: Optional[str] = None, description: Optional[str] = None
     ):
+        if name is None and fn.__name__ == "<lambda>":
+            name = "__lambda__"
+
         return cls(
             function=ToolFunction(
                 name=name or fn.__name__,
@@ -102,3 +118,191 @@ def __init__(
     ):
         super().__init__(content=content, role=role, **kwargs)
         self.tool_result = tool_result
+
+
+# -----------------------------------------------
+# -----------------------------------------------
+# -----------------------------------------------
+# -----------------------------------------------
+# -----------------------------------------------
+
+
+class _OpenAIBaseType(ControlFlowModel):
+    model_config = dict(extra="allow")
+
+
+Role = Literal["system", "user", "assistant", "tool"]
+
+
+class TextContent(_OpenAIBaseType):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageDetails(_OpenAIBaseType):
+    url: str
+    detail: Literal["auto", "high", "low"] = "auto"
+
+
+class ImageContent(_OpenAIBaseType):
+    type: Literal["image_url"] = "image_url"
+    image_url: ImageDetails
+
+
+class ControlFlowMessage(_OpenAIBaseType):
+    # ---- begin openai fields
+    role: Role = Field(openai_field=True)
+    _openai_fields: set[str] = {"role"}
+    # ---- end openai fields
+
+    id: str = Field(default_factory=lambda: uuid.uuid4().hex, repr=False)
+    timestamp: datetime.datetime = Field(
+        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc),
+    )
+    llm_response: Optional[litellm.ModelResponse] = Field(None, repr=False)
+
+    @field_validator("role", mode="before")
+    def _lowercase_role(cls, v):
+        if isinstance(v, str):
+            v = v.lower()
+        return v
+
+    @field_validator("timestamp", mode="before")
+    def _validate_timestamp(cls, v):
+        if isinstance(v, int):
+            v = datetime.datetime.fromtimestamp(v)
+        return v
+
+    @model_validator(mode="after")
+    def _finalize(self):
+        self._openai_fields = (
+            getattr(super(), "_openai_fields", set()) | self._openai_fields
+        )
+        return self
+
+    @field_serializer("timestamp")
+    def _serialize_timestamp(self, timestamp: datetime.datetime):
+        return timestamp.isoformat()
+
+    def as_openai_message(self, include: set[str] = None, **kwargs) -> dict:
+        include = self._openai_fields | (include or set())
+        return self.model_dump(include=include, **kwargs)
+
+
+class SystemMessage(ControlFlowMessage):
+    # ---- begin openai fields
+    role: Literal["system"] = "system"
+    content: str
+    name: Optional[str] = None
+    _openai_fields = {"role", "content", "name"}
+    # ---- end openai fields
+
+
+class UserMessage(ControlFlowMessage):
+    # ---- begin openai fields
+    role: Literal["user"] = "user"
+    content: List[Union[TextContent, ImageContent]]
+    name: Optional[str] = None
+    _openai_fields = {"role", "content", "name"}
+    # ---- end openai fields
+
+    @field_validator("content", mode="before")
+    def _validate_content(cls, v):
+        if isinstance(v, str):
+            v = [TextContent(text=v)]
+        return v
+
+
+class AssistantMessage(ControlFlowMessage):
+    """A message from the assistant."""
+
+    # ---- begin openai fields
+    role: Literal["assistant"] = "assistant"
+    content: Optional[str] = None
+    tool_calls: Optional[List["ToolCall"]] = None
+    _openai_fields = {"role", "content", "tool_calls"}
+    # ---- end openai fields
+
+    is_delta: bool = Field(
+        default=False,
+        description="If True, this message is a streamed delta, or chunk, of a full message.",
+    )
+
+    def has_tool_calls(self):
+        return bool(self.tool_calls)
+
+
+class ToolMessage(ControlFlowMessage):
+    """A message for reporting the result of a tool call."""
+
+    # ---- begin openai fields
+    role: Literal["tool"] = "tool"
+    content: str = Field(description="The string result of the tool call.")
+    tool_call_id: str = Field(description="The ID of the tool call.")
+    _openai_fields = {"role", "content", "tool_call_id"}
+    # ---- end openai fields
+
+    tool_call: "ToolCall" = Field(cf_field=True, repr=False)
+    tool_result: Any = Field(None, cf_field=True, exclude=True)
+    tool_failed: bool = Field(False, cf_field=True)
+
+
+class ToolCallFunction(_OpenAIBaseType):
+    name: Optional[str]
+    arguments: str
+
+    def json_arguments(self):
+        return json.loads(self.arguments)
+
+
+class ToolCall(_OpenAIBaseType):
+    id: Optional[str]
+    type: Literal["function"] = "function"
+    function: ToolCallFunction
+
+
+def as_cf_messages(
+    messages: list[Union[litellm.Message, litellm.ModelResponse]],
+) -> list[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]]:
+    message_ta = TypeAdapter(
+        Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
+    )
+
+    result = []
+    for msg in messages:
+        if isinstance(msg, ControlFlowMessage):
+            result.append(msg)
+        elif isinstance(msg, litellm.Message):
+            new_msg = message_ta.validate_python(msg.model_dump())
+            result.append(new_msg)
+        elif isinstance(msg, litellm.ModelResponse):
+            for i, choice in enumerate(msg.choices):
+                # handle delta messages streaming from the assistant
+                if hasattr(choice, "delta"):
+                    if choice.delta.role is None:
+                        # keep delta content even when the role field is unset
+                        data = choice.delta.model_dump(exclude={"role"})
+                    else:
+                        data = choice.delta.model_dump()
+                    new_msg = AssistantMessage(**data, is_delta=True)
+                else:
+                    new_msg = message_ta.validate_python(choice.message.model_dump())
+                new_msg.id = f"{msg.id}-{i}"
+                new_msg.timestamp = msg.created
+                new_msg.llm_response = msg
+                result.append(new_msg)
+        else:
+            raise ValueError(f"Invalid message type: {type(msg)}")
+    return result
+
+
+def as_oai_messages(messages: list[Union[ControlFlowMessage, litellm.Message]]):
+    result = []
+    for msg in messages:
+        if isinstance(msg, ControlFlowMessage):
+            result.append(msg.as_openai_message())
+        elif isinstance(msg, litellm.Message):
+            result.append(msg)
+        else:
+            raise ValueError(f"Invalid message type: {type(msg)}")
+    return result
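
The new message types round-trip between ControlFlow's richer models
and the OpenAI wire format. A brief sketch (the values are
illustrative):

    from controlflow.utilities.types import (
        SystemMessage,
        UserMessage,
        as_oai_messages,
    )

    msgs = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is 2 + 2?"),  # str coerced to [TextContent]
    ]
    # as_openai_message() keeps only the fields OpenAI expects
    oai = as_oai_messages(msgs)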

From e3042453169f8d5d27761d103da14cad18f1115b Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Wed, 22 May 2024 16:19:17 -0400
Subject: [PATCH 2/3] Remove old types

---
 src/controlflow/core/controller/controller.py |  4 +-
 src/controlflow/core/flow.py                  |  4 +-
 src/controlflow/core/task.py                  | 49 -------------------
 src/controlflow/llm/completions.py            |  7 ---
 src/controlflow/llm/handlers.py               | 28 +++++------
 src/controlflow/llm/history.py                | 20 ++++----
 src/controlflow/utilities/types.py            | 32 ++----------
 tests/llm/test_streaming.py                   |  4 +-
 8 files changed, 34 insertions(+), 114 deletions(-)

diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py
index 9055a3f9..47193a99 100644
--- a/src/controlflow/core/controller/controller.py
+++ b/src/controlflow/core/controller/controller.py
@@ -20,7 +20,7 @@
 from controlflow.tui.app import TUIApp as TUI
 from controlflow.utilities.context import ctx
 from controlflow.utilities.tasks import all_complete, any_incomplete
-from controlflow.utilities.types import FunctionTool, Message
+from controlflow.utilities.types import FunctionTool, SystemMessage
 
 logger = logging.getLogger(__name__)
 
@@ -119,7 +119,7 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
         instructions = instructions_template.render()
 
         # prepare messages
-        system_message = Message(content=instructions, role="system")
+        system_message = SystemMessage(content=instructions)
         messages = self.history.load_messages(thread_id=self.flow.thread_id)
 
         # call llm
diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py
index c458bced..529e40d8 100644
--- a/src/controlflow/core/flow.py
+++ b/src/controlflow/core/flow.py
@@ -6,7 +6,7 @@
 
 from controlflow.utilities.context import ctx
 from controlflow.utilities.logging import get_logger
-from controlflow.utilities.types import ControlFlowModel, Message
+from controlflow.utilities.types import ControlFlowModel, MessageType
 
 if TYPE_CHECKING:
     from controlflow.core.agent import Agent
@@ -68,7 +68,7 @@ def get_flow() -> Optional[Flow]:
     return flow
 
 
-def get_flow_messages(limit: int = None) -> list[Message]:
+def get_flow_messages(limit: int = None) -> list[MessageType]:
     """
     Loads messages from the flow's thread.
 
diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py
index 2cfb1124..592370f0 100644
--- a/src/controlflow/core/task.py
+++ b/src/controlflow/core/task.py
@@ -36,7 +36,6 @@
 from controlflow.utilities.types import (
     NOTSET,
     ControlFlowModel,
-    Message,
     PandasDataFrame,
     PandasSeries,
     ToolType,
@@ -58,54 +57,6 @@ class TaskStatus(Enum):
     SKIPPED = "SKIPPED"
 
 
-class LoadMessage(ControlFlowModel):
-    """
-    This special object can be used to indicate that a task result should be
-    loaded from a recent message posted to the flow's thread.
-    """
-
-    type: Literal["LoadMessage"] = Field(
-        'You must provide this value as "LoadMessage".'
-    )
-
-    num_messages_ago: int = Field(
-        1,
-        description="The number of messages ago to retrieve. Default is 1, or the most recent message.",
-    )
-
-    strip_prefix: str = Field(
-        None,
-        description="These characters will be removed from the start "
-        "of the message. For example, remove text like your name prefix.",
-    )
-
-    strip_suffix: Optional[str] = Field(
-        None,
-        description="These characters will be removed from the end of "
-        "the message. For example, remove comments like 'I'll mark the task complete now.'",
-    )
-
-    def trim_message(self, message: Message) -> str:
-        content = message.content[0].text.value
-        if self.strip_prefix:
-            if content.startswith(self.strip_prefix):
-                content = content[len(self.strip_prefix) :]
-            else:
-                raise ValueError(
-                    f'Invalid strip prefix "{self.strip_prefix}"; messages '
-                    f'starts with "{content[:len(self.strip_prefix) + 10]}"'
-                )
-        if self.strip_suffix:
-            if content.endswith(self.strip_suffix):
-                content = content[: -len(self.strip_suffix)]
-            else:
-                raise ValueError(
-                    f'Invalid strip suffix "{self.strip_suffix}"; messages '
-                    f'ends with "{content[-len(self.strip_suffix) - 10:]}"'
-                )
-        return content.strip()
-
-
 class Task(ControlFlowModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
     objective: str = Field(
diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py
index 4456b91a..b0915b46 100644
--- a/src/controlflow/llm/completions.py
+++ b/src/controlflow/llm/completions.py
@@ -14,18 +14,11 @@
 )
 from controlflow.utilities.types import (
     ControlFlowMessage,
-    Message,
     as_cf_messages,
     as_oai_messages,
 )
 
 
-def as_cf_message(message: Union[Message, litellm.Message]) -> Message:
-    if isinstance(message, Message):
-        return message
-    return Message(**message.model_dump())
-
-
 async def maybe_coro(coro):
     if inspect.isawaitable(coro):
         await coro
diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py
index c0e23ec2..459687a3 100644
--- a/src/controlflow/llm/handlers.py
+++ b/src/controlflow/llm/handlers.py
@@ -1,6 +1,6 @@
 from controlflow.llm.tools import get_tool_calls
 from controlflow.utilities.context import ctx
-from controlflow.utilities.types import AssistantMessage, Message, ToolMessage
+from controlflow.utilities.types import AssistantMessage, ToolMessage
 
 
 class StreamHandler:
@@ -10,7 +10,7 @@ def on_message_created(self, delta: AssistantMessage):
     def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_message_done(self, response: AssistantMessage):
+    def on_message_done(self, message: AssistantMessage):
         pass
 
     def on_tool_call_created(self, delta: AssistantMessage):
@@ -19,10 +19,10 @@ def on_tool_call_created(self, delta: AssistantMessage):
     def on_tool_call_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         pass
 
-    def on_tool_call_done(self, tool_call: AssistantMessage):
+    def on_tool_call_done(self, message: AssistantMessage):
         pass
 
-    def on_tool_result(self, tool_result: ToolMessage):
+    def on_tool_result(self, message: ToolMessage):
         pass
 
 
@@ -35,7 +35,7 @@ async def on_message_delta(
     ):
         pass
 
-    async def on_message_done(self, response: AssistantMessage):
+    async def on_message_done(self, message: AssistantMessage):
         pass
 
     async def on_tool_call_created(self, delta: AssistantMessage):
@@ -46,10 +46,10 @@ async def on_tool_call_delta(
     ):
         pass
 
-    async def on_tool_call_done(self, tool_call: AssistantMessage):
+    async def on_tool_call_done(self, message: AssistantMessage):
         pass
 
-    async def on_tool_result(self, tool_result: ToolMessage):
+    async def on_tool_result(self, message: ToolMessage):
         pass
 
 
@@ -67,7 +67,7 @@ async def on_tool_call_delta(
             for tool_call in get_tool_calls(snapshot):
                 tui.update_message(message=snapshot)
 
-    async def on_tool_result(self, message: Message):
+    async def on_tool_result(self, message: ToolMessage):
         if tui := ctx.get("tui"):
             tui.update_tool_result(message=message)
 
@@ -76,14 +76,14 @@ class PrintHandler(AsyncStreamHandler):
     def on_message_created(self, delta: AssistantMessage):
         print(f"Created: {delta}\n")
 
-    def on_message_done(self, response: AssistantMessage):
-        print(f"Done: {response}\n")
+    def on_message_done(self, message: AssistantMessage):
+        print(f"Done: {message}\n")
 
     def on_tool_call_created(self, delta: AssistantMessage):
         print(f"Tool call created: {delta}\n")
 
-    def on_tool_call_done(self, tool_call: AssistantMessage):
-        print(f"Tool call: {tool_call}\n")
+    def on_tool_call_done(self, message: AssistantMessage):
+        print(f"Tool call: {message}\n")
 
-    def on_tool_result(self, tool_result: ToolMessage):
-        print(f"Tool result: {tool_result}\n")
+    def on_tool_result(self, message: ToolMessage):
+        print(f"Tool result: {message}\n")
diff --git a/src/controlflow/llm/history.py b/src/controlflow/llm/history.py
index 3cefa6c4..52efdd98 100644
--- a/src/controlflow/llm/history.py
+++ b/src/controlflow/llm/history.py
@@ -9,7 +9,7 @@
-from pydantic import Field, field_validator
+from pydantic import Field, TypeAdapter, field_validator
 
 import controlflow
-from controlflow.utilities.types import ControlFlowModel, Message
+from controlflow.utilities.types import ControlFlowModel, MessageType
 
 
 def get_default_history() -> "History":
@@ -24,12 +24,12 @@ def load_messages(
         limit: int = None,
         before: datetime.datetime = None,
         after: datetime.datetime = None,
-    ) -> list[Message]:
+    ) -> list[MessageType]:
         raise NotImplementedError()
 
     def load_messages_to_token_limit(
         self, thread_id: str, model: str = None
-    ) -> list[Message]:
+    ) -> list[MessageType]:
         messages = []
         # as long as the messages are not trimmed, keep loading more
         while messages == (trim := trim_messages(messages, model=model)):
@@ -42,12 +42,12 @@ def load_messages_to_token_limit(
         return trim
 
     @abc.abstractmethod
-    def save_messages(self, thread_id: str, messages: list[Message]):
+    def save_messages(self, thread_id: str, messages: list[MessageType]):
         raise NotImplementedError()
 
 
 class InMemoryHistory(History):
-    _history: ClassVar[dict[str, list[Message]]] = {}
+    _history: ClassVar[dict[str, list[MessageType]]] = {}
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -58,7 +58,7 @@ def load_messages(
         limit: int = None,
         before: datetime.datetime = None,
         after: datetime.datetime = None,
-    ) -> list[Message]:
+    ) -> list[MessageType]:
         messages = InMemoryHistory._history.get(thread_id, [])
         filtered_messages = [
             msg
@@ -69,7 +69,7 @@ def load_messages(
         ]
         return list(reversed(filtered_messages))
 
-    def save_messages(self, thread_id: str, messages: list[Message]):
+    def save_messages(self, thread_id: str, messages: list[MessageType]):
         InMemoryHistory._history.setdefault(thread_id, []).extend(messages)
 
 
@@ -94,7 +94,7 @@ def load_messages(
         limit: int = None,
         before: datetime.datetime = None,
         after: datetime.datetime = None,
-    ) -> list[Message]:
+    ) -> list[MessageType]:
         if not self.path(thread_id).exists():
             return []
 
@@ -103,7 +103,7 @@ def load_messages(
 
         messages = []
         for msg in reversed(all_messages):
-            message = Message.model_validate(msg)
+            message = TypeAdapter(MessageType).validate_python(msg)
             if before is None or message.timestamp < before:
                 if after is None or message.timestamp > after:
                     messages.append(message)
@@ -112,7 +112,7 @@ def load_messages(
 
         return list(reversed(messages))
 
-    def save_messages(self, thread_id: str, messages: list[Message]):
+    def save_messages(self, thread_id: str, messages: list[MessageType]):
         if self.path(thread_id).exists():
             with open(self.path(thread_id), "r") as f:
                 all_messages = json.load(f)
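
Because MessageType is a typing.Union rather than a model class, it has no
model_validate of its own; deserialization goes through a pydantic
TypeAdapter, which selects the concrete message class from the payload
(assuming each concrete class constrains its role field). A sketch of the
pattern FileHistory uses above, with an illustrative payload:

    # Validating the MessageType union with pydantic's TypeAdapter.
    from pydantic import TypeAdapter
    from controlflow.utilities.types import MessageType

    adapter = TypeAdapter(MessageType)

    raw = {"role": "user", "content": "hello"}  # illustrative payload
    message = adapter.validate_python(raw)      # yields a UserMessage here
    print(type(message).__name__)
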
diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py
index 0fd94085..6be3d7a0 100644
--- a/src/controlflow/utilities/types.py
+++ b/src/controlflow/utilities/types.py
@@ -95,35 +95,8 @@ def __call__(self, *args, **kwargs):
         return self._fn(*args, **kwargs)
 
 
-class ToolResult(ControlFlowModel):
-    model_config = dict(allow_arbitrary_types=True)
-    tool_call_id: str
-    tool_name: str
-    tool: Tool
-    args: dict
-    is_error: bool
-    result: Any = Field(None, exclude=True)
-
-
-class Message(litellm.Message):
-    model_config = dict(validate_assignment=True)
-    timestamp: datetime.datetime = Field(
-        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
-    )
-
-    tool_result: Optional[ToolResult] = None
-
-    def __init__(
-        self, content: str, *, role: str = None, tool_result: Any = None, **kwargs
-    ):
-        super().__init__(content=content, role=role, **kwargs)
-        self.tool_result = tool_result
-
-
-# -----------------------------------------------
-# -----------------------------------------------
-# -----------------------------------------------
 # -----------------------------------------------
+# Messages
 # -----------------------------------------------
 
 
@@ -247,6 +220,9 @@ class ToolMessage(ControlFlowMessage):
     tool_failed: bool = Field(False, cf_field=True)
 
 
+MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
+
+
 class ToolCallFunction(_OpenAIBaseType):
     name: Optional[str]
     arguments: str
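
The new MessageType alias gives downstream code a single annotation for any
message in a thread while still allowing isinstance dispatch on the concrete
class. A small sketch (the role attribute and tool_failed field follow the
definitions in this patch; the branch bodies are illustrative):

    # Sketch: dispatching on the concrete members of the MessageType union.
    from controlflow.utilities.types import MessageType, ToolMessage

    def describe(message: MessageType) -> str:
        if isinstance(message, ToolMessage):
            status = "failed" if message.tool_failed else "ok"
            return f"tool result ({status})"
        return f"{message.role} message"
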
diff --git a/tests/llm/test_streaming.py b/tests/llm/test_streaming.py
index 604eea5a..18db6612 100644
--- a/tests/llm/test_streaming.py
+++ b/tests/llm/test_streaming.py
@@ -4,7 +4,7 @@
 from controlflow.llm.completions import completion_stream
 from controlflow.llm.handlers import StreamHandler
 from controlflow.llm.tools import ToolResult
-from controlflow.utilities.types import Message
+from controlflow.utilities.types import AssistantMessage
 from pydantic import BaseModel
 
 
@@ -30,7 +30,7 @@ def on_message_delta(self, delta: litellm.utils.Delta, snapshot: litellm.Message
             )
         )
 
-    def on_message_done(self, message: Message):
+    def on_message_done(self, message: AssistantMessage):
         self.calls.append(
             StreamCall(method="on_message_done", args=dict(message=message))
         )
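
A sketch of how a recording handler like the one in this test might be
driven end to end; the handler class name, model name, and assertion are
illustrative assumptions, and a real run needs a configured API key:

    # Sketch: streaming a completion through a recording handler.
    from controlflow.llm.completions import completion_stream

    handler = RecordingStreamHandler()  # hypothetical: the handler above
    for _ in completion_stream(
        messages=[{"role": "user", "content": "hi"}],
        model="gpt-4o-mini",  # illustrative model name
        handlers=[handler],
    ):
        pass

    assert any(c.method == "on_message_done" for c in handler.calls)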

From 44ca723fec16abddf92401718f0bc6ec92aee5f1 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Wed, 22 May 2024 17:20:22 -0400
Subject: [PATCH 3/3] Clean up completions

---
 src/controlflow/core/controller/controller.py |  1 -
 src/controlflow/llm/completions.py            | 28 +++++++++++--------
 src/controlflow/utilities/types.py            |  4 +--
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py
index 47193a99..1a013c10 100644
--- a/src/controlflow/core/controller/controller.py
+++ b/src/controlflow/core/controller/controller.py
@@ -130,7 +130,6 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
             tools=tools,
             handlers=[TUIHandler()] if controlflow.settings.enable_tui else None,
             max_iterations=1,
-            yield_deltas=False,
         ):
             response_messages.append(msg)
 
diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py
index b0915b46..6515799e 100644
--- a/src/controlflow/llm/completions.py
+++ b/src/controlflow/llm/completions.py
@@ -57,11 +57,12 @@ def completion(
 
     counter = 0
     while not response_messages or get_tool_calls(response_messages):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = litellm.completion(
             model=model,
-            messages=trim_messages(
-                messages + as_oai_messages(new_messages), model=model
-            ),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
@@ -129,11 +130,12 @@ def completion_stream(
 
     counter = 0
     while not snapshot_message or get_tool_calls([snapshot_message]):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = litellm.completion(
             model=model,
-            messages=trim_messages(
-                messages + as_oai_messages(new_messages), model=model
-            ),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
@@ -217,11 +219,12 @@ async def completion_async(
 
     counter = 0
     while not response_messages or get_tool_calls(response_messages):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(
-                messages + as_oai_messages(new_messages), model=model
-            ),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             **kwargs,
         )
@@ -287,11 +290,12 @@ async def completion_stream_async(
 
     counter = 0
     while not snapshot_message or get_tool_calls([snapshot_message]):
+        completion_messages = trim_messages(
+            as_oai_messages(messages + new_messages), model=model
+        )
         response = await litellm.acompletion(
             model=model,
-            messages=trim_messages(
-                messages + as_oai_messages(new_messages), model=model
-            ),
+            messages=completion_messages,
             tools=[t.model_dump() for t in tools] if tools else None,
             stream=True,
             **kwargs,
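
All four completion variants now share the same hoisted pattern: convert the
accumulated ControlFlow messages to OpenAI format, trim them to the model's
context window, then call the model, looping until the latest response
carries no tool calls. A condensed sketch of the synchronous loop (helper
names follow this patch; the tool-execution step, and the omitted handler and
max_iterations plumbing, are simplified assumptions):

    # Condensed sketch of the post-refactor completion loop.
    import litellm
    from litellm.utils import trim_messages

    from controlflow.llm.tools import get_tool_calls, handle_tool_call
    from controlflow.utilities.types import as_cf_messages, as_oai_messages

    def completion_sketch(messages, model, tools=None, **kwargs):
        new_messages = []
        response_messages = []
        while not response_messages or get_tool_calls(response_messages):
            # hoisted: convert and trim once per iteration
            completion_messages = trim_messages(
                as_oai_messages(messages + new_messages), model=model
            )
            response = litellm.completion(
                model=model,
                messages=completion_messages,
                tools=[t.model_dump() for t in tools] if tools else None,
                **kwargs,
            )
            # as_cf_messages is assumed to accept litellm responses here
            response_messages = as_cf_messages([response])
            new_messages.extend(response_messages)
            # run each requested tool and append its ToolMessage result
            for tool_call in get_tool_calls(response_messages):
                new_messages.append(handle_tool_call(tool_call, tools))
        return new_messages
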
diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py
index 6be3d7a0..5537f655 100644
--- a/src/controlflow/utilities/types.py
+++ b/src/controlflow/utilities/types.py
@@ -272,12 +272,12 @@ def as_cf_messages(
     return result
 
 
-def as_oai_messages(messages: list[Union[ControlFlowMessage, litellm.Message]]):
+def as_oai_messages(messages: list[Union[dict, ControlFlowMessage, litellm.Message]]):
     result = []
     for msg in messages:
         if isinstance(msg, ControlFlowMessage):
             result.append(msg.as_openai_message())
-        elif isinstance(msg, litellm.Message):
+        elif isinstance(msg, (dict, litellm.Message)):
             result.append(msg)
         else:
             raise ValueError(f"Invalid message type: {type(msg)}")
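
With this change, plain OpenAI-style dicts pass through as_oai_messages
untouched, so callers can mix raw dicts with ControlFlowMessage instances in
a single list. A small illustration (UserMessage is assumed to accept a
content keyword argument, per the message classes in this series):

    # Illustrative: mixing raw dicts and ControlFlow messages.
    from controlflow.utilities.types import UserMessage, as_oai_messages

    mixed = [
        {"role": "system", "content": "You are terse."},  # passed through
        UserMessage(content="hello"),  # converted via as_openai_message()
    ]
    print(as_oai_messages(mixed))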