From 0ef68b0c493ae865c2da22ee30ae0f1d8dffa10b Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Fri, 26 Jul 2024 03:46:59 -0700 Subject: [PATCH] Improve passing context (#619) Co-authored-by: Philipp Rudiger --- lumen/ai/assistant.py | 9 ++++++--- lumen/ai/views.py | 3 +++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index d4f63c8f..183d1549 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -8,7 +8,7 @@ import param from panel import bind -from panel.chat import ChatInterface +from panel.chat import ChatInterface, ChatStep from panel.layout import Column, FlexBox, Tabs from panel.pane import HTML, Markdown from panel.viewable import Viewer @@ -338,7 +338,6 @@ async def _get_agent(self, messages: list | str): step.stream(output.chain_of_thought, replace=True) agent = output.agent step.success_title = f"Selected {agent}" - messages.append({"role": "assistant", "content": output.chain_of_thought}) if agent is None: return None @@ -391,6 +390,9 @@ async def _get_agent(self, messages: list | str): def _serialize(self, obj): if isinstance(obj, (Tabs, Column)): for o in obj: + if isinstance(obj, ChatStep) and not obj.title.startswith("Selected"): + # only want the chain of thoughts from the selected agent + continue if hasattr(o, "visible") and o.visible: break return self._serialize(o) @@ -422,7 +424,8 @@ async def invoke(self, messages: list | str) -> str: print(f"{message['role']!r}: {message['content']}") print("ENTRY" + "-" * 10) - await agent.invoke(messages[-2:]) + print("\n\033[95mAGENT:\033[0m", agent, messages[-3:]) + await agent.invoke(messages[-3:]) self._current_agent.object = "## No agent active" if "current_pipeline" in agent.provides: diff --git a/lumen/ai/views.py b/lumen/ai/views.py index e86522d7..265d3bce 100644 --- a/lumen/ai/views.py +++ b/lumen/ai/views.py @@ -121,6 +121,9 @@ def __panel__(self): def __repr__(self): return self.spec + def __str__(self): + return self.spec + class SQLOutput(LumenOutput):