Skip to content

Commit

Permalink
fix: join source bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Dec 16, 2023
1 parent 8639f2a commit 9b5f943
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 20 deletions.
29 changes: 21 additions & 8 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ def join_asof_source(
def join_source(
source: BatchDataSource,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_left: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_right: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
left_request: RetrivalRequest | None = None,
) -> JoinDataSource:
Expand All @@ -392,10 +393,15 @@ def join_source(

right_source, right_request = view_wrapper_instance_source(view)

if on is None:
on_keys = list(right_request.entity_names)
if on_left is None:
left_keys = list(right_request.entity_names)
else:
on_keys = resolve_keys(on)
left_keys = resolve_keys(on_left)

if on_right is None:
right_keys = list(right_request.entity_names)
else:
right_keys = resolve_keys(on_right)

if left_request is None:
if isinstance(source, JoinDataSource):
Expand All @@ -411,8 +417,8 @@ def join_source(
left_request=left_request,
right_source=right_source,
right_request=right_request,
left_on=on_keys,
right_on=on_keys,
left_on=left_keys,
right_on=right_keys,
method=how,
)

Expand Down Expand Up @@ -545,9 +551,16 @@ def join(
self,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_left: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_right: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
) -> BatchDataSource:
return join_source(self, view, on, how)
) -> JoinDataSource:

if on:
on_left = on
on_right = on

return join_source(self, view, on_left, on_right, how)

def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on().intersection(self.right_source.depends_on())
Expand Down
2 changes: 0 additions & 2 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,6 @@ def features_for(
if isinstance(entities, dict):
# Do not load the features if they already exist as an entity
request.features = {feature for feature in request.features if feature.name not in entities}
if len(request.features) == 0 and request.location.location != 'combined_view':
request.derived_features = set()

return self.features_for_request(requests, entities, feature_names)

Expand Down
16 changes: 13 additions & 3 deletions aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,29 @@ def filter(

if materialize_source:
meta.materialized_source = materialize_source
meta.source = main_source

return FeatureViewWrapper(metadata=meta, view=self.view)

def join(
self, view: Any, on: str | FeatureFactory | list[str] | list[FeatureFactory], how: str = 'inner'
self,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_left: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
on_right: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
) -> JoinDataSource:
from aligned.schemas.feature_view import FeatureViewReferenceSource

compiled_view = self.compile()
source = FeatureViewReferenceSource(compiled_view)

return join_source(source, view, on, how, left_request=compiled_view.request_all.needed_requests[0])
if on:
on_left = on
on_right = on

return join_source(
source, view, on_left, on_right, how, left_request=compiled_view.request_all.needed_requests[0]
)

def join_asof(
self, view: Any, on: str | FeatureFactory | list[str] | list[FeatureFactory]
Expand Down
58 changes: 56 additions & 2 deletions aligned/feature_view/tests/test_joined_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ class RightData:
other_feature = Int32()


@feature_view(name='right', source=FileSource.csv_at('some_file.csv'))
class RightOtherIdData:

other_id = Int32().as_entity()

other_feature = Int32()


@pytest.mark.asyncio
async def test_join_different_types_polars() -> None:

Expand Down Expand Up @@ -52,12 +60,58 @@ async def test_join_different_types_polars() -> None:


@pytest.mark.asyncio
async def test_unique_entities() -> None:
async def test_join_different_join_keys() -> None:

left_data = LeftData.from_data( # type: ignore
pl.DataFrame(
{'some_id': [1, 2, 3], 'feature': [2, 3, 4]}, schema={'some_id': pl.Int8, 'feature': pl.Int32}
)
)

right_data = RightOtherIdData.from_data( # type: ignore
pl.DataFrame(
{'other_id': [1, 3, 2], 'other_feature': [3, 4, 5]},
schema={'other_id': pl.Int16, 'other_feature': pl.Int32},
)
)

expected_df = pl.DataFrame(
data={'some_id': [1, 2, 3], 'feature': [2, 3, 4], 'other_feature': [3, 5, 4]},
schema={
'some_id': pl.Int32,
'feature': pl.Int32,
'other_feature': pl.Int32,
},
)

new_data = left_data.join(right_data, 'inner', left_on='some_id', right_on='other_id')

result = await new_data.to_polars()
req_result = new_data.request_result

joined = result.collect().sort('some_id', descending=False)

assert joined.frame_equal(expected_df.select(joined.columns))
assert joined.select(req_result.entity_columns).frame_equal(expected_df.select(['some_id']))


@pytest.mark.asyncio
async def test_unique_entities() -> None:

data = LeftData.from_data( # type: ignore
pl.DataFrame(
{'some_id': [1, 3, 3], 'feature': [2, 3, 4]}, schema={'some_id': pl.Int8, 'feature': pl.Int32}
)
)
expected_df = pl.DataFrame(
data={'some_id': [1, 3], 'feature': [2, 4]},
schema={
'some_id': pl.Int8,
'feature': pl.Int32,
},
)

result = await data.unique_on(['some_id'], sort_key='feature').to_polars()
sorted = result.sort('some_id').select(['some_id', 'feature']).collect()

left_data.unique_entities()
assert sorted.frame_equal(expected_df)
32 changes: 29 additions & 3 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def cached_at(self, location: DataFileReference | str) -> SupervisedJob:
self.target_columns,
)

def validate(self, validator: Validator) -> SupervisedJob:
def drop_invalid(self, validator: Validator) -> SupervisedJob:
return SupervisedJob(
self.job.drop_invalid(validator),
self.target_columns,
Expand Down Expand Up @@ -401,6 +401,12 @@ def unique_on(self, unique_on: list[str], sort_key: str | None = None) -> Retriv
def unique_entities(self) -> RetrivalJob:
request = self.request_result

if not request.event_timestamp:
logger.info(
'Unable to find event_timestamp for `unique_entities`. '
'This can lead to inconsistent features.'
)

return self.unique_on(unique_on=request.entity_columns, sort_key=request.event_timestamp)

def fill_missing_columns(self) -> RetrivalJob:
Expand Down Expand Up @@ -594,7 +600,21 @@ class JoinJobs(RetrivalJob):

@property
def request_result(self) -> RequestResult:
return RequestResult.from_result_list([self.left_job.request_result, self.right_job.request_result])
request = RequestResult.from_result_list(
[self.left_job.request_result, self.right_job.request_result]
)

right_entities = self.right_job.request_result.entities

for feature in right_entities:
if feature.name in self.right_on:
request.entities.remove(feature)

return request

@property
def retrival_requests(self) -> list[RetrivalRequest]:
return RetrivalRequest.combine(self.left_job.retrival_requests + self.right_job.retrival_requests)

async def to_polars(self) -> pl.LazyFrame:
left = await self.left_job.to_polars()
Expand Down Expand Up @@ -1375,8 +1395,14 @@ async def to_pandas(self) -> pd.DataFrame:
df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
elif feature.dtype == FeatureType.datetime() or feature.dtype == FeatureType.string():
continue
elif feature.dtype != FeatureType.array():
elif feature.dtype == FeatureType.array():
import json

if df[feature.name].dtype == 'object':
df[feature.name] = df[feature.name].apply(
lambda x: json.loads(x) if isinstance(x, str) else x
)
else:
if feature.dtype.is_numeric:
df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
feature.dtype.pandas_type
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.53"
version = "0.0.54"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down
Binary file modified test_data/credit_history_mater.parquet
Binary file not shown.
2 changes: 1 addition & 1 deletion test_data/feature-store.json

Large diffs are not rendered by default.

Binary file modified test_data/test_model.parquet
Binary file not shown.

0 comments on commit 9b5f943

Please sign in to comment.