From 584b2e61d79c847be29006eb7beef291235ab43a Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Fri, 20 Sep 2024 04:29:22 -0700 Subject: [PATCH] Consider count when invalidating (#705) --- lumen/ai/assistant.py | 3 ++- lumen/ai/models.py | 8 +++++--- lumen/ai/utils.py | 8 ++++++++ lumen/sources/duckdb.py | 8 +++++++- lumen/tests/sources/test_duckdb.py | 6 ++++-- lumen/transforms/sql.py | 13 +++++++++++++ 6 files changed, 39 insertions(+), 7 deletions(-) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index a3d9c77d..b4f82109 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -262,10 +262,11 @@ async def _invalidate_memory(self, messages): raise KeyError(f'Table {table} could not be found in available sources.') try: - spec = get_schema(source, table=table) + spec = get_schema(source, table=table, include_count=True) except Exception: # If the selected table cannot be fetched we should invalidate it spec = None + sql = memory.get("current_sql") system = render_template("check_validity.jinja2", table=table, spec=spec, sql=sql, analyses=self._analyses) with self.interface.add_step(title="Checking memory...", user="Assistant") as step: diff --git a/lumen/ai/models.py b/lumen/ai/models.py index a734382b..0bde70e6 100644 --- a/lumen/ai/models.py +++ b/lumen/ai/models.py @@ -72,9 +72,11 @@ class Validity(BaseModel): correct_assessment: str = Field( description=""" - Thoughts on whether the current table meets the requirement - to answer the user's query, i.e. table contains all necessary columns, - unless user explicitly asks for a refresh. + Restate the current table and think thru whether the current table meets the requirement + to answer the user's query, i.e. table contains all the necessary raw columns. + However, if the query can be solved through SQL, the data is assumed to be valid. + If the number of rows is insufficient, the table is invalid. 
+ If the user explicitly asks for a refresh, then the table is invalid. """ ) diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py index 628b9677..30d2e29e 100644 --- a/lumen/ai/utils.py +++ b/lumen/ai/utils.py @@ -102,6 +102,7 @@ def get_schema( table: str | None = None, include_min_max: bool = True, include_enum: bool = True, + include_count: bool = False, **get_kwargs ): if isinstance(source, Pipeline): @@ -112,6 +113,10 @@ def get_schema( schema = source.get_schema(table, **get_kwargs) schema = dict(schema) + # first pop regardless to prevent + # argument of type 'numpy.int64' is not iterable + count = schema.pop("count", None) + if include_min_max: for field, spec in schema.items(): if "inclusiveMinimum" in spec: @@ -134,6 +139,9 @@ def get_schema( if "enum" in spec: spec.pop("enum") + if count and include_count: + spec["count"] = count + schema = format_schema(schema) return schema diff --git a/lumen/sources/duckdb.py b/lumen/sources/duckdb.py index d85f6c1c..f30a0850 100644 --- a/lumen/sources/duckdb.py +++ b/lumen/sources/duckdb.py @@ -11,7 +11,7 @@ from ..serializers import Serializer from ..transforms import Filter from ..transforms.sql import ( - SQLDistinct, SQLFilter, SQLLimit, SQLMinMax, + SQLCount, SQLDistinct, SQLFilter, SQLLimit, SQLMinMax, ) from ..util import get_dataframe_schema from .base import ( @@ -313,4 +313,10 @@ def get_schema( cast = lambda v: v schema[col]['inclusiveMinimum'] = cast(minmax_data[f'{col}_min'].iloc[0]) schema[col]['inclusiveMaximum'] = cast(minmax_data[f'{col}_max'].iloc[0]) + + count_expr = SQLCount().apply(sql_expr) + count_expr = ' '.join(count_expr.splitlines()) + count_data = self._connection.execute(count_expr).fetch_df() + schema['count'] = cast(count_data['count'].iloc[0]) + return schemas if table is None else schemas[table] diff --git a/lumen/tests/sources/test_duckdb.py b/lumen/tests/sources/test_duckdb.py index daed6a00..1514f2ed 100644 --- a/lumen/tests/sources/test_duckdb.py +++ 
b/lumen/tests/sources/test_duckdb.py @@ -65,7 +65,8 @@ def test_duckdb_get_schema(duckdb_source): 'inclusiveMaximum': '2009-01-07 00:00:00', 'inclusiveMinimum': '2009-01-01 00:00:00', 'type': 'string' - } + }, + 'count': '5' } source = duckdb_source.get_schema('test_sql') source["C"]["enum"].sort() @@ -84,7 +85,8 @@ def test_duckdb_get_schema_with_none(duckdb_source): 'inclusiveMaximum': '2009-01-07 00:00:00', 'inclusiveMinimum': '2009-01-01 00:00:00', 'type': 'string' - } + }, + 'count': '5' } source = duckdb_source.get_schema('test_sql_with_none') source["C"]["enum"].sort(key=enum.index) diff --git a/lumen/transforms/sql.py b/lumen/transforms/sql.py index 43fbe3a1..d9917851 100644 --- a/lumen/transforms/sql.py +++ b/lumen/transforms/sql.py @@ -133,6 +133,19 @@ def apply(self, sql_in): return self._render_template(template, sql_in=sql_in, columns=', '.join(map(quote, self.columns))) +class SQLCount(SQLTransform): + + transform_type: ClassVar[str] = 'sql_count' + + def apply(self, sql_in): + sql_in = super().apply(sql_in) + template = """ + SELECT + COUNT({{column}}) as count + FROM ( {{sql_in}} )""" + return self._render_template(template, sql_in=sql_in) + + class SQLMinMax(SQLTransform): columns = param.List(default=[], doc="Columns to return min/max values for.")