Skip to content

Commit

Permalink
Merge pull request #67 from jlowin/tui
Browse files Browse the repository at this point in the history
fix bug with reentrant contexts
  • Loading branch information
jlowin authored May 22, 2024
2 parents 81cff32 + c4aa63b commit 4c8da63
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 79 deletions.
11 changes: 8 additions & 3 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class Flow(ControlFlowModel):
_tasks: dict[str, "Task"] = {}
context: dict[str, Any] = {}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__cm_stack = []

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
raise ValueError(
Expand All @@ -43,11 +47,12 @@ def _context(self):
yield self

def __enter__(self):
self.__cm = self._context()
return self.__cm.__enter__()
# use stack so we can enter the context multiple times
self.__cm_stack.append(self._context())
return self.__cm_stack[-1].__enter__()

def __exit__(self, *exc_info):
return self.__cm.__exit__(*exc_info)
return self.__cm_stack.pop().__exit__(*exc_info)

def run(self):
"""
Expand Down
10 changes: 7 additions & 3 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
).strip()

super().__init__(**kwargs)
self.__cm_stack = []

def __repr__(self):
include_fields = [
Expand Down Expand Up @@ -316,11 +317,11 @@ def _context(self):
yield self

def __enter__(self):
self.__cm = self._context()
return self.__cm.__enter__()
self.__cm_stack.append(self._context())
return self.__cm_stack[-1].__enter__()

def __exit__(self, *exc_info):
return self.__cm.__exit__(*exc_info)
return self.__cm_stack.pop().__exit__(*exc_info)

def is_incomplete(self) -> bool:
return self.status == TaskStatus.INCOMPLETE
Expand Down Expand Up @@ -367,6 +368,7 @@ def succeed(result: result_schema) -> str: # type: ignore
succeed,
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
metadata=dict(is_task_status_tool=True),
)

def _create_fail_tool(self) -> Callable:
Expand All @@ -378,6 +380,7 @@ def _create_fail_tool(self) -> Callable:
self.mark_failed,
name=f"mark_task_{self.id}_failed",
description=f"Mark task {self.id} as failed. Only use when a technical issue like a broken tool or unresponsive human prevents completion.",
metadata=dict(is_task_status_tool=True),
)

def _create_skip_tool(self) -> Callable:
Expand All @@ -388,6 +391,7 @@ def _create_skip_tool(self) -> Callable:
self.mark_skipped,
name=f"mark_task_{self.id}_skipped",
description=f"Mark task {self.id} as skipped. Only use when completing its parent task early.",
metadata=dict(is_task_status_tool=True),
)

def get_agents(self) -> list["Agent"]:
Expand Down
33 changes: 20 additions & 13 deletions src/controlflow/llm/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ def tool(
*,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Tool:
if fn is None:
return partial(tool, name=name, description=description)
return Tool.from_function(fn, name=name, description=description)
return partial(tool, name=name, description=description, metadata=metadata)
return Tool.from_function(fn, name=name, description=description, metadata=metadata)


def annotate_fn(
fn: Callable, name: Optional[str], description: Optional[str]
fn: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Callable:
"""
Annotate a function with a new name and description without modifying the
Expand All @@ -37,6 +41,7 @@ def annotate_fn(
new_fn = functools.partial(fn)
new_fn.__name__ = name or fn.__name__
new_fn.__doc__ = description or fn.__doc__
new_fn.__metadata__ = getattr(fn, "__metadata__", {}) | metadata
return new_fn


Expand Down Expand Up @@ -110,25 +115,26 @@ def get_tool_calls(
def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMessage:
tool_lookup = as_tool_lookup(tools)
fn_name = tool_call.function.name
fn_args = (None,)
fn_args = None
metadata = {}
try:
tool_failed = False
if fn_name not in tool_lookup:
fn_output = f'Function "{fn_name}" not found.'
tool_failed = True
metadata["is_failed"] = True
else:
tool = tool_lookup[fn_name]
metadata.update(tool._metadata)
fn_args = json.loads(tool_call.function.arguments)
fn_output = tool(**fn_args)
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
tool_failed = True
metadata["is_failed"] = True
return ToolMessage(
content=output_to_string(fn_output),
tool_call_id=tool_call.id,
tool_call=tool_call,
tool_result=fn_output,
tool_failed=tool_failed,
tool_metadata=metadata,
)


Expand All @@ -137,25 +143,26 @@ async def handle_tool_call_async(
) -> ToolMessage:
tool_lookup = as_tool_lookup(tools)
fn_name = tool_call.function.name
fn_args = (None,)
fn_args = None
metadata = {}
try:
tool_failed = False
if fn_name not in tool_lookup:
fn_output = f'Function "{fn_name}" not found.'
tool_failed = True
metadata["is_failed"] = True
else:
tool = tool_lookup[fn_name]
metadata = tool._metadata
fn_args = json.loads(tool_call.function.arguments)
fn_output = tool(**fn_args)
if inspect.is_awaitable(fn_output):
fn_output = await fn_output
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
tool_failed = True
metadata["is_failed"] = True
return ToolMessage(
content=output_to_string(fn_output),
tool_call_id=tool_call.id,
tool_call=tool_call,
tool_result=fn_output,
tool_failed=tool_failed,
tool_metadata=metadata,
)
2 changes: 1 addition & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Settings(ControlFlowSettings):
# ------------ TUI settings ------------

enable_tui: bool = Field(
True,
False,
description="If True, the TUI will be enabled. If False, the TUI will be disabled.",
)
run_tui_headless: bool = Field(
Expand Down
16 changes: 11 additions & 5 deletions src/controlflow/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, flow: "controlflow.Flow", **kwargs):
async def run_context(
self,
run: bool = True,
inline: bool = True,
inline: bool = False,
inline_stay_visible: bool = True,
headless: bool = None,
hold: bool = True,
Expand All @@ -64,6 +64,9 @@ async def run_context(

try:
yield self
except Exception:
self.hold = False
raise
finally:
if run:
while self.hold:
Expand All @@ -79,10 +82,13 @@ def action_toggle_hold(self):
self.hold = not self.hold

def watch_hold(self, hold: bool):
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
try:
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
except NoMatches:
pass

def on_mount(self):
if self._flow.name:
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tui/app.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TUITask {
.task-info-row {
height: auto;
width: 1fr;
margin-top: 1;
margin-top: 0;
# margin-left: 4;

}
Expand Down
22 changes: 9 additions & 13 deletions src/controlflow/tui/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def watch_task(self, task: Task):
self.status = task.status.value

self.result = task.result
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
self.error_msg = task.error
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"
if self.is_mounted:
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"

def compose(self):
self.border_title = f"Task {self.task.id}"
Expand All @@ -78,19 +79,14 @@ def compose(self):
yield Label(self.task.objective, classes="objective task-info")

with Vertical(classes="task-info-row"):
# yield Label(
# f"ID: {self.task.id}",
# classes="task-info",
# )
yield Label(
f"Agents: {', '.join(a.name for a in self.task.get_agents())}",
classes="user-access task-info",
)
# yield Rule(orientation="vertical")
yield Label(
f"User access: {self.task.user_access}",
classes="user-access task-info",
)
# yield Label(
# f"User access: {self.task.user_access}",
# classes="user-access task-info",
# )

# ------------------ success

Expand Down
9 changes: 5 additions & 4 deletions src/controlflow/tui/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from controlflow import Task
from controlflow.core.flow import Flow
from controlflow.tui.app import TUIApp
from controlflow.utilities.types import AssistantMessage


class Name(BaseModel):
Expand Down Expand Up @@ -49,13 +50,13 @@ async def run():
)
await asyncio.sleep(1)
t0.mark_failed(message="this is my result")
app.update_message(m_id="1", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="2", message="hello there" * 50, role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="3", message="hello there", role="user")
app.update_message(AssistantMessage(content="hello there" * 50))
await asyncio.sleep(1)
app.update_message(m_id="4", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)

await asyncio.sleep(inf)
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/tui/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ async def run():


if __name__ == "__main__":
# r = asyncio.run(run())
r = asyncio.run(run())
# print(r)
flow.run()
# flow.run()
Loading

0 comments on commit 4c8da63

Please sign in to comment.