Skip to content

Commit

Permalink
update tui for new types
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed May 22, 2024
1 parent bba50f7 commit c4aa63b
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Settings(ControlFlowSettings):
# ------------ TUI settings ------------

enable_tui: bool = Field(
True,
False,
description="If True, the TUI will be enabled. If False, the TUI will be disabled.",
)
run_tui_headless: bool = Field(
Expand Down
16 changes: 11 additions & 5 deletions src/controlflow/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, flow: "controlflow.Flow", **kwargs):
async def run_context(
self,
run: bool = True,
inline: bool = True,
inline: bool = False,
inline_stay_visible: bool = True,
headless: bool = None,
hold: bool = True,
Expand All @@ -64,6 +64,9 @@ async def run_context(

try:
yield self
except Exception:
self.hold = False
raise
finally:
if run:
while self.hold:
Expand All @@ -79,10 +82,13 @@ def action_toggle_hold(self):
self.hold = not self.hold

def watch_hold(self, hold: bool):
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
try:
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
except NoMatches:
pass

def on_mount(self):
if self._flow.name:
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tui/app.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TUITask {
.task-info-row {
height: auto;
width: 1fr;
margin-top: 1;
margin-top: 0;
# margin-left: 4;

}
Expand Down
22 changes: 9 additions & 13 deletions src/controlflow/tui/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def watch_task(self, task: Task):
self.status = task.status.value

self.result = task.result
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
self.error_msg = task.error
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"
if self.is_mounted:
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"

def compose(self):
self.border_title = f"Task {self.task.id}"
Expand All @@ -78,19 +79,14 @@ def compose(self):
yield Label(self.task.objective, classes="objective task-info")

with Vertical(classes="task-info-row"):
# yield Label(
# f"ID: {self.task.id}",
# classes="task-info",
# )
yield Label(
f"Agents: {', '.join(a.name for a in self.task.get_agents())}",
classes="user-access task-info",
)
# yield Rule(orientation="vertical")
yield Label(
f"User access: {self.task.user_access}",
classes="user-access task-info",
)
# yield Label(
# f"User access: {self.task.user_access}",
# classes="user-access task-info",
# )

# ------------------ success

Expand Down
9 changes: 5 additions & 4 deletions src/controlflow/tui/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from controlflow import Task
from controlflow.core.flow import Flow
from controlflow.tui.app import TUIApp
from controlflow.utilities.types import AssistantMessage


class Name(BaseModel):
Expand Down Expand Up @@ -49,13 +50,13 @@ async def run():
)
await asyncio.sleep(1)
t0.mark_failed(message="this is my result")
app.update_message(m_id="1", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="2", message="hello there" * 50, role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="3", message="hello there", role="user")
app.update_message(AssistantMessage(content="hello there" * 50))
await asyncio.sleep(1)
app.update_message(m_id="4", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)

await asyncio.sleep(inf)
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/tui/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ async def run():


if __name__ == "__main__":
# r = asyncio.run(run())
r = asyncio.run(run())
# print(r)
flow.run()
# flow.run()
86 changes: 57 additions & 29 deletions src/controlflow/tui/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from textual.widgets import Static

from controlflow.core.task import TaskStatus
from controlflow.llm.tools import get_tool_calls
from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage


Expand Down Expand Up @@ -48,32 +49,53 @@ def render(self):
"user": "green",
"assistant": "blue",
}
if isinstance(self.message, AssistantMessage) and self.message.has_tool_calls():
content = Markdown(
inspect.cleandoc("""
:hammer_and_wrench: Calling `{name}` with the following arguments:
```json
{args}
```
""").format(name=self.tool_name, args=self.tool_args)
panels = []
if tool_calls := get_tool_calls(self.message):
for tool_call in tool_calls:
content = Markdown(
inspect.cleandoc("""
:hammer_and_wrench: Calling `{name}` with the following arguments:
```json
{args}
```
""").format(
name=tool_call.function.name, args=tool_call.function.arguments
)
)
panels.append(
Panel(
content,
title="[bold]Tool Call[/]",
subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
title_align="left",
subtitle_align="right",
border_style=role_colors.get(self.message.role, "red"),
box=box.ROUNDED,
width=100,
expand=True,
padding=(1, 2),
)
)
else:
panels.append(
Panel(
Markdown(self.message.content),
title=f"[bold]{self.message.role.capitalize()}[/]",
subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
title_align="left",
subtitle_align="right",
border_style=role_colors.get(self.message.role, "red"),
box=box.ROUNDED,
width=100,
expand=True,
padding=(1, 2),
)
)
title = "Tool Call"
if len(panels) == 1:
return panels[0]
else:
content = self.message.content
title = self.message.role.capitalize()
return Panel(
content,
title=f"[bold]{title}[/]",
subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
title_align="left",
subtitle_align="right",
border_style=role_colors.get(self.message.role, "red"),
box=box.ROUNDED,
width=100,
expand=True,
padding=(1, 2),
)
return Group(*panels)


class TUIToolMessage(Static):
Expand All @@ -84,18 +106,24 @@ def __init__(self, message: ToolMessage, **kwargs):
self.message = message

def render(self):
if self.message.tool_failed:
content = f":x: The tool call to [markdown.code]{self.message.tool_name}[/] failed."
else:
if self.message.tool_metadata.get("is_failed"):
content = f":x: The tool call to [markdown.code]{self.message.tool_call.function.name}[/] failed."
elif not self.message.tool_metadata.get("is_task_status_tool"):
content_type = (
"json" if isinstance(self.message.tool_result, (dict, list)) else ""
)
content = Group(
f":white_check_mark: Received output from the [markdown.code]{self.message.tool_call.function.name}[/] tool.\n",
Markdown(f"```json\n{self.tool_result}\n```"),
Markdown(f"```{content_type}\n{self.message.content}\n```"),
)
else:
self.display = False
return ""

return Panel(
content,
title="Tool Call Result",
subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]",
subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]",
title_align="left",
subtitle_align="right",
border_style="blue",
Expand Down
17 changes: 12 additions & 5 deletions src/controlflow/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,20 @@ class Tool(ControlFlowModel):
type: Literal["function"] = "function"
function: ToolFunction
_fn: Callable = PrivateAttr()
_metadata: dict = PrivateAttr(default_factory=dict)

def __init__(self, *, _fn: Callable, **kwargs):
def __init__(self, *, _fn: Callable, _metadata: dict = None, **kwargs):
super().__init__(**kwargs)
self._fn = _fn
self._metadata = _metadata or {}

@classmethod
def from_function(
cls, fn: Callable, name: Optional[str] = None, description: Optional[str] = None
cls,
fn: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[dict] = None,
):
if name is None and fn.__name__ == "<lambda>":
name = "__lambda__"
Expand All @@ -89,6 +95,7 @@ def from_function(
).json_schema(),
),
_fn=fn,
_metadata=metadata or getattr(fn, "__metadata__", {}),
)

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -215,9 +222,9 @@ class ToolMessage(ControlFlowMessage):
_openai_fields = {"role", "content", "tool_call_id"}
# ---- end openai fields

tool_call: "ToolCall" = Field(cf_field=True, repr=False)
tool_result: Any = Field(None, cf_field=True, exclude=True)
tool_failed: bool = Field(False, cf_field=True)
tool_call: "ToolCall" = Field(repr=False)
tool_result: Any = Field(None, exclude=True)
tool_metadata: dict = Field(default_factory=dict)


MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
Expand Down

0 comments on commit c4aa63b

Please sign in to comment.