Skip to content

Commit

Permalink
Limit the amount of data that TableAgent pulls (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Aug 21, 2024
1 parent e6a8212 commit 2009652
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
13 changes: 10 additions & 3 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
from ..base import Component
from ..dashboard import Config
from ..pipeline import Pipeline
from ..sources.base import BaseSQLSource
from ..sources.duckdb import DuckDBSource
from ..state import state
from ..transforms.sql import SQLOverride, SQLTransform, Transform
from ..transforms.sql import (
SQLLimit, SQLOverride, SQLTransform, Transform,
)
from ..views import VegaLiteView, View, hvPlotUIView
from .analysis import Analysis
from .config import FUZZY_TABLE_LENGTH
Expand Down Expand Up @@ -344,6 +347,7 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non
message_kwargs = dict(value=out, user=self.user)
self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)


class TableAgent(LumenBaseAgent):
"""
Displays a single table / dataset. Does not discuss.
Expand Down Expand Up @@ -412,12 +416,15 @@ async def answer(self, messages: list | str):
sources = [src for src in available_sources if table in src]
source = sources[0] if sources else memory["current_source"]

get_kwargs = {}
if isinstance(source, BaseSQLSource):
get_kwargs['sql_transforms'] = [SQLLimit(limit=1_000_000)]
memory["current_source"] = source
memory["current_table"] = table
memory["current_pipeline"] = pipeline = Pipeline(
source=source, table=table
source=source, table=table, **get_kwargs
)
df = pipeline.__panel__()[-1].value
df = pipeline.data
if len(df) > 0:
memory["current_data"] = describe_data(df)
if self.debug:
Expand Down
23 changes: 21 additions & 2 deletions lumen/ai/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..dashboard import load_yaml
from ..downloads import Download
from ..pipeline import Pipeline
from ..transforms.sql import SQLLimit
from ..views.base import Table


Expand Down Expand Up @@ -117,8 +118,26 @@ async def _render_component(self):
)
download_pane = download.__panel__()
download_pane.sizing_mode = 'fixed'
download_pane.styles = {'position': 'absolute', 'right': '40px', 'top': '-35px'}
output = pn.Column(download_pane, table)
controls = pn.Row(
download_pane,
styles={'position': 'absolute', 'right': '40px', 'top': '-35px'}
)
for sql_limit in self.component.sql_transforms:
if isinstance(sql_limit, SQLLimit):
break
else:
sql_limit = None
if sql_limit:
limited = len(self.component.data) == sql_limit.limit
if limited:
def unlimit(e):
sql_limit.limit = None if e.new else 1_000_000
full_data = pn.widgets.Checkbox(
name='Full data', width=100, visible=limited
)
full_data.param.watch(unlimit, 'value')
controls.insert(0, full_data)
output = pn.Column(controls, table)
else:
output = self.component.__panel__()
self._last_output.clear()
Expand Down
2 changes: 1 addition & 1 deletion lumen/tests/transforms/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_sql_group_by_multi_columns():
def test_sql_limit():
assert (
SQLLimit.apply_to('SELECT * FROM TABLE', limit=10) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nLIMIT 10"""
"""SELECT * FROM TABLE LIMIT 10"""
)

def test_sql_columns():
Expand Down
11 changes: 4 additions & 7 deletions lumen/transforms/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,15 @@ class SQLLimit(SQLTransform):
Performs a LIMIT SQL operation on the query
"""

limit = param.Integer(default=1000, doc="Limit on the number of rows to return")
limit = param.Integer(default=1000, allow_None=True, doc="Limit on the number of rows to return")

transform_type: ClassVar[str] = 'sql_limit'

def apply(self, sql_in):
if self.limit is None:
return sql_in
sql_in = super().apply(sql_in)
template = """
SELECT
*
FROM ( {{sql_in}} )
LIMIT {{limit}}
"""
template = "{{sql_in}} LIMIT {{limit}}"
return self._render_template(template, sql_in=sql_in, limit=self.limit)


Expand Down

0 comments on commit 2009652

Please sign in to comment.