Skip to content

Commit

Permalink
Merge pull request #65 from jlowin/tui
Browse files Browse the repository at this point in the history
tui running
  • Loading branch information
jlowin authored May 22, 2024
2 parents 50f7de8 + 1fee220 commit 859187b
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 249 deletions.
116 changes: 5 additions & 111 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
import datetime
import json
import logging
import math
from contextlib import asynccontextmanager
from functools import cached_property
from typing import Union

import prefect
from marvin.beta.assistants import EndRun, PrintHandler, Run
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from openai import AsyncAssistantEventHandler
from openai.types.beta.threads import Message as OAIMessage
from openai.types.beta.threads import MessageDelta as OAIMessageDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta, ToolCall
from prefect import get_client as get_prefect_client
from prefect.context import FlowRunContext
from pydantic import BaseModel, Field, computed_field, model_validator

import controlflow
Expand All @@ -25,14 +15,10 @@
from controlflow.core.task import Task
from controlflow.instructions import get_instructions
from controlflow.llm.completions import completion_stream_async
from controlflow.llm.handlers import PrintHandler
from controlflow.llm.handlers import TUIHandler
from controlflow.llm.history import History
from controlflow.tui.app import TUIApp as TUI
from controlflow.utilities.context import ctx
from controlflow.utilities.prefect import (
create_json_artifact,
create_python_artifact,
)
from controlflow.utilities.tasks import all_complete, any_incomplete
from controlflow.utilities.types import FunctionTool, Message

Expand Down Expand Up @@ -102,11 +88,12 @@ def help_im_stuck():
if self._endrun_count >= 3:
self._should_abort = True
self._endrun_count = 0
return EndRun()

return f"Ending turn. {3 - self._endrun_count} more uses will abort the workflow."

return help_im_stuck

async def _run_agent(self, agent: Agent, tasks: list[Task] = None) -> Run:
async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
"""
Run a single agent.
"""
Expand Down Expand Up @@ -141,7 +128,7 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None) -> Run:
messages=[system_message] + messages,
model=agent.model,
tools=tools,
handlers=[PrintHandler()],
handlers=[TUIHandler()] if controlflow.settings.enable_tui else None,
max_iterations=1,
response_callback=r.append,
):
Expand Down Expand Up @@ -243,96 +230,3 @@ async def run_async(self):
f"Task iterations exceeded maximum of {max_task_iterations} for each task."
)
self._should_abort = False


class TUIHandler(AsyncAssistantEventHandler):
async def on_message_delta(
self, delta: OAIMessageDelta, snapshot: OAIMessage
) -> None:
if tui := ctx.get("tui"):
content = []
for item in snapshot.content:
if item.type == "text":
content.append(item.text.value)

tui.update_message(
m_id=snapshot.id,
message="\n\n".join(content),
role=snapshot.role,
timestamp=datetime.datetime.fromtimestamp(snapshot.created_at),
)

async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
if tui := ctx.get("tui"):
tui.update_step(snapshot)


class AgentHandler(PrintHandler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.tool_calls = {}

async def on_tool_call_created(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is created"""

if tool_call.type == "function":
task_run_name = "Prepare arguments for tool call"
else:
task_run_name = f"Tool call: {tool_call.type}"

client = get_prefect_client()
engine_context = FlowRunContext.get()
if not engine_context:
return

task_run = await client.create_task_run(
task=prefect.Task(fn=lambda: None),
name=task_run_name,
extra_tags=["tool-call"],
flow_run_id=engine_context.flow_run.id,
dynamic_key=tool_call.id,
state=prefect.states.Running(),
)

self.tool_calls[tool_call.id] = task_run

async def on_tool_call_done(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is done"""

client = get_prefect_client()
task_run = self.tool_calls.get(tool_call.id)

if not task_run:
return
await client.set_task_run_state(
task_run_id=task_run.id, state=prefect.states.Completed(), force=True
)

# code interpreter is run as a single call, so we can publish a result artifact
if tool_call.type == "code_interpreter":
# images = []
# for output in tool_call.code_interpreter.outputs:
# if output.type == "image":
# image_path = download_temp_file(output.image.file_id)
# images.append(image_path)

create_python_artifact(
key="code",
code=tool_call.code_interpreter.input,
description="Code executed in the code interpreter",
task_run_id=task_run.id,
)
create_json_artifact(
key="output",
data=tool_call.code_interpreter.outputs,
description="Output from the code interpreter",
task_run_id=task_run.id,
)

elif tool_call.type == "function":
create_json_artifact(
key="arguments",
data=json.dumps(json.loads(tool_call.function.arguments), indent=2),
description=f"Arguments for the `{tool_call.function.name}` tool",
task_run_id=task_run.id,
)
6 changes: 3 additions & 3 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,16 @@ def mark_successful(self, result: T = None, validate_upstreams: bool = True):
self.result = validate_result(result, self.result_type)
self.set_status(TaskStatus.SUCCESSFUL)

return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked successful."

def mark_failed(self, message: Union[str, None] = None):
self.error = message
self.set_status(TaskStatus.FAILED)
return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked failed."

def mark_skipped(self):
self.set_status(TaskStatus.SKIPPED)
return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked skipped."


def generate_result_schema(result_type: type[T]) -> type[T]:
Expand Down
81 changes: 48 additions & 33 deletions src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def completion(

# on message done
for h in handlers:
h.on_message_done(response.choices[0].message)
h.on_message_done(response)
new_messages.append(response.choices[0].message)

for tool_call in get_tool_calls(response):
for h in handlers:
h.on_tool_call(tool_call=tool_call)
h.on_tool_call_done(tool_call=tool_call)
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
h.on_tool_result(tool_message.tool_result)
h.on_tool_result(tool_message)
new_messages.append(tool_message)

if len(responses) >= (max_iterations or math.inf):
Expand Down Expand Up @@ -156,44 +156,52 @@ def completion_stream(

while not response or has_tool_calls(response):
deltas = []
is_tool_call = False
for delta in litellm.completion(
model=model,
messages=trim_messages(messages + new_messages, model=model),
tools=[t.model_dump() for t in tools] if tools else None,
stream=True,
**kwargs,
):
# on message created
if not deltas:
for h in handlers:
h.on_message_created(delta=delta.choices[0].delta)

deltas.append(delta)
response = litellm.stream_chunk_builder(deltas)

# on message created
if len(deltas) == 1:
if get_tool_calls(response):
is_tool_call = True
for h in handlers:
if is_tool_call:
h.on_tool_call_created(delta=delta)
else:
h.on_message_created(delta=delta)

# on message delta
for h in handlers:
h.on_message_delta(
delta=delta.choices[0].delta, snapshot=response.choices[0].message
)
if is_tool_call:
h.on_tool_call_delta(delta=delta, snapshot=response)
else:
h.on_message_delta(delta=delta, snapshot=response)

# yield
yield delta, response

responses.append(response)

# on message done
for h in handlers:
h.on_message_done(response.choices[0].message)
if not is_tool_call:
for h in handlers:
h.on_message_done(response)
new_messages.append(response.choices[0].message)

# tool calls
for tool_call in get_tool_calls(response):
for h in handlers:
h.on_tool_call(tool_call=tool_call)
h.on_tool_call_done(tool_call=tool_call)
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
h.on_tool_result(tool_message.tool_result)
h.on_tool_result(tool_message)
new_messages.append(tool_message)

yield None, tool_message
Expand Down Expand Up @@ -250,15 +258,15 @@ async def completion_async(

# on message done
for h in handlers:
await maybe_coro(h.on_message_done(response.choices[0].message))
await maybe_coro(h.on_message_done(response))
new_messages.append(response.choices[0].message)

for tool_call in get_tool_calls(response):
for h in handlers:
await maybe_coro(h.on_tool_call(tool_call=tool_call))
await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
await maybe_coro(h.on_tool_result(tool_message.tool_result))
await maybe_coro(h.on_tool_result(tool_message))
new_messages.append(tool_message)

if len(responses) >= (max_iterations or math.inf):
Expand Down Expand Up @@ -310,47 +318,54 @@ async def completion_stream_async(

while not response or has_tool_calls(response):
deltas = []
is_tool_call = False
async for delta in await litellm.acompletion(
model=model,
messages=trim_messages(messages + new_messages, model=model),
tools=[t.model_dump() for t in tools] if tools else None,
stream=True,
**kwargs,
):
# on message created
if not deltas:
for h in handlers:
await maybe_coro(h.on_message_created(delta=delta.choices[0].delta))

deltas.append(delta)
response = litellm.stream_chunk_builder(deltas)

# on message delta
# on message / tool call created
if len(deltas) == 1:
if get_tool_calls(response):
is_tool_call = True
for h in handlers:
if is_tool_call:
await maybe_coro(h.on_tool_call_created(delta=delta))
else:
await maybe_coro(h.on_message_created(delta=delta))

# on message / tool call delta
for h in handlers:
await maybe_coro(
h.on_message_delta(
delta=delta.choices[0].delta,
snapshot=response.choices[0].message,
if is_tool_call:
await maybe_coro(
h.on_tool_call_delta(delta=delta, snapshot=response)
)
)
else:
await maybe_coro(h.on_message_delta(delta=delta, snapshot=response))

# yield
yield delta, response

responses.append(response)

# on message done
for h in handlers:
await maybe_coro(h.on_message_done(response.choices[0].message))
if not is_tool_call:
for h in handlers:
await maybe_coro(h.on_message_done(response))
new_messages.append(response.choices[0].message)

# tool calls
for tool_call in get_tool_calls(response):
for h in handlers:
await maybe_coro(h.on_tool_call(tool_call=tool_call))
await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
await maybe_coro(h.on_tool_result(tool_message.tool_result))
await maybe_coro(h.on_tool_result(tool_message))
new_messages.append(tool_message)

yield None, tool_message
Expand Down
Loading

0 comments on commit 859187b

Please sign in to comment.