Merge pull request #13 from MatsMoll/matsei/model-write-to-pred
feat: added new pred write features for models
MatsMoll authored Oct 13, 2023
2 parents 5cb8c55 + 062b28c commit 9f1beee
Showing 8 changed files with 179 additions and 32 deletions.
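Before the file-by-file diff, a hedged end-to-end sketch of what this commit enables. Assumptions: `model_contract`, `Int32`, `FileSource`, and `FeatureStore` are importable from the top-level `aligned` package (the `write_predictions` docstring below uses the same names), and an empty `features` list is accepted for brevity.

```python
import asyncio

from aligned import FeatureStore, FileSource, Int32, model_contract


@model_contract(
    name="taxi_eta",
    features=[],  # upstream model inputs elided in this sketch
    predictions_source=FileSource.parquet_at("predictions.parquet"),
)
class TaxiEta:
    trip_id = Int32().as_entity()
    duration = Int32()


async def main() -> None:
    # Mirrors the write_predictions docstring; in some aligned versions
    # FeatureStore.from_dir is a coroutine and needs an await.
    store = FeatureStore.from_dir(".")

    # New in this commit: write model output straight to the prediction source
    await store.model("taxi_eta").write_predictions(
        {"trip_id": [1, 2, 3], "duration": [20, 33, 42]}
    )

    # Also new: predictions_for is routed through the regular feature-store
    # path using the model's FeatureLocation wildcard (the f'{location_id}:*' below)
    preds = await store.model("taxi_eta").predictions_for({"trip_id": [1, 2]}).to_polars()
    print(preds.collect())


asyncio.run(main())
```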
20 changes: 18 additions & 2 deletions aligned/feature_source.py
@@ -1,16 +1,20 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from datetime import datetime
-from typing import Any
+from typing import Any, TYPE_CHECKING

+import asyncio
 import numpy as np
 import pandas as pd
 import polars as pl

 from aligned.data_source.batch_data_source import BatchDataSource
 from aligned.request.retrival_request import FeatureRequest, RequestResult, RetrivalRequest
 from aligned.retrival_job import RetrivalJob
 from aligned.schemas.feature import FeatureLocation, EventTimestamp
+
+if TYPE_CHECKING:
+    from datetime import datetime


 class FeatureSourceFactory:
@@ -22,6 +26,9 @@ class FeatureSource:
     def features_for(self, facts: RetrivalJob, request: FeatureRequest) -> RetrivalJob:
         raise NotImplementedError()

+    async def freshness_for(self, locations: list[FeatureLocation]) -> dict[FeatureLocation, datetime]:
+        raise NotImplementedError()
+

 class WritableFeatureSource:
     async def write(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
@@ -125,6 +132,15 @@ def all_between(self, start_date: datetime, end_date: datetime, request: Feature
             .derive_features(requests=request.needed_requests)
         )

+    async def freshness_for(
+        self, locations: dict[FeatureLocation, EventTimestamp]
+    ) -> dict[FeatureLocation, datetime]:
+        locs = list(locations.keys())
+        results = await asyncio.gather(
+            *[self.sources[loc.identifier].freshness(locations[loc]) for loc in locs]
+        )
+        return dict(zip(locs, results))
+

 class FactualInMemoryJob(RetrivalJob):
     """
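A note on the new freshness plumbing: `asyncio.gather` returns results in the order of its arguments, so `dict(zip(locs, results))` pairs each `FeatureLocation` with the result from its own source. Each underlying source only needs to implement `freshness(event_timestamp)`; the file source later in this diff does so by taking the maximum of the event-timestamp column, and the PostgreSQL source already exposes the same hook.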
98 changes: 79 additions & 19 deletions aligned/feature_store.py
@@ -591,6 +591,10 @@ class ModelFeatureStore:
     model: ModelSchema
     store: FeatureStore

+    @property
+    def location(self) -> FeatureLocation:
+        return FeatureLocation.model(self.model.name)
+
     def raw_string_features(self, except_features: set[str]) -> set[str]:
         return {
             f'{feature.location.identifier}:{feature.name}'
@@ -737,23 +741,8 @@ def process_features(self, input: RetrivalJob | dict[str, list]) -> RetrivalJob:

     def predictions_for(self, entities: dict[str, list] | RetrivalJob) -> RetrivalJob:

-        if self.model.predictions_view.source is None:
-            raise ValueError(
-                'Model does not have a prediction source. '
-                'This can be set in the metadata for a model contract.'
-            )
-
-        source = self.model.predictions_view.source
-        request = self.model.predictions_view.request(self.model.name)
-
-        if isinstance(entities, RetrivalJob):
-            job = entities
-        elif isinstance(entities, dict):
-            job = RetrivalJob.from_dict(entities, request=[request])
-        else:
-            raise ValueError(f'features must be a dict or a RetrivalJob, was {type(input)}')
-
-        return source.features_for(job, request).ensure_types([request]).derive_features()
+        location_id = self.location.identifier
+        return self.store.features_for(entities, features=[f'{location_id}:*'])

     def all_predictions(self, limit: int | None = None) -> RetrivalJob:

@@ -768,6 +757,66 @@ def all_predictions(self, limit: int | None = None) -> RetrivalJob:
         request = pred_view.request(self.model.name)
         return pred_view.source.all_data(request, limit=limit)

+    def using_source(
+        self, source: FeatureSource | FeatureSourceFactory | BatchDataSource
+    ) -> ModelFeatureStore:
+
+        model_source: FeatureSource | FeatureSourceFactory
+
+        if isinstance(source, BatchDataSource):
+            model_source = BatchFeatureSource({FeatureLocation.model(self.model.name).identifier: source})
+        else:
+            model_source = source
+
+        return ModelFeatureStore(self.model, self.store.with_source(model_source))
+
+    async def write_predictions(self, predictions: dict[str, list] | RetrivalJob) -> None:
+        """
+        Writes data to a source defined as a prediction source
+
+        ```python
+        @model_contract(
+            name="taxi_eta",
+            features=[...],
+            predictions_source=FileSource.parquet_at("predictions.parquet")
+        )
+        class TaxiEta:
+            trip_id = Int32().as_entity()
+            duration = Int32()
+
+        ...
+
+        store = FeatureStore.from_dir(".")
+
+        await store.model("taxi_eta").write_predictions({
+            "trip_id": [1, 2, 3, ...],
+            "duration": [20, 33, 42, ...]
+        })
+        ```
+        """
+        source: Any = self.store.feature_source
+
+        if isinstance(source, BatchFeatureSource):
+            location = FeatureLocation.model(self.model.name).identifier
+            source = source.sources[location]
+
+        if not isinstance(source, WritableFeatureSource):
+            raise ValueError(f'The prediction source {type(source)} needs to be writable')
+
+        write_job: RetrivalJob
+        request = self.model.predictions_view.request(self.model.name)
+
+        if isinstance(predictions, dict):
+            write_job = RetrivalJob.from_dict(predictions, request)
+        elif isinstance(predictions, RetrivalJob):
+            write_job = predictions
+        else:
+            raise ValueError(f'Unable to write predictions of type {type(predictions)}')
+
+        await source.write(write_job, [request])


 @dataclass
 class SupervisedModelFeatureStore:

@@ -918,7 +967,9 @@ def request(self) -> RetrivalRequest:
     def source(self) -> FeatureSource:
         return self.store.feature_source

-    def using_source(self, source: BatchDataSource) -> FeatureViewStore:
+    def using_source(
+        self, source: FeatureSource | FeatureSourceFactory | BatchDataSource
+    ) -> FeatureViewStore:
         """
         Sets the source to load features from.
@@ -939,8 +990,17 @@ def using_source(self, source: BatchDataSource) -> FeatureViewStore:
         Returns:
             A new `FeatureViewStore` that sends queries to the passed source
         """
+        view_source: FeatureSource | FeatureSourceFactory
+
+        if isinstance(source, BatchDataSource):
+            view_source = BatchFeatureSource(
+                {FeatureLocation.feature_view(self.view.name).identifier: source}
+            )
+        else:
+            view_source = source
+
         return FeatureViewStore(
-            store=self.store.with_source(BatchFeatureSource({self.view.name: source})),
+            self.store.with_source(view_source),
             view=self.view,
             event_triggers=self.event_triggers,
             feature_filter=self.feature_filter,
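A hedged usage sketch of the broadened `using_source`: redirect a model's prediction reads and writes to another source at runtime, as the new `test_postgresql_write` at the bottom of this diff does with a Postgres table. The Parquet path is illustrative, and the sketch assumes (per the docstring above) that a Parquet file source is accepted as a writable prediction source.

```python
from aligned import FeatureStore, FileSource


async def shadow_write(store: FeatureStore) -> None:
    # using_source wraps a plain BatchDataSource in a BatchFeatureSource keyed
    # by the model's FeatureLocation identifier, exactly as implemented above
    shadow = FileSource.parquet_at("shadow_predictions.parquet")  # illustrative path
    model_store = store.model("taxi_eta").using_source(shadow)

    await model_store.write_predictions({"trip_id": [1], "duration": [25]})
```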
7 changes: 7 additions & 0 deletions aligned/request/retrival_request.py
@@ -76,6 +76,13 @@ def filter_features(self, feature_names: set[str]) -> 'RetrivalRequest':
             features_to_include=feature_names,
         )

+    @property
+    def all_returned_columns(self) -> list[str]:
+        result = self.all_feature_names.union(self.entity_names)
+        if self.event_timestamp:
+            result = result.union({self.event_timestamp.name})
+        return list(result)
+
     @property
     def returned_features(self) -> set[Feature]:
         return {feature for feature in self.all_features if feature.name in self.features_to_include}
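One caveat on the new property: `all_returned_columns` is built from `set` unions, so the order of the returned list is unspecified between runs. The PostgreSQL writer added later in this diff uses it only to choose which columns to write, and its `write_database` call matches columns by name, so the instability should be harmless there.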
9 changes: 7 additions & 2 deletions aligned/retrival_job.py
@@ -354,7 +354,7 @@ def from_polars_df(df: pl.DataFrame, request: list[RetrivalRequest]) -> Retrival

         return LiteralRetrivalJob(df.lazy(), RequestResult.from_request_list(request))

-    async def write_to_source(self, source: WritableFeatureSource):
+    async def write_to_source(self, source: WritableFeatureSource | DataFileReference) -> None:
         """
         Writes the output of the retrival job to the passed source.

@@ -373,7 +373,12 @@ async def write_to_source(self, source: WritableFeatureSource):
         Args:
             source (WritableFeatureSource): A source that we can write to.
         """
-        await source.write(self, self.retrival_requests)
+        from aligned.sources.local import DataFileReference
+
+        if isinstance(source, DataFileReference):
+            await source.write_polars(await self.to_polars())
+        else:
+            await source.write(self, self.retrival_requests)


 JobType = TypeVar('JobType')
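A hedged sketch of the broadened `write_to_source`, dumping a job's output to a file. `FileSource.csv_at` is an assumption here (only `parquet_at` appears in this diff); the point is the new dispatch: a `DataFileReference` receives the collected polars frame via `write_polars`, anything else goes through `source.write(...)`.

```python
from aligned import FeatureStore, FileSource


async def dump_predictions(store: FeatureStore) -> None:
    job = store.model("taxi_eta").all_predictions(limit=1000)

    # DataFileReference branch: the job is collected to polars and written out
    await job.write_to_source(FileSource.csv_at("predictions_dump.csv"))  # csv_at assumed
```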
11 changes: 9 additions & 2 deletions aligned/sources/local.py
@@ -61,8 +61,7 @@ async def as_repo_definition(self) -> RepoDefinition:

 async def data_file_freshness(reference: DataFileReference, column_name: str) -> datetime:
     file = await reference.to_polars()
-    file.select(pl.col(column_name).max().alias('max_value'))
-    return file.collect()['max_value'].to_list()[0]
+    return file.select(column_name).max().collect()[0, column_name]


 @dataclass
@@ -124,6 +123,9 @@ async def write_pandas(self, df: pd.DataFrame) -> None:
             index=self.csv_config.should_write_index,
         )

+    async def write_polars(self, df: pl.LazyFrame) -> None:
+        await self.write_pandas(df.collect().to_pandas())
+
     def std(
         self, columns: set[str], time: TimespanSelector | None = None, limit: int | None = None
     ) -> Enricher:

@@ -284,6 +286,11 @@ async def feature_view_code(self, view_name: str) -> str:
             schema, data_source_code, view_name, 'from aligned import FileSource'
         )

+    async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
+        df = await self.to_polars()
+        et_name = event_timestamp.name
+        return df.select(et_name).max().collect()[0, et_name]
+

 @dataclass
 class StorageFileSource(StorageFileReference):
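The old `data_file_freshness` had a real bug that this commit fixes: it called `.select(...)` without assigning the result, and polars frames are immutable, so the follow-up `collect()` ran against the unmodified frame where `'max_value'` never existed. A standalone sketch of the corrected pattern, using only public polars APIs:

```python
import datetime as dt

import polars as pl

lf = pl.LazyFrame(
    {"ts": [dt.datetime(2023, 1, 1), dt.datetime(2023, 10, 13), dt.datetime(2023, 6, 5)]}
)

# keep one column, reduce it to its max, then index the single resulting cell
newest = lf.select("ts").max().collect()[0, "ts"]
print(newest)  # 2023-10-13 00:00:00
```

The new `write_polars` on the CSV source takes the opposite shortcut: it collects and converts to pandas so the existing `write_pandas` keeps all CSV-config handling in one place, at the cost of materializing the whole frame in memory.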
15 changes: 14 additions & 1 deletion aligned/sources/psql.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Callable, Any

 from aligned.data_source.batch_data_source import BatchDataSource, ColumnFeatureMappable
+from aligned.feature_source import WritableFeatureSource
 from aligned.request.retrival_request import RetrivalRequest
 from aligned.retrival_job import FactualRetrivalJob, RetrivalJob
 from aligned.schemas.codable import Codable

@@ -61,7 +62,7 @@ def fetch(self, query: str) -> RetrivalJob:

 @dataclass
-class PostgreSQLDataSource(BatchDataSource, ColumnFeatureMappable):
+class PostgreSQLDataSource(BatchDataSource, ColumnFeatureMappable, WritableFeatureSource):

     config: PostgreSQLConfig
     table: str

@@ -169,3 +170,15 @@ async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
                 raise ValueError(f'Unsupported freshness value {value}')
         else:
             return None

+    async def write(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+
+        if len(requests) != 1:
+            raise ValueError(f'Only support writing for one request, got {len(requests)}.')
+
+        request = requests[0]
+
+        data = await job.to_polars()
+        data.select(request.all_returned_columns).collect().write_database(
+            self.table, connection_uri=self.config.url, if_exists='append'
+        )
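Two notes on the new writer, hedged since they depend on polars behavior at the time: `data.select(request.all_returned_columns)` trims the frame to exactly the columns the request declares before writing, and `write_database(..., if_exists='append')` appends rows matched by column name, which is what makes the unordered `all_returned_columns` safe to use here. The `connection_uri` keyword matches polars releases from this era; later releases renamed it to `connection`.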
49 changes: 44 additions & 5 deletions aligned/sources/tests/test_psql.py
@@ -7,14 +7,19 @@
 import platform


+@pytest.fixture
+def psql() -> PostgreSQLConfig:
+    if 'PSQL_DATABASE_TEST' not in environ:
+        environ['PSQL_DATABASE_TEST'] = 'postgresql://postgres:postgres@127.0.0.1:5433/aligned-test'
+
+    return PostgreSQLConfig('PSQL_DATABASE_TEST')
+
+
 @pytest.mark.skipif(
     platform.uname().machine.startswith('arm'), reason='Needs psycopg2 which is not supported on arm'
 )
 @pytest.mark.asyncio
-async def test_postgresql(point_in_time_data_test: DataTest) -> None:
-
-    if 'PSQL_DATABASE_TEST' not in environ:
-        environ['PSQL_DATABASE_TEST'] = 'postgresql://postgres:postgres@127.0.0.1:5433/aligned-test'
+async def test_postgresql(point_in_time_data_test: DataTest, psql: PostgreSQLConfig) -> None:

     psql_database = environ['PSQL_DATABASE_TEST']

@@ -28,7 +33,7 @@ async def test_postgresql(point_in_time_data_test: DataTest) -> None:
     view.metadata = FeatureView.metadata_with(  # type: ignore
         name=view.metadata.name,
         description=view.metadata.description,
-        batch_source=PostgreSQLConfig('PSQL_DATABASE_TEST').table(db_name),
+        batch_source=psql.table(db_name),
     )
     store.add_feature_view(view)

@@ -44,3 +49,37 @@

     ordered_columns = data.select(expected.columns)
     assert ordered_columns.frame_equal(expected), f'Expected: {expected}\nGot: {ordered_columns}'


+@pytest.mark.skipif(
+    platform.uname().machine.startswith('arm'), reason='Needs psycopg2 which is not supported on arm'
+)
+@pytest.mark.asyncio
+async def test_postgresql_write(titanic_feature_store: FeatureStore, psql: PostgreSQLConfig) -> None:
+    import polars as pl
+    from polars.testing import assert_frame_equal
+
+    source = psql.table('titanic')
+
+    data: dict[str, list] = {'passenger_id': [1, 2, 3, 4], 'will_survive': [False, True, True, False]}
+
+    store = titanic_feature_store.model('titanic').using_source(source)
+    await store.write_predictions(data)
+
+    stored_data = await psql.fetch('SELECT * FROM titanic').to_polars()
+    assert_frame_equal(
+        pl.DataFrame(data),
+        stored_data.collect(),
+        check_row_order=False,
+        check_column_order=False,
+        check_dtype=False,
+    )
+
+    preds = await store.predictions_for({'passenger_id': [1, 3, 2, 4]}).to_polars()
+    assert_frame_equal(
+        pl.DataFrame(data),
+        preds.collect(),
+        check_row_order=False,
+        check_column_order=False,
+        check_dtype=False,
+    )
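The relaxed `assert_frame_equal` flags carry this test: row and column order are not guaranteed after the Postgres round trip (and `all_returned_columns` is unordered), while `check_dtype=False` tolerates the type widening that can happen on the way through the database driver.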
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.25"
+version = "0.0.26"
 description = "A scalable feature store that makes it easy to align offline and online ML systems"
 authors = ["Mats E. Mollestad <mats@mollestad.no>"]
 license = "Apache-2.0"
