Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore basic human in the loop #80

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ class TasksTemplate(Template):
template: str = """
## Tasks

Your job is to complete the tasks assigned to you. Tasks may have multiple agents assigned. Only one agent
will be active at a time.
Your job is to complete any tasks assigned to you. Tasks may have multiple agents assigned.

### Current tasks

These tasks are assigned to you and ready to be worked on because their dependencies have been completed.
These tasks are assigned to you and ready to be worked on because their dependencies have been completed:

{% for task in tasks %}
{% if task.is_ready %}
Expand Down
3 changes: 2 additions & 1 deletion src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
as_tools,
get_tool_calls,
handle_tool_call,
handle_tool_call_async,
)


Expand Down Expand Up @@ -209,7 +210,7 @@ async def _completion_async_generator(
response_messages.append(response_message)

for tool_call in get_tool_calls(response_message):
tool_result_message = handle_tool_call(tool_call, tools)
tool_result_message = await handle_tool_call_async(tool_call, tools)
yield CompletionEvent(
type="tool_result_done", payload=dict(message=tool_result_message)
)
Expand Down
11 changes: 8 additions & 3 deletions src/controlflow/llm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,26 @@ def on_tool_result_done(self, message: ToolMessage):


class PrintHandler(CompletionHandler):
def __init__(self):
self.messages: dict[str, ControlFlowMessage] = {}
self.live = Live(auto_refresh=False)

def on_start(self):
self.live = Live(refresh_per_second=12)
self.live.start()
self.messages: dict[str, ControlFlowMessage] = {}

def on_end(self):
self.live.stop()

def on_exception(self, exc: Exception):
self.live.stop()

def update_live(self):
messages = sorted(self.messages.values(), key=lambda m: m.timestamp)
content = []
for message in messages:
content.append(format_message(message))

self.live.update(Group(*content))
self.live.update(Group(*content), refresh=True)

def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
self.messages[snapshot.id] = snapshot
Expand Down
5 changes: 4 additions & 1 deletion src/controlflow/llm/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Literal, Optional, Union

import pydantic
from prefect.utilities.asyncutils import run_coro_as_sync

from controlflow.llm.messages import (
AssistantMessage,
Expand Down Expand Up @@ -171,6 +172,8 @@ def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMe
metadata.update(tool._metadata)
fn_args = tool_call.function.json_arguments()
fn_output = tool(**fn_args)
if inspect.isawaitable(fn_output):
fn_output = run_coro_as_sync(fn_output)
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
metadata["is_failed"] = True
Expand Down Expand Up @@ -199,7 +202,7 @@ async def handle_tool_call_async(
metadata = tool._metadata
fn_args = tool_call.function.json_arguments()
fn_output = tool(**fn_args)
if inspect.is_awaitable(fn_output):
if inspect.isawaitable(fn_output):
fn_output = await fn_output
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
Expand Down
32 changes: 13 additions & 19 deletions src/controlflow/tools/talk_to_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,33 @@

from prefect.context import FlowRunContext
from prefect.input.run_input import receive_input
from rich.prompt import Prompt

import controlflow
from controlflow.llm.tools import tool
from controlflow.utilities.context import ctx

if TYPE_CHECKING:
from controlflow.tui.app import TUIApp
pass


async def get_terminal_input():
# as a convenience, we wait for human input on the local terminal
# this is not necessary for the flow to run, but can be useful for testing
loop = asyncio.get_event_loop()
user_input = await loop.run_in_executor(None, input, "Type your response: ")
# user_input = await loop.run_in_executor(None, input, "Type your response: ")
user_input = await loop.run_in_executor(None, Prompt.ask, "Type your response")
return user_input


async def get_tui_input(tui: "TUIApp", message: str):
container = []
await tui.get_input(message=message, container=container)
while not container:
await asyncio.sleep(0.1)
return container[0]
# async def get_tui_input(tui: "TUIApp", message: str):
# container = []
# await tui.get_input(message=message, container=container)
# while not container:
# await asyncio.sleep(0.1)
# return container[0]


async def listen_for_response():
async def get_flow_run_input():
async for response in receive_input(
str, flow_run_id=FlowRunContext.get().flow_run.id, poll_interval=0.2
):
Expand All @@ -48,18 +49,11 @@ async def talk_to_human(message: str, get_response: bool = True) -> str:
tasks = []
# if running in a Prefect flow, listen for a remote input
if (frc := FlowRunContext.get()) and frc.flow_run and frc.flow_run.id:
remote_input = asyncio.create_task(listen_for_response())
remote_input = asyncio.create_task(get_flow_run_input())
tasks.append(remote_input)
# if terminal input is enabled, listen for local input
if controlflow.settings.enable_local_input:
# if a TUI is running, use it to get input
if controlflow.settings.enable_tui and ctx.get("tui"):
local_input = asyncio.create_task(
get_tui_input(tui=ctx.get("tui"), message=message)
)
# otherwise use terminal
else:
local_input = asyncio.create_task(get_terminal_input())
local_input = asyncio.create_task(get_terminal_input())
tasks.append(local_input)
if not tasks:
raise ValueError(
Expand Down
Loading