Skip to content

Commit

Permalink
Merge pull request #16 from MatsMoll/matsei/minor-model-changes
Browse files Browse the repository at this point in the history
Minor bug fixes
  • Loading branch information
MatsMoll authored Oct 23, 2023
2 parents 4157812 + 6acda6e commit 6f7f998
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 52 deletions.
2 changes: 1 addition & 1 deletion aligned/data_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def read_pandas(self) -> pd.DataFrame:
raise NotImplementedError()

async def to_pandas(self) -> pd.DataFrame:
await self.read_pandas()
return await self.read_pandas()

async def to_polars(self) -> pl.LazyFrame:
raise NotImplementedError()
Expand Down
29 changes: 17 additions & 12 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from importlib import import_module
from typing import Any
from typing import Any, Union

from prometheus_client import Histogram

Expand Down Expand Up @@ -44,6 +44,8 @@
labelnames=['feature_view'],
)

FeatureSourceable = Union[FeatureSource, FeatureSourceFactory, None]


@dataclass
class SourceRequest:
Expand Down Expand Up @@ -493,12 +495,12 @@ def add_model(self, model: ModelContract) -> None:
compiled_model = type(model).compile()
self.models[compiled_model.name] = compiled_model

def with_source(self, source: FeatureSource | FeatureSourceFactory | None = None) -> FeatureStore:
def with_source(self, source: FeatureSourceable = None) -> FeatureStore:
"""
Creates a new instance of a feature store, but changes where to fetch the features from
```
store = # Load the store
store = await FeatureStore.from_dir(".")
redis_store = store.with_source(redis)
batch_source = redis_store.with_source()
```
Expand All @@ -511,7 +513,7 @@ def with_source(self, source: FeatureSource | FeatureSourceFactory | None = None
"""
if isinstance(source, FeatureSourceFactory):
feature_source = source.feature_source()
else:
elif source is None:
sources = {
FeatureLocation.feature_view(view.name).identifier: view.batch_data_source
for view in set(self.feature_views.values())
Expand All @@ -521,6 +523,13 @@ def with_source(self, source: FeatureSource | FeatureSourceFactory | None = None
if model.predictions_view.source is not None
}
feature_source = source or BatchFeatureSource(sources=sources)
elif isinstance(source, FeatureSource):
feature_source = source
else:
raise ValueError(
'Setting a dedicated source needs to be either a FeatureSource, '
f'or FeatureSourceFactory. Got: {type(source)}'
)

return FeatureStore(
feature_views=self.feature_views,
Expand Down Expand Up @@ -781,11 +790,9 @@ def all_predictions(self, limit: int | None = None) -> RetrivalJob:
request = pred_view.request(self.model.name)
return pred_view.source.all_data(request, limit=limit)

def using_source(
self, source: FeatureSource | FeatureSourceFactory | BatchDataSource
) -> ModelFeatureStore:
def using_source(self, source: FeatureSourceable | BatchDataSource) -> ModelFeatureStore:

model_source: FeatureSource | FeatureSourceFactory
model_source: FeatureSourceable

if isinstance(source, BatchDataSource):
model_source = BatchFeatureSource({FeatureLocation.model(self.model.name).identifier: source})
Expand Down Expand Up @@ -1038,9 +1045,7 @@ def request(self) -> RetrivalRequest:
def source(self) -> FeatureSource:
return self.store.feature_source

def using_source(
self, source: FeatureSource | FeatureSourceFactory | BatchDataSource
) -> FeatureViewStore:
def using_source(self, source: FeatureSourceable | BatchDataSource) -> FeatureViewStore:
"""
Sets the source to load features from.
Expand All @@ -1061,7 +1066,7 @@ def using_source(
Returns:
A new `FeatureViewStore` that sends queries to the passed source
"""
view_source: FeatureSource | FeatureSourceFactory
view_source: FeatureSourceable

if isinstance(source, BatchDataSource):
view_source = BatchFeatureSource(
Expand Down
5 changes: 2 additions & 3 deletions aligned/feature_view/combined_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def query(self) -> 'FeatureViewStore':
"""Makes it possible to query the feature view for features
```python
class SomeView(FeatureView):
metadata = ...
@feature_view(...)
class SomeView:
id = Int32().as_entity()
Expand Down
21 changes: 21 additions & 0 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def fill_missing_columns(self) -> RetrivalJob:
def rename(self, mappings: dict[str, str]) -> RetrivalJob:
return RenameJob(self, mappings)

def drop_duplicate_entities(self) -> RetrivalJob:
return DropDuplicateEntities(self)

def ignore_event_timestamp(self) -> RetrivalJob:
if isinstance(self, ModificationJob):
return self.copy_with(self.job.ignore_event_timestamp())
Expand Down Expand Up @@ -433,6 +436,24 @@ async def to_polars(self) -> pl.LazyFrame:
return df.rename(self.mappings)


@dataclass
class DropDuplicateEntities(RetrivalJob, ModificationJob):

job: RetrivalJob

@property
def entity_columns(self) -> list[str]:
return self.job.request_result.entity_columns

async def to_polars(self) -> pl.LazyFrame:
df = await self.job.to_polars()
return df.unique(subset=self.entity_columns)

async def to_pandas(self) -> pd.DataFrame:
df = await self.job.to_pandas()
return df.drop_duplicates(subset=self.entity_columns)


@dataclass
class UpdateVectorIndexJob(RetrivalJob, ModificationJob):

Expand Down
8 changes: 1 addition & 7 deletions aligned/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,10 @@ def full_schema(self) -> set[Feature]:
return schema

def request(self, name: str) -> RetrivalRequest:
entities = self.entities
if self.model_version_column:
entities.add(self.model_version_column)
return RetrivalRequest(
name=name,
location=FeatureLocation.model(name),
entities=entities,
entities=self.entities,
features=self.features,
derived_features=self.derived_features,
event_timestamp=self.event_timestamp,
Expand All @@ -84,9 +81,6 @@ def request(self, name: str) -> RetrivalRequest:
def request_for(self, features: set[str], name: str) -> RetrivalRequest:
entities = self.entities

# if self.model_version_column:
# entities.add(self.model_version_column)

return RetrivalRequest(
name=name,
location=FeatureLocation.model(name),
Expand Down
22 changes: 4 additions & 18 deletions aligned/validation/pandera.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,13 @@ def _column_for(self, feature: Feature) -> Column:
)

def _build_schema(self, features: list[Feature]) -> DataFrameSchema:
return DataFrameSchema(columns={feature.name: self._column_for(feature) for feature in features})
return DataFrameSchema(
columns={feature.name: self._column_for(feature) for feature in features}, drop_invalid_rows=True
)

async def validate_pandas(self, features: list[Feature], df: pd.DataFrame) -> pd.DataFrame:
from pandera.errors import SchemaError

schema = self._build_schema(features)
try:
return schema.validate(df)
except SchemaError as error:
# Will only return one error at a time, so will remove
# errors and then run it recrusive

if error.failure_cases.shape[0] == df.shape[0]:
raise ValueError('Validation is removing all the data.')

if error.failure_cases['index'].iloc[0] is None:
raise ValueError(error)

return await self.validate_pandas(
features, df.loc[df.index.delete(error.failure_cases['index'])].reset_index()
)
return schema.validate(df, lazy=True)

async def validate_polars(self, features: list[Feature], df: pl.LazyFrame) -> pl.LazyFrame:
input_df = df.collect().to_pandas()
Expand Down
49 changes: 40 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.30"
version = "0.0.31"
description = "A scalable feature store that makes it easy to align offline and online ML systems"
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down Expand Up @@ -55,7 +55,7 @@ nest-asyncio = "^1.5.5"
pydantic = "^2.0.0"
prometheus_client = "^0.16.0"
asgi-correlation-id = { version = "^3.0.0", optional = true }
pandera = { version = "^0.13.3", optional = true}
pandera = { version = "^0.17.0", optional = true}
httpx = "^0.23.0"
polars = { version = "^0.17.15", extras = ["all"] }
pillow = { version = "^9.4.0", optional = true }
Expand Down

0 comments on commit 6f7f998

Please sign in to comment.