Skip to content

Commit

Permalink
Added SQL support on the store
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 18, 2024
1 parent ce49120 commit 81af2e8
Show file tree
Hide file tree
Showing 17 changed files with 259 additions and 132 deletions.
12 changes: 1 addition & 11 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,7 @@ def as_view(self) -> CompiledFeatureView | None:
compiled = self.compile()
view = compiled.predictions_view

if not view.source:
return None

return CompiledFeatureView(
name=self.metadata.name,
source=view.source,
entities=view.entities,
features=view.features,
derived_features=view.derived_features,
event_timestamp=view.event_timestamp,
)
return view.as_view(self.metadata.name)

def filter(
self, name: str, where: Callable[[T], Bool], application_source: BatchDataSource | None = None
Expand Down
108 changes: 108 additions & 0 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
StreamAggregationJob,
SupervisedJob,
ConvertableToRetrivalJob,
CustomLazyPolarsJob,
)
from aligned.schemas.feature import FeatureLocation, Feature, FeatureReferance
from aligned.schemas.feature_view import CompiledFeatureView
Expand Down Expand Up @@ -249,6 +250,94 @@ async def from_dir(path: str = '.') -> FeatureStore:
definition = await RepoDefinition.from_path(path)
return FeatureStore.from_definition(definition)

def execute_sql(self, query: str) -> RetrivalJob:
    """Run a SQL query against the store's feature views and models.

    The query is parsed with ``sqlglot`` to discover which tables and
    columns it references. The matching feature view / model data is
    loaded lazily, and the query is then executed by polars' SQL engine.

    Args:
        query: The SQL query to run. Table names must match feature view
            or model names registered in this store.

    Returns:
        A lazy job that evaluates the query when materialised.

    Raises:
        ValueError: If a referenced column can not be resolved to a
            feature view or model, or if a resolved request points at an
            unsupported location.
    """
    import polars as pl
    import sqlglot

    parsed = sqlglot.parse_one(query)
    select_expressions = parsed.find_all(sqlglot.exp.Select)

    tables = set()
    table_alias: dict[str, str] = {}
    table_columns: dict[str, set[str]] = defaultdict(set)
    # Columns that exist in exactly one table may be referenced without a
    # table prefix, so remember which table each unique column belongs to.
    unique_column_table_lookup: dict[str, str] = {}

    all_table_columns = {
        table_name: set(view.request_all.needed_requests[0].all_returned_columns)
        for table_name, view in self.feature_views.items()
    }
    all_model_columns = {
        table_name: set(model.predictions_view.request(table_name).all_returned_columns)
        for table_name, model in self.models.items()
    }

    # NOTE: named `select` rather than rebinding `parsed`/`expr`, so the
    # parsed root expression is not shadowed inside the loop.
    for select in select_expressions:

        for table in select.find_all(sqlglot.exp.Table):
            tables.add(table.name)
            table_alias[table.alias_or_name] = table.name

            for column in all_table_columns.get(table.name, set()).union(
                all_model_columns.get(table.name, set())
            ):
                # A column appearing in more than one table is ambiguous,
                # so it is removed from the unique lookup again.
                if column in unique_column_table_lookup:
                    del unique_column_table_lookup[column]
                else:
                    unique_column_table_lookup[column] = table.name

        if select.find(sqlglot.exp.Star):
            # SELECT * pulls in every column from every referenced table.
            for table_name in tables:
                table_columns[table_name].update(
                    all_table_columns.get(table_name, set()).union(
                        all_model_columns.get(table_name, set())
                    )
                )
        else:
            for column in select.find_all(sqlglot.exp.Column):
                source_table = table_alias.get(column.table)

                if source_table:
                    table_columns[source_table].add(column.name)
                    continue

                if column.table == '' and column.name in unique_column_table_lookup:
                    table_columns[unique_column_table_lookup[column.name]].add(column.name)
                    continue

                # `column.table` is '' for unprefixed columns, so include the
                # column name to keep the error message meaningful.
                raise ValueError(
                    f"Unable to resolve column `{column.name}` "
                    f"(table `{column.table}`) for query `{query}`"
                )

    all_features = set()

    for table_name, columns in table_columns.items():
        all_features.update(f'{table_name}:{column}' for column in columns)

    raw_request = RawStringFeatureRequest(features=all_features)
    feature_request = self.requests_for(raw_request, None)

    request = RetrivalRequest.unsafe_combine(feature_request.needed_requests)

    async def run_query() -> pl.LazyFrame:
        # Load each referenced location into a named lazy frame, then let
        # polars execute the original query against those frames.
        dfs = {}

        for req in feature_request.needed_requests:

            if req.location.location == 'feature_view':
                view = self.feature_view(req.location.name).select(req.all_feature_names).all()
                dfs[req.location.name] = await view.to_lazy_polars()
            elif req.location.location == 'model':
                model = (
                    self.model(req.location.name).all_predictions().select_columns(req.all_feature_names)
                )
                dfs[req.location.name] = await model.to_lazy_polars()
            else:
                raise ValueError(f"Unsupported location: {req.location}")

        return pl.SQLContext(dfs).execute(query)

    return CustomLazyPolarsJob(
        request=request,
        method=run_query,
    )

def features_for_request(
self,
requests: FeatureRequest,
Expand Down Expand Up @@ -339,6 +428,7 @@ def _requests_for(
models: dict[str, ModelSchema],
event_timestamp_column: str | None = None,
) -> FeatureRequest:

features = feature_request.grouped_features
requests: list[RetrivalRequest] = []
entity_names = set()
Expand Down Expand Up @@ -367,13 +457,28 @@ def _requests_for(

elif location_name in feature_views:
feature_view = feature_views[location_name]

if len(features[location]) == 1 and list(features[location])[0] == '*':
sub_requests = feature_view.request_all
else:
sub_requests = feature_view.request_for(features[location])
requests.extend(sub_requests.needed_requests)
for request in sub_requests.needed_requests:
entity_names.update(request.entity_names)
elif location_name in models:
model = models[location_name]
feature_view = model.predictions_view

if feature_view is None:
raise ValueError(f'Unable to find: {location_name}')

if len(features[location]) == 1 and list(features[location])[0] == '*':
sub_request = feature_view.request(location_name)
else:
sub_request = feature_view.request_for(features[location], location_name)

requests.append(sub_request)
entity_names.update(sub_request.entity_names)
else:
raise ValueError(
f'Unable to find: {location_name}, '
Expand All @@ -387,6 +492,9 @@ def _requests_for(
else:
requests = [request.without_event_timestamp() for request in requests]

if not requests:
raise ValueError(f'Unable to find any requests for: {feature_request}')

return FeatureRequest(
FeatureLocation.model('custom features'),
feature_request.feature_names.union(entity_names),
Expand Down
13 changes: 13 additions & 0 deletions aligned/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ class PredictionsView(Codable):
acceptable_freshness: timedelta | None = field(default=None)
unacceptable_freshness: timedelta | None = field(default=None)

def as_view(self, name: str) -> CompiledFeatureView | None:
    """Materialise this predictions view as a compiled feature view.

    Args:
        name: The name to give the resulting view.

    Returns:
        The compiled feature view, or ``None`` when no source is set.
    """
    source = self.source
    if not source:
        return None

    return CompiledFeatureView(
        name=name,
        source=source,
        entities=self.entities,
        features=self.features,
        derived_features=self.derived_features,
        event_timestamp=self.event_timestamp,
    )

@property
def full_schema(self) -> set[Feature]:

Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.81"
version = "0.0.82"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down Expand Up @@ -68,6 +68,7 @@ prometheus-fastapi-instrumentator = { version="^5.9.1", optional = true }
kafka-python = { version= "^2.0.2", optional = true }
connectorx = { version = "^0.3.2", optional = true }
asyncpg = { version = "^0.29.0", optional = true }
sqlglot = "^22.5.0"

[tool.poetry.extras]
aws = ["aioaws", "connectorx"]
Expand Down
14 changes: 7 additions & 7 deletions test_data/credit_history.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
student_loan_due,bankruptcies,credit_card_due,event_timestamp,dob_ssn,due_sum
22328,0,8419,2020-04-26 18:01:04.746575+00:00,19530219_5179,30747
2515,0,2944,2020-04-26 18:01:04.746575+00:00,19520816_8737,5459
33000,0,833,2020-04-26 18:01:04.746575+00:00,19860413_2537,33833
48955,0,5936,2020-04-27 18:01:04.746575+00:00,19530219_5179,54891
9501,0,1575,2020-04-27 18:01:04.746575+00:00,19520816_8737,11076
35510,0,6263,2020-04-27 18:01:04.746575+00:00,19860413_2537,41773
student_loan_due,credit_card_due,bankruptcies,event_timestamp,due_sum,dob_ssn
22328,8419,0,2020-04-26 18:01:04.746575+00:00,30747,19530219_5179
2515,2944,0,2020-04-26 18:01:04.746575+00:00,5459,19520816_8737
33000,833,0,2020-04-26 18:01:04.746575+00:00,33833,19860413_2537
48955,5936,0,2020-04-27 18:01:04.746575+00:00,54891,19530219_5179
9501,1575,0,2020-04-27 18:01:04.746575+00:00,11076,19520816_8737
35510,6263,0,2020-04-27 18:01:04.746575+00:00,41773,19860413_2537
Binary file modified test_data/credit_history.parquet
Binary file not shown.
Binary file modified test_data/credit_history_agg.parquet
Binary file not shown.
Binary file modified test_data/credit_history_mater.parquet
Binary file not shown.
Loading

0 comments on commit 81af2e8

Please sign in to comment.