From 780d98ad8b53f673907ca18f7ddb6feebe6510e6 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Mon, 23 Sep 2024 03:48:06 -0700 Subject: [PATCH] Add cot to table (#709) Co-authored-by: Philipp Rudiger --- lumen/ai/agents.py | 25 +++++++++++++++++++------ lumen/ai/prompts/check_validity.jinja2 | 4 +++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index d892f87d..de7cc3f4 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -348,13 +348,20 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width) + class TableAgent(LumenBaseAgent): """ Displays a single table / dataset. Does not discuss. """ system_prompt = param.String( - default="You are an agent responsible for finding the correct table based on the user prompt." + default=textwrap.dedent( + """ + Identify the most relevant table that contains the most columns useful + for answering the user's query. Keep in mind that additional tables + can be joined later, so focus on selecting the best starting point. + """ + ) ) requires = param.List(default=["current_source"], readonly=True) @@ -365,6 +372,9 @@ class TableAgent(LumenBaseAgent): def _create_table_model(tables): table_model = create_model( "Table", + chain_of_thought=(str, FieldInfo( + description="The thought process behind selecting the table, listing out which columns are useful." + )), relevant_table=(Literal[tables], FieldInfo( description="The most relevant table based on the user query; if none are relevant, select the first." )) @@ -409,9 +419,10 @@ async def answer(self, messages: list | str): allow_partial=False, ) table = result.relevant_table + step.stream(f"{result.chain_of_thought}\n\nSelected table: {table}") else: table = tables[0] - step.stream(f"Selected table: {table}") + step.stream(f"Selected table: {table}") if table in tables_to_source: source = tables_to_source[table] @@ -496,10 +507,12 @@ class SQLAgent(LumenBaseAgent): """ system_prompt = param.String( - default=textwrap.dedent(""" - You are an agent responsible for writing a SQL query that will - perform the data transformations the user requested. - """) + default=textwrap.dedent( + """ + You are an agent responsible for writing a SQL query that will + perform the data transformations the user requested. + """ + ) ) requires = param.List(default=["current_table", "current_source"], readonly=True) diff --git a/lumen/ai/prompts/check_validity.jinja2 b/lumen/ai/prompts/check_validity.jinja2 index 1b48e8ee..23e58565 100644 --- a/lumen/ai/prompts/check_validity.jinja2 +++ b/lumen/ai/prompts/check_validity.jinja2 @@ -1,4 +1,4 @@ -Based on the latest user's query, decide whether the current table and data contains the the required data. If the user is referencing "this data" or "that" they probably are referring to the current data. Pay particular attention to whether the data actually contains the columns required to answer the query, e.g. if they are asking for a location but there are no location related column the data is invalid. However, if the query can be solved through SQL, the data is assumed to be valid. +Based on the latest user's query, decide whether the current table and data contains the the required data. If the user is referencing "this data" or "that" they probably are referring to the current data. Pay particular attention to whether the data actually contains the columns required to answer the query, e.g. if they are asking for a location but there are no location related column the data is invalid. If the number of rows is insufficient, the table is invalid. However, if the query can be solved through SQL, the data is assumed to be valid. ### Current Table: ``` @@ -23,7 +23,9 @@ Schema for current table could not be determined. If the user requests one of the current analyses, the data is assumed to be valid. ### Current Analyses: +``` {% for analysis in analyses %} - {{ analysis }} {% endfor %} {% endif %} +```