Skip to content

Commit

Permalink
Changed from to_polars to to_lazy_polars
Browse files: browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Feb 10, 2024
1 parent 4206359 commit ecd0173
Show file tree
Hide file tree
Showing 41 changed files with 334 additions and 291 deletions.
6 changes: 3 additions & 3 deletions aligned/active_learning/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class ActiveLearningJob(RetrivalJob):
selection: ActiveLearningSelection
write_policy: ActiveLearningWritePolicy

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
if not self.model.predictions_view.classification_targets:
logger.info('Found no target. Therefore, no data will be written to an active learning dataset.')
return await self.job.to_polars()
return await self.job.to_lazy_polars()

data = await self.job.to_polars()
data = await self.job.to_lazy_polars()
active_learning_set = self.selection.select(self.model, data, self.metric)
await self.write_policy.write(active_learning_set, self.model)
return data
Expand Down
10 changes: 10 additions & 0 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Type, TypeVar, Generic, TYPE_CHECKING
from datetime import timedelta

from uuid import uuid4

Expand Down Expand Up @@ -54,6 +55,9 @@ class ModelMetadata:
prediction_stream: StreamDataSource | None = field(default=None)
application_source: BatchDataSource | None = field(default=None)

acceptable_freshness: timedelta | None = field(default=None)
unacceptable_freshness: timedelta | None = field(default=None)

exposed_at_url: str | None = field(default=None)

dataset_store: DatasetStore | None = field(default=None)
Expand Down Expand Up @@ -171,6 +175,8 @@ def model_contract(
application_source: BatchDataSource | None = None,
dataset_store: DatasetStore | StorageFileReference | None = None,
exposed_at_url: str | None = None,
acceptable_freshness: timedelta | None = None,
unacceptable_freshness: timedelta | None = None,
) -> Callable[[Type[T]], ModelContractWrapper[T]]:
def decorator(cls: Type[T]) -> ModelContractWrapper[T]:

Expand All @@ -190,6 +196,8 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
application_source=application_source,
dataset_store=resolve_dataset_store(dataset_store) if dataset_store else None,
exposed_at_url=exposed_at_url,
acceptable_freshness=acceptable_freshness,
unacceptable_freshness=unacceptable_freshness,
)
return ModelContractWrapper(metadata, cls)

Expand Down Expand Up @@ -225,6 +233,8 @@ class MyModel(ModelContract):
classification_targets=set(),
regression_targets=set(),
recommendation_targets=set(),
acceptable_freshness=metadata.acceptable_freshness,
unacceptable_freshness=metadata.unacceptable_freshness,
)
probability_features: dict[str, set[TargetProbability]] = {}

Expand Down
5 changes: 4 additions & 1 deletion aligned/data_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ async def read_pandas(self) -> pd.DataFrame:
async def to_pandas(self) -> pd.DataFrame:
return await self.read_pandas()

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
raise NotImplementedError()

async def to_polars(self) -> pl.DataFrame:
return (await self.to_lazy_polars()).collect()

async def write_polars(self, df: pl.LazyFrame) -> None:
raise NotImplementedError()

Expand Down
8 changes: 4 additions & 4 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,9 @@ async def insert_into(
import polars as pl

columns = write_request.all_returned_columns
new_df = (await values.to_polars()).select(columns)
new_df = (await values.to_lazy_polars()).select(columns)
try:
existing_df = await source.to_polars()
existing_df = await source.to_lazy_polars()
write_df = pl.concat([new_df, existing_df.select(columns)], how='vertical_relaxed')
except UnableToFindFileException:
write_df = new_df
Expand Down Expand Up @@ -710,10 +710,10 @@ async def upsert_into(
if isinstance(source, WritableFeatureSource):
await source.upsert(values, [write_request])
elif isinstance(source, DataFileReference):
new_df = (await values.to_polars()).select(write_request.all_returned_columns)
new_df = (await values.to_lazy_polars()).select(write_request.all_returned_columns)
entities = list(write_request.entity_names)
try:
existing_df = await source.to_polars()
existing_df = await source.to_lazy_polars()
write_df = upsert_on_column(entities, new_df, existing_df)
except UnableToFindFileException:
write_df = new_df
Expand Down
2 changes: 1 addition & 1 deletion aligned/feature_view/combined_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SomeView:
return store.feature_view(self.metadata.name)

async def process(self, data: dict[str, list[Any]]) -> list[dict]:
df = await self.query().process_input(data).to_polars()
df = await self.query().process_input(data).to_lazy_polars()
return df.collect().to_dicts()


Expand Down
17 changes: 14 additions & 3 deletions aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import polars as pl
import pandas as pd

from datetime import timedelta
from abc import ABC, abstractproperty
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar, Generic, Type, Callable
Expand Down Expand Up @@ -58,6 +59,8 @@ class FeatureViewMetadata:
materialized_source: BatchDataSource | None = field(default=None)
contacts: list[str] | None = field(default=None)
tags: dict[str, str] = field(default_factory=dict)
acceptable_freshness: timedelta | None = field(default=None)
unacceptable_freshness: timedelta | None = field(default=None)

@staticmethod
def from_compiled(view: CompiledFeatureView) -> FeatureViewMetadata:
Expand All @@ -69,6 +72,8 @@ def from_compiled(view: CompiledFeatureView) -> FeatureViewMetadata:
stream_source=view.stream_data_source,
application_source=view.application_source,
materialized_source=view.materialized_source,
acceptable_freshness=view.acceptable_freshness,
unacceptable_freshness=view.unacceptable_freshness,
)


Expand All @@ -94,6 +99,8 @@ def feature_view(
materialized_source: BatchDataSource | None = None,
contacts: list[str] | None = None,
tags: dict[str, str] | None = None,
acceptable_freshness: timedelta | None = None,
unacceptable_freshness: timedelta | None = None,
) -> Callable[[Type[T]], FeatureViewWrapper[T]]:
def decorator(cls: Type[T]) -> FeatureViewWrapper[T]:

Expand All @@ -106,6 +113,8 @@ def decorator(cls: Type[T]) -> FeatureViewWrapper[T]:
materialized_source=materialized_source,
contacts=contacts,
tags=tags or {},
acceptable_freshness=acceptable_freshness,
unacceptable_freshness=unacceptable_freshness,
)
return FeatureViewWrapper(metadata, cls())

Expand Down Expand Up @@ -296,7 +305,7 @@ def process_input(self, data: ConvertableToRetrivalJob) -> RetrivalJob:
return self.query().process_input(data)

async def process(self, data: ConvertableToRetrivalJob) -> list[dict]:
df = await self.query().process_input(data).to_polars()
df = await self.query().process_input(data).to_lazy_polars()
return df.collect().to_dicts()

async def freshness_in_source(self, source: BatchDataSource) -> datetime | None:
Expand Down Expand Up @@ -471,6 +480,8 @@ def compile_with_metadata(feature_view: Any, metadata: FeatureViewMetadata) -> C
stream_data_source=metadata.stream_source,
application_source=metadata.application_source,
materialized_source=metadata.materialized_source,
acceptable_freshness=metadata.acceptable_freshness,
unacceptable_freshness=metadata.unacceptable_freshness,
indexes=[],
)
aggregations: list[FeatureFactory] = []
Expand Down Expand Up @@ -626,7 +637,7 @@ class SomeView(FeatureView):

@classmethod
async def process(cls, data: dict[str, list[Any]]) -> list[dict]:
df = await cls.query().process_input(data).to_polars()
df = await cls.query().process_input(data).to_lazy_polars()
return df.collect().to_dicts()

@staticmethod
Expand Down Expand Up @@ -684,7 +695,7 @@ class MyView:
@feature_view(
name="{view_name}",
description="some description",
source={batch_source_code}
source={batch_source_code},
stream_source=None,
)
class MyView:
Expand Down
14 changes: 7 additions & 7 deletions aligned/feature_view/tests/test_joined_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ async def test_join_different_types_polars() -> None:
)

new_data = left_data.join(right_data, 'inner', left_on='some_id', right_on='some_id')
result = await new_data.to_polars()
result = await new_data.to_lazy_polars()

joined = result.collect().sort('some_id', descending=False)
assert joined.frame_equal(expected_df.select(joined.columns))
assert joined.equals(expected_df.select(joined.columns))


@pytest.mark.asyncio
Expand Down Expand Up @@ -111,13 +111,13 @@ async def test_join_different_join_keys() -> None:

new_data = left_data.join(right_data, 'inner', left_on='some_id', right_on='other_id')

result = await new_data.to_polars()
result = await new_data.to_lazy_polars()
req_result = new_data.request_result

joined = result.collect().sort('some_id', descending=False)

assert joined.frame_equal(expected_df.select(joined.columns))
assert joined.select(req_result.entity_columns).frame_equal(expected_df.select(['some_id']))
assert joined.equals(expected_df.select(joined.columns))
assert joined.select(req_result.entity_columns).equals(expected_df.select(['some_id']))


@pytest.mark.asyncio
Expand All @@ -136,10 +136,10 @@ async def test_unique_entities() -> None:
},
)

result = await data.unique_on(['some_id'], sort_key='feature').to_polars()
result = await data.unique_on(['some_id'], sort_key='feature').to_lazy_polars()
sorted = result.sort('some_id').select(['some_id', 'feature']).collect()

assert sorted.frame_equal(expected_df)
assert sorted.equals(expected_df)


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion aligned/jobs/tests/test_combined_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def test_combined_polars(
job = CombineFactualJob(
jobs=[retrival_job, retrival_job_with_timestamp], combined_requests=[combined_retrival_request]
)
data = (await job.to_polars()).collect()
data = (await job.to_lazy_polars()).collect()

assert set(data.columns) == {'id', 'a', 'b', 'c', 'd', 'created_at', 'c+d', 'a+c+d'}
assert data.shape[0] == 5
4 changes: 2 additions & 2 deletions aligned/jobs/tests/test_derived_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def feature_store() -> FeatureStore:
@pytest.mark.asyncio
async def test_aggregate_over_derived() -> None:

data = await IncomeAgg.query().all().to_polars()
data = await IncomeAgg.query().all().to_lazy_polars()

df = data.collect()

Expand All @@ -142,7 +142,7 @@ async def test_aggregate_over_derived_fact() -> None:

data = await store.features_for(
entities={'user_id': ['a', 'b']}, features=['income_agg:total_amount']
).to_polars()
).to_lazy_polars()

df = data.collect()

Expand Down
20 changes: 10 additions & 10 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def request_result(self) -> RequestResult:
async def to_pandas(self) -> pd.DataFrame:
return self.df.collect().to_pandas()

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
return self.df


Expand Down Expand Up @@ -217,10 +217,10 @@ async def file_transform_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:
return df

async def to_pandas(self) -> pd.DataFrame:
return (await self.to_polars()).collect().to_pandas()
return (await self.to_lazy_polars()).collect().to_pandas()

async def to_polars(self) -> pl.LazyFrame:
file = await self.source.to_polars()
async def to_lazy_polars(self) -> pl.LazyFrame:
file = await self.source.to_lazy_polars()
return await self.file_transform_polars(file)


Expand Down Expand Up @@ -292,8 +292,8 @@ async def to_pandas(self) -> pd.DataFrame:
file = await self.source.read_pandas()
return self.file_transformations(file)

async def to_polars(self) -> pl.LazyFrame:
file = await self.source.to_polars()
async def to_lazy_polars(self) -> pl.LazyFrame:
file = await self.source.to_lazy_polars()
return self.file_transform_polars(file)


Expand Down Expand Up @@ -369,7 +369,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
for request in self.requests:
all_features.update(request.all_required_features)

result = await self.facts.to_polars()
result = await self.facts.to_lazy_polars()
event_timestamp_col = 'aligned_event_timestamp'

event_timestamp_entity_columns = [
Expand Down Expand Up @@ -489,10 +489,10 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
return result.select([pl.exclude('row_id')])

async def to_pandas(self) -> pd.DataFrame:
return (await self.to_polars()).collect().to_pandas()
return (await self.to_lazy_polars()).collect().to_pandas()

async def to_polars(self) -> pl.LazyFrame:
return await self.file_transformations(await self.source.to_polars())
async def to_lazy_polars(self) -> pl.LazyFrame:
return await self.file_transformations(await self.source.to_lazy_polars())

def log_each_job(self) -> RetrivalJob:
from aligned.retrival_job import LogJob
Expand Down
3 changes: 2 additions & 1 deletion aligned/local/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_file_full_job_polars(retrival_request_without_derived: RetrivalRe
}
)
job = FileFullJob(source=LiteralReference(frame), request=retrival_request_without_derived)
data = (await job.to_polars()).collect()
data = (await job.to_lazy_polars()).collect()

assert set(data.columns) == {'id', 'a', 'b'}
assert data.shape[0] == 5
Expand All @@ -45,3 +45,4 @@ async def test_write_and_read_feature_store(titanic_feature_store_scd: FeatureSt
await source.write(definition.to_json().encode('utf-8'))
store = await source.feature_store()
assert store is not None
assert store.model('titanic').model.predictions_view.acceptable_freshness is not None
10 changes: 5 additions & 5 deletions aligned/psql/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def will_load_list_feature(self) -> bool:
return False

async def to_pandas(self) -> pd.DataFrame:
df = await self.to_polars()
df = await self.to_lazy_polars()
return df.collect().to_pandas()

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
try:
return pl.read_database(self.query, self.config.url).lazy()
except Exception as e:
Expand Down Expand Up @@ -264,9 +264,9 @@ async def to_pandas(self) -> pd.DataFrame:
job = await self.psql_job()
return await job.to_pandas()

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
job = await self.psql_job()
return await job.to_polars()
return await job.to_lazy_polars()

async def psql_job(self) -> PostgreSqlJob:
if isinstance(self.facts, PostgreSqlJob):
Expand Down Expand Up @@ -492,7 +492,7 @@ def aggregated_values_from_request(self, request: RetrivalRequest) -> list[Table
return fetches

async def build_request(self) -> str:
facts = await self.facts.to_polars()
facts = await self.facts.to_lazy_polars()
return self.build_request_from_facts(facts)

def build_request_from_facts(self, facts: pl.LazyFrame) -> str:
Expand Down
6 changes: 3 additions & 3 deletions aligned/redis/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ def retrival_requests(self) -> list[RetrivalRequest]:
return self.requests

async def to_pandas(self) -> pd.DataFrame:
return (await self.to_polars()).collect().to_pandas()
return (await self.to_lazy_polars()).collect().to_pandas()

def describe(self) -> str:
features_to_load = [list(request.all_feature_names) for request in self.requests]
return f'Loading features from Redis using HMGET {features_to_load}'

async def to_polars(self) -> pl.LazyFrame:
async def to_lazy_polars(self) -> pl.LazyFrame:
redis = self.config.redis()

result_df = (await self.facts.to_polars()).collect()
result_df = (await self.facts.to_lazy_polars()).collect()

for request in self.requests:
redis_combine_id = 'redis_combine_entity_id'
Expand Down
4 changes: 2 additions & 2 deletions aligned/redis/tests/test_redis_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,6 @@ async def test_write_job(mocker, retrival_request: RetrivalRequest) -> None: #
await source.insert(insert_facts, [retrival_request])

job = FactualRedisJob(RedisConfig.localhost(), requests=[retrival_request], facts=facts)
data = await job.to_polars()
data = await job.to_lazy_polars()

assert data.collect().select('x').to_series().series_equal(pl.Series('x', [1, 2, 3, None]))
assert data.collect().select('x').to_series().equals(pl.Series('x', [1, 2, 3, None]))
Loading

0 comments on commit ecd0173

Please sign in to comment.