Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug with reentrant contexts #67

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class Flow(ControlFlowModel):
_tasks: dict[str, "Task"] = {}
context: dict[str, Any] = {}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__cm_stack = []

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
raise ValueError(
Expand All @@ -43,11 +47,12 @@ def _context(self):
yield self

def __enter__(self):
self.__cm = self._context()
return self.__cm.__enter__()
# use stack so we can enter the context multiple times
self.__cm_stack.append(self._context())
return self.__cm_stack[-1].__enter__()

def __exit__(self, *exc_info):
return self.__cm.__exit__(*exc_info)
return self.__cm_stack.pop().__exit__(*exc_info)

def run(self):
"""
Expand Down
10 changes: 7 additions & 3 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
).strip()

super().__init__(**kwargs)
self.__cm_stack = []

def __repr__(self):
include_fields = [
Expand Down Expand Up @@ -316,11 +317,11 @@ def _context(self):
yield self

def __enter__(self):
self.__cm = self._context()
return self.__cm.__enter__()
self.__cm_stack.append(self._context())
return self.__cm_stack[-1].__enter__()

def __exit__(self, *exc_info):
return self.__cm.__exit__(*exc_info)
return self.__cm_stack.pop().__exit__(*exc_info)

def is_incomplete(self) -> bool:
return self.status == TaskStatus.INCOMPLETE
Expand Down Expand Up @@ -367,6 +368,7 @@ def succeed(result: result_schema) -> str: # type: ignore
succeed,
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
metadata=dict(is_task_status_tool=True),
)

def _create_fail_tool(self) -> Callable:
Expand All @@ -378,6 +380,7 @@ def _create_fail_tool(self) -> Callable:
self.mark_failed,
name=f"mark_task_{self.id}_failed",
description=f"Mark task {self.id} as failed. Only use when a technical issue like a broken tool or unresponsive human prevents completion.",
metadata=dict(is_task_status_tool=True),
)

def _create_skip_tool(self) -> Callable:
Expand All @@ -388,6 +391,7 @@ def _create_skip_tool(self) -> Callable:
self.mark_skipped,
name=f"mark_task_{self.id}_skipped",
description=f"Mark task {self.id} as skipped. Only use when completing its parent task early.",
metadata=dict(is_task_status_tool=True),
)

def get_agents(self) -> list["Agent"]:
Expand Down
33 changes: 20 additions & 13 deletions src/controlflow/llm/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ def tool(
*,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Tool:
if fn is None:
return partial(tool, name=name, description=description)
return Tool.from_function(fn, name=name, description=description)
return partial(tool, name=name, description=description, metadata=metadata)
return Tool.from_function(fn, name=name, description=description, metadata=metadata)


def annotate_fn(
fn: Callable, name: Optional[str], description: Optional[str]
fn: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Callable:
"""
Annotate a function with a new name and description without modifying the
Expand All @@ -37,6 +41,7 @@ def annotate_fn(
new_fn = functools.partial(fn)
new_fn.__name__ = name or fn.__name__
new_fn.__doc__ = description or fn.__doc__
new_fn.__metadata__ = getattr(fn, "__metadata__", {}) | metadata
return new_fn


Expand Down Expand Up @@ -110,25 +115,26 @@ def get_tool_calls(
def handle_tool_call(tool_call: ToolCall, tools: list[dict, Callable]) -> ToolMessage:
tool_lookup = as_tool_lookup(tools)
fn_name = tool_call.function.name
fn_args = (None,)
fn_args = None
metadata = {}
try:
tool_failed = False
if fn_name not in tool_lookup:
fn_output = f'Function "{fn_name}" not found.'
tool_failed = True
metadata["is_failed"] = True
else:
tool = tool_lookup[fn_name]
metadata.update(tool._metadata)
fn_args = json.loads(tool_call.function.arguments)
fn_output = tool(**fn_args)
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
tool_failed = True
metadata["is_failed"] = True
return ToolMessage(
content=output_to_string(fn_output),
tool_call_id=tool_call.id,
tool_call=tool_call,
tool_result=fn_output,
tool_failed=tool_failed,
tool_metadata=metadata,
)


Expand All @@ -137,25 +143,26 @@ async def handle_tool_call_async(
) -> ToolMessage:
tool_lookup = as_tool_lookup(tools)
fn_name = tool_call.function.name
fn_args = (None,)
fn_args = None
metadata = {}
try:
tool_failed = False
if fn_name not in tool_lookup:
fn_output = f'Function "{fn_name}" not found.'
tool_failed = True
metadata["is_failed"] = True
else:
tool = tool_lookup[fn_name]
metadata = tool._metadata
fn_args = json.loads(tool_call.function.arguments)
fn_output = tool(**fn_args)
if inspect.is_awaitable(fn_output):
fn_output = await fn_output
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
tool_failed = True
metadata["is_failed"] = True
return ToolMessage(
content=output_to_string(fn_output),
tool_call_id=tool_call.id,
tool_call=tool_call,
tool_result=fn_output,
tool_failed=tool_failed,
tool_metadata=metadata,
)
2 changes: 1 addition & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Settings(ControlFlowSettings):
# ------------ TUI settings ------------

enable_tui: bool = Field(
True,
False,
description="If True, the TUI will be enabled. If False, the TUI will be disabled.",
)
run_tui_headless: bool = Field(
Expand Down
16 changes: 11 additions & 5 deletions src/controlflow/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, flow: "controlflow.Flow", **kwargs):
async def run_context(
self,
run: bool = True,
inline: bool = True,
inline: bool = False,
inline_stay_visible: bool = True,
headless: bool = None,
hold: bool = True,
Expand All @@ -64,6 +64,9 @@ async def run_context(

try:
yield self
except Exception:
self.hold = False
raise
finally:
if run:
while self.hold:
Expand All @@ -79,10 +82,13 @@ def action_toggle_hold(self):
self.hold = not self.hold

def watch_hold(self, hold: bool):
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
try:
if hold:
self.query_one("#hold-banner").display = "block"
else:
self.query_one("#hold-banner").display = "none"
except NoMatches:
pass

def on_mount(self):
if self._flow.name:
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tui/app.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TUITask {
.task-info-row {
height: auto;
width: 1fr;
margin-top: 1;
margin-top: 0;
# margin-left: 4;

}
Expand Down
22 changes: 9 additions & 13 deletions src/controlflow/tui/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def watch_task(self, task: Task):
self.status = task.status.value

self.result = task.result
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
self.error_msg = task.error
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"
if self.is_mounted:
if self.result is not None:
self.query_one(".result-collapsible", Collapsible).display = "block"
if self.error_msg is not None:
self.query_one(".error-collapsible", Collapsible).display = "block"

def compose(self):
self.border_title = f"Task {self.task.id}"
Expand All @@ -78,19 +79,14 @@ def compose(self):
yield Label(self.task.objective, classes="objective task-info")

with Vertical(classes="task-info-row"):
# yield Label(
# f"ID: {self.task.id}",
# classes="task-info",
# )
yield Label(
f"Agents: {', '.join(a.name for a in self.task.get_agents())}",
classes="user-access task-info",
)
# yield Rule(orientation="vertical")
yield Label(
f"User access: {self.task.user_access}",
classes="user-access task-info",
)
# yield Label(
# f"User access: {self.task.user_access}",
# classes="user-access task-info",
# )

# ------------------ success

Expand Down
9 changes: 5 additions & 4 deletions src/controlflow/tui/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from controlflow import Task
from controlflow.core.flow import Flow
from controlflow.tui.app import TUIApp
from controlflow.utilities.types import AssistantMessage


class Name(BaseModel):
Expand Down Expand Up @@ -49,13 +50,13 @@ async def run():
)
await asyncio.sleep(1)
t0.mark_failed(message="this is my result")
app.update_message(m_id="1", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="2", message="hello there" * 50, role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)
app.update_message(m_id="3", message="hello there", role="user")
app.update_message(AssistantMessage(content="hello there" * 50))
await asyncio.sleep(1)
app.update_message(m_id="4", message="hello there", role="assistant")
app.update_message(AssistantMessage(content="hello there"))
await asyncio.sleep(1)

await asyncio.sleep(inf)
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/tui/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ async def run():


if __name__ == "__main__":
# r = asyncio.run(run())
r = asyncio.run(run())
# print(r)
flow.run()
# flow.run()
Loading
Loading