Skip to content

Commit

Permalink
search across all sources (#676)
Browse files · Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Aug 21, 2024
1 parent c50ead0 commit d48b429
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
22 changes: 19 additions & 3 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non
message_kwargs = dict(value=out, user=self.user)
self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)


class TableAgent(LumenBaseAgent):
"""
Displays a single table / dataset. Does not discuss.
Expand All @@ -366,7 +365,22 @@ def _create_table_model(tables):
return table_model

async def answer(self, messages: list | str):
tables = tuple(memory["current_source"].get_tables())

if len(memory["available_sources"]) >= 1:
available_sources = memory["available_sources"]
tables_to_source = {}
tables_schema_str = "\nHere are the table schemas\n"
for source in available_sources:
for table in source.get_tables():
tables_to_source[table] = source
schema = get_schema(source, table, include_min_max=False, include_enum=False, limit=1)
tables_schema_str += f"### {table}\n```yaml\n{yaml.safe_dump(schema)}```\n"
else:
source = memory["current_source"]
tables_to_source = {table: source for table in source.get_tables()}
tables_schema_str = ""

tables = tuple(tables_to_source)
if len(tables) == 1:
table = tables[0]
else:
Expand All @@ -376,7 +390,7 @@ async def answer(self, messages: list | str):
tables = closest_tables
elif len(tables) > FUZZY_TABLE_LENGTH:
tables = await self._get_closest_tables(messages, tables)
system_prompt = await self._system_prompt_with_context(messages)
system_prompt = await self._system_prompt_with_context(messages) + tables_schema_str
if self.debug:
print(f"{self.name} is being instructed that it should {system_prompt}")
if len(tables) > 1:
Expand All @@ -391,6 +405,8 @@ async def answer(self, messages: list | str):
else:
table = tables[0]
step.stream(f"Selected table: {table}")

memory["current_source"] = tables_to_source[table]
memory["current_table"] = table
memory["current_pipeline"] = pipeline = Pipeline(
source=memory["current_source"], table=table
Expand Down
20 changes: 18 additions & 2 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,18 @@ def format_schema(schema):


def get_schema(
source: Source | Pipeline, table: str | None = None, include_min_max: bool = True
source: Source | Pipeline,
table: str | None = None,
include_min_max: bool = True,
include_enum: bool = True,
**get_kwargs
):
if isinstance(source, Pipeline):
schema = source.get_schema()
else:
schema = source.get_schema(table, limit=100)
if "limit" not in get_kwargs:
get_kwargs["limit"] = 100
schema = source.get_schema(table, **get_kwargs)
schema = dict(schema)

if include_min_max:
Expand All @@ -118,6 +124,16 @@ def get_schema(
spec.pop("inclusiveMinimum")
if "inclusiveMaximum" in spec:
spec.pop("inclusiveMaximum")
if "min" in spec:
spec.pop("min")
if "max" in spec:
spec.pop("max")

if not include_enum:
for field, spec in schema.items():
if "enum" in spec:
spec.pop("enum")

schema = format_schema(schema)
return schema

Expand Down

0 comments on commit d48b429

Please sign in to comment.