From 918aab97648fac5ec7060b15a9fdbdca9955841e Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 22 May 2024 18:19:28 -0400 Subject: [PATCH 1/3] fix bug with reentrant contexts --- src/controlflow/core/flow.py | 11 ++++++++--- src/controlflow/core/task.py | 10 +++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py index 529e40d8..ef729bf6 100644 --- a/src/controlflow/core/flow.py +++ b/src/controlflow/core/flow.py @@ -30,6 +30,10 @@ class Flow(ControlFlowModel): _tasks: dict[str, "Task"] = {} context: dict[str, Any] = {} + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__cm_stack = [] + def add_task(self, task: "Task"): if self._tasks.get(task.id, task) is not task: raise ValueError( @@ -43,11 +47,12 @@ def _context(self): yield self def __enter__(self): - self.__cm = self._context() - return self.__cm.__enter__() + # use stack so we can enter the context multiple times + self.__cm_stack.append(self._context()) + return self.__cm_stack[-1].__enter__() def __exit__(self, *exc_info): - return self.__cm.__exit__(*exc_info) + return self.__cm_stack.pop().__exit__(*exc_info) def run(self): """ diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 592370f0..5a4cb6e1 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -122,6 +122,7 @@ def __init__( ).strip() super().__init__(**kwargs) + self.__cm_stack = [] def __repr__(self): include_fields = [ @@ -316,11 +317,11 @@ def _context(self): yield self def __enter__(self): - self.__cm = self._context() - return self.__cm.__enter__() + self.__cm_stack.append(self._context()) + return self.__cm_stack[-1].__enter__() def __exit__(self, *exc_info): - return self.__cm.__exit__(*exc_info) + return self.__cm_stack.pop().__exit__(*exc_info) def is_incomplete(self) -> bool: return self.status == TaskStatus.INCOMPLETE @@ -367,6 +368,7 @@ def succeed(result: result_schema) -> str: # type: ignore succeed, name=f"mark_task_{self.id}_successful", description=f"Mark task {self.id} as successful.", + metadata=dict(is_task_status_tool=True), ) def _create_fail_tool(self) -> Callable: @@ -378,6 +380,7 @@ def _create_fail_tool(self) -> Callable: self.mark_failed, name=f"mark_task_{self.id}_failed", description=f"Mark task {self.id} as failed. Only use when a technical issue like a broken tool or unresponsive human prevents completion.", + metadata=dict(is_task_status_tool=True), ) def _create_skip_tool(self) -> Callable: @@ -388,6 +391,7 @@ def _create_skip_tool(self) -> Callable: self.mark_skipped, name=f"mark_task_{self.id}_skipped", description=f"Mark task {self.id} as skipped. Only use when completing its parent task early.", + metadata=dict(is_task_status_tool=True), ) def get_agents(self) -> list["Agent"]: From bba50f751720ded86be082cd0e5477672a4075b8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 22 May 2024 18:19:36 -0400 Subject: [PATCH 2/3] apply tool metadata --- src/controlflow/llm/tools.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/controlflow/llm/tools.py b/src/controlflow/llm/tools.py index 70a1c41e..17820152 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/llm/tools.py @@ -20,14 +20,18 @@ def tool( *, name: Optional[str] = None, description: Optional[str] = None, + metadata: Optional[dict] = None, ) -> Tool: if fn is None: - return partial(tool, name=name, description=description) - return Tool.from_function(fn, name=name, description=description) + return partial(tool, name=name, description=description, metadata=metadata) + return Tool.from_function(fn, name=name, description=description, metadata=metadata) def annotate_fn( - fn: Callable, name: Optional[str], description: Optional[str] + fn: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, ) -> Callable: """ Annotate a function with a new name and description without modifying the @@ -37,6 +41,7 @@ def annotate_fn( new_fn = functools.partial(fn) new_fn.__name__ = name or fn.__name__ new_fn.__doc__ = description or fn.__doc__ + new_fn.__metadata__ = getattr(fn, "__metadata__", {}) | metadata return new_fn @@ -110,25 +115,26 @@ def get_tool_calls( def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMessage: tool_lookup = as_tool_lookup(tools) fn_name = tool_call.function.name - fn_args = (None,) + fn_args = None + metadata = {} try: - tool_failed = False if fn_name not in tool_lookup: fn_output = f'Function "{fn_name}" not found.' - tool_failed = True + metadata["is_failed"] = True else: tool = tool_lookup[fn_name] + metadata.update(tool._metadata) fn_args = json.loads(tool_call.function.arguments) fn_output = tool(**fn_args) except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' - tool_failed = True + metadata["is_failed"] = True return ToolMessage( content=output_to_string(fn_output), tool_call_id=tool_call.id, tool_call=tool_call, tool_result=fn_output, - tool_failed=tool_failed, + tool_metadata=metadata, ) @@ -137,25 +143,26 @@ async def handle_tool_call_async( ) -> ToolMessage: tool_lookup = as_tool_lookup(tools) fn_name = tool_call.function.name - fn_args = (None,) + fn_args = None + metadata = {} try: - tool_failed = False if fn_name not in tool_lookup: fn_output = f'Function "{fn_name}" not found.' - tool_failed = True + metadata["is_failed"] = True else: tool = tool_lookup[fn_name] + metadata = tool._metadata fn_args = json.loads(tool_call.function.arguments) fn_output = tool(**fn_args) if inspect.is_awaitable(fn_output): fn_output = await fn_output except Exception as exc: fn_output = f'Error calling function "{fn_name}": {exc}' - tool_failed = True + metadata["is_failed"] = True return ToolMessage( content=output_to_string(fn_output), tool_call_id=tool_call.id, tool_call=tool_call, tool_result=fn_output, - tool_failed=tool_failed, + tool_metadata=metadata, ) From c4aa63bce7f742e8c8b4215336f7b4b1a255e11a Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 22 May 2024 18:19:42 -0400 Subject: [PATCH 3/3] update tui for new types --- src/controlflow/settings.py | 2 +- src/controlflow/tui/app.py | 16 ++++-- src/controlflow/tui/app.tcss | 2 +- src/controlflow/tui/task.py | 22 ++++---- src/controlflow/tui/test.py | 9 ++-- src/controlflow/tui/test2.py | 4 +- src/controlflow/tui/thread.py | 86 ++++++++++++++++++++---------- src/controlflow/utilities/types.py | 17 ++++-- 8 files changed, 98 insertions(+), 60 deletions(-) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 188f4811..2c1378a3 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -95,7 +95,7 @@ class Settings(ControlFlowSettings): # ------------ TUI settings ------------ enable_tui: bool = Field( - True, + False, description="If True, the TUI will be enabled. If False, the TUI will be disabled.", ) run_tui_headless: bool = Field( diff --git a/src/controlflow/tui/app.py b/src/controlflow/tui/app.py index 2fd58fcf..c0573a43 100644 --- a/src/controlflow/tui/app.py +++ b/src/controlflow/tui/app.py @@ -39,7 +39,7 @@ def __init__(self, flow: "controlflow.Flow", **kwargs): async def run_context( self, run: bool = True, - inline: bool = True, + inline: bool = False, inline_stay_visible: bool = True, headless: bool = None, hold: bool = True, @@ -64,6 +64,9 @@ async def run_context( try: yield self + except Exception: + self.hold = False + raise finally: if run: while self.hold: @@ -79,10 +82,13 @@ def action_toggle_hold(self): self.hold = not self.hold def watch_hold(self, hold: bool): - if hold: - self.query_one("#hold-banner").display = "block" - else: - self.query_one("#hold-banner").display = "none" + try: + if hold: + self.query_one("#hold-banner").display = "block" + else: + self.query_one("#hold-banner").display = "none" + except NoMatches: + pass def on_mount(self): if self._flow.name: diff --git a/src/controlflow/tui/app.tcss b/src/controlflow/tui/app.tcss index d0e3290c..de04c490 100644 --- a/src/controlflow/tui/app.tcss +++ b/src/controlflow/tui/app.tcss @@ -73,7 +73,7 @@ TUITask { .task-info-row { height: auto; width: 1fr; - margin-top: 1; + margin-top: 0; # margin-left: 4; } diff --git a/src/controlflow/tui/task.py b/src/controlflow/tui/task.py index 7ab4df2f..58470f74 100644 --- a/src/controlflow/tui/task.py +++ b/src/controlflow/tui/task.py @@ -63,11 +63,12 @@ def watch_task(self, task: Task): self.status = task.status.value self.result = task.result - if self.result is not None: - self.query_one(".result-collapsible", Collapsible).display = "block" self.error_msg = task.error - if self.error_msg is not None: - self.query_one(".error-collapsible", Collapsible).display = "block" + if self.is_mounted: + if self.result is not None: + self.query_one(".result-collapsible", Collapsible).display = "block" + if self.error_msg is not None: + self.query_one(".error-collapsible", Collapsible).display = "block" def compose(self): self.border_title = f"Task {self.task.id}" @@ -78,19 +79,14 @@ def compose(self): yield Label(self.task.objective, classes="objective task-info") with Vertical(classes="task-info-row"): - # yield Label( - # f"ID: {self.task.id}", - # classes="task-info", - # ) yield Label( f"Agents: {', '.join(a.name for a in self.task.get_agents())}", classes="user-access task-info", ) - # yield Rule(orientation="vertical") - yield Label( - f"User access: {self.task.user_access}", - classes="user-access task-info", - ) + # yield Label( + # f"User access: {self.task.user_access}", + # classes="user-access task-info", + # ) # ------------------ success diff --git a/src/controlflow/tui/test.py b/src/controlflow/tui/test.py index eb9060d3..d077aed0 100644 --- a/src/controlflow/tui/test.py +++ b/src/controlflow/tui/test.py @@ -6,6 +6,7 @@ from controlflow import Task from controlflow.core.flow import Flow from controlflow.tui.app import TUIApp +from controlflow.utilities.types import AssistantMessage class Name(BaseModel): @@ -49,13 +50,13 @@ async def run(): ) await asyncio.sleep(1) t0.mark_failed(message="this is my result") - app.update_message(m_id="1", message="hello there", role="assistant") + app.update_message(AssistantMessage(content="hello there")) await asyncio.sleep(1) - app.update_message(m_id="2", message="hello there" * 50, role="assistant") + app.update_message(AssistantMessage(content="hello there")) await asyncio.sleep(1) - app.update_message(m_id="3", message="hello there", role="user") + app.update_message(AssistantMessage(content="hello there" * 50)) await asyncio.sleep(1) - app.update_message(m_id="4", message="hello there", role="assistant") + app.update_message(AssistantMessage(content="hello there")) await asyncio.sleep(1) await asyncio.sleep(inf) diff --git a/src/controlflow/tui/test2.py b/src/controlflow/tui/test2.py index 18193e78..f78ae9fa 100644 --- a/src/controlflow/tui/test2.py +++ b/src/controlflow/tui/test2.py @@ -24,6 +24,6 @@ async def run(): if __name__ == "__main__": - # r = asyncio.run(run()) + r = asyncio.run(run()) # print(r) - flow.run() + # flow.run() diff --git a/src/controlflow/tui/thread.py b/src/controlflow/tui/thread.py index a4cb22f7..6d60352a 100644 --- a/src/controlflow/tui/thread.py +++ b/src/controlflow/tui/thread.py @@ -10,6 +10,7 @@ from textual.widgets import Static from controlflow.core.task import TaskStatus +from controlflow.llm.tools import get_tool_calls from controlflow.utilities.types import AssistantMessage, ToolMessage, UserMessage @@ -48,32 +49,53 @@ def render(self): "user": "green", "assistant": "blue", } - if isinstance(self.message, AssistantMessage) and self.message.has_tool_calls(): - content = Markdown( - inspect.cleandoc(""" - :hammer_and_wrench: Calling `{name}` with the following arguments: - - ```json - {args} - ``` - """).format(name=self.tool_name, args=self.tool_args) + panels = [] + if tool_calls := get_tool_calls(self.message): + for tool_call in tool_calls: + content = Markdown( + inspect.cleandoc(""" + :hammer_and_wrench: Calling `{name}` with the following arguments: + + ```json + {args} + ``` + """).format( + name=tool_call.function.name, args=tool_call.function.arguments + ) + ) + panels.append( + Panel( + content, + title="[bold]Tool Call[/]", + subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]", + title_align="left", + subtitle_align="right", + border_style=role_colors.get(self.message.role, "red"), + box=box.ROUNDED, + width=100, + expand=True, + padding=(1, 2), + ) + ) + else: + panels.append( + Panel( + Markdown(self.message.content), + title=f"[bold]{self.message.role.capitalize()}[/]", + subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]", + title_align="left", + subtitle_align="right", + border_style=role_colors.get(self.message.role, "red"), + box=box.ROUNDED, + width=100, + expand=True, + padding=(1, 2), + ) ) - title = "Tool Call" + if len(panels) == 1: + return panels[0] else: - content = self.message.content - title = self.message.role.capitalize() - return Panel( - content, - title=f"[bold]{title}[/]", - subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]", - title_align="left", - subtitle_align="right", - border_style=role_colors.get(self.message.role, "red"), - box=box.ROUNDED, - width=100, - expand=True, - padding=(1, 2), - ) + return Group(*panels) class TUIToolMessage(Static): @@ -84,18 +106,24 @@ def __init__(self, message: ToolMessage, **kwargs): self.message = message def render(self): - if self.message.tool_failed: - content = f":x: The tool call to [markdown.code]{self.message.tool_name}[/] failed." - else: + if self.message.tool_metadata.get("is_failed"): + content = f":x: The tool call to [markdown.code]{self.message.tool_call.function.name}[/] failed." + elif not self.message.tool_metadata.get("is_task_status_tool"): + content_type = ( + "json" if isinstance(self.message.tool_result, (dict, list)) else "" + ) content = Group( f":white_check_mark: Received output from the [markdown.code]{self.message.tool_call.function.name}[/] tool.\n", - Markdown(f"```json\n{self.tool_result}\n```"), + Markdown(f"```{content_type}\n{self.message.content}\n```"), ) + else: + self.display = False + return "" return Panel( content, title="Tool Call Result", - subtitle=f"[italic]{format_timestamp(self._timestamp)}[/]", + subtitle=f"[italic]{format_timestamp(self.message.timestamp)}[/]", title_align="left", subtitle_align="right", border_style="blue", diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index 5537f655..f40e39c5 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -68,14 +68,20 @@ class Tool(ControlFlowModel): type: Literal["function"] = "function" function: ToolFunction _fn: Callable = PrivateAttr() + _metadata: dict = PrivateAttr(default_factory=dict) - def __init__(self, *, _fn: Callable, **kwargs): + def __init__(self, *, _fn: Callable, _metadata: dict = None, **kwargs): super().__init__(**kwargs) self._fn = _fn + self._metadata = _metadata or {} @classmethod def from_function( - cls, fn: Callable, name: Optional[str] = None, description: Optional[str] = None + cls, + fn: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, ): if name is None and fn.__name__ == "": name = "__lambda__" @@ -89,6 +95,7 @@ def from_function( ).json_schema(), ), _fn=fn, + _metadata=metadata or getattr(fn, "__metadata__", {}), ) def __call__(self, *args, **kwargs): @@ -215,9 +222,9 @@ class ToolMessage(ControlFlowMessage): _openai_fields = {"role", "content", "tool_call_id"} # ---- end openai fields - tool_call: "ToolCall" = Field(cf_field=True, repr=False) - tool_result: Any = Field(None, cf_field=True, exclude=True) - tool_failed: bool = Field(False, cf_field=True) + tool_call: "ToolCall" = Field(repr=False) + tool_result: Any = Field(None, exclude=True) + tool_metadata: dict = Field(default_factory=dict) MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]