Make get data async #713

Closed
wants to merge 2 commits
14 changes: 14 additions & 0 deletions lumen/ai/agents.py
@@ -527,6 +527,9 @@ def _render_sql(self, query):

@retry_llm_output()
async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
+import time
+start = time.perf_counter()

if errors:
last_query = self.interface.serialize()[-1]["content"].replace("```sql", "").rstrip("```").strip()
errors = '\n'.join(errors)
@@ -552,6 +555,8 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
step_message += f"\n```sql\n{sql_query}\n```"
step.stream(step_message, replace=True)

print(f"SQL thought took {time.perf_counter() - start:.2f}s")

if not sql_query:
raise ValueError("No SQL query was generated.")

@@ -574,16 +579,21 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
else:
source = next(iter(sources))

+# time
+print(f"SQL mirrors took {time.perf_counter() - start:.2f}s")

# check whether the SQL query is valid
expr_slug = output.expr_slug
try:
sql_expr_source = source.create_sql_expr_source({expr_slug: sql_query})
# Get validated query
sql_query = sql_expr_source.tables[expr_slug]
sql_transforms = [SQLLimit(limit=1_000_000)]
print(f"SQL source took {time.perf_counter() - start:.2f}s")
pipeline = Pipeline(
source=sql_expr_source, table=expr_slug, sql_transforms=sql_transforms
)
print(f"SQL run took {time.perf_counter() - start:.2f}s")
except InstructorRetryException as e:
error_msg = str(e)
step.stream(f'\n```python\n{error_msg}\n```')
@@ -605,10 +615,14 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
step.status = "failed"
raise e

print(f"SQL validation took {time.perf_counter() - start:.2f}s")

df = pipeline.data
if len(df) > 0:
memory["current_data"] = describe_data(df)

print(f"SQL describe took {time.perf_counter() - start:.2f}s")

memory["available_sources"].append(sql_expr_source)
memory["current_source"] = sql_expr_source
memory["current_pipeline"] = pipeline
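The added lines instrument `_create_valid_sql` with a simple checkpoint pattern: capture `time.perf_counter()` once, then print the elapsed delta after each stage. A minimal sketch of the same pattern factored into a helper (the `Stopwatch` class is illustrative, not part of Lumen):

```python
import time

class Stopwatch:
    """Prints elapsed seconds at named checkpoints, as the debug prints above do."""

    def __init__(self):
        self._start = time.perf_counter()

    def checkpoint(self, label: str) -> None:
        print(f"{label} took {time.perf_counter() - self._start:.2f}s")

watch = Stopwatch()
time.sleep(0.1)  # stand-in for the LLM call that produces the SQL
watch.checkpoint("SQL thought")
time.sleep(0.1)  # stand-in for building and running the pipeline
watch.checkpoint("SQL run")
```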
5 changes: 3 additions & 2 deletions lumen/sources/ae5.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import asyncio
import datetime as dt

from typing import Any
@@ -288,10 +289,10 @@ def _get_resource_allocations(self):
return allocations[self._allocation_columns]

@cached
-def get(self, table, **query):
+async def get(self, table, **query):
if table not in self._tables:
raise ValueError(f"AE5Source has no '{table}' table, choose from {self._tables!r}.")
-return getattr(self, f'_get_{table}')()
+return await asyncio.to_thread(getattr(self, f'_get_{table}'))

@cached_schema
def get_schema(
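This is the conversion pattern the whole PR applies: `get` becomes a coroutine and its blocking body moves onto a worker thread via `asyncio.to_thread` (available since Python 3.9), so the event loop can service other work while the fetch runs. A self-contained sketch with hypothetical names:

```python
import asyncio
import time

def _fetch_blocking(table: str) -> str:
    time.sleep(1)  # simulate slow I/O such as an HTTP request
    return f"rows for {table!r}"

async def get(table: str) -> str:
    # Run the blocking fetch on a thread; the event loop stays free.
    return await asyncio.to_thread(_fetch_blocking, table)

async def main() -> None:
    # The two fetches overlap, finishing in ~1s rather than ~2s.
    users, orders = await asyncio.gather(get("users"), get("orders"))
    print(users, orders)

asyncio.run(main())
```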
30 changes: 16 additions & 14 deletions lumen/sources/base.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import asyncio
import hashlib
import json
import re
@@ -458,7 +459,7 @@ def get_schema(
msg = match_suggestion_message(table or '', names, msg)
raise ValidationError(msg) from e

-def get(self, table: str, **query) -> DataFrame:
+async def get(self, table: str, **query) -> DataFrame:
"""
Return a table; optionally filtered by the given query.

@@ -501,9 +502,9 @@ def get_schema(
response.json().items()}

@cached
-def get(self, table: str, **query) -> pd.DataFrame:
+async def get(self, table: str, **query) -> pd.DataFrame:
query = dict(table=table, **query)
-r = requests.get(self.url+'/data', params=query)
+r = await asyncio.to_thread(requests.get, self.url+'/data', params=query)
df = pd.DataFrame(r.json())
return df

@@ -527,9 +528,9 @@ def get_schema(
else:
return {t: get_dataframe_schema(self.get(t))['items']['properties'] for t in self.get_tables()}

-def get(self, table: str, **query) -> pd.DataFrame:
+async def get(self, table: str, **query) -> pd.DataFrame:
dask = query.pop('__dask', False)
-table = self.tables.get(table)
+table = await asyncio.to_thread(self.tables.get, table)
df = FilterTransform.apply_to(table, conditions=list(query.items()))
return df if dask or not hasattr(df, 'compute') else df.compute()

@@ -700,9 +701,9 @@ def _load_table(self, table: str, dask: bool = True) -> DataFrame:
return df

@cached
-def get(self, table: str, **query) -> DataFrame:
+async def get(self, table: str, **query) -> DataFrame:
dask = query.pop('__dask', self.dask)
-df = self._load_table(table)
+df = await asyncio.to_thread(self._load_table, table)
df = FilterTransform.apply_to(df, conditions=list(query.items()))
return df if dask or not hasattr(df, 'compute') else df.compute()

@@ -820,11 +821,11 @@ def get_tables(self) -> list[str]:
return ['status']

@cached
-def get(self, table: str, **query) -> pd.DataFrame:
+async def get(self, table: str, **query) -> pd.DataFrame:
data = []
for url in self.urls:
try:
-r = requests.get(url)
+r = await asyncio.to_thread(requests.get, url)
live = r.status_code == 200
except Exception:
live = False
@@ -922,7 +923,8 @@ def _get_session_info(self, table: str, url: str) -> list[dict[str, Any]]:
return data

@cached
-def get(self, table: str, **query) -> pd.DataFrame:
+async def get(self, table: str, **query) -> pd.DataFrame:
+    # TODO: make this async instead of using ThreadPoolExecutor
data = []
with futures.ThreadPoolExecutor(len(self.urls)) as executor:
tasks = {executor.submit(self._get_session_info, table, url): url
@@ -1016,15 +1018,15 @@ def get_schema(
return schemas if table is None else schemas[table]

@cached
-def get(self, table: str, **query) -> DataFrame:
+async def get(self, table: str, **query) -> DataFrame:
df, left_key = None, None
for spec in self.tables[table]:
source, subtable = spec['source'], spec['table']
source_query = dict(query)
right_key = spec.get('index')
if df is not None and left_key and right_key not in query:
source_query[right_key] = list(df[left_key].unique())
-df_merge = self.sources[source].get(subtable, **source_query)
+df_merge = await asyncio.to_thread(self.sources[source].get, subtable, **source_query)
if df is None:
df = df_merge
left_key = spec.get('index')
@@ -1140,15 +1142,15 @@ def _get_source_table(self, table: str) -> DataFrame:
return source.get(table, **query)

@cached
-def get(self, table: str, **query) -> DataFrame:
+async def get(self, table: str, **query) -> DataFrame:
df = self._get_source_table(table)
if self.tables:
transforms = self.tables[table].get('transforms', []) + self.transforms
else:
transforms = self.transforms
transforms.append(FilterTransform(conditions=list(query.items())))
for transform in transforms:
-df = transform.apply(df)
+df = await asyncio.to_thread(transform.apply, df)
return df

get.__doc__ = Source.get.__doc__
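The TODO left in the session-info `get` points at the natural next step: replace the hand-rolled `ThreadPoolExecutor` fan-out with `asyncio.gather` over `to_thread` calls. A sketch of that follow-up, where `fetch` stands in for the diff's `self._get_session_info`:

```python
import asyncio

async def gather_session_info(urls, fetch, table):
    # One worker thread per URL, collected concurrently; a failing
    # endpoint yields an exception object instead of sinking the rest.
    results = await asyncio.gather(
        *(asyncio.to_thread(fetch, table, url) for url in urls),
        return_exceptions=True,
    )
    data = []
    for result in results:
        if not isinstance(result, Exception):
            data.extend(result)
    return data
```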
10 changes: 8 additions & 2 deletions lumen/sources/duckdb.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import asyncio
import re

from typing import TYPE_CHECKING, Any
@@ -248,8 +249,11 @@ def get_sql_expr(self, table: str):
sql_expr = self.sql_expr.format(table=table)
return sql_expr.rstrip(";")

+def _execute(self, sql_expr):
+    return self._connection.execute(sql_expr).fetch_df(date_as_object=True)

@cached
-def get(self, table, **query):
+async def get(self, table, **query):
query.pop('__dask', None)
sql_expr = self.get_sql_expr(table)
sql_transforms = query.pop('sql_transforms', [])
@@ -258,7 +262,7 @@ def get(self, table, **query):
sql_transforms = [SQLFilter(conditions=conditions)] + sql_transforms
for st in sql_transforms:
sql_expr = st.apply(sql_expr)
-df = self._connection.execute(sql_expr).fetch_df(date_as_object=True)
+df = await asyncio.to_thread(self._execute, sql_expr)
if not self.filter_in_sql:
df = Filter.apply_to(df, conditions=conditions)
return df
@@ -280,6 +284,7 @@ def get_schema(
continue
sql_expr = self.get_sql_expr(entry)
data = self._connection.execute(sql_limit.apply(sql_expr)).fetch_df()
+print(data)
schemas[entry] = schema = get_dataframe_schema(data)['items']['properties']
if limit:
continue
@@ -319,5 +324,6 @@ def get_schema(
count_expr = ' '.join(count_expr.splitlines())
count_data = self._connection.execute(count_expr).fetch_df()
schema['count'] = cast(count_data['count'].iloc[0])
+print(schema)

return schemas if table is None else schemas[table]
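Factoring the fetch into `_execute` gives `asyncio.to_thread` a plain callable to run. One caveat worth noting: if several of these calls can overlap, it is safer to give each worker thread its own cursor onto the shared DuckDB connection rather than one shared handle; a sketch under that assumption:

```python
import asyncio

import duckdb

con = duckdb.connect()  # in-memory database for the sketch
con.execute("CREATE TABLE t AS SELECT range AS x FROM range(5)")

def _execute(sql_expr: str):
    # cursor() returns a per-thread handle onto the same database, so
    # overlapping to_thread calls don't contend on a single connection.
    return con.cursor().execute(sql_expr).fetch_df()

async def main() -> None:
    df = await asyncio.to_thread(_execute, "SELECT * FROM t")
    print(df)

asyncio.run(main())
```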
6 changes: 4 additions & 2 deletions lumen/sources/intake.py
@@ -1,5 +1,7 @@
from __future__ import annotations

+import asyncio

from typing import Any

import intake # type: ignore
@@ -49,15 +51,15 @@ def get_schema(
return schemas if table is None else schemas[table]

@cached
-def get(self, table, **query):
+async def get(self, table, **query):
dask = query.pop('__dask', self.dask)
try:
entry = self.cat[table]
except KeyError:
raise KeyError(f"'{table}' table could not be found in Intake "
"catalog. Available tables include: "
f"{list(self.cat)}.")
-df = self._read(entry, dask)
+df = await asyncio.to_thread(self._read, entry, dask)
return df if dask or not hasattr(df, 'compute') else df.compute()


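Throughout the PR, `@cached` is applied unchanged to the now-async `get`, which only works if the decorator awaits the coroutine and stores its result; caching the coroutine object itself would break on second use. Lumen's actual `cached` implementation is not shown in this diff; a minimal async-aware version might look like:

```python
import functools

def cached(method):
    """Sketch of an async-aware memoizer; the real decorator also
    handles shared state and invalidation."""
    @functools.wraps(method)
    async def wrapper(self, table, **query):
        key = (table, tuple(sorted(query.items())))
        cache = self.__dict__.setdefault("_cache", {})
        if key not in cache:
            # Await here so the cache stores data, not a coroutine object.
            cache[key] = await method(self, table, **query)
        return cache[key]
    return wrapper
```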
8 changes: 5 additions & 3 deletions lumen/sources/intake_sql.py
@@ -1,5 +1,7 @@
from __future__ import annotations

+import asyncio

from typing import Any

import param # type: ignore
@@ -49,7 +51,7 @@ def get_sql_expr(self, table):
return self._get_source(table)._sql_expr

@cached
-def get(self, table, **query):
+async def get(self, table, **query):
'''
Applies SQL Transforms, creating new temp catalog on the fly
and querying the database.
Expand All @@ -62,12 +64,12 @@ def get(self, table, **query):
raise ValueError(
'SQLTransforms cannot be applied to non-SQL based Intake source.'
)
-return super().get(table, **query)
+return await super().get(table, **query)
conditions = list(query.items())
if self.filter_in_sql:
sql_transforms = [SQLFilter(conditions=conditions)] + sql_transforms
source = self._apply_transforms(source, sql_transforms)
-df = self._read(source)
+df = await asyncio.to_thread(self._read, source)
if not self.filter_in_sql:
df = Filter.apply_to(df, conditions=conditions)
return df if dask or not hasattr(df, 'compute') else df.compute()
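A subtlety in the early-exit branch above: once the base class `get` is a coroutine, the subclass must `return await super().get(...)`; returning it without `await` would hand callers an unawaited coroutine object. A minimal sketch with hypothetical classes:

```python
import asyncio

class BaseSource:
    async def get(self, table):
        return f"base fetch of {table!r}"

class SQLSource(BaseSource):
    async def get(self, table, sql_transforms=()):
        if not sql_transforms:
            # Delegate to the parent coroutine; `await` is required here.
            return await super().get(table)
        return f"transformed fetch of {table!r}"

print(asyncio.run(SQLSource().get("metrics")))
```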
5 changes: 3 additions & 2 deletions lumen/sources/prometheus.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import asyncio
import datetime as dt
import urllib.parse as urlparse

@@ -242,11 +243,11 @@ def get_schema(
return {"timeseries": schema} if table is None else schema

@cached
-def get(self, table, **query):
+async def get(self, table, **query):
if table not in ('timeseries',):
raise ValueError(f"PrometheusSource has no '{table}' table, "
"it currently only has a 'timeseries' table.")
-return self._make_query()
+return await asyncio.to_thread(self._make_query)

def get_tables(self):
return list(self._metrics)
Loading