Skip to content

Commit

Permalink
Improve detection of table in source (#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Aug 21, 2024
1 parent 756b746 commit e6a8212
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
9 changes: 8 additions & 1 deletion lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ async def answer(self, messages: list | str):
tables_schema_str += f"### {table}\n```yaml\n{yaml.safe_dump(schema)}```\n"
else:
source = memory["current_source"]
available_sources = [source]
tables_to_source = {table: source for table in source.get_tables()}
tables_schema_str = ""

Expand Down Expand Up @@ -405,7 +406,13 @@ async def answer(self, messages: list | str):
table = tables[0]
step.stream(f"Selected table: {table}")

memory["current_source"] = source = tables_to_source.get(table, memory['current_source'])
if table in tables_to_source:
source = tables_to_source[table]
else:
sources = [src for src in available_sources if table in src]
source = sources[0] if sources else memory["current_source"]

memory["current_source"] = source
memory["current_table"] = table
memory["current_pipeline"] = pipeline = Pipeline(
source=source, table=table
Expand Down
4 changes: 2 additions & 2 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ async def _invalidate_memory(self, messages):
return

source = memory.get("current_source")
if table not in source.get_tables():
sources = [src for src in memory.get('available_sources', []) if table in src.get_tables()]
if table not in source:
sources = [src for src in memory.get('available_sources', []) if table in src]
if sources:
memory['current_source'] = source = sources[0]
else:
Expand Down
3 changes: 3 additions & 0 deletions lumen/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ def _set_cache(
f"Error during saving process: {e}"
)

def __contains__(self, table):
return table in self.get_tables()

def clear_cache(self, *events: param.parameterized.Event):
"""
Clears any cached data.
Expand Down
7 changes: 7 additions & 0 deletions lumen/sources/intake_dremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def _get_source(self, table):
table = tables[normalized_tables.index(normalized_table)]
return super()._get_source(table)

def __contains__(self, table):
try:
self._get_source(table)
return True
except KeyError:
return False

def create_sql_expr_source(self, tables: dict[str, str], **kwargs):
"""
Creates a new SQL Source given a set of table names and
Expand Down

0 comments on commit e6a8212

Please sign in to comment.