Skip to content

Commit

Permalink
Add cot to table (#709)
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp Rudiger <prudiger@anaconda.com>
  • Loading branch information
ahuang11 and philippjfr authored Sep 23, 2024
1 parent a79738a commit 780d98a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
25 changes: 19 additions & 6 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,20 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non
self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)



class TableAgent(LumenBaseAgent):
"""
Displays a single table / dataset. Does not discuss.
"""

system_prompt = param.String(
default="You are an agent responsible for finding the correct table based on the user prompt."
default=textwrap.dedent(
"""
Identify the most relevant table that contains the most columns useful
for answering the user's query. Keep in mind that additional tables
can be joined later, so focus on selecting the best starting point.
"""
)
)

requires = param.List(default=["current_source"], readonly=True)
Expand All @@ -365,6 +372,9 @@ class TableAgent(LumenBaseAgent):
def _create_table_model(tables):
table_model = create_model(
"Table",
chain_of_thought=(str, FieldInfo(
description="The thought process behind selecting the table, listing out which columns are useful."
)),
relevant_table=(Literal[tables], FieldInfo(
description="The most relevant table based on the user query; if none are relevant, select the first."
))
Expand Down Expand Up @@ -409,9 +419,10 @@ async def answer(self, messages: list | str):
allow_partial=False,
)
table = result.relevant_table
step.stream(f"{result.chain_of_thought}\n\nSelected table: {table}")
else:
table = tables[0]
step.stream(f"Selected table: {table}")
step.stream(f"Selected table: {table}")

if table in tables_to_source:
source = tables_to_source[table]
Expand Down Expand Up @@ -496,10 +507,12 @@ class SQLAgent(LumenBaseAgent):
"""

system_prompt = param.String(
default=textwrap.dedent("""
You are an agent responsible for writing a SQL query that will
perform the data transformations the user requested.
""")
default=textwrap.dedent(
"""
You are an agent responsible for writing a SQL query that will
perform the data transformations the user requested.
"""
)
)

requires = param.List(default=["current_table", "current_source"], readonly=True)
Expand Down
4 changes: 3 additions & 1 deletion lumen/ai/prompts/check_validity.jinja2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Based on the latest user's query, decide whether the current table and data contains the the required data. If the user is referencing "this data" or "that" they probably are referring to the current data. Pay particular attention to whether the data actually contains the columns required to answer the query, e.g. if they are asking for a location but there are no location related column the data is invalid. However, if the query can be solved through SQL, the data is assumed to be valid.
Based on the latest user's query, decide whether the current table and data contains the the required data. If the user is referencing "this data" or "that" they probably are referring to the current data. Pay particular attention to whether the data actually contains the columns required to answer the query, e.g. if they are asking for a location but there are no location related column the data is invalid. If the number of rows is insufficient, the table is invalid. However, if the query can be solved through SQL, the data is assumed to be valid.

### Current Table:
```
Expand All @@ -23,7 +23,9 @@ Schema for current table could not be determined.
If the user requests one of the current analyses, the data is assumed to be valid.

### Current Analyses:
```
{% for analysis in analyses %}
- {{ analysis }}
{% endfor %}
{% endif %}
```

0 comments on commit 780d98a

Please sign in to comment.