Async data calls #714

Merged · 10 commits · Sep 25, 2024
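This PR replaces the agents' blocking data access (direct `Pipeline(...)` construction, `pipeline.data` reads, and synchronous `get_schema`/`describe_data` calls) with awaitable helpers, so the Panel event loop stays responsive while queries run. The helpers themselves (`get_data` and `get_pipeline`, imported from `lumen/ai/utils.py` below) are not shown in this diff; here is a minimal sketch of what they presumably look like, with the implementations assumed rather than taken from the PR:

```python
# Sketch only: lumen/ai/utils.py is not part of this diff, so these
# implementations are assumptions inferred from how the helpers are awaited below.
import asyncio

async def get_data(pipeline):
    # pipeline.data is a blocking property (it may execute SQL), so
    # evaluate it in a worker thread instead of blocking the event loop.
    return await asyncio.to_thread(lambda: pipeline.data)

async def get_pipeline(**kwargs):
    # Constructing a Pipeline can also trigger I/O, so build it off-thread.
    from lumen.pipeline import Pipeline
    return await asyncio.to_thread(Pipeline, **kwargs)
```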
lumen/ai/agents.py (60 changes: 33 additions & 27 deletions)
@@ -39,7 +39,8 @@
 )
 from .translate import param_to_pydantic
 from .utils import (
-    clean_sql, describe_data, get_schema, render_template, retry_llm_output,
+    clean_sql, describe_data, get_data, get_pipeline, get_schema,
+    render_template, retry_llm_output,
 )
 from .views import AnalysisOutput, LumenOutput, SQLOutput
@@ -274,7 +275,7 @@ async def _system_prompt_with_context(
             context = f"Available tables: {', '.join(closest_tables)}"
         else:
             memory["current_table"] = table = memory.get("current_table", tables[0])
-            schema = get_schema(memory["current_source"], table)
+            schema = await get_schema(memory["current_source"], table)
             if schema:
                 context = f"{table} with schema: {schema}"
@@ -389,7 +390,7 @@ async def answer(self, messages: list | str):
         for table in source.get_tables():
             tables_to_source[table] = source
             if isinstance(source, DuckDBSource) and source.ephemeral:
-                schema = get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
+                schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
                 tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
             else:
                 tables_schema_str += f"### {table}\n"
@@ -435,12 +436,12 @@ async def answer(self, messages: list | str):
                 get_kwargs['sql_transforms'] = [SQLLimit(limit=1_000_000)]
             memory["current_source"] = source
             memory["current_table"] = table
-            memory["current_pipeline"] = pipeline = Pipeline(
+            memory["current_pipeline"] = pipeline = await get_pipeline(
                 source=source, table=table, **get_kwargs
             )
-            df = pipeline.data
+            df = await get_data(pipeline)
             if len(df) > 0:
-                memory["current_data"] = describe_data(df)
+                memory["current_data"] = await describe_data(df)
             if self.debug:
                 print(f"{self.name} thinks that the user is talking about {table=!r}.")
             return pipeline
@@ -581,7 +582,7 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
                 # Get validated query
                 sql_query = sql_expr_source.tables[expr_slug]
                 sql_transforms = [SQLLimit(limit=1_000_000)]
-                pipeline = Pipeline(
+                pipeline = await get_pipeline(
                     source=sql_expr_source, table=expr_slug, sql_transforms=sql_transforms
                 )
             except InstructorRetryException as e:
@@ -605,9 +606,9 @@
                 step.status = "failed"
                 raise e

-        df = pipeline.data
+        df = await get_data(pipeline)
         if len(df) > 0:
-            memory["current_data"] = describe_data(df)
+            memory["current_data"] = await describe_data(df)

         memory["available_sources"].append(sql_expr_source)
         memory["current_source"] = sql_expr_source
@@ -701,7 +702,7 @@ async def answer(self, messages: list | str):
         if not hasattr(source, "get_sql_expr"):
             return None

-        schema = get_schema(source, table, include_min_max=False)
+        schema = await get_schema(source, table, include_min_max=False)
         join_required = await self.check_join_required(messages, schema, table)
         if join_required:
             tables_to_source = await self.find_join_tables(messages)
@@ -713,7 +714,7 @@
                 if source_table == table:
                     table_schema = schema
                 else:
-                    table_schema = get_schema(source, source_table, include_min_max=False)
+                    table_schema = await get_schema(source, source_table, include_min_max=False)
                 table_schemas[source_table] = {
                     "schema": yaml.dump(table_schema),
                     "sql": source.get_sql_expr(source_table)
@@ -754,13 +755,14 @@ async def answer(self, messages: list | str) -> Transform:
         if "current_pipeline" in memory:
             pipeline = memory["current_pipeline"]
         else:
-            pipeline = Pipeline(
+            pipeline = await get_pipeline(
                 source=memory["current_source"],
                 table=memory["current_table"],
             )
         memory["current_pipeline"] = pipeline
-        pipeline._update_data(force=True)
-        memory["current_data"] = describe_data(pipeline.data)
+        await asyncio.to_thread(pipeline._update_data, force=True)
+        data = await get_data(pipeline)
+        memory["current_data"] = await describe_data(data)
         return pipeline

     async def invoke(self, messages: list | str):
@@ -867,7 +869,7 @@ async def _construct_transform(
         self, messages: list | str, transform: type[Transform], system_prompt: str
     ) -> Transform:
         excluded = transform._internal_params + ["controls", "type"]
-        schema = get_schema(memory["current_pipeline"])
+        schema = await get_schema(memory["current_pipeline"])
         table = memory["current_table"]
         model = param_to_pydantic(transform, excluded=excluded, schema=schema)[
             transform.__name__
@@ -912,8 +914,9 @@ async def answer(self, messages: list | str) -> Transform:
         else:
             pipeline.add_transform(transform)

-        pipeline._update_data(force=True)
-        memory["current_data"] = describe_data(pipeline.data)
+        await asyncio.to_thread(pipeline._update_data, force=True)
+        data = await get_data(pipeline)
+        memory["current_data"] = await describe_data(data)
         return pipeline

     async def invoke(self, messages: list | str):
@@ -927,15 +930,15 @@ class BaseViewAgent(LumenBaseAgent):

     provides = param.List(default=["current_plot"], readonly=True)

-    def _extract_spec(self, model: BaseModel):
+    async def _extract_spec(self, model: BaseModel):
         return dict(model)

     async def answer(self, messages: list | str) -> hvPlotUIView:
         pipeline = memory["current_pipeline"]

         # Write prompts
         system_prompt = await self._system_prompt_with_context(messages)
-        schema = get_schema(pipeline, include_min_max=False)
+        schema = await get_schema(pipeline, include_min_max=False)
         view_prompt = render_template(
             "plot_agent.jinja2",
             schema=yaml.dump(schema),
@@ -951,7 +954,7 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
             system=system_prompt + view_prompt,
             response_model=self._get_model(schema),
         )
-        spec = self._extract_spec(output)
+        spec = await self._extract_spec(output)
         chain_of_thought = spec.pop("chain_of_thought")
         with self.interface.add_step(title="Generating view...") as step:
             step.stream(chain_of_thought)
@@ -1002,7 +1005,7 @@ def _get_model(cls, schema):
         })
         return model[cls.view_type.__name__]

-    def _extract_spec(self, model):
+    async def _extract_spec(self, model):
         pipeline = memory["current_pipeline"]
         spec = {
             key: val for key, val in dict(model).items()
@@ -1014,7 +1017,8 @@ def _extract_spec(self, model):

         # Add defaults
         spec["responsive"] = True
-        if len(pipeline.data) > 20000 and spec["kind"] in ("line", "scatter", "points"):
+        data = await get_data(pipeline)
+        if len(data) > 20000 and spec["kind"] in ("line", "scatter", "points"):
             spec["rasterize"] = True
             spec["cnorm"] = "log"
         return spec
@@ -1039,7 +1043,7 @@ class VegaLiteAgent(BaseViewAgent):
     def _get_model(cls, schema):
         return VegaLiteSpec

-    def _extract_spec(self, model):
+    async def _extract_spec(self, model):
         vega_spec = json.loads(model.json_spec)
         if "$schema" not in vega_spec:
             vega_spec["$schema"] = "https://vega.github.io/schema/vega-lite/v5.json"
@@ -1092,7 +1096,7 @@ async def _system_prompt_with_context(

     async def answer(self, messages: list | str, agents: list[Agent] | None = None):
         pipeline = memory['current_pipeline']
-        analyses = {a.name: a for a in self.analyses if a.applies(pipeline)}
+        analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
         if not analyses:
             print("NONE found...")
             return None
@@ -1125,8 +1129,10 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
         with self.interface.add_step(title="Creating view...", user="Assistant") as step:
             await asyncio.sleep(0.1)  # necessary to give it time to render before calling sync function...
             analysis_callable = analyses[analysis_name].instance(agents=agents)
+
+            data = await get_data(pipeline)
             for field in analysis_callable._field_params:
-                analysis_callable.param[field].objects = list(pipeline.data.columns)
+                analysis_callable.param[field].objects = list(data.columns)
             memory["current_analysis"] = analysis_callable

             if analysis_callable.autorun:
@@ -1143,8 +1149,8 @@
                 # Ensure current_data reflects processed pipeline
                 if pipeline is not memory['current_pipeline']:
                     pipeline = memory['current_pipeline']
-                if len(pipeline.data) > 0:
-                    memory["current_data"] = describe_data(pipeline.data)
+                if len(data) > 0:
+                    memory["current_data"] = await describe_data(data)
                 yaml_spec = yaml.dump(spec)
                 step.stream(f"Generated view\n```yaml\n{yaml_spec}\n```")
                 step.success_title = "Generated view"
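A recurring pattern in the hunks above: internals that stay synchronous, such as `pipeline._update_data(force=True)`, are not rewritten but simply wrapped in `asyncio.to_thread(...)` and awaited. A self-contained illustration of why that matters (the blocking function here is a stand-in, not Lumen code):

```python
import asyncio
import time

def update_data(force=False):
    # Stand-in for a blocking call like pipeline._update_data.
    time.sleep(1)
    return "refreshed" if force else "cached"

async def main():
    # The event loop stays free to serve other callbacks (UI renders,
    # step streaming) while the blocking call runs in a worker thread.
    result = await asyncio.to_thread(update_data, force=True)
    print(result)

asyncio.run(main())
```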
lumen/ai/analysis.py (12 changes: 7 additions & 5 deletions)
@@ -1,10 +1,11 @@
 import panel as pn
 import param

+from lumen.ai.utils import get_data
+
 from ..base import Component
 from .controls import SourceControls
 from .memory import memory
-from .utils import get_schema


 class Analysis(param.ParameterizedFunction):
@@ -34,13 +35,14 @@ class Analysis(param.ParameterizedFunction):
     _field_params = []

     @classmethod
-    def applies(cls, pipeline) -> bool:
+    async def applies(cls, pipeline) -> bool:
         applies = True
+        data = await get_data(pipeline)
         for col in cls.columns:
             if isinstance(col, tuple):
-                applies &= any(c in pipeline.data.columns for c in col)
+                applies &= any(c in data.columns for c in col)
             else:
-                applies &= col in pipeline.data.columns
+                applies &= col in data.columns
         return applies

     def controls(self):
@@ -80,7 +82,7 @@ def controls(self):
         table = memory.get("current_table")
         self._previous_source = source
         self._previous_table = table
-        columns = list(get_schema(source, table=table).keys())
+        columns = list(source.get_schema(table).keys())
         index_col = pn.widgets.AutocompleteInput.from_param(
             self.param.index_col, options=columns, name="Join on",
             placeholder="Start typing column name", search_strategy="includes",
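Since `Analysis.applies` is now a coroutine, callers can no longer filter analyses synchronously; each check must be awaited, as the `await analysis.applies(pipeline)` call sites in `agents.py` and `assistant.py` show. The checks are independent of one another, so a caller could also run them concurrently; a hedged sketch of that variant (not part of this PR):

```python
import asyncio

async def applicable_analyses(analyses, pipeline):
    # Run all applies() checks concurrently rather than one await at a time.
    flags = await asyncio.gather(*(a.applies(pipeline) for a in analyses))
    return [a for a, ok in zip(analyses, flags) if ok]
```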
lumen/ai/assistant.py (10 changes: 5 additions & 5 deletions)
@@ -186,7 +186,7 @@ async def use_suggestion(event):
             else:
                 return
             await agent.invoke([{'role': 'user', 'content': contents}], agents=self.agents)
-            self._add_analysis_suggestions()
+            await self._add_analysis_suggestions()
         else:
             self.interface.send(contents)

@@ -233,13 +233,13 @@ async def run_demo(event):
         self.interface.param.watch(hide_suggestions, "objects")
         return message

-    def _add_analysis_suggestions(self):
+    async def _add_analysis_suggestions(self):
         pipeline = memory['current_pipeline']
         current_analysis = memory.get("current_analysis")
         allow_consecutive = getattr(current_analysis, '_consecutive_calls', True)
         applicable_analyses = []
         for analysis in self._analyses:
-            if analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)):
+            if await analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)):
                 applicable_analyses.append(analysis)
         self._add_suggestions_to_footer(
             [f"Apply {analysis.__name__}" for analysis in applicable_analyses],
@@ -263,7 +263,7 @@ async def _invalidate_memory(self, messages):
             raise KeyError(f'Table {table} could not be found in available sources.')

         try:
-            spec = get_schema(source, table=table, include_count=True)
+            spec = await get_schema(source, table=table, include_count=True)
         except Exception:
             # If the selected table cannot be fetched we should invalidate it
             spec = None
@@ -482,7 +482,7 @@ async def invoke(self, messages: list | str) -> str:
         await agent.invoke(messages[-context_length:], **kwargs)
         self._current_agent.object = "## No agent active"
         if "current_pipeline" in agent.provides:
-            self._add_analysis_suggestions()
+            await self._add_analysis_suggestions()
         print("\033[92mDONE\033[0m", "\n\n")

     def controls(self):
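Making `_add_analysis_suggestions` a coroutine forces an `await` at every call site, which is why both `use_suggestion` and `invoke` are updated above. A missed call site would fail silently: Python only emits a `RuntimeWarning: coroutine ... was never awaited` when the orphaned coroutine is garbage-collected, and the suggestions simply never appear. A minimal illustration:

```python
import asyncio

async def _add_analysis_suggestions():
    print("suggestions added")

async def invoke():
    _add_analysis_suggestions()        # bug: builds a coroutine object, never runs it
    await _add_analysis_suggestions()  # correct: the coroutine actually executes

asyncio.run(invoke())
```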