fix: feature view ref source fact req
MatsMoll committed Dec 13, 2023
1 parent 4a5cd60 commit 8639f2a
Showing 13 changed files with 312 additions and 95 deletions.
15 changes: 9 additions & 6 deletions aligned/compiler/transformation_factory.py
@@ -27,12 +27,12 @@ def using_features(self) -> list[FeatureFactory]:
        return [self.left]

    def compile(self) -> Transformation:
-       from aligned.schemas.transformation import Equals
+       from aligned.schemas.transformation import EqualsLiteral, Equals

        if isinstance(self.right, FeatureFactory):
-           raise NotImplementedError()
+           return Equals(self.left.name, self.right.name)
        else:
-           return Equals(self.left.name, LiteralValue.from_value(self.right))
+           return EqualsLiteral(self.left.name, LiteralValue.from_value(self.right))


@dataclass
@@ -122,17 +122,20 @@ def compile(self) -> Transformation:
@dataclass
class NotEqualsFactory(TransformationFactory):

-   value: Any
+   value: Any | FeatureFactory
    in_feature: FeatureFactory

    @property
    def using_features(self) -> list[FeatureFactory]:
        return [self.in_feature]

    def compile(self) -> Transformation:
-       from aligned.schemas.transformation import NotEquals as NotEqualsTransformation
+       from aligned.schemas.transformation import NotEqualsLiteral, NotEquals as NotEqualsTransformation

-       return NotEqualsTransformation(self.in_feature.name, LiteralValue.from_value(self.value))
+       if isinstance(self.value, FeatureFactory):
+           return NotEqualsTransformation(self.in_feature.name, self.value.name)
+       else:
+           return NotEqualsLiteral(self.in_feature.name, LiteralValue.from_value(self.value))


@dataclass
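The user-facing effect of this file: `==` and `!=` between two features now compile to feature-to-feature transformations instead of raising `NotImplementedError`, while literal comparisons move to the new `EqualsLiteral`/`NotEqualsLiteral` schemas. A minimal sketch of a view that exercises both paths — the view and column names are hypothetical, and it assumes the `==`/`!=` operators build the factories shown above:

from aligned import feature_view, String, FileSource


@feature_view(name='orders', source=FileSource.csv_at('orders.csv'))
class Orders:
    order_id = String().as_entity()

    status = String()
    expected_status = String()

    matches_expected = status == expected_status  # feature vs. feature -> Equals
    is_open = status != 'closed'  # feature vs. literal -> NotEqualsLiteral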
21 changes: 21 additions & 0 deletions aligned/data_source/batch_data_source.py
@@ -284,6 +284,27 @@ class FilteredDataSource(BatchDataSource):
    def job_group_key(self) -> str:
        return f'subset/{self.source.job_group_key()}'

+   @classmethod
+   def multi_source_features_for(
+       cls: type[FilteredDataSource],
+       facts: RetrivalJob,
+       requests: list[tuple[FilteredDataSource, RetrivalRequest]],
+   ) -> RetrivalJob:
+
+       sources = {source.job_group_key() for source, _ in requests if isinstance(source, BatchDataSource)}
+       if len(sources) != 1:
+           raise NotImplementedError(
+               f'Type: {cls} has not implemented how to load fact data with multiple sources.'
+           )
+       source, request = requests[0]
+
+       if isinstance(source.condition, Feature):
+           request.features.add(source.condition)
+       else:
+           request.derived_features.add(source.condition)
+
+       return source.source.features_for(facts, request).filter(source.condition)
+
    def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

        if isinstance(self.condition, Feature):
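This classmethod is what lets a filtered view serve point-in-time (fact) requests: the filter condition is attached to the request so it gets loaded or derived, the wrapped source fetches features for the facts, and the resulting job is filtered afterwards. A rough usage sketch, where `view_source`, `condition_feature`, `entity_facts`, and `request` are hypothetical placeholders (normally the feature store builds this wiring for you):

filtered = FilteredDataSource(source=view_source, condition=condition_feature)

job = FilteredDataSource.multi_source_features_for(
    facts=entity_facts,
    requests=[(filtered, request)],
)

Requests spanning more than one underlying source still raise `NotImplementedError`.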
5 changes: 5 additions & 0 deletions aligned/feature_source.py
@@ -90,6 +90,11 @@ def features_for(self, facts: RetrivalJob, request: FeatureRequest) -> RetrivalJob:
            )
            if has_derived_features:
                job = job.derive_features()
+
+           if len(requests) == 1 and requests[0][1].aggregated_features:
+               req = requests[0][1]
+               job = job.aggregate(req)
+
            jobs.append(job)

        if len(combined_requests) > 0 or len(jobs) > 1:
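This is the piece that lets a single fact request against an aggregated view return aggregated values: when the lone request carries aggregated features, the job aggregates immediately after deriving. The new `test_aggregate_over_derived_fact` below exercises exactly this shape (here `store` is a compiled `FeatureStore`):

data = await store.features_for(
    entities={'user_id': ['a', 'b']},
    features=['income_agg:total_amount'],
).to_polars()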
74 changes: 74 additions & 0 deletions aligned/jobs/tests/test_derived_job.py
@@ -4,6 +4,8 @@
import pytest

from aligned import feature_view, Float, String, FileSource
+from aligned.compiler.model import model_contract
+from aligned.feature_store import FeatureStore
from aligned.local.job import FileFullJob
from aligned.retrival_job import DerivedFeatureJob, RetrivalRequest
from aligned.sources.local import LiteralReference
@@ -94,6 +96,34 @@ class ExpenceAgg:

IncomeAgg = ExpenceAgg.with_source(named='income_agg', source=Income)  # type: ignore

+income_agg = IncomeAgg()
+
+
+@model_contract(
+    name='model',
+    features=[
+        expences.abs_amount,
+        expences.is_expence,
+        income_agg.total_amount,
+    ],
+)
+class Model:
+    user_id = String().as_entity()
+
+    pred_amount = expences.amount.as_regression_label()
+
+
+def feature_store() -> FeatureStore:
+    store = FeatureStore.experimental()
+
+    views = [Transaction, Expences, Income, ExpenceAgg, IncomeAgg]
+    for view in views:
+        store.add_compiled_view(view.compile())
+
+    store.add_compiled_model(Model.compile())
+
+    return store
+
+
@pytest.mark.asyncio
async def test_aggregate_over_derived() -> None:
@@ -103,3 +133,47 @@ async def test_aggregate_over_derived() -> None:
    df = data.collect()

    assert df.height == 2
+
+
+@pytest.mark.asyncio
+async def test_aggregate_over_derived_fact() -> None:
+
+    store = feature_store()
+
+    data = await store.features_for(
+        entities={'user_id': ['a', 'b']}, features=['income_agg:total_amount']
+    ).to_polars()
+
+    df = data.collect()
+
+    assert df.height == 2
+
+
+@pytest.mark.asyncio
+async def test_model_with_label_multiple_views() -> None:
+
+    store = feature_store()
+
+    entities = await store.feature_view('expence').all().to_pandas()
+
+    data_job = store.model('model').with_labels().features_for(entities)
+    data = await data_job.to_pandas()
+
+    expected_df = pd.DataFrame(
+        {
+            'transaction_id': ['b', 'd', 'q', 'e'],
+            'user_id': ['b', 'b', 'a', 'a'],
+            'total_amount': [109.0, 109.0, 120.0, 120.0],
+            'is_expence': [True, True, True, True],
+            'abs_amount': [20, 100, 20, 100],
+            'amount': [-20.0, -100.0, -20.0, -100.0],
+        }
+    )
+
+    assert data.labels.shape[0] != 0
+    assert data.input.shape[1] == 3
+    assert data.input.shape[0] != 0
+
+    assert data.data.sort_values(['user_id', 'transaction_id'])[expected_df.columns].equals(
+        expected_df.sort_values(['user_id', 'transaction_id'])
+    )
95 changes: 59 additions & 36 deletions aligned/local/job.py
@@ -299,7 +299,7 @@ async def aggregate_over(
@dataclass
class FileFactualJob(RetrivalJob):

-   source: DataFileReference
+   source: DataFileReference | RetrivalJob
    requests: list[RetrivalRequest]
    facts: RetrivalJob
@@ -351,11 +351,13 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:

        row_id_name = 'row_id'
        result = result.with_row_count(row_id_name)

        for request in self.requests:

            entity_names = request.entity_names
            all_names = request.all_required_feature_names.union(entity_names)

+           column_selects = list(entity_names.union({'row_id'}))
+
            if request.event_timestamp_request:
                using_event_timestamp = event_timestamp_entity_column is not None
            else:
@@ -364,42 +366,52 @@ if request.event_timestamp:
            if request.event_timestamp:
                all_names.add(request.event_timestamp.name)

-           all_names = list(all_names)
-
-           request_features = list(all_names)
-           if isinstance(self.source, ColumnFeatureMappable):
-               request_features = self.source.feature_identifier_for(all_names)
-
-           feature_df = df.select(request_features)
-
-           renames = {
-               org_name: wanted_name
-               for org_name, wanted_name in zip(request_features, all_names)
-               if org_name != wanted_name
-           }
-           if renames:
-               feature_df = feature_df.rename(renames)
-
-           for entity in request.entities:
-               feature_df = feature_df.with_columns(pl.col(entity.name).cast(entity.dtype.polars_type))
-               result = result.with_columns(pl.col(entity.name).cast(entity.dtype.polars_type))
-
-           column_selects = list(entity_names.union({'row_id'}))

            if using_event_timestamp:
                column_selects.append(event_timestamp_col)

-           # Need to only select the relevant entities and row_id
-           # Otherwise we will get a duplicate column error
-           # We also need to remove the entities after the row_id is joined
-           new_result: pl.LazyFrame = result.select(column_selects).join(
-               feature_df, on=list(entity_names), how='left'
-           )
-           new_result = new_result.select(pl.exclude(list(entity_names)))
+           missing_agg_features = [
+               feat for feat in request.aggregated_features if feat.name not in df.columns
+           ]
+           if request.aggregated_features and not missing_agg_features:
+               new_result = result.join(
+                   df.select(request.all_returned_columns), on=list(entity_names), how='left'
+               )
+           else:
+               all_names = list(all_names)
+
+               request_features = list(all_names)
+               if isinstance(self.source, ColumnFeatureMappable):
+                   request_features = self.source.feature_identifier_for(all_names)
+
+               feature_df = df.select(request_features)
+
+               renames = {
+                   org_name: wanted_name
+                   for org_name, wanted_name in zip(request_features, all_names)
+                   if org_name != wanted_name
+               }
+               if renames:
+                   feature_df = feature_df.rename(renames)
+
+               for entity in request.entities:
+                   feature_df = feature_df.with_columns(pl.col(entity.name).cast(entity.dtype.polars_type))
+                   result = result.with_columns(pl.col(entity.name).cast(entity.dtype.polars_type))
+
+               # Need to only select the relevant entities and row_id
+               # Otherwise we will get a duplicate column error
+               # We also need to remove the entities after the row_id is joined
+               new_result: pl.LazyFrame = result.select(column_selects).join(
+                   feature_df, on=list(entity_names), how='left'
+               )
+               new_result = new_result.select(pl.exclude(list(entity_names)))

-           for group, features in request.aggregate_over().items():
-               aggregated_df = await aggregate_over(group, features, new_result, event_timestamp_col)
-               new_result = new_result.join(aggregated_df, on='row_id', how='left')
+           for group, features in request.aggregate_over().items():
+               missing_features = [
+                   feature.name for feature in features if feature.name not in df.columns
+               ]
+               if missing_features:
+                   aggregated_df = await aggregate_over(group, features, new_result, event_timestamp_col)
+                   new_result = new_result.join(aggregated_df, on='row_id', how='left')

            if request.event_timestamp and using_event_timestamp:
                field = request.event_timestamp.name
@@ -409,6 +421,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
                    new_result = new_result.with_columns(
                        pl.col(field).str.strptime(pl.Datetime, '%+').alias(field)
                    )
+
                if ttl:
                    ttl_request = (pl.col(field) <= pl.col(event_timestamp_col)) & (
                        pl.col(field) >= pl.col(event_timestamp_col) - ttl
@@ -420,12 +433,14 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
                    )
                    new_result = new_result.sort(field, descending=True).select(pl.exclude(field))
            elif request.event_timestamp:
-               new_result = new_result.sort([row_id_name, request.event_timestamp.name], descending=True)
+               new_result = new_result.sort(
+                   [row_id_name, request.event_timestamp.name], descending=True
+               ).select(pl.exclude(request.event_timestamp.name))

            unique = new_result.unique(subset=row_id_name, keep='first')
            column_selects.remove('row_id')
            result = result.join(unique.select(pl.exclude(column_selects)), on=row_id_name, how='left')
-           result = result.select(pl.exclude('.*_right$'))
+           result = result.select(pl.exclude('.*_right'))

        if did_rename_event_timestamp:
            result = result.rename({event_timestamp_col: event_timestamp_entity_column})
@@ -437,3 +452,11 @@ async def to_pandas(self) -> pd.DataFrame:

    async def to_polars(self) -> pl.LazyFrame:
        return await self.file_transformations(await self.source.to_polars())
+
+   def log_each_job(self) -> RetrivalJob:
+       from aligned.retrival_job import LogJob
+
+       if isinstance(self.source, RetrivalJob):
+           return FileFactualJob(LogJob(self.source), self.requests, self.facts)
+       else:
+           return self
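Two behavioural changes land in this file: `FileFactualJob.source` may now be another `RetrivalJob` (which is why `log_each_job` wraps it in `LogJob`), and aggregation work is skipped when the upstream frame already carries the aggregated columns. A condensed, illustrative restatement of that branch — not the full method, with `df` as the source frame and the other names as in the diff:

missing_agg_features = [
    feat for feat in request.aggregated_features if feat.name not in df.columns
]
if request.aggregated_features and not missing_agg_features:
    # Everything was aggregated upstream: a plain join on the entities suffices.
    new_result = result.join(
        df.select(request.all_returned_columns), on=list(entity_names), how='left'
    )
else:
    # Otherwise select, rename, and join the raw features, then aggregate only
    # the windows whose outputs are still missing from the frame.
    ...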
19 changes: 12 additions & 7 deletions aligned/request/retrival_request.py
@@ -301,10 +301,11 @@ def entity_columns(self) -> list[str]:
        return [entity.name for entity in self.entities]

    def __add__(self, obj: 'RequestResult') -> 'RequestResult':
+
        return RequestResult(
            entities=self.entities.union(obj.entities),
            features=self.features.union(obj.features),
-           event_timestamp='event_timestamp' if self.event_timestamp or obj.event_timestamp else None,
+           event_timestamp=self.event_timestamp or obj.event_timestamp,
        )

    def filter_features(self, features_to_include: set[str]) -> 'RequestResult':
@@ -328,6 +329,10 @@ def from_request_list(requests: list[RetrivalRequest]) -> 'RequestResult':
        if request_len == 0:
            return RequestResult(entities=set(), features=set(), event_timestamp=None)
        elif request_len > 1:
+           event_timestamp = None
+           requests_with_event = [req.event_timestamp for req in requests if req.event_timestamp]
+           if requests_with_event:
+               event_timestamp = requests_with_event[0].name
            return RequestResult(
                entities=set().union(*[request.entities for request in requests]),
                features=set().union(
@@ -341,9 +346,7 @@ def from_request_list(requests: list[RetrivalRequest]) -> 'RequestResult':
                        for request in requests
                    ]
                ),
-               event_timestamp='event_timestamp'
-               if any(request.event_timestamp for request in requests)
-               else None,
+               event_timestamp=event_timestamp,
            )
        else:
            return RequestResult.from_request(requests[0])
@@ -354,12 +357,14 @@ def from_result_list(requests: list['RequestResult']) -> 'RequestResult':
        if request_len == 0:
            return RequestResult(entities=set(), features=set(), event_timestamp=None)
        elif request_len > 1:
+           event_timestamp = None
+           requests_with_event = [req.event_timestamp for req in requests if req.event_timestamp]
+           if requests_with_event:
+               event_timestamp = requests_with_event[0]
            return RequestResult(
                entities=set().union(*[request.entities for request in requests]),
                features=set().union(*[request.features for request in requests]),
-               event_timestamp='event_timestamp'
-               if any(request.event_timestamp for request in requests)
-               else None,
+               event_timestamp=event_timestamp,
            )
        else:
            return requests[0]
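All three call sites follow the same pattern: rather than collapsing every result to the hardcoded column name 'event_timestamp', keep the first request's actual timestamp (or None). A schematic of that rule in plain Python, outside the aligned API:

def combined_event_timestamp(timestamps: list[str | None]) -> str | None:
    # Keep the first concrete event-timestamp column name, if any request has one.
    with_event = [ts for ts in timestamps if ts]
    return with_event[0] if with_event else None


assert combined_event_timestamp(['created_at', None]) == 'created_at'
assert combined_event_timestamp([None, None]) is None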