Commit

Merge pull request #80 from PrefectHQ/human
Restore basic human in the loop
jlowin authored Jun 3, 2024
2 parents 5f40000 + 70d09e8 commit a66e4c0
Showing 5 changed files with 29 additions and 27 deletions.
src/controlflow/core/controller/instruction_template.py (5 changes: 2 additions & 3 deletions)
@@ -57,12 +57,11 @@ class TasksTemplate(Template):
     template: str = """
 ## Tasks
-Your job is to complete the tasks assigned to you. Tasks may have multiple agents assigned. Only one agent
-will be active at a time.
+Your job is to complete any tasks assigned to you. Tasks may have multiple agents assigned.
 ### Current tasks
-These tasks are assigned to you and ready to be worked on because their dependencies have been completed.
+These tasks are assigned to you and ready to be worked on because their dependencies have been completed:
 {% for task in tasks %}
 {% if task.is_ready %}
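For context, the task list in this template is rendered with Jinja-style tags, so only tasks whose is_ready flag is set show up under "Current tasks". A rough illustration using plain jinja2 (ControlFlow's Template class is assumed to wrap something similar; the task objects below are stand-ins):

import jinja2

tasks_template = jinja2.Template(
    "### Current tasks\n"
    "{% for task in tasks %}{% if task.is_ready %}- {{ task.objective }}\n{% endif %}{% endfor %}"
)

class FakeTask:
    def __init__(self, objective: str, is_ready: bool):
        self.objective = objective
        self.is_ready = is_ready

# only the ready task is rendered, mirroring the template's is_ready check
print(tasks_template.render(tasks=[FakeTask("draft reply", True), FakeTask("send reply", False)]))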
src/controlflow/llm/completions.py (3 changes: 2 additions & 1 deletion)
@@ -19,6 +19,7 @@
     as_tools,
     get_tool_calls,
     handle_tool_call,
+    handle_tool_call_async,
 )


@@ -209,7 +210,7 @@ async def _completion_async_generator(
         response_messages.append(response_message)

         for tool_call in get_tool_calls(response_message):
-            tool_result_message = handle_tool_call(tool_call, tools)
+            tool_result_message = await handle_tool_call_async(tool_call, tools)
             yield CompletionEvent(
                 type="tool_result_done", payload=dict(message=tool_result_message)
             )
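The second hunk is the substance of this change: inside the async completion generator, tool calls are now resolved by awaiting handle_tool_call_async rather than calling the blocking handle_tool_call. A minimal sketch of that pattern, with simplified stand-ins for the real handler and event types:

import asyncio
import inspect


async def handle_tool_call_async_sketch(fn, **kwargs):
    # stand-in for handle_tool_call_async: run the tool and await its
    # result if the tool produced a coroutine
    result = fn(**kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result


async def completion_events(tool_calls, tools):
    # stand-in for the async generator: each tool result is awaited
    # before the corresponding "tool_result_done" event is yielded
    for name, kwargs in tool_calls:
        message = await handle_tool_call_async_sketch(tools[name], **kwargs)
        yield {"type": "tool_result_done", "message": message}


async def demo():
    tools = {"add": lambda a, b: a + b}
    async for event in completion_events([("add", {"a": 1, "b": 2})], tools):
        print(event)  # {'type': 'tool_result_done', 'message': 3}


asyncio.run(demo())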
src/controlflow/llm/handlers.py (11 changes: 8 additions & 3 deletions)
@@ -95,21 +95,26 @@ def on_tool_result_done(self, message: ToolMessage):


 class PrintHandler(CompletionHandler):
+    def __init__(self):
+        self.messages: dict[str, ControlFlowMessage] = {}
+        self.live = Live(auto_refresh=False)
+
     def on_start(self):
-        self.live = Live(refresh_per_second=12)
         self.live.start()
-        self.messages: dict[str, ControlFlowMessage] = {}

     def on_end(self):
         self.live.stop()

+    def on_exception(self, exc: Exception):
+        self.live.stop()
+
     def update_live(self):
         messages = sorted(self.messages.values(), key=lambda m: m.timestamp)
         content = []
         for message in messages:
             content.append(format_message(message))

-        self.live.update(Group(*content))
+        self.live.update(Group(*content), refresh=True)

     def on_message_delta(self, delta: AssistantMessage, snapshot: AssistantMessage):
         self.messages[snapshot.id] = snapshot
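PrintHandler now builds its Live display once in __init__ with auto_refresh=False and repaints only when update_live passes refresh=True, instead of letting Rich redraw twelve times per second; on_exception also stops the display so a failed completion doesn't leave the terminal stuck in live mode. A small standalone sketch of that Rich pattern (the message content here is a placeholder for format_message output):

import time

from rich.console import Group
from rich.live import Live
from rich.text import Text

live = Live(auto_refresh=False)  # no background refresh thread
live.start()
try:
    rendered = []
    for i in range(3):
        rendered.append(Text(f"message {i}"))
        # redraw only on demand, mirroring update_live(..., refresh=True)
        live.update(Group(*rendered), refresh=True)
        time.sleep(0.2)
finally:
    # the handler does the same from on_end and on_exception
    live.stop()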
src/controlflow/llm/tools.py (5 changes: 4 additions & 1 deletion)
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Literal, Optional, Union

 import pydantic
+from prefect.utilities.asyncutils import run_coro_as_sync

 from controlflow.llm.messages import (
     AssistantMessage,
@@ -171,6 +172,8 @@ def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMe
         metadata.update(tool._metadata)
         fn_args = tool_call.function.json_arguments()
         fn_output = tool(**fn_args)
+        if inspect.isawaitable(fn_output):
+            fn_output = run_coro_as_sync(fn_output)
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
         metadata["is_failed"] = True
@@ -199,7 +202,7 @@ async def handle_tool_call_async(
         metadata = tool._metadata
         fn_args = tool_call.function.json_arguments()
         fn_output = tool(**fn_args)
-        if inspect.is_awaitable(fn_output):
+        if inspect.isawaitable(fn_output):
             fn_output = await fn_output
     except Exception as exc:
         fn_output = f'Error calling function "{fn_name}": {exc}'
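Together these hunks let a single tool return either a plain value or a coroutine: the sync handle_tool_call resolves awaitables through Prefect's run_coro_as_sync, while handle_tool_call_async awaits them directly (the old inspect.is_awaitable call was a typo that would have raised AttributeError, since the inspect module only exposes isawaitable). A rough sketch of the dual-path idea, substituting asyncio.run for run_coro_as_sync so the example is self-contained:

import asyncio
import inspect


async def lookup_weather(city: str) -> str:
    await asyncio.sleep(0)  # pretend to call an API
    return f"Sunny in {city}"


def handle_sync(tool, **kwargs):
    output = tool(**kwargs)
    if inspect.isawaitable(output):
        # the PR uses run_coro_as_sync here; asyncio.run only works
        # when no event loop is already running in this thread
        output = asyncio.run(output)
    return output


async def handle_async(tool, **kwargs):
    output = tool(**kwargs)
    if inspect.isawaitable(output):
        output = await output
    return output


print(handle_sync(lookup_weather, city="Oslo"))
print(asyncio.run(handle_async(lookup_weather, city="Oslo")))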
src/controlflow/tools/talk_to_human.py (32 changes: 13 additions & 19 deletions)
@@ -4,32 +4,33 @@

 from prefect.context import FlowRunContext
 from prefect.input.run_input import receive_input
+from rich.prompt import Prompt

 import controlflow
 from controlflow.llm.tools import tool
 from controlflow.utilities.context import ctx

 if TYPE_CHECKING:
-    from controlflow.tui.app import TUIApp
+    pass


 async def get_terminal_input():
     # as a convenience, we wait for human input on the local terminal
     # this is not necessary for the flow to run, but can be useful for testing
     loop = asyncio.get_event_loop()
-    user_input = await loop.run_in_executor(None, input, "Type your response: ")
+    # user_input = await loop.run_in_executor(None, input, "Type your response: ")
+    user_input = await loop.run_in_executor(None, Prompt.ask, "Type your response")
     return user_input


-async def get_tui_input(tui: "TUIApp", message: str):
-    container = []
-    await tui.get_input(message=message, container=container)
-    while not container:
-        await asyncio.sleep(0.1)
-    return container[0]
+# async def get_tui_input(tui: "TUIApp", message: str):
+#     container = []
+#     await tui.get_input(message=message, container=container)
+#     while not container:
+#         await asyncio.sleep(0.1)
+#     return container[0]


-async def listen_for_response():
+async def get_flow_run_input():
     async for response in receive_input(
         str, flow_run_id=FlowRunContext.get().flow_run.id, poll_interval=0.2
     ):
@@ -48,18 +49,11 @@ async def talk_to_human(message: str, get_response: bool = True) -> str:
     tasks = []
     # if running in a Prefect flow, listen for a remote input
     if (frc := FlowRunContext.get()) and frc.flow_run and frc.flow_run.id:
-        remote_input = asyncio.create_task(listen_for_response())
+        remote_input = asyncio.create_task(get_flow_run_input())
         tasks.append(remote_input)
     # if terminal input is enabled, listen for local input
     if controlflow.settings.enable_local_input:
-        # if a TUI is running, use it to get input
-        if controlflow.settings.enable_tui and ctx.get("tui"):
-            local_input = asyncio.create_task(
-                get_tui_input(tui=ctx.get("tui"), message=message)
-            )
-        # otherwise use terminal
-        else:
-            local_input = asyncio.create_task(get_terminal_input())
+        local_input = asyncio.create_task(get_terminal_input())
         tasks.append(local_input)
     if not tasks:
         raise ValueError(
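The restored flow races whichever input sources are enabled: a Prefect flow-run input (get_flow_run_input) and, when local input is allowed, a terminal prompt (get_terminal_input); the rest of the function is collapsed in this view, but it presumably waits for whichever source responds first. A generic sketch of that race with asyncio, using dummy input sources in place of the real ones:

import asyncio


async def fake_terminal_input() -> str:
    await asyncio.sleep(0.1)  # stand-in for the executor-based Prompt.ask call
    return "typed at the terminal"


async def fake_flow_run_input() -> str:
    await asyncio.sleep(1.0)  # stand-in for polling receive_input on the flow run
    return "sent via Prefect"


async def first_response() -> str:
    tasks = [
        asyncio.create_task(fake_terminal_input()),
        asyncio.create_task(fake_flow_run_input()),
    ]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()  # the losing source is abandoned
    return done.pop().result()


print(asyncio.run(first_response()))  # "typed at the terminal"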
