Skip to content

Commit

Permalink
chore: restructured a lot of the sources
Browse files Browse the repository at this point in the history
  • Loading branch information
Mats E. Mollestad committed Nov 9, 2023
1 parent 450a6a8 commit dc4c879
Show file tree
Hide file tree
Showing 25 changed files with 250 additions and 138 deletions.
69 changes: 69 additions & 0 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,75 @@ def feature_referance(self) -> FeatureReferance:
pass


def compile_hidden_features(
    feature: FeatureFactory,
    location: FeatureLocation,
    hidden_features: int,
    var_name: str,
    entities: set[Feature],
):
    """
    Compile a transformed feature and all of its dependencies into plain
    and derived feature definitions.

    Dependencies that have no name of their own (e.g. the intermediate
    result of ``(x & y)`` in ``z = (x & y) | x``) are "hidden" features
    and are named using the running ``hidden_features`` counter.

    Args:
        feature: The feature whose transformation graph should be compiled.
        location: Location assigned to dependencies that have none set.
        hidden_features: Start value used when naming unnamed dependencies.
        var_name: Name given to the raw (depth 0) value if it is unnamed.
        entities: Entity features that should not be re-added as features.

    Returns:
        A ``(features, derived_features)`` tuple of compiled definitions.

    NOTE(review): dependencies with an aggregation transformation are
    collected in a local ``aggregations`` list that is never returned —
    confirm whether callers are expected to receive them.
    """
    aggregations = []

    features = set()
    derived_features = set()

    if feature.transformation:
        # Adding features that are not stored in the view
        # e.g:
        # class SomeView(FeatureView):
        #     ...
        #     x, y = Bool(), Bool()
        #     z = (x & y) | x
        #
        # Here will (x & y)'s result be a 'hidden' feature
        feature_deps = [(feat.depth(), feat) for feat in feature.feature_dependencies()]

        # Sorting by depth in order to instantiate the "core" features first,
        # making it possible for other features to reference them.
        def sort_key(x: tuple[int, FeatureFactory]) -> int:
            return x[0]

        for depth, feature_dep in sorted(feature_deps, key=sort_key):

            if not feature_dep._location:
                feature_dep._location = location

            # Named dependencies may already be registered, or be entities.
            if feature_dep._name:
                feat_dep = feature_dep.feature()
                if feat_dep in features or feat_dep in entities:
                    continue

            if depth == 0:
                # The raw value and the transformed have the same name
                if not feature_dep._name:
                    feature_dep._name = var_name
                feat_dep = feature_dep.feature()
                features.add(feat_dep)
                continue

            if not feature_dep._name:
                feature_dep._name = str(hidden_features)
                hidden_features += 1

            if isinstance(feature_dep.transformation, AggregationTransformationFactory):
                aggregations.append(feature_dep)
            else:
                feature_graph = feature_dep.compile()  # Should decide on which payload to send
                if feature_graph in derived_features:
                    continue

                # Fix: reuse the graph compiled above instead of compiling
                # the same dependency a second time.
                derived_features.add(feature_graph)

    if not feature._name:
        # NOTE(review): 'ephemoral' looks like a typo of 'ephemeral', but the
        # string is kept as-is since stored definitions may depend on it.
        feature._name = 'ephemoral'
    if isinstance(feature.transformation, AggregationTransformationFactory):
        aggregations.append(feature)
    else:
        derived_features.add(feature.compile())  # Should decide on which payload to send

    return features, derived_features


@dataclass
class RegressionLabel(FeatureReferencable):
feature: FeatureFactory
Expand Down
28 changes: 14 additions & 14 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class ModelMetadata:
contacts: list[str] | None = field(default=None)
tags: dict[str, str] | None = field(default=None)
description: str | None = field(default=None)
predictions_source: BatchDataSource | None = field(default=None)
predictions_stream: StreamDataSource | None = field(default=None)
historical_source: BatchDataSource | None = field(default=None)
prediction_source: BatchDataSource | None = field(default=None)
prediction_stream: StreamDataSource | None = field(default=None)
application_source: BatchDataSource | None = field(default=None)
dataset_folder: Folder | None = field(default=None)


Expand All @@ -94,9 +94,9 @@ def model_contract(
contacts: list[str] | None = None,
tags: dict[str, str] | None = None,
description: str | None = None,
predictions_source: BatchDataSource | None = None,
predictions_stream: StreamDataSource | None = None,
historical_source: BatchDataSource | None = None,
prediction_source: BatchDataSource | None = None,
prediction_stream: StreamDataSource | None = None,
application_source: BatchDataSource | None = None,
dataset_folder: Folder | None = None,
) -> Callable[[Type[T]], ModelContractWrapper[T]]:
def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
Expand All @@ -106,9 +106,9 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
contacts=contacts,
tags=tags,
description=description,
predictions_source=predictions_source,
predictions_stream=predictions_stream,
historical_source=historical_source,
prediction_source=prediction_source,
prediction_stream=prediction_stream,
application_source=application_source,
dataset_folder=dataset_folder,
)
return ModelContractWrapper(metadata, cls)
Expand All @@ -126,7 +126,7 @@ def metadata_with(
tags: dict[str, str] | None = None,
predictions_source: BatchDataSource | None = None,
predictions_stream: StreamDataSource | None = None,
historical_source: BatchDataSource | None = None,
application_source: BatchDataSource | None = None,
dataset_folder: Folder | None = None,
) -> ModelMetadata:
return ModelMetadata(
Expand All @@ -137,7 +137,7 @@ def metadata_with(
description,
predictions_source,
predictions_stream,
historical_source=historical_source,
application_source=application_source,
dataset_folder=dataset_folder,
)

Expand Down Expand Up @@ -174,9 +174,9 @@ class MyModel(ModelContract):
features=set(),
derived_features=set(),
model_version_column=None,
source=metadata.predictions_source,
historical_source=metadata.historical_source,
stream_source=metadata.predictions_stream,
source=metadata.prediction_source,
application_source=metadata.application_source,
stream_source=metadata.prediction_stream,
classification_targets=set(),
regression_targets=set(),
)
Expand Down
2 changes: 1 addition & 1 deletion aligned/compiler/tests/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
source = PostgreSQLConfig.localhost('test')


@feature_view(name='test', description='test', batch_source=source.table('test'))
@feature_view(name='test', description='test', source=source.table('test'))
class Test:
id = UUID().as_entity()

Expand Down
2 changes: 1 addition & 1 deletion aligned/compiler/tests/test_repo_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_repo_reader() -> None:
view = list(definitions.feature_views)[0]

assert view.name == 'test'
assert view.batch_data_source.type_name == 'psql'
assert view.source.type_name == 'psql'
assert len(view.derived_features) == 1
assert len(view.features) == 2
assert len(view.entities) == 1
9 changes: 9 additions & 0 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,17 @@ class BatchDataSource(ABC, Codable, SerializableType):

@abstractmethod
def job_group_key(self) -> str:
    """
    A key defining which sources can be grouped together in one request.

    Subclasses must return a stable string; sources reporting the same
    key are treated as belonging to the same request group.
    """
    pass

def source_id(self) -> str:
    """
    An id that identifies a source from others.

    Defaults to the job group key, so sources that share a group share
    an id unless a subclass overrides this method.
    """
    return self.job_group_key()

def _serialize(self) -> dict:
assert (
self.type_name in BatchDataSourceFactory.shared().supported_data_sources
Expand Down
19 changes: 8 additions & 11 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def from_definition(repo: RepoDefinition) -> FeatureStore:

FeatureStore.register_enrichers(repo.enrichers)
sources = {
FeatureLocation.feature_view(view.name).identifier: view.batch_data_source
FeatureLocation.feature_view(view.name).identifier: view.materialized_source
if view.materialized_source
else view.source
for view in repo.feature_views
} | {
FeatureLocation.model(model.name).identifier: model.predictions_view.source
Expand Down Expand Up @@ -468,9 +470,7 @@ class MyFeatureView:
"""
self.feature_views[view.name] = view
if isinstance(self.feature_source, BatchFeatureSource):
self.feature_source.sources[
FeatureLocation.feature_view(view.name).identifier
] = view.batch_data_source
self.feature_source.sources[FeatureLocation.feature_view(view.name).identifier] = view.source

def add_feature_view(self, feature_view: FeatureView) -> None:
self.add_compiled_view(feature_view.compile_instance())
Expand Down Expand Up @@ -519,7 +519,7 @@ def with_source(self, source: FeatureSourceable = None) -> FeatureStore:
feature_source = source.feature_source()
elif source is None:
sources = {
FeatureLocation.feature_view(view.name).identifier: view.batch_data_source
FeatureLocation.feature_view(view.name).identifier: view.source
for view in set(self.feature_views.values())
} | {
FeatureLocation.model(model.name).identifier: model.predictions_view.source
Expand Down Expand Up @@ -560,8 +560,7 @@ def use_application_sources(self) -> FeatureStore:
FeatureStore: A new feature store that loads features from the application source
"""
sources = {
FeatureLocation.feature_view(view.name).identifier: view.application_source
or view.batch_data_source
FeatureLocation.feature_view(view.name).identifier: view.application_source or view.source
for view in set(self.feature_views.values())
} | {
FeatureLocation.model(model.name).identifier: model.predictions_view.source
Expand Down Expand Up @@ -601,10 +600,8 @@ def views_with_config(self, config: Any) -> list[SourceRequest]:
views: list[SourceRequest] = []
for view in self.feature_views.values():
request = view.request_all.needed_requests[0]
if view.batch_data_source.contains_config(config):
views.append(
SourceRequest(FeatureLocation.feature_view(view.name), view.batch_data_source, request)
)
if view.source.contains_config(config):
views.append(SourceRequest(FeatureLocation.feature_view(view.name), view.source, request))

if view.application_source and view.application_source.contains_config(config):
views.append(
Expand Down
50 changes: 26 additions & 24 deletions aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
@dataclass
class FeatureViewMetadata:
name: str
batch_source: BatchDataSource
source: BatchDataSource
description: str | None = field(default=None)
stream_source: StreamDataSource | None = field(default=None)
application_source: BatchDataSource | None = field(default=None)
staging_source: BatchDataSource | None = field(default=None)
materialized_source: BatchDataSource | None = field(default=None)
contacts: list[str] | None = field(default=None)
tags: dict[str, str] = field(default_factory=dict)

Expand All @@ -52,31 +52,31 @@ def from_compiled(view: CompiledFeatureView) -> FeatureViewMetadata:
name=view.name,
description=view.description,
tags=view.tags,
batch_source=view.batch_data_source,
source=view.source,
stream_source=view.stream_data_source,
application_source=view.application_source,
staging_source=view.staging_source,
materialized_source=view.materialized_source,
)


def feature_view(
name: str,
batch_source: BatchDataSource,
source: BatchDataSource,
description: str | None = None,
stream_source: StreamDataSource | None = None,
application_source: BatchDataSource | None = None,
staging_source: BatchDataSource | None = None,
materialized_source: BatchDataSource | None = None,
contacts: list[str] | None = None,
tags: dict[str, str] | None = None,
) -> Callable[[Type[T]], FeatureViewWrapper[T]]:
def decorator(cls: Type[T]) -> FeatureViewWrapper[T]:
metadata = FeatureViewMetadata(
name,
batch_source,
source,
description=description,
stream_source=stream_source,
application_source=application_source,
staging_source=staging_source,
materialized_source=materialized_source,
contacts=contacts,
tags=tags or {},
)
Expand All @@ -95,15 +95,15 @@ def __call__(self) -> T:
# Needs to compiile the model to set the location for the view features
_ = self.compile()
view = self.view()
setattr(view, "__view_wrapper__", self)
setattr(view, '__view_wrapper__', self)
return view

def compile(self) -> CompiledFeatureView:

return FeatureView.compile_with_metadata(self.view(), self.metadata)

def with_filter(
self, named: str, where: Callable[[T], Bool], stored_at: BatchDataSource | None = None
self, named: str, where: Callable[[T], Bool], materialize_source: BatchDataSource | None = None
) -> FeatureViewWrapper[T]:

from aligned.data_source.batch_data_source import FilteredDataSource
Expand All @@ -113,14 +113,16 @@ def with_filter(

condition = where(self.__call__())

main_source = meta.materialized_source if meta.materialized_source else meta.source

if condition.transformation:
meta.batch_source = FilteredDataSource(self.metadata.batch_source, condition.compile())
meta.source = FilteredDataSource(main_source, condition.compile())
else:
meta.batch_source = FilteredDataSource(self.metadata.batch_source, condition.feature())
meta.source = FilteredDataSource(main_source, condition.feature())

if stored_at:
meta.staging_source = meta.batch_source
meta.batch_source = stored_at
if materialize_source:
meta.materialized_source = materialize_source
meta.source = main_source

return FeatureViewWrapper(metadata=meta, view=self.view)

Expand All @@ -139,8 +141,8 @@ def with_joined(self, view: Any, join_on: str, method: str = 'inner') -> BatchDa
request = compiled_view.request_all

return JoinDataSource(
source=self.metadata.batch_source,
right_source=compiled_view.batch_data_source,
source=self.metadata.source,
right_source=compiled_view.source,
right_request=request.needed_requests[0],
left_on=join_on,
right_on=join_on,
Expand All @@ -157,8 +159,8 @@ def with_entity_renaming(self, named: str, renames: dict[str, str] | str) -> Fea
meta.name = named

all_data_sources = [
meta.batch_source,
meta.staging_source,
meta.source,
meta.materialized_source,
meta.application_source,
meta.stream_source,
]
Expand Down Expand Up @@ -279,7 +281,7 @@ class MyView:
```
"""
compiled = self.compile()
return await FeatureView.freshness_in_source(compiled, compiled.batch_data_source)
return await FeatureView.freshness_in_source(compiled, compiled.source)


class FeatureView(ABC):
Expand Down Expand Up @@ -312,7 +314,7 @@ def metadata_with(
description,
stream_source or HttpStreamSource(name),
application_source=application_source,
staging_source=staging_source,
materialized_source=staging_source,
contacts=contacts,
tags=tags or {},
)
Expand All @@ -330,7 +332,7 @@ async def batch_source_freshness(cls) -> datetime | None:
Returns the freshest datetime for the batch data source
"""
compiled = cls().compile_instance()
return await FeatureView.freshness_in_source(compiled, compiled.batch_data_source)
return await FeatureView.freshness_in_source(compiled, compiled.source)

@staticmethod
async def freshness_in_source(view: CompiledFeatureView, source: BatchDataSource) -> datetime | None:
Expand All @@ -354,15 +356,15 @@ def compile_with_metadata(feature_view: Any, metadata: FeatureViewMetadata) -> C
name=metadata.name,
description=metadata.description,
tags=metadata.tags,
batch_data_source=metadata.batch_source,
source=metadata.source,
entities=set(),
features=set(),
derived_features=set(),
aggregated_features=set(),
event_timestamp=None,
stream_data_source=metadata.stream_source,
application_source=metadata.application_source,
staging_source=metadata.staging_source,
materialized_source=metadata.materialized_source,
indexes=[],
)
aggregations: list[FeatureFactory] = []
Expand Down
Loading

0 comments on commit dc4c879

Please sign in to comment.