Skip to content

Commit

Permalink
Consider count when invalidating (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Sep 20, 2024
1 parent 453312b commit 584b2e6
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 7 deletions.
3 changes: 2 additions & 1 deletion lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,11 @@ async def _invalidate_memory(self, messages):
raise KeyError(f'Table {table} could not be found in available sources.')

try:
spec = get_schema(source, table=table)
spec = get_schema(source, table=table, include_count=True)
except Exception:
# If the selected table cannot be fetched we should invalidate it
spec = None

sql = memory.get("current_sql")
system = render_template("check_validity.jinja2", table=table, spec=spec, sql=sql, analyses=self._analyses)
with self.interface.add_step(title="Checking memory...", user="Assistant") as step:
Expand Down
8 changes: 5 additions & 3 deletions lumen/ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ class Validity(BaseModel):

correct_assessment: str = Field(
description="""
Thoughts on whether the current table meets the requirement
to answer the user's query, i.e. table contains all necessary columns,
unless user explicitly asks for a refresh.
Restate the current table and think thru whether the current table meets the requirement
to answer the user's query, i.e. table contains all the necessary raw columns.
However, if the query can be solved through SQL, the data is assumed to be valid.
If the number of rows is insufficient, the table is invalid.
If the user explicitly asks for a refresh, then the table is invalid.
"""
)

Expand Down
8 changes: 8 additions & 0 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def get_schema(
table: str | None = None,
include_min_max: bool = True,
include_enum: bool = True,
include_count: bool = False,
**get_kwargs
):
if isinstance(source, Pipeline):
Expand All @@ -112,6 +113,10 @@ def get_schema(
schema = source.get_schema(table, **get_kwargs)
schema = dict(schema)

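# Always pop the row count before the per-field loops below: it is a
# scalar rather than a field spec, so membership tests such as
# `"enum" in spec` would otherwise raise
# "argument of type 'numpy.int64' is not iterable".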
count = schema.pop("count", None)

if include_min_max:
for field, spec in schema.items():
if "inclusiveMinimum" in spec:
Expand All @@ -134,6 +139,9 @@ def get_schema(
if "enum" in spec:
spec.pop("enum")

if count and include_count:
spec["count"] = count

schema = format_schema(schema)
return schema

Expand Down
8 changes: 7 additions & 1 deletion lumen/sources/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..serializers import Serializer
from ..transforms import Filter
from ..transforms.sql import (
SQLDistinct, SQLFilter, SQLLimit, SQLMinMax,
SQLCount, SQLDistinct, SQLFilter, SQLLimit, SQLMinMax,
)
from ..util import get_dataframe_schema
from .base import (
Expand Down Expand Up @@ -313,4 +313,10 @@ def get_schema(
cast = lambda v: v
schema[col]['inclusiveMinimum'] = cast(minmax_data[f'{col}_min'].iloc[0])
schema[col]['inclusiveMaximum'] = cast(minmax_data[f'{col}_max'].iloc[0])

count_expr = SQLCount().apply(sql_expr)
count_expr = ' '.join(count_expr.splitlines())
count_data = self._connection.execute(count_expr).fetch_df()
schema['count'] = cast(count_data['count'].iloc[0])

return schemas if table is None else schemas[table]
6 changes: 4 additions & 2 deletions lumen/tests/sources/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def test_duckdb_get_schema(duckdb_source):
'inclusiveMaximum': '2009-01-07 00:00:00',
'inclusiveMinimum': '2009-01-01 00:00:00',
'type': 'string'
}
},
'count': '5'
}
source = duckdb_source.get_schema('test_sql')
source["C"]["enum"].sort()
Expand All @@ -84,7 +85,8 @@ def test_duckdb_get_schema_with_none(duckdb_source):
'inclusiveMaximum': '2009-01-07 00:00:00',
'inclusiveMinimum': '2009-01-01 00:00:00',
'type': 'string'
}
},
'count': '5'
}
source = duckdb_source.get_schema('test_sql_with_none')
source["C"]["enum"].sort(key=enum.index)
Expand Down
13 changes: 13 additions & 0 deletions lumen/transforms/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ def apply(self, sql_in):
return self._render_template(template, sql_in=sql_in, columns=', '.join(map(quote, self.columns)))


class SQLCount(SQLTransform):
    """
    SQL transform that wraps a query so it returns its row count.

    The generated statement selects a single aggregate aliased as
    ``count`` from the input query, which is used as a subquery.
    """

    transform_type: ClassVar[str] = 'sql_count'

    def apply(self, sql_in):
        """
        Build a row-counting query around ``sql_in``.

        Parameters
        ----------
        sql_in: str
            The SQL query whose rows should be counted.

        Returns
        -------
        str
            The rendered SQL selecting ``COUNT(*) as count`` from the
            input query.
        """
        # Let the parent class pre-process the input query first
        # (NOTE(review): exact behavior is defined on SQLTransform, which
        # is not visible here).
        sql_in = super().apply(sql_in)
        # Use COUNT(*) explicitly. The template previously referenced a
        # `column` variable that was never handed to `_render_template`
        # and is not a parameter of this class, so the aggregate argument
        # depended on how the renderer treats undefined names.
        template = """
            SELECT
                COUNT(*) as count
            FROM ( {{sql_in}} )"""
        return self._render_template(template, sql_in=sql_in)


class SQLMinMax(SQLTransform):

columns = param.List(default=[], doc="Columns to return min/max values for.")
Expand Down

0 comments on commit 584b2e6

Please sign in to comment.