Skip to content

Commit

Permalink
Better schema support
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 13, 2024
1 parent 15ef90d commit 225addd
Show file tree
Hide file tree
Showing 21 changed files with 263 additions and 209 deletions.
18 changes: 15 additions & 3 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,21 +1125,33 @@ def aggregate(self) -> CategoricalAggregation:


class Timestamp(DateFeature, ArithmeticFeature):

time_zone: str | None

def __init__(self, time_zone: str | None = 'UTC') -> None:
self.time_zone = time_zone

@property
def dtype(self) -> FeatureType:
return FeatureType.datetime()
from zoneinfo import ZoneInfo

return FeatureType.datetime(ZoneInfo(self.time_zone) if self.time_zone else None)


class EventTimestamp(DateFeature, ArithmeticFeature):

ttl: timedelta | None
time_zone: str | None

@property
def dtype(self) -> FeatureType:
return FeatureType.datetime()
from zoneinfo import ZoneInfo

return FeatureType.datetime(ZoneInfo(self.time_zone) if self.time_zone else None)

def __init__(self, ttl: timedelta | None = None):
def __init__(self, ttl: timedelta | None = None, time_zone: str | None = 'UTC') -> None:
self.ttl = ttl
self.time_zone = time_zone

def event_timestamp(self) -> EventTimestampFeature:
return EventTimestampFeature(
Expand Down
22 changes: 19 additions & 3 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation, FeatureType
from aligned.request.retrival_request import RequestResult, RetrivalRequest
from aligned.compiler.feature_factory import FeatureFactory
from polars.type_aliases import TimeUnit
Expand Down Expand Up @@ -203,7 +203,7 @@ def multi_source_features_for(
def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return type(self).multi_source_features_for(facts, [(self, request)])

async def schema(self) -> dict[str, FeatureFactory]:
async def schema(self) -> dict[str, FeatureType]:
"""Returns the schema for the data source
```python
Expand Down Expand Up @@ -259,7 +259,8 @@ class MyView(FeatureView):
from aligned.feature_view.feature_view import FeatureView

schema = await self.schema()
return FeatureView.feature_view_code_template(schema, f'{self}', view_name)
feature_types = {name: feature_type.feature_factory for name, feature_type in schema.items()}
return FeatureView.feature_view_code_template(feature_types, f'{self}', view_name)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
"""
Expand Down Expand Up @@ -378,6 +379,9 @@ class FilteredDataSource(BatchDataSource):
def job_group_key(self) -> str:
return f'subset/{self.source.job_group_key()}'

async def schema(self) -> dict[str, FeatureType]:
return await self.source.schema()

@classmethod
def multi_source_features_for(
cls: type[FilteredDataSource],
Expand Down Expand Up @@ -599,6 +603,12 @@ class JoinAsofDataSource(BatchDataSource):

type_name: str = 'join_asof'

async def schema(self) -> dict[str, FeatureType]:
left_schema = await self.source.schema()
right_schema = await self.right_source.schema()

return {**left_schema, **right_schema}

def job_group_key(self) -> str:
return f'join/{self.source.job_group_key()}'

Expand Down Expand Up @@ -720,6 +730,12 @@ class JoinDataSource(BatchDataSource):

type_name: str = 'join'

async def schema(self) -> dict[str, FeatureType]:
left_schema = await self.source.schema()
right_schema = await self.right_source.schema()

return {**left_schema, **right_schema}

def job_group_key(self) -> str:
return f'join/{self.source.job_group_key()}'

Expand Down
2 changes: 1 addition & 1 deletion aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def my_function(data: Annotated[pd.DataFrame, MyView]):
"""

def decorator(func: Callable) -> Callable:
def func_wrapper(*args, **kwargs) -> Any:
def func_wrapper(*args, **kwargs) -> Any: # type: ignore
from typing import _AnnotatedAlias # type: ignore

params_to_check = {
Expand Down
10 changes: 3 additions & 7 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ async def to_polars(self) -> SupervisedDataSet[pl.DataFrame]:
async def to_lazy_polars(self) -> SupervisedDataSet[pl.LazyFrame]:
data = await self.job.to_lazy_polars()
if self.should_filter_out_null_targets:
data = data.drop_nulls([column for column in self.target_columns])
data = data.drop_nulls(list(self.target_columns))

features = [
feature.name
Expand Down Expand Up @@ -1782,9 +1782,7 @@ async def to_pandas(self) -> pd.DataFrame:
df[feature.name] = df[feature.name].apply(
lambda x: json.loads(x) if isinstance(x, str) else x
)
elif feature.dtype == FeatureType.json():
pass
elif feature.dtype == FeatureType.datetime():
elif (feature.dtype == FeatureType.json()) or feature.dtype.is_datetime:
pass
else:
if feature.dtype.is_numeric:
Expand Down Expand Up @@ -1829,9 +1827,7 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
dtype = df.select(feature.name).dtypes[0]
if dtype == pl.Utf8:
df = df.with_columns(pl.col(feature.name).str.json_extract(pl.List(pl.Utf8)))
elif feature.dtype == FeatureType.json():
pass
elif feature.dtype == FeatureType.datetime():
elif (feature.dtype == FeatureType.json()) or feature.dtype.is_datetime:
pass
else:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type, strict=False))
Expand Down
85 changes: 63 additions & 22 deletions aligned/schemas/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,33 @@

from dataclasses import dataclass
from typing import Literal
from zoneinfo import ZoneInfo

import polars as pl

import aligned.compiler.feature_factory as ff
from aligned.schemas.codable import Codable
from aligned.schemas.constraints import Constraint

NAME_POLARS_MAPPING = {
'string': pl.Utf8,
'int8': pl.Int8,
'int16': pl.Int16,
'int32': pl.Int32,
'int64': pl.Int64,
'float': pl.Float64,
'double': pl.Float64,
'bool': pl.Boolean,
'date': pl.Date,
'datetime': pl.Datetime,
'time': pl.Time,
'timedelta': pl.Duration,
'uuid': pl.Utf8,
'array': pl.List(pl.Utf8),
'embedding': pl.List,
'json': pl.Utf8,
}
NAME_POLARS_MAPPING = [
('string', pl.Utf8),
('int8', pl.Int8),
('int16', pl.Int16),
('int32', pl.Int32),
('int64', pl.Int64),
('float', pl.Float64),
('float', pl.Float32),
('double', pl.Float64),
('bool', pl.Boolean),
('date', pl.Date),
('datetime', pl.Datetime),
('time', pl.Time),
('timedelta', pl.Duration),
('uuid', pl.Utf8),
('array', pl.List(pl.Utf8)),
('embedding', pl.List),
('json', pl.Utf8),
]


@dataclass
Expand All @@ -47,6 +49,10 @@ def is_numeric(self) -> bool:
'double',
} # Can be represented as an int

@property
def is_datetime(self) -> bool:
return self.name.startswith('datetime')

@property
def python_type(self) -> type:
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -98,10 +104,27 @@ def pandas_type(self) -> str | type:

@property
def polars_type(self) -> type:
return NAME_POLARS_MAPPING[self.name]
if self.name.startswith('datetime-'):
time_zone = self.name.split('-')[1]
return pl.Datetime(time_zone=time_zone) # type: ignore

for name, dtype in NAME_POLARS_MAPPING:
if name == self.name:
return dtype

raise ValueError(f'Unable to find a value that can represent {self.name}')

@property
def feature_factory(self) -> ff.FeatureFactory:

if self.name.startswith('datetime-'):
time_zone = self.name.split('-')[1]
return ff.Timestamp(time_zone=time_zone)

if self.name.startswith('array-'):
sub_type = '-'.join(self.name.split('-')[1:])
return ff.List(FeatureType(name=sub_type).feature_factory)

return {
'string': ff.String(),
'int8': ff.Int8(),
Expand Down Expand Up @@ -135,9 +158,25 @@ def __pre_serialize__(self) -> FeatureType:

@staticmethod
def from_polars(polars_type: pl.DataType) -> FeatureType:
for name, dtype in NAME_POLARS_MAPPING.items():
if isinstance(polars_type, pl.Datetime):
if polars_type.time_zone:
return FeatureType(name=f'datetime-{polars_type.time_zone}')
return FeatureType(name='datetime')

if isinstance(polars_type, pl.List):
if polars_type.inner:
sub_type = FeatureType.from_polars(polars_type.inner) # type: ignore
return FeatureType(name=f'array-{sub_type.name}')

return FeatureType(name='array')

if isinstance(polars_type, pl.Struct):
return FeatureType(name='json')

for name, dtype in NAME_POLARS_MAPPING:
if polars_type.is_(dtype):
return FeatureType(name=name)

raise ValueError(f'Unable to find a value that can represent {polars_type}')

@staticmethod
Expand Down Expand Up @@ -181,8 +220,10 @@ def uuid() -> FeatureType:
return FeatureType(name='uuid')

@staticmethod
def datetime() -> FeatureType:
return FeatureType(name='datetime')
def datetime(tz: ZoneInfo | None = ZoneInfo('UTC')) -> FeatureType:
if not tz:
return FeatureType(name='datetime')
return FeatureType(name=f'datetime-{tz.key}')

@staticmethod
def json() -> FeatureType:
Expand Down
8 changes: 7 additions & 1 deletion aligned/schemas/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import AggregatedFeature, DerivedFeature
from aligned.schemas.event_trigger import EventTrigger
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation, FeatureType
from aligned.schemas.vector_storage import VectorIndex

if TYPE_CHECKING:
Expand Down Expand Up @@ -329,6 +329,12 @@ class FeatureViewReferenceSource(BatchDataSource):
def job_group_key(self) -> str:
return self.view.name

async def schema(self) -> dict[str, FeatureType]:
if self.view.materialized_source:
return await self.view.materialized_source.schema()

return await self.view.source.schema()

def sub_request(self, request: RetrivalRequest) -> RetrivalRequest:

sub_references: set[str] = request.entity_names.union(request.feature_names)
Expand Down
7 changes: 6 additions & 1 deletion aligned/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from aligned.request.retrival_request import FeatureRequest, RetrivalRequest
from aligned.schemas.codable import Codable
from aligned.schemas.feature import FeatureLocation
from aligned.schemas.feature import FeatureLocation, FeatureType
from aligned.schemas.feature import EventTimestamp, Feature, FeatureReferance
from aligned.schemas.event_trigger import EventTrigger
from aligned.schemas.target import ClassificationTarget, RecommendationTarget, RegressionTarget
Expand Down Expand Up @@ -176,6 +176,11 @@ class ModelSource(BatchDataSource):

type_name: str = 'model_source'

async def schema(self) -> dict[str, FeatureType]:
if self.model.predictions_view.source:
return await self.model.predictions_view.source.schema()
return {}

def source(self) -> FeatureViewReferenceSource:
return FeatureViewReferenceSource(self.pred_view, FeatureLocation.model(self.pred_view.name))

Expand Down
26 changes: 12 additions & 14 deletions aligned/sources/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from aligned.schemas.date_formatter import DateFormatter

if TYPE_CHECKING:
from aligned.compiler.feature_factory import FeatureFactory
from datetime import datetime
from aligned.schemas.repo_definition import RepoDefinition
from aligned.feature_store import FeatureStore
Expand Down Expand Up @@ -266,14 +265,15 @@ def multi_source_features_for(
date_formatter=source.formatter,
)

async def schema(self) -> dict[str, FeatureFactory]:
async def schema(self) -> dict[str, FeatureType]:
df = await self.to_lazy_polars()
return {name: FeatureType.from_polars(pl_type).feature_factory for name, pl_type in df.schema.items()}
return {name: FeatureType.from_polars(pl_type) for name, pl_type in df.schema.items()}

async def feature_view_code(self, view_name: str) -> str:
from aligned.feature_view.feature_view import FeatureView

schema = await self.schema()
raw_schema = await self.schema()
schema = {name: feat.feature_factory for name, feat in raw_schema.items()}
data_source_code = f'FileSource.csv_at("{self.path}", csv_config={self.csv_config})'
return FeatureView.feature_view_code_template(
schema,
Expand Down Expand Up @@ -370,20 +370,19 @@ def multi_source_features_for(
facts=facts,
)

async def schema(self) -> dict[str, FeatureFactory]:
async def schema(self) -> dict[str, FeatureType]:
if self.path.startswith('http'):
parquet_schema = pl.scan_parquet(self.path).schema
else:
parquet_schema = pl.read_parquet_schema(self.path)

return {
name: FeatureType.from_polars(pl_type).feature_factory for name, pl_type in parquet_schema.items()
}
return {name: FeatureType.from_polars(pl_type) for name, pl_type in parquet_schema.items()}

async def feature_view_code(self, view_name: str) -> str:
from aligned.feature_view.feature_view import FeatureView

schema = await self.schema()
raw_schema = await self.schema()
schema = {name: feat.feature_factory for name, feat in raw_schema.items()}
data_source_code = f'FileSource.parquet_at("{self.path}")'
return FeatureView.feature_view_code_template(
schema, data_source_code, view_name, 'from aligned import FileSource'
Expand Down Expand Up @@ -436,16 +435,15 @@ async def write_polars(self, df: pl.LazyFrame) -> None:
self.path, mode=self.config.mode, overwrite_schema=self.config.overwrite_schema
)

async def schema(self) -> dict[str, FeatureFactory]:
async def schema(self) -> dict[str, FeatureType]:
parquet_schema = pl.read_delta(self.path).schema
return {
name: FeatureType.from_polars(pl_type).feature_factory for name, pl_type in parquet_schema.items()
}
return {name: FeatureType.from_polars(pl_type) for name, pl_type in parquet_schema.items()}

async def feature_view_code(self, view_name: str) -> str:
from aligned.feature_view.feature_view import FeatureView

schema = await self.schema()
raw_schema = await self.schema()
schema = {name: feat.feature_factory for name, feat in raw_schema.items()}
data_source_code = f'FileSource.parquet_at("{self.path}")'
return FeatureView.feature_view_code_template(
schema, data_source_code, view_name, 'from aligned import FileSource'
Expand Down
Loading

0 comments on commit 225addd

Please sign in to comment.