Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added JoinAsof, and minor bug fixes #22

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion aligned/compiler/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations

import copy
import logging
from abc import ABC, abstractproperty
from dataclasses import dataclass, field
Expand Down Expand Up @@ -86,7 +88,20 @@ class ModelContractWrapper(Generic[T]):
def __call__(self) -> T:
# Needs to compiile the model to set the location for the view features
_ = self.compile()
return self.contract()

# Need to copy and set location in case filters are used.
# As this can lead to incorrect features otherwise
contract = copy.deepcopy(self.contract())
for attribute in dir(contract):
if attribute.startswith('__'):
continue

value = getattr(contract, attribute)
if isinstance(value, FeatureFactory):
value._location = FeatureLocation.model(self.metadata.name)
setattr(contract, attribute, copy.deepcopy(value))

return contract

def compile(self) -> ModelSchema:
return ModelContract.compile_with_metadata(self.contract(), self.metadata)
Expand Down
2 changes: 1 addition & 1 deletion aligned/data_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def upsert_on_column(columns: list[str], new_data: pl.LazyFrame, existing_data:
if column_diff:
raise ValueError(f'Mismatching columns, missing columns {column_diff}.')

combined = pl.concat([new_data, existing_data.select(new_data.columns)])
combined = pl.concat([new_data, existing_data.select(new_data.columns)], how='vertical_relaxed')
return combined.unique(columns, keep='first')


Expand Down
216 changes: 183 additions & 33 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation
from aligned.request.retrival_request import RetrivalRequest
from aligned.compiler.feature_factory import FeatureFactory

if TYPE_CHECKING:
from aligned.compiler.feature_factory import FeatureFactory
from aligned.retrival_job import RetrivalJob
from datetime import datetime

Expand All @@ -40,6 +40,7 @@ def __init__(self) -> None:
AwsS3ParquetDataSource,
RedshiftSQLDataSource,
JoinDataSource,
JoinAsofDataSource,
FilteredDataSource,
FeatureViewReferenceSource,
]
Expand Down Expand Up @@ -284,63 +285,212 @@ def job_group_key(self) -> str:
def wrap_job(self, job: RetrivalJob) -> RetrivalJob:
return job.filter(self.condition)

def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on()


def resolve_keys(keys: str | FeatureFactory | list[str] | list[FeatureFactory]) -> list[str]:

if isinstance(keys, FeatureFactory):
return [keys.name]

if isinstance(keys, str):
return [keys]

if isinstance(keys[0], FeatureFactory):
return [key.name for key in keys] # type: ignore

return keys # type: ignore


def view_wrapper_instance_source(view: Any) -> tuple[BatchDataSource, RetrivalRequest]:
from aligned.feature_view.feature_view import FeatureViewWrapper
from aligned.schemas.feature_view import FeatureViewReferenceSource

if not hasattr(view, '__view_wrapper__'):
raise ValueError(
f'Unable to join {view} as a __view_wrapper__ is needed. Make sure you have used @feature_view'
)

wrapper = getattr(view, '__view_wrapper__')
if not isinstance(wrapper, FeatureViewWrapper):
raise ValueError()

compiled_view = wrapper.compile()

return (FeatureViewReferenceSource(compiled_view), compiled_view.request_all.needed_requests[0])


def join_asof_source(
source: BatchDataSource,
left_request: RetrivalRequest,
view: Any,
left_on: list[str] | None = None,
right_on: list[str] | None = None,
) -> JoinAsofDataSource:

right_source, right_request = view_wrapper_instance_source(view)

left_event_timestamp = left_request.event_timestamp
right_event_timestamp = right_request.event_timestamp

if left_event_timestamp is None:
raise ValueError('A left event timestamp is needed, but found none.')
if right_event_timestamp is None:
raise ValueError('A right event timestamp is needed, but found none.')

return JoinAsofDataSource(
source=source,
left_request=left_request,
right_source=right_source,
right_request=right_request,
left_event_timestamp=left_event_timestamp.name,
right_event_timestamp=right_event_timestamp.name,
left_on=left_on,
right_on=right_on,
)


def join_source(
source: BatchDataSource,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
left_request: RetrivalRequest | None = None,
) -> JoinDataSource:
from aligned.data_source.batch_data_source import JoinDataSource
from aligned.feature_view.feature_view import FeatureViewWrapper

right_source, right_request = view_wrapper_instance_source(view)

if on is None:
on_keys = list(right_request.entity_names)
else:
on_keys = resolve_keys(on)

if left_request is None:
if isinstance(source, JoinDataSource):
left_request = RetrivalRequest.unsafe_combine([source.left_request, source.right_request])
elif isinstance(source, FeatureViewWrapper):
left_request = source.compile().request_all.needed_requests[0]

if left_request is None:
raise ValueError('Unable to resolve the left request. Concider adding a `left_request` param.')

return JoinDataSource(
source=source,
left_request=left_request,
right_source=right_source,
right_request=right_request,
left_on=on_keys,
right_on=on_keys,
method=how,
)


@dataclass
class JoinDataSource(BatchSourceModification, BatchDataSource):
class JoinAsofDataSource(BatchDataSource):

source: BatchDataSource
left_request: RetrivalRequest
right_source: BatchDataSource
right_request: RetrivalRequest
left_on: str
right_on: str
method: str

type_name: str = 'join'
left_event_timestamp: str
right_event_timestamp: str

left_on: list[str] | None = None
right_on: list[str] | None = None

type_name: str = 'join_asof'

def job_group_key(self) -> str:
return f'join/{self.source.job_group_key()}'

def wrap_job(self, job: RetrivalJob) -> RetrivalJob:
def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

right_job = self.right_source.all_data(self.right_request, limit=None).derive_features(
[self.right_request]
)

return job.derive_features([self.left_request]).join(
right_job, self.method, (self.left_on, self.right_on)
return (
self.source.all_data(self.left_request, limit=limit)
.derive_features([self.left_request])
.join_asof(
right_job,
left_event_timestamp=self.left_event_timestamp,
right_event_timestamp=self.right_event_timestamp,
left_on=self.left_on,
right_on=self.right_on,
)
.derive_features([request])
)

def join(
self,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
) -> JoinDataSource:
return join_source(self, view, on, how)

def join_asof(
self, view: Any, on: str | FeatureFactory | list[str] | list[FeatureFactory]
) -> JoinAsofDataSource:

left_on = None
right_on = None
if on:
left_on = resolve_keys(on)
right_on = left_on

left_request = RetrivalRequest.unsafe_combine([self.left_request, self.right_request])

return join_asof_source(
self, left_request=left_request, view=view, left_on=left_on, right_on=right_on
)

def join(self, view: Any, on: str | FeatureFactory, how: str = 'inner') -> BatchDataSource:
from aligned.compiler.feature_factory import FeatureFactory
from aligned.data_source.batch_data_source import JoinDataSource
from aligned.feature_view.feature_view import FeatureViewWrapper
def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on().intersection(self.right_source.depends_on())

if not hasattr(view, '__view_wrapper__'):
raise ValueError(f'Unable to join {view}')

wrapper = getattr(view, '__view_wrapper__')
if not isinstance(wrapper, FeatureViewWrapper):
raise ValueError()
@dataclass
class JoinDataSource(BatchDataSource):

if isinstance(on, FeatureFactory):
on = on.name
source: BatchDataSource
left_request: RetrivalRequest
right_source: BatchDataSource
right_request: RetrivalRequest
left_on: list[str]
right_on: list[str]
method: str

left_request = RetrivalRequest.unsafe_combine([self.left_request, self.right_request])
compiled_view = wrapper.compile()

request = compiled_view.request_all

return JoinDataSource(
source=self,
left_request=left_request,
right_source=compiled_view.materialized_source or compiled_view.source,
right_request=request.needed_requests[0],
left_on=on,
right_on=on,
method=how,
type_name: str = 'join'

def job_group_key(self) -> str:
return f'join/{self.source.job_group_key()}'

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

right_job = self.right_source.all_data(self.right_request, limit=None).derive_features(
[self.right_request]
)

return (
self.source.all_data(self.left_request, limit=limit)
.derive_features([self.left_request])
.join(right_job, method=self.method, left_on=self.left_on, right_on=self.right_on)
.derive_features([request])
)

def join(
self,
view: Any,
on: str | FeatureFactory | list[str] | list[FeatureFactory] | None = None,
how: str = 'inner',
) -> BatchDataSource:
return join_source(self, view, on, how)

def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on().intersection(self.right_source.depends_on())

Expand Down
Loading