Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tui running #65

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 5 additions & 111 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
import datetime
import json
import logging
import math
from contextlib import asynccontextmanager
from functools import cached_property
from typing import Union

import prefect
from marvin.beta.assistants import EndRun, PrintHandler, Run
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from openai import AsyncAssistantEventHandler
from openai.types.beta.threads import Message as OAIMessage
from openai.types.beta.threads import MessageDelta as OAIMessageDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta, ToolCall
from prefect import get_client as get_prefect_client
from prefect.context import FlowRunContext
from pydantic import BaseModel, Field, computed_field, model_validator

import controlflow
Expand All @@ -25,14 +15,10 @@
from controlflow.core.task import Task
from controlflow.instructions import get_instructions
from controlflow.llm.completions import completion_stream_async
from controlflow.llm.handlers import PrintHandler
from controlflow.llm.handlers import TUIHandler
from controlflow.llm.history import History
from controlflow.tui.app import TUIApp as TUI
from controlflow.utilities.context import ctx
from controlflow.utilities.prefect import (
create_json_artifact,
create_python_artifact,
)
from controlflow.utilities.tasks import all_complete, any_incomplete
from controlflow.utilities.types import FunctionTool, Message

Expand Down Expand Up @@ -102,11 +88,12 @@ def help_im_stuck():
if self._endrun_count >= 3:
self._should_abort = True
self._endrun_count = 0
return EndRun()

return f"Ending turn. {3 - self._endrun_count} more uses will abort the workflow."

return help_im_stuck

async def _run_agent(self, agent: Agent, tasks: list[Task] = None) -> Run:
async def _run_agent(self, agent: Agent, tasks: list[Task] = None):
"""
Run a single agent.
"""
Expand Down Expand Up @@ -141,7 +128,7 @@ async def _run_agent(self, agent: Agent, tasks: list[Task] = None) -> Run:
messages=[system_message] + messages,
model=agent.model,
tools=tools,
handlers=[PrintHandler()],
handlers=[TUIHandler()] if controlflow.settings.enable_tui else None,
max_iterations=1,
response_callback=r.append,
):
Expand Down Expand Up @@ -243,96 +230,3 @@ async def run_async(self):
f"Task iterations exceeded maximum of {max_task_iterations} for each task."
)
self._should_abort = False


class TUIHandler(AsyncAssistantEventHandler):
async def on_message_delta(
self, delta: OAIMessageDelta, snapshot: OAIMessage
) -> None:
if tui := ctx.get("tui"):
content = []
for item in snapshot.content:
if item.type == "text":
content.append(item.text.value)

tui.update_message(
m_id=snapshot.id,
message="\n\n".join(content),
role=snapshot.role,
timestamp=datetime.datetime.fromtimestamp(snapshot.created_at),
)

async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
if tui := ctx.get("tui"):
tui.update_step(snapshot)


class AgentHandler(PrintHandler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.tool_calls = {}

async def on_tool_call_created(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is created"""

if tool_call.type == "function":
task_run_name = "Prepare arguments for tool call"
else:
task_run_name = f"Tool call: {tool_call.type}"

client = get_prefect_client()
engine_context = FlowRunContext.get()
if not engine_context:
return

task_run = await client.create_task_run(
task=prefect.Task(fn=lambda: None),
name=task_run_name,
extra_tags=["tool-call"],
flow_run_id=engine_context.flow_run.id,
dynamic_key=tool_call.id,
state=prefect.states.Running(),
)

self.tool_calls[tool_call.id] = task_run

async def on_tool_call_done(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is done"""

client = get_prefect_client()
task_run = self.tool_calls.get(tool_call.id)

if not task_run:
return
await client.set_task_run_state(
task_run_id=task_run.id, state=prefect.states.Completed(), force=True
)

# code interpreter is run as a single call, so we can publish a result artifact
if tool_call.type == "code_interpreter":
# images = []
# for output in tool_call.code_interpreter.outputs:
# if output.type == "image":
# image_path = download_temp_file(output.image.file_id)
# images.append(image_path)

create_python_artifact(
key="code",
code=tool_call.code_interpreter.input,
description="Code executed in the code interpreter",
task_run_id=task_run.id,
)
create_json_artifact(
key="output",
data=tool_call.code_interpreter.outputs,
description="Output from the code interpreter",
task_run_id=task_run.id,
)

elif tool_call.type == "function":
create_json_artifact(
key="arguments",
data=json.dumps(json.loads(tool_call.function.arguments), indent=2),
description=f"Arguments for the `{tool_call.function.name}` tool",
task_run_id=task_run.id,
)
6 changes: 3 additions & 3 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,16 @@ def mark_successful(self, result: T = None, validate_upstreams: bool = True):
self.result = validate_result(result, self.result_type)
self.set_status(TaskStatus.SUCCESSFUL)

return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked successful."

def mark_failed(self, message: Union[str, None] = None):
self.error = message
self.set_status(TaskStatus.FAILED)
return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked failed."

def mark_skipped(self):
self.set_status(TaskStatus.SKIPPED)
return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"
return f"{self.friendly_name()} marked skipped."


def generate_result_schema(result_type: type[T]) -> type[T]:
Expand Down
81 changes: 48 additions & 33 deletions src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def completion(

# on message done
for h in handlers:
h.on_message_done(response.choices[0].message)
h.on_message_done(response)
new_messages.append(response.choices[0].message)

for tool_call in get_tool_calls(response):
for h in handlers:
h.on_tool_call(tool_call=tool_call)
h.on_tool_call_done(tool_call=tool_call)
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
h.on_tool_result(tool_message.tool_result)
h.on_tool_result(tool_message)
new_messages.append(tool_message)

if len(responses) >= (max_iterations or math.inf):
Expand Down Expand Up @@ -156,44 +156,52 @@ def completion_stream(

while not response or has_tool_calls(response):
deltas = []
is_tool_call = False
for delta in litellm.completion(
model=model,
messages=trim_messages(messages + new_messages, model=model),
tools=[t.model_dump() for t in tools] if tools else None,
stream=True,
**kwargs,
):
# on message created
if not deltas:
for h in handlers:
h.on_message_created(delta=delta.choices[0].delta)

deltas.append(delta)
response = litellm.stream_chunk_builder(deltas)

# on message created
if len(deltas) == 1:
if get_tool_calls(response):
is_tool_call = True
for h in handlers:
if is_tool_call:
h.on_tool_call_created(delta=delta)
else:
h.on_message_created(delta=delta)

# on message delta
for h in handlers:
h.on_message_delta(
delta=delta.choices[0].delta, snapshot=response.choices[0].message
)
if is_tool_call:
h.on_tool_call_delta(delta=delta, snapshot=response)
else:
h.on_message_delta(delta=delta, snapshot=response)

# yield
yield delta, response

responses.append(response)

# on message done
for h in handlers:
h.on_message_done(response.choices[0].message)
if not is_tool_call:
for h in handlers:
h.on_message_done(response)
new_messages.append(response.choices[0].message)

# tool calls
for tool_call in get_tool_calls(response):
for h in handlers:
h.on_tool_call(tool_call=tool_call)
h.on_tool_call_done(tool_call=tool_call)
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
h.on_tool_result(tool_message.tool_result)
h.on_tool_result(tool_message)
new_messages.append(tool_message)

yield None, tool_message
Expand Down Expand Up @@ -250,15 +258,15 @@ async def completion_async(

# on message done
for h in handlers:
await maybe_coro(h.on_message_done(response.choices[0].message))
await maybe_coro(h.on_message_done(response))
new_messages.append(response.choices[0].message)

for tool_call in get_tool_calls(response):
for h in handlers:
await maybe_coro(h.on_tool_call(tool_call=tool_call))
await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
await maybe_coro(h.on_tool_result(tool_message.tool_result))
await maybe_coro(h.on_tool_result(tool_message))
new_messages.append(tool_message)

if len(responses) >= (max_iterations or math.inf):
Expand Down Expand Up @@ -310,47 +318,54 @@ async def completion_stream_async(

while not response or has_tool_calls(response):
deltas = []
is_tool_call = False
async for delta in await litellm.acompletion(
model=model,
messages=trim_messages(messages + new_messages, model=model),
tools=[t.model_dump() for t in tools] if tools else None,
stream=True,
**kwargs,
):
# on message created
if not deltas:
for h in handlers:
await maybe_coro(h.on_message_created(delta=delta.choices[0].delta))

deltas.append(delta)
response = litellm.stream_chunk_builder(deltas)

# on message delta
# on message / tool call created
if len(deltas) == 1:
if get_tool_calls(response):
is_tool_call = True
for h in handlers:
if is_tool_call:
await maybe_coro(h.on_tool_call_created(delta=delta))
else:
await maybe_coro(h.on_message_created(delta=delta))

# on message / tool call delta
for h in handlers:
await maybe_coro(
h.on_message_delta(
delta=delta.choices[0].delta,
snapshot=response.choices[0].message,
if is_tool_call:
await maybe_coro(
h.on_tool_call_delta(delta=delta, snapshot=response)
)
)
else:
await maybe_coro(h.on_message_delta(delta=delta, snapshot=response))

# yield
yield delta, response

responses.append(response)

# on message done
for h in handlers:
await maybe_coro(h.on_message_done(response.choices[0].message))
if not is_tool_call:
for h in handlers:
await maybe_coro(h.on_message_done(response))
new_messages.append(response.choices[0].message)

# tool calls
for tool_call in get_tool_calls(response):
for h in handlers:
await maybe_coro(h.on_tool_call(tool_call=tool_call))
await maybe_coro(h.on_tool_call_done(tool_call=tool_call))
tool_message = handle_tool_call(tool_call, tools)
for h in handlers:
await maybe_coro(h.on_tool_result(tool_message.tool_result))
await maybe_coro(h.on_tool_result(tool_message))
new_messages.append(tool_message)

yield None, tool_message
Expand Down
Loading
Loading