Skip to content

Commit

Permalink
fix: some event timestamp bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Mats E. Mollestad committed Nov 2, 2023
1 parent 195602f commit 9e5b686
Show file tree
Hide file tree
Showing 13 changed files with 346 additions and 56 deletions.
4 changes: 3 additions & 1 deletion aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,9 @@ def __init__(self, ttl: timedelta | None = None):

def event_timestamp(self) -> EventTimestampFeature:
    """Build the ``EventTimestampFeature`` described by this factory.

    The factory's ``name`` and ``_description`` are forwarded unchanged.
    The ``ttl`` timedelta is converted to whole seconds; a missing or
    zero-length ``ttl`` (both falsy) is passed on as ``None``.
    """
    return EventTimestampFeature(
        name=self.name,
        # total_seconds() yields a float; int() truncates it to whole seconds.
        ttl=int(self.ttl.total_seconds()) if self.ttl else None,
        description=self._description,
    )


Expand Down
6 changes: 6 additions & 0 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
.freshness()
)
"""
from aligned.data_file import DataFileReference
from aligned.sources.local import data_file_freshness

if isinstance(self, DataFileReference):
return await data_file_freshness(self, event_timestamp.name)

raise NotImplementedError(f'Freshness is not implemented for {type(self)}.')


Expand Down
4 changes: 2 additions & 2 deletions aligned/feature_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def features_for(self, facts: RetrivalJob, request: FeatureRequest) -> RetrivalJ

async def freshness_for(
    self, locations: dict[FeatureLocation, EventTimestamp]
) -> dict[FeatureLocation, datetime | None]:
    """Look up the freshness of each given location.

    Args:
        locations: Maps each feature location to the event timestamp
            that should be used when checking it.

    Returns:
        A mapping from location to a datetime, or ``None`` for a location
        without a known freshness value.

    Raises:
        NotImplementedError: Always; this implementation provides no lookup.
    """
    raise NotImplementedError()


Expand Down Expand Up @@ -136,7 +136,7 @@ def all_between(self, start_date: datetime, end_date: datetime, request: Feature

async def freshness_for(
self, locations: dict[FeatureLocation, EventTimestamp]
) -> dict[FeatureLocation, datetime]:
) -> dict[FeatureLocation, datetime | None]:
locs = list(locations.keys())
results = await asyncio.gather(
*[self.sources[loc.identifier].freshness(locations[loc]) for loc in locs]
Expand Down
58 changes: 34 additions & 24 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,8 @@ def features_for(

feature_names = set()

if event_timestamp_column:
if event_timestamp_column and requests.needs_event_timestamp:
feature_names.add(event_timestamp_column)
if isinstance(entities, dict) and event_timestamp_column in entities:
length = len(list(entities.values())[0])
entities[event_timestamp_column] = [datetime.utcnow()] * length

for view, feature_set in feature_request.grouped_features.items():
if feature_set != {'*'}:
Expand Down Expand Up @@ -345,8 +342,6 @@ def _requests_for(
requests: list[RetrivalRequest] = []
entity_names = set()

needs_event_timestamp = False

for location in feature_request.locations:
location_name = location.name
if location.location == 'model':
Expand All @@ -358,8 +353,6 @@ def _requests_for(
request = view.request_for(features[location], location_name)
requests.append(request)
entity_names.update(request.entity_names)
if request.event_timestamp:
needs_event_timestamp = True

elif location_name in combined_feature_views:
cfv = combined_feature_views[location_name]
Expand All @@ -370,8 +363,6 @@ def _requests_for(
requests.extend(sub_requests.needed_requests)
for request in sub_requests.needed_requests:
entity_names.update(request.entity_names)
if request.event_timestamp:
needs_event_timestamp = True

elif location_name in feature_views:
feature_view = feature_views[location_name]
Expand All @@ -382,16 +373,18 @@ def _requests_for(
requests.extend(sub_requests.needed_requests)
for request in sub_requests.needed_requests:
entity_names.update(request.entity_names)
if request.event_timestamp:
needs_event_timestamp = True
else:
raise ValueError(
f'Unable to find: {location_name}, '
f'availible views are: {combined_feature_views.keys()}, and: {feature_views.keys()}'
)

if needs_event_timestamp and event_timestamp_column:
if event_timestamp_column:
entity_names.add(event_timestamp_column)
requests = [request.with_event_timestamp_column(event_timestamp_column) for request in requests]

else:
requests = [request.without_event_timestamp() for request in requests]

return FeatureRequest(
FeatureLocation.model('custom features'),
Expand Down Expand Up @@ -688,14 +681,14 @@ def features_for(

return job.select_columns(request.features_to_include)

async def freshness(self) -> dict[FeatureLocation, datetime | None]:
    """Return the freshness of every location this request depends on.

    Only needed requests that define an event timestamp are included.
    The collected mapping is handed to the store's feature source, whose
    ``freshness_for`` result is returned as-is.
    """
    from aligned.schemas.feature import EventTimestamp

    # Store the timestamp under its location. A bare `locs[req.location]`
    # lookup would raise KeyError on the empty dict and record nothing.
    locs: dict[FeatureLocation, EventTimestamp] = {
        req.location: req.event_timestamp
        for req in self.request().needed_requests
        if req.event_timestamp
    }

    return await self.store.feature_source.freshness_for(locs)

Expand Down Expand Up @@ -782,10 +775,14 @@ def process_features(self, input: RetrivalJob | ConvertableToRetrivalJob) -> Ret
.select_columns(request.features_to_include)
)

def predictions_for(
    self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
) -> RetrivalJob:
    """Load all features stored at this model's location for the entities.

    Args:
        entities: The entities to load values for.
        event_timestamp_column: Optional name of the event timestamp
            column, forwarded unchanged to ``store.features_for``.

    Returns:
        The job produced by ``store.features_for`` for the wildcard
        request ``'<location identifier>:*'``.
    """
    location_id = self.location.identifier
    # The timestamp column must be forwarded; otherwise the caller's
    # choice is silently dropped.
    return self.store.features_for(
        entities, features=[f'{location_id}:*'], event_timestamp_column=event_timestamp_column
    )

def all_predictions(self, limit: int | None = None) -> RetrivalJob:

Expand Down Expand Up @@ -912,7 +909,9 @@ class SupervisedModelFeatureStore:
model: ModelSchema
store: FeatureStore

def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> SupervisedJob:
def features_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
) -> SupervisedJob:
"""Loads the features and labels for a model
```python
Expand Down Expand Up @@ -956,9 +955,11 @@ def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> Supe
else:
raise ValueError('Found no targets in the model')

request = self.store.requests_for(RawStringFeatureRequest(features))
request = self.store.requests_for(
RawStringFeatureRequest(features), event_timestamp_column=event_timestamp_column
)
target_request = self.store.requests_for(
RawStringFeatureRequest(target_features)
RawStringFeatureRequest(target_features), event_timestamp_column=event_timestamp_column
).without_event_timestamp(name_sufix='target')

total_request = FeatureRequest(
Expand All @@ -972,7 +973,9 @@ def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> Supe
target_columns=targets,
)

def predictions_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> RetrivalJob:
def predictions_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
) -> RetrivalJob:
"""Loads the predictions and labels / ground truths for a model
```python
Expand Down Expand Up @@ -1017,7 +1020,9 @@ def predictions_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> R
labels = pred_view.labels()
target_features = {feature.identifier for feature in target_features}
pred_features = {f'model:{self.model.name}:{feature.name}' for feature in labels}
request = self.store.requests_for(RawStringFeatureRequest(pred_features))
request = self.store.requests_for(
RawStringFeatureRequest(pred_features), event_timestamp_column=event_timestamp_column
)
target_request = self.store.requests_for(
RawStringFeatureRequest(target_features)
).without_event_timestamp(name_sufix='target')
Expand Down Expand Up @@ -1134,12 +1139,17 @@ def previous(self, days: int = 0, minutes: int = 0, seconds: int = 0) -> Retriva
start_date = end_date - timedelta(days=days, minutes=minutes, seconds=seconds)
return self.between_dates(start_date, end_date)

def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> RetrivalJob:
def features_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
) -> RetrivalJob:

request = self.view.request_all
if self.feature_filter:
request = self.view.request_for(self.feature_filter)

if not event_timestamp_column:
request = request.without_event_timestamp()

if isinstance(entities, RetrivalJob):
entity_job = entities
else:
Expand Down Expand Up @@ -1256,7 +1266,7 @@ async def batch_write(self, values: ConvertableToRetrivalJob | RetrivalJob) -> N
with feature_view_write_time.labels(self.view.name).time():
await self.source.write(job, job.retrival_requests)

async def freshness(self) -> datetime:
async def freshness(self) -> datetime | None:

view = self.view
if not view.event_timestamp:
Expand Down
41 changes: 31 additions & 10 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,38 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:

result = await self.facts.to_polars()
event_timestamp_col = 'aligned_event_timestamp'
using_event_timestamp = False
if 'event_timestamp' in result.columns:
using_event_timestamp = True
result = result.rename({'event_timestamp': event_timestamp_col})

event_timestamp_entity_columns = [
req.event_timestamp_request.entity_column for req in self.requests if req.event_timestamp_request
]
event_timestamp_entity_column = None
did_rename_event_timestamp = False

if event_timestamp_entity_columns:
event_timestamp_entity_column = event_timestamp_entity_columns[0]

if event_timestamp_entity_column and event_timestamp_entity_column in result:
result = result.rename({event_timestamp_entity_column: event_timestamp_col})
did_rename_event_timestamp = True

row_id_name = 'row_id'
result = result.with_row_count(row_id_name)

for request in self.requests:
entity_names = request.entity_names
all_names = request.all_required_feature_names.union(entity_names)

if request.event_timestamp_request:
using_event_timestamp = event_timestamp_entity_column is not None
else:
using_event_timestamp = False

if request.event_timestamp:
all_names.add(request.event_timestamp.name)

request_features = all_names
all_names = list(all_names)

request_features = list(all_names)
if isinstance(self.source, ColumnFeatureMappable):
request_features = self.source.feature_identifier_for(all_names)

Expand All @@ -259,7 +276,8 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
result = result.with_columns(pl.col(entity.name).cast(entity.dtype.polars_type))

column_selects = list(entity_names.union({'row_id'}))
if request.event_timestamp:

if using_event_timestamp:
column_selects.append(event_timestamp_col)

# Need to only select the relevent entities and row_id
Expand All @@ -274,7 +292,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
aggregated_df = await self.aggregate_over(group, features, new_result, event_timestamp_col)
new_result = new_result.join(aggregated_df, on='row_id', how='left')

if request.event_timestamp:
if request.event_timestamp and using_event_timestamp:
field = request.event_timestamp.name
ttl = request.event_timestamp.ttl

Expand All @@ -292,13 +310,16 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
pl.col(field).is_null() | (pl.col(field) <= pl.col(event_timestamp_col))
)
new_result = new_result.sort(field, descending=True).select(pl.exclude(field))
elif request.event_timestamp:
new_result = new_result.sort([row_id_name, request.event_timestamp.name], descending=True)

unique = new_result.unique(subset=row_id_name, keep='first')
result = result.join(unique, on=row_id_name, how='left')
column_selects.remove('row_id')
result = result.join(unique.select(pl.exclude(column_selects)), on=row_id_name, how='left')
result = result.select(pl.exclude('.*_right$'))

if using_event_timestamp:
result = result.rename({event_timestamp_col: 'event_timestamp'})
if did_rename_event_timestamp:
result = result.rename({event_timestamp_col: event_timestamp_entity_column})

return result.select([pl.exclude('row_id')])

Expand Down
27 changes: 23 additions & 4 deletions aligned/request/retrival_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class EventTimestampRequest(Codable):
    """Pairs an event timestamp feature with an optional entity column."""

    # The event timestamp feature this request refers to.
    event_timestamp: EventTimestamp
    # Name of the timestamp column on the entity side; None when no
    # such column is used (the default).
    entity_column: str | None = field(default=None)


@dataclass
Expand Down Expand Up @@ -59,8 +59,7 @@ def __init__(
self.event_timestamp_request = event_timestamp_request
elif event_timestamp:
self.event_timestamp_request = EventTimestampRequest(
event_timestamp=event_timestamp,
entity_column=entity_timestamp_columns or 'event_timestamp',
event_timestamp=event_timestamp, entity_column=entity_timestamp_columns
)
self.features_to_include = features_to_include or self.all_feature_names

Expand Down Expand Up @@ -162,13 +161,33 @@ def aggregate_over(self) -> dict[AggregateOver, set[AggregatedFeature]]:
return features

def without_event_timestamp(self, name_sufix: str | None = None) -> 'RetrivalRequest':
    """Return a copy of this request with the entity timestamp column cleared.

    The event timestamp feature itself is kept when one is present; only
    its ``entity_column`` is reset to ``None``. ``name_sufix``, when given,
    is appended to the copied request's name.
    """
    current = self.event_timestamp_request
    # Keep the timestamp feature but drop the entity-side column.
    cleared = EventTimestampRequest(current.event_timestamp, None) if current else None

    return RetrivalRequest(
        name=f'{self.name}{name_sufix or ""}',
        location=self.location,
        entities=self.entities,
        features=self.features,
        derived_features=self.derived_features,
        aggregated_features=self.aggregated_features,
        event_timestamp_request=cleared,
    )

def with_event_timestamp_column(self, column: str) -> 'RetrivalRequest':
    """Return a copy of this request bound to the given entity timestamp column.

    When the request has an event timestamp request, the copy keeps the
    same timestamp feature and uses ``column`` as its ``entity_column``.
    Otherwise the copy has no event timestamp request at all.
    """
    existing = self.event_timestamp_request
    if existing:
        rebound = EventTimestampRequest(existing.event_timestamp, column)
    else:
        rebound = None

    return RetrivalRequest(
        name=self.name,
        location=self.location,
        entities=self.entities,
        features=self.features,
        derived_features=self.derived_features,
        aggregated_features=self.aggregated_features,
        event_timestamp_request=rebound,
    )

@staticmethod
Expand All @@ -187,7 +206,7 @@ def combine(requests: list['RetrivalRequest']) -> list['RetrivalRequest']:
features=request.features,
derived_features=request.derived_features,
aggregated_features=request.aggregated_features,
event_timestamp=request.event_timestamp,
event_timestamp_request=request.event_timestamp_request,
)
returned_features[fv_name] = request.returned_features
else:
Expand Down
16 changes: 9 additions & 7 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,7 @@ async def to_polars(self) -> pl.LazyFrame:
features_to_check = {feature.derived_feature for feature in request.aggregated_features}

for feature in features_to_check:

if feature.dtype == FeatureType('').bool:
df = df.with_columns(pl.col(feature.name).cast(pl.Int8).cast(pl.Boolean))
elif feature.dtype == FeatureType('').datetime:
Expand All @@ -1213,13 +1214,14 @@ async def to_polars(self) -> pl.LazyFrame:
if feature.name not in df.columns:
continue
current_dtype = df.select([feature.name]).dtypes[0]
if isinstance(current_dtype, pl.Datetime):
continue
df = df.with_columns(
(pl.col(feature.name).cast(pl.Int64) * 1000)
.cast(pl.Datetime(time_zone='UTC'))
.alias(feature.name)
)

if not isinstance(current_dtype, pl.Datetime):
df = df.with_columns(
(pl.col(feature.name).cast(pl.Int64) * 1000)
.cast(pl.Datetime(time_zone='UTC'))
.alias(feature.name)
)

return df

def remove_derived_features(self) -> RetrivalJob:
Expand Down
Loading

0 comments on commit 9e5b686

Please sign in to comment.