From e6a8212cb7be1bbc015814e7a986dcf0a8d20cff Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 21 Aug 2024 14:44:51 +0200 Subject: [PATCH] Improve detection of table in source (#681) --- lumen/ai/agents.py | 9 ++++++++- lumen/ai/assistant.py | 4 ++-- lumen/sources/base.py | 3 +++ lumen/sources/intake_dremio.py | 7 +++++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index afb2a0bd0..735b71712 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -376,6 +376,7 @@ async def answer(self, messages: list | str): tables_schema_str += f"### {table}\n```yaml\n{yaml.safe_dump(schema)}```\n" else: source = memory["current_source"] + available_sources = [source] tables_to_source = {table: source for table in source.get_tables()} tables_schema_str = "" @@ -405,7 +406,13 @@ async def answer(self, messages: list | str): table = tables[0] step.stream(f"Selected table: {table}") - memory["current_source"] = source = tables_to_source.get(table, memory['current_source']) + if table in tables_to_source: + source = tables_to_source[table] + else: + sources = [src for src in available_sources if table in src] + source = sources[0] if sources else memory["current_source"] + + memory["current_source"] = source memory["current_table"] = table memory["current_pipeline"] = pipeline = Pipeline( source=source, table=table diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index 15e276db7..dc4fb2bab 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -250,8 +250,8 @@ async def _invalidate_memory(self, messages): return source = memory.get("current_source") - if table not in source.get_tables(): - sources = [src for src in memory.get('available_sources', []) if table in src.get_tables()] + if table not in source: + sources = [src for src in memory.get('available_sources', []) if table in src] if sources: memory['current_source'] = source = sources[0] else: diff --git a/lumen/sources/base.py b/lumen/sources/base.py index 4bade3244..45052978e 100644 --- a/lumen/sources/base.py +++ b/lumen/sources/base.py @@ -391,6 +391,9 @@ def _set_cache( f"Error during saving process: {e}" ) + def __contains__(self, table): + return table in self.get_tables() + def clear_cache(self, *events: param.parameterized.Event): """ Clears any cached data. diff --git a/lumen/sources/intake_dremio.py b/lumen/sources/intake_dremio.py index fd4adb414..59ee59742 100644 --- a/lumen/sources/intake_dremio.py +++ b/lumen/sources/intake_dremio.py @@ -34,6 +34,13 @@ def _get_source(self, table): table = tables[normalized_tables.index(normalized_table)] return super()._get_source(table) + def __contains__(self, table): + try: + self._get_source(table) + return True + except KeyError: + return False + def create_sql_expr_source(self, tables: dict[str, str], **kwargs): """ Creates a new SQL Source given a set of table names and