Commit
Materialize from (#26)
* Added materialize from field

* Added partitioned parquet source
MatsMoll authored May 23, 2024
1 parent c5ac466 commit 17dd447
Showing 13 changed files with 239 additions and 7 deletions.
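
The two additions listed above surface as a new `materialize_from: datetime` argument on the `feature_view` decorator and a `FileSource.partitioned_parquet_at(...)` factory. Below is a minimal sketch of how they might be combined, assuming the decorator also accepts `name` and `source` as elsewhere in the project and that `Int32`, `Float`, and `EventTimestamp` are importable from `aligned`; the view name, paths, and columns are hypothetical.

```python
from datetime import datetime, timedelta, timezone

from aligned import feature_view, FileSource, Int32, Float, EventTimestamp

# Hypothetical partitioned batch source: one folder per pickup_date value.
taxi_source = FileSource.partitioned_parquet_at(
    'data/taxi_trips',
    partition_keys=['pickup_date'],
)


@feature_view(
    name='taxi_trips',
    source=taxi_source,
    materialized_source=FileSource.parquet_at('data/materialized/taxi_trips.parquet'),
    # New in this commit: only materialize data observed after this point in time.
    materialize_from=datetime(2024, 1, 1, tzinfo=timezone.utc),
    acceptable_freshness=timedelta(days=1),
)
class TaxiTrips:
    trip_id = Int32().as_entity()
    pickup_date = EventTimestamp()
    duration = Float()
```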
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ mlruns

test_data/feature-store.json
test_data/mlruns
test_data/temp

# Byte-compiled / optimized / DLL files
__pycache__/
2 changes: 1 addition & 1 deletion README.md
@@ -98,7 +98,7 @@ All this is described through a `model_contract`, as shown below.
)
class EtaTaxi:
trip_id = Int32().as_entity()
predicted_at = EventTimestamp()
predicted_at = ValidFrom()
predicted_duration = trips.duration.as_regression_target()
```

8 changes: 7 additions & 1 deletion aligned/data_source/batch_data_source.py
@@ -33,7 +33,12 @@ class BatchDataSourceFactory:
_shared: BatchDataSourceFactory | None = None

def __init__(self) -> None:
from aligned.sources.local import CsvFileSource, ParquetFileSource, DeltaFileSource
from aligned.sources.local import (
CsvFileSource,
ParquetFileSource,
DeltaFileSource,
PartitionedParquetFileSource,
)
from aligned.sources.psql import PostgreSQLDataSource
from aligned.sources.redshift import RedshiftSQLDataSource
from aligned.sources.s3 import AwsS3CsvDataSource, AwsS3ParquetDataSource
@@ -49,6 +54,7 @@ def __init__(self) -> None:
PostgreSQLDataSource,
# File Sources
ParquetFileSource,
PartitionedParquetFileSource,
CsvFileSource,
DeltaFileSource,
# Aws Sources
5 changes: 5 additions & 0 deletions aligned/feature_view/feature_view.py
@@ -57,6 +57,7 @@ class FeatureViewMetadata:
stream_source: StreamDataSource | None = field(default=None)
application_source: BatchDataSource | None = field(default=None)
materialized_source: BatchDataSource | None = field(default=None)
materialize_from: datetime | None = field(default=None)
contacts: list[str] | None = field(default=None)
tags: list[str] | None = field(default=None)
acceptable_freshness: timedelta | None = field(default=None)
@@ -72,6 +73,7 @@ def from_compiled(view: CompiledFeatureView) -> FeatureViewMetadata:
stream_source=view.stream_data_source,
application_source=view.application_source,
materialized_source=view.materialized_source,
materialize_from=view.materialize_from,
acceptable_freshness=view.acceptable_freshness,
unacceptable_freshness=view.unacceptable_freshness,
)
@@ -97,6 +99,7 @@ def feature_view(
stream_source: StreamDataSource | None = None,
application_source: BatchDataSource | None = None,
materialized_source: BatchDataSource | None = None,
materialize_from: datetime | None = None,
contacts: list[str] | None = None,
tags: list[str] | None = None,
acceptable_freshness: timedelta | None = None,
@@ -114,6 +117,7 @@ def decorator(cls: Type[T]) -> FeatureViewWrapper[T]:
stream_source=stream_source,
application_source=application_source,
materialized_source=materialized_source,
materialize_from=materialize_from,
contacts=contacts,
tags=tags,
acceptable_freshness=acceptable_freshness,
@@ -508,6 +512,7 @@ def compile_with_metadata(feature_view: Any, metadata: FeatureViewMetadata) -> C
stream_data_source=metadata.stream_source,
application_source=metadata.application_source,
materialized_source=metadata.materialized_source,
materialize_from=metadata.materialize_from,
acceptable_freshness=metadata.acceptable_freshness,
unacceptable_freshness=metadata.unacceptable_freshness,
indexes=[],
2 changes: 1 addition & 1 deletion aligned/local/tests/test_jobs.py
@@ -40,7 +40,7 @@ async def test_file_full_job_polars(retrival_request_without_derived: RetrivalRe

@pytest.mark.asyncio
async def test_write_and_read_feature_store(titanic_feature_store_scd: ContractStore) -> None:
source = FileSource.json_at('test_data/feature-store.json')
source = FileSource.json_at('test_data/temp/feature-store.json')
definition = titanic_feature_store_scd.repo_definition()
await source.write(definition.to_json().encode('utf-8'))
store = await source.feature_store()
5 changes: 4 additions & 1 deletion aligned/retrival_job.py
@@ -931,9 +931,12 @@ def ignore_event_timestamp(self) -> RetrivalJob:
return self.copy_with(self.job.ignore_event_timestamp())
raise NotImplementedError('Not implemented ignore_event_timestamp')

def polars_method(self, polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]) -> RetrivalJob:
def transform_polars(self, polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]) -> RetrivalJob:
return CustomPolarsJob(self, polars_method)

def polars_method(self, polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]) -> RetrivalJob:
return self.transform_polars(polars_method)

@staticmethod
def from_dict(data: dict[str, list], request: list[RetrivalRequest] | RetrivalRequest) -> RetrivalJob:
if isinstance(request, RetrivalRequest):
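
The hunk above renames `polars_method` to `transform_polars` and keeps the old name as a thin alias, so existing callers keep working. A hedged sketch of picking up the new name; the transformation and column name are illustrative.

```python
import polars as pl

from aligned.retrival_job import RetrivalJob


def add_duration_minutes(df: pl.LazyFrame) -> pl.LazyFrame:
    # Illustrative polars transformation applied lazily to the job's frame.
    return df.with_columns((pl.col('duration') / 60).alias('duration_minutes'))


def with_minutes(job: RetrivalJob) -> RetrivalJob:
    # New preferred name added in this commit.
    return job.transform_polars(add_duration_minutes)


def with_minutes_legacy(job: RetrivalJob) -> RetrivalJob:
    # The old name still works and simply delegates to transform_polars.
    return job.polars_method(add_duration_minutes)
```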
2 changes: 2 additions & 0 deletions aligned/schemas/feature_view.py
@@ -37,6 +37,8 @@ class CompiledFeatureView(Codable):
application_source: BatchDataSource | None = field(default=None)
materialized_source: BatchDataSource | None = field(default=None)

materialize_from: datetime | None = field(default=None)

acceptable_freshness: timedelta | None = field(default=None)
unacceptable_freshness: timedelta | None = field(default=None)

11 changes: 11 additions & 0 deletions aligned/sources/azure_blob_storage.py
@@ -22,6 +22,7 @@
ParquetConfig,
StorageFileReference,
Directory,
PartitionedParquetFileSource,
data_file_freshness,
)
from aligned.storage import Storage
@@ -88,6 +89,16 @@ def parquet_at(
self, path, mapping_keys=mapping_keys or {}, date_formatter=date_formatter or DateFormatter.noop()
)

def partitioned_parquet_at(
self,
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:
raise NotImplementedError(type(self))

def csv_at(
self,
path: str,
147 changes: 147 additions & 0 deletions aligned/sources/local.py
@@ -375,6 +375,109 @@ class ParquetConfig(Codable):
should_write_index: bool = field(default=False)


@dataclass
class PartitionedParquetFileSource(BatchDataSource, ColumnFeatureMappable, DataFileReference):
"""
A source pointing to a directory of partitioned Parquet files
"""

directory: str
partition_keys: list[str]
mapping_keys: dict[str, str] = field(default_factory=dict)
config: ParquetConfig = field(default_factory=ParquetConfig)
date_formatter: DateFormatter = field(default_factory=lambda: DateFormatter.noop())

type_name: str = 'partition_parquet'

@property
def to_markdown(self) -> str:
return f'''#### Partitioned Parquet File
*Partition keys*: {self.partition_keys}
*Renames*: {self.mapping_keys}
*Directory*: {self.directory}
[Go to directory]({self.directory})''' # noqa

def job_group_key(self) -> str:
return f'{self.type_name}/{self.directory}'

def __hash__(self) -> int:
return hash(self.job_group_key())

async def to_pandas(self) -> pd.DataFrame:
return (await self.to_lazy_polars()).collect().to_pandas()

async def to_lazy_polars(self) -> pl.LazyFrame:

glob_path = f'{self.directory}/**/*.parquet'
try:
return pl.scan_parquet(glob_path, retries=3)
except OSError:
raise UnableToFindFileException(self.directory)

async def write_polars(self, df: pl.LazyFrame) -> None:
create_parent_dir(self.directory)
df.collect().write_parquet(
self.directory,
compression=self.config.compression,
use_pyarrow=True,
pyarrow_options={
'partition_cols': self.partition_keys,
},
)

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
return FileFullJob(self, request, limit, date_formatter=self.date_formatter)

def all_between_dates(
self, request: RetrivalRequest, start_date: datetime, end_date: datetime
) -> RetrivalJob:
return FileDateJob(
source=self,
request=request,
start_date=start_date,
end_date=end_date,
date_formatter=self.date_formatter,
)

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[PartitionedParquetFileSource, RetrivalRequest]]
) -> RetrivalJob:

source = requests[0][0]
if not isinstance(source, cls):
raise ValueError(f'Only {cls} is supported, received: {source}')

# Group based on config
return FileFactualJob(
source=source,
requests=[request for _, request in requests],
facts=facts,
date_formatter=source.date_formatter,
)

async def schema(self) -> dict[str, FeatureType]:
glob_path = f'{self.directory}/**/*.parquet'
parquet_schema = pl.scan_parquet(glob_path).schema

return {name: FeatureType.from_polars(pl_type) for name, pl_type in parquet_schema.items()}

async def feature_view_code(self, view_name: str) -> str:
from aligned.feature_view.feature_view import FeatureView

raw_schema = await self.schema()
schema = {name: feat.feature_factory for name, feat in raw_schema.items()}
data_source_code = f'FileSource.partitioned_parquet_at("{self.directory}", partition_keys={self.partition_keys})'
return FeatureView.feature_view_code_template(
schema, data_source_code, view_name, 'from aligned import FileSource'
)


@dataclass
class ParquetFileSource(BatchDataSource, ColumnFeatureMappable, DataFileReference):
"""
@@ -642,6 +745,16 @@ def csv_at(
) -> BatchDataSource:
...

def partitioned_parquet_at(
self,
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:
...

def parquet_at(
self, path: str, mapping_keys: dict[str, str] | None = None, config: ParquetConfig | None = None
) -> BatchDataSource:
@@ -688,6 +801,23 @@ def parquet_at(
path=self.path_string(path), mapping_keys=mapping_keys or {}, config=config or ParquetConfig()
)

def partitioned_parquet_at(
self,
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:

return PartitionedParquetFileSource(
directory=self.path_string(directory),
partition_keys=partition_keys,
mapping_keys=mapping_keys or {},
config=config or ParquetConfig(),
date_formatter=date_formatter or DateFormatter.noop(),
)

def delta_at(
self, path: str, mapping_keys: dict[str, str] | None = None, config: DeltaFileConfig | None = None
) -> DeltaFileSource:
@@ -729,6 +859,23 @@ def csv_at(
formatter=date_formatter or DateFormatter.iso_8601(),
)

@staticmethod
def partitioned_parquet_at(
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:

return PartitionedParquetFileSource(
directory=directory,
partition_keys=partition_keys,
mapping_keys=mapping_keys or {},
config=config or ParquetConfig(),
date_formatter=date_formatter or DateFormatter.noop(),
)

@staticmethod
def parquet_at(
path: str,
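
`PartitionedParquetFileSource` writes a frame out as a pyarrow-partitioned dataset and scans it back through a `**/*.parquet` glob. A minimal local round-trip sketch, assuming a hypothetical `test_data/temp/trips` directory and `pickup_date` partition column:

```python
import asyncio

import polars as pl
from aligned import FileSource

source = FileSource.partitioned_parquet_at(
    'test_data/temp/trips',
    partition_keys=['pickup_date'],
)


async def round_trip() -> pl.DataFrame:
    df = pl.DataFrame({
        'trip_id': [1, 2, 3],
        'pickup_date': ['2024-05-01', '2024-05-01', '2024-05-02'],
        'duration': [10.0, 12.5, 7.0],
    })
    # Writes one sub-directory per pickup_date value via pyarrow's partition_cols.
    await source.write_polars(df.lazy())
    # Reads every partition back through the <directory>/**/*.parquet glob.
    return (await source.to_lazy_polars()).collect()


print(asyncio.run(round_trip()))
```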
12 changes: 12 additions & 0 deletions aligned/sources/s3.py
@@ -16,9 +16,11 @@
CsvConfig,
DataFileReference,
ParquetConfig,
PartitionedParquetFileSource,
StorageFileReference,
Directory,
DeltaFileConfig,
DateFormatter,
)
from aligned.storage import Storage

@@ -114,6 +116,16 @@ def parquet_at(
parquet_config=config or ParquetConfig(),
)

def partitioned_parquet_at(
self,
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:
raise NotImplementedError(type(self))

def delta_at(
self, path: str, mapping_keys: dict[str, str] | None = None, config: DeltaFileConfig | None = None
) -> BatchDataSource:
48 changes: 47 additions & 1 deletion aligned/sources/tests/test_parquet.py
@@ -35,7 +35,53 @@ async def test_read_parquet(point_in_time_data_test: DataTest) -> None:


@pytest.mark.asyncio
async def test_parquest(point_in_time_data_test: DataTest) -> None:
async def test_partition_parquet(point_in_time_data_test: DataTest) -> None:
store = ContractStore.experimental()

agg_features: list[str] = []

for source in point_in_time_data_test.sources:
view = source.view
view_name = view.metadata.name

compiled = view.compile()

if '_agg' in view_name:
agg_features.extend([feat.name for feat in compiled.aggregated_features])
continue

entities = compiled.entitiy_names

file_source = FileSource.partitioned_parquet_at(
f'test_data/temp/{view_name}',
partition_keys=list(entities),
)
await file_source.write_polars(source.data.lazy())

view.metadata = FeatureView.metadata_with( # type: ignore
name=view.metadata.name,
description=view.metadata.description,
batch_source=file_source,
)
store.add_feature_view(view)

job = store.features_for(
point_in_time_data_test.entities,
[feat for feat in point_in_time_data_test.feature_reference if '_agg' not in feat],
event_timestamp_column='event_timestamp',
)
data = (await job.to_lazy_polars()).collect()

expected = point_in_time_data_test.expected_output.drop(agg_features)
assert expected.shape == data.shape, f'Expected: {expected.shape}\nGot: {data.shape}'
assert set(expected.columns) == set(data.columns), f'Expected: {expected.columns}\nGot: {data.columns}'

ordered_columns = data.select(expected.columns)
assert ordered_columns.equals(expected), f'Expected: {expected}\nGot: {ordered_columns}'


@pytest.mark.asyncio
async def test_parquet(point_in_time_data_test: DataTest) -> None:

store = ContractStore.experimental()

