Commit facd0f2

Pyright changes and minor bug fixes

MatsMoll committed Oct 1, 2024
1 parent a892780 commit facd0f2
Showing 20 changed files with 141 additions and 146 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ mlruns
 test_data/feature-store.json
 test_data/mlruns
 test_data/temp
+test_data

 # Byte-compiled / optimized / DLL files
 __pycache__/
2 changes: 1 addition & 1 deletion aligned/compiler/aggregation_factory.py
@@ -376,6 +376,6 @@ def compile(self) -> Transformation:
         else:
             return PolarsFunctionTransformation(
                 code=code,
-                function_name=dill.source.getname(self.method),
+                function_name=dill.source.getname(self.method),  # type: ignore
                 dtype=self.dtype.dtype,
             )
3 changes: 3 additions & 0 deletions aligned/compiler/feature_factory.py
@@ -1354,6 +1354,9 @@ class Timestamp(DateFeature, ArithmeticFeature):
     def __init__(self, time_zone: str | None = 'UTC') -> None:
         self.time_zone = time_zone

+    def defines_freshness(self) -> Timestamp:
+        return self.with_tag('freshness_timestamp')
+
     @property
     def dtype(self) -> FeatureType:
         from zoneinfo import ZoneInfo
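Note: `defines_freshness` is just a tagging shortcut over `with_tag`. A minimal usage sketch, assuming aligned's `@feature_view` decorator API; the view name, source path, and columns here are made up for illustration:

```python
from aligned import FileSource, String, Timestamp, feature_view

@feature_view(name='transactions', source=FileSource.parquet_at('transactions.parquet'))
class Transactions:
    transaction_id = String().as_entity()
    # Tagged with 'freshness_timestamp', so freshness checks can locate this column
    created_at = Timestamp().defines_freshness()
```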
6 changes: 5 additions & 1 deletion aligned/compiler/model.py
@@ -559,7 +559,11 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
         from aligned.schemas.transformation import MapArgMax

         transformation = MapArgMax(
-            {probs._name: LiteralValue.from_value(probs.of_value) for probs in probabilities}
+            {
+                probs._name: LiteralValue.from_value(probs.of_value)
+                for probs in probabilities
+                if probs._name is not None
+            }
         )

         arg_max_feature = DerivedFeature(
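Note: the comprehension change is the standard way to narrow an optional key type for pyright. A self-contained sketch of the pattern, with illustrative names:

```python
# Filtering out None keys narrows dict[str | None, int] to dict[str, int],
# which is what pyright requires of the annotated target.
names: list[str | None] = ['prob_a', None, 'prob_b']
mapping: dict[str, int] = {name: len(name) for name in names if name is not None}
assert mapping == {'prob_a': 6, 'prob_b': 6}
```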
12 changes: 12 additions & 0 deletions aligned/data_source/batch_data_source.py
@@ -1166,6 +1166,18 @@ def random_values_for(feature: Feature, size: int, seed: int | None = None) -> p
     else:
         values = np.random.random(size)

+    if max_value is None and dtype.name.startswith('uint'):
+        bits = dtype.name.lstrip('uint')
+        if bits.isdigit():
+            max_value = 2 ** int(bits)
+            min_value = 0
+    elif max_value is None and dtype.name.startswith('int'):
+        bits = dtype.name.lstrip('int')
+        if bits.isdigit():
+            value_range = 2 ** int(bits) / 2
+            max_value = value_range
+            min_value = -value_range
+
     if max_value and min_value:
         values = values * (max_value - min_value) + min_value
     elif max_value is not None:
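Note: the new branch derives a default value range from the dtype name. A worked sketch of the arithmetic (variable names are illustrative; note that `str.lstrip` strips a character set, not a prefix, which happens to work for names like `'int8'` and `'uint32'`):

```python
import numpy as np

dtype_name = 'int8'
bits = dtype_name.lstrip('uint')  # removes leading 'u', 'i', 'n', 't' chars -> '8'
assert bits == '8'

value_range = 2 ** int(bits) / 2  # 128.0, so signed values land in [-128, 128)
values = np.random.random(5) * (value_range - -value_range) + -value_range
assert ((values >= -value_range) & (values < value_range)).all()
```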
3 changes: 2 additions & 1 deletion aligned/feature_store.py
@@ -94,7 +94,8 @@ def feature_names(self) -> set[str]:
 def unpack_feature(feature: str) -> tuple[FeatureLocation, str]:
     splits = feature.split(':')
     if len(splits) == 3:
-        return (FeatureLocation(splits[1], splits[0]), splits[2])
+        assert splits[0]
+        return (FeatureLocation(splits[1], splits[0]), splits[2])  # type: ignore
     if len(splits) == 2:
         return (FeatureLocation(splits[0], 'feature_view'), splits[1])
     else:
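Note: the assert tells pyright the location-type segment is non-empty. A standalone sketch of the parsing logic above; the tuple here is a stand-in for aligned's `FeatureLocation`, and the example strings are made up:

```python
def unpack(feature: str) -> tuple[tuple[str, str], str]:
    splits = feature.split(':')
    if len(splits) == 3:
        # 'location_type:name:feature' -> ((name, location_type), feature)
        return ((splits[1], splits[0]), splits[2])
    # 'name:feature' defaults the location type to 'feature_view'
    return ((splits[0], 'feature_view'), splits[1])

assert unpack('model:taxi:duration') == (('taxi', 'model'), 'duration')
assert unpack('titanic:is_male') == (('titanic', 'feature_view'), 'is_male')
```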
10 changes: 5 additions & 5 deletions aligned/request/tests/test_feature_request_generation.py
@@ -1,12 +1,12 @@
 import pytest

-from aligned.feature_view.feature_view import FeatureView
+from aligned.feature_view.feature_view import FeatureViewWrapper


 @pytest.mark.asyncio
-async def test_fetch_all_request(titanic_feature_view: FeatureView) -> None:
+async def test_fetch_all_request(titanic_feature_view: FeatureViewWrapper) -> None:

-    compiled_view = type(titanic_feature_view).compile()
+    compiled_view = titanic_feature_view.compile()
     request = compiled_view.request_all

     expected_features = {
@@ -35,9 +35,9 @@ async def test_fetch_all_request(titanic_feature_view: FeatureView) -> None:


 @pytest.mark.asyncio
-async def test_fetch_features_request(titanic_feature_view: FeatureView) -> None:
+async def test_fetch_features_request(titanic_feature_view: FeatureViewWrapper) -> None:

-    compiled_view = type(titanic_feature_view).compile()
+    compiled_view = titanic_feature_view.compile()
     wanted_features = {'cabin', 'is_male'}
     request = compiled_view.request_for(wanted_features)
     expected_features = {'sex', 'cabin', 'is_male'}
15 changes: 4 additions & 11 deletions aligned/retrival_job.py
@@ -1739,15 +1739,8 @@ async def compute_derived_features_pandas(self, df: pd.DataFrame) -> pd.DataFrame:

             logger.debug(f'Computing feature with pandas: {feature.name}')
             df[feature.name] = await feature.transformation.transform_pandas(
-                df[feature.depending_on_names]
+                df[feature.depending_on_names]  # type: ignore
             )
-            # if df[feature.name].dtype != feature.dtype.pandas_type:
-            #     if feature.dtype.is_numeric:
-            #         df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
-            #             feature.dtype.pandas_type
-            #         )
-            #     else:
-            #         df[feature.name] = df[feature.name].astype(feature.dtype.pandas_type)
         return df

     async def to_pandas(self) -> pd.DataFrame:
@@ -1989,7 +1982,7 @@ async def to_polars(self) -> AsyncIterator[pl.LazyFrame]:
             df = raw_files[start:end, :]

             chunked_job = (
-                LiteralRetrivalJob(df.lazy(), RequestResult.from_request_list(needed_requests))
+                LiteralRetrivalJob(df.lazy(), needed_requests)
                 .derive_features(needed_requests)
                 .select_columns(features_to_include_names)
             )
@@ -2312,7 +2305,7 @@ async def combine_data(self, df: pd.DataFrame) -> pd.DataFrame:
                 continue
             logger.debug(f'Computing feature: {feature.name}')
             df[feature.name] = await feature.transformation.transform_pandas(
-                df[feature.depending_on_names]
+                df[feature.depending_on_names]  # type: ignore
             )
         return df

@@ -2406,7 +2399,7 @@ async def to_pandas(self) -> pd.DataFrame:
         df = await self.job.to_pandas()
         if self.include_features:
             total_list = list({ent.name for ent in self.request_result.entities}.union(self.include_features))
-            return df[total_list]
+            return df[total_list]  # type: ignore
         else:
             return df
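Note: the `LiteralRetrivalJob` change sits inside a chunked-iteration pattern over a large frame. A minimal sketch of that pattern with plain polars; everything besides polars itself is illustrative:

```python
import polars as pl

def lazy_chunks(df: pl.DataFrame, chunk_size: int):
    # Yield fixed-size slices as LazyFrames, mirroring the start:end
    # slicing in to_polars above
    for start in range(0, df.height, chunk_size):
        yield df.slice(start, chunk_size).lazy()

df = pl.DataFrame({'x': range(10)})
for chunk in lazy_chunks(df, 4):
    # Derive features per chunk, then materialize
    print(chunk.with_columns((pl.col('x') * 2).alias('x_doubled')).collect())
```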
5 changes: 3 additions & 2 deletions aligned/sources/azure_blob_storage.py
@@ -571,8 +571,9 @@ def delete_directory_recursively(directory_path: str) -> None:
             fs.rmdir(directory_path)

         upsert_on = sorted(request.entity_names)
+        returend_columns = request.all_returned_columns

-        df = await job.select(request.all_returned_columns).to_polars()
+        df = await job.select(returend_columns).to_polars()
         unique_partitions = df.select(self.partition_keys).unique()

         filters: list[pl.Expr] = []
@@ -590,7 +591,7 @@ def delete_directory_recursively(directory_path: str) -> None:

         try:
             existing_df = (await self.to_lazy_polars()).filter(*filters)
-            write_df = upsert_on_column(upsert_on, df.lazy(), existing_df).collect()
+            write_df = upsert_on_column(upsert_on, df.lazy(), existing_df).select(returend_columns).collect()
         except (UnableToFindFileException, pl.ComputeError):
             write_df = df.lazy()
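Note: the extra `.select(returend_columns)` pins the column set after the upsert merge, so stray columns from the existing file don't leak into the write. A minimal stand-in for the merge step, showing its shape; this is illustrative, not aligned's `upsert_on_column` implementation:

```python
import polars as pl

def upsert_on(keys: list[str], new: pl.LazyFrame, existing: pl.LazyFrame) -> pl.LazyFrame:
    # Keep existing rows whose keys do not collide with the new data,
    # then stack the new rows on top
    kept = existing.join(new.select(keys), on=keys, how='anti')
    return pl.concat([new, kept])

new = pl.LazyFrame({'id': [1, 2], 'value': ['a', 'b']})
old = pl.LazyFrame({'id': [2, 3], 'value': ['old', 'c']})
print(upsert_on(['id'], new, old).select(['id', 'value']).collect())  # id 2 -> 'b'
```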
7 changes: 4 additions & 3 deletions aligned/sources/local.py
@@ -230,7 +230,7 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
         schema = {  # type: ignore
             name: dtype.polars_type
             for name, dtype in self.expected_schema.items()
-            if not dtype.is_datetime
+            if not dtype.is_datetime and dtype.name != 'bool'
         }

         if self.mapping_keys:
@@ -506,7 +506,8 @@ async def upsert(self, job: RetrivalJob, request: RetrivalRequest) -> None:

         upsert_on = sorted(request.entity_names)

-        df = await job.select(request.all_returned_columns).to_polars()
+        returned_columns = request.all_returned_columns
+        df = await job.select(returned_columns).to_polars()
         unique_partitions = df.select(self.partition_keys).unique()

         filters: list[pl.Expr] = []
@@ -524,7 +525,7 @@ async def upsert(self, job: RetrivalJob, request: RetrivalRequest) -> None:

         try:
             existing_df = (await self.to_lazy_polars()).filter(*filters)
-            write_df = upsert_on_column(upsert_on, df.lazy(), existing_df).collect()
+            write_df = upsert_on_column(upsert_on, df.lazy(), existing_df).select(returned_columns).collect()
         except (UnableToFindFileException, pl.ComputeError):
             write_df = df.lazy()
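Note: `unique_partitions` feeds the `filters` list used to load only the affected partitions. One plausible sketch of how such per-partition filters combine, purely illustrative:

```python
import polars as pl
from functools import reduce

unique_partitions = pl.DataFrame({'year': [2023, 2024], 'month': [1, 2]})

# One AND-ed expression per partition row, OR-ed together across rows
filters = [
    reduce(lambda a, b: a & b, [pl.col(name) == row[name] for name in unique_partitions.columns])
    for row in unique_partitions.iter_rows(named=True)
]
combined = reduce(lambda a, b: a | b, filters)
```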
2 changes: 1 addition & 1 deletion aligned/sources/tests/test_parquet.py
@@ -209,7 +209,7 @@ async def test_read_csv(point_in_time_data_test: DataTest) -> None:
             description=view.metadata.description,
             batch_source=file_source,
         )
-        compiled = view.compile_instance()
+        compiled = view.compile()
         assert compiled.source.path == file_source.path

         store.add_compiled_view(compiled)
20 changes: 10 additions & 10 deletions aligned/worker.py
@@ -187,7 +187,8 @@ async def start(self) -> None:
         processes = []
         for topic_name, views in feature_views.items():
             process_views = views
-            stream: StreamDataSource = views[0].view.stream_data_source
+            stream: StreamDataSource | None = views[0].view.stream_data_source
+            assert stream is not None
             stream_consumer = stream.consumer(
                 self.read_timestamps.get(topic_name, self.default_start_timestamp)
             )
@@ -204,10 +205,12 @@ async def start(self) -> None:

             if not source:
                 logger.debug(f'Skipping to setup active learning set for {model_name}')
-
-            processes.append(
-                process_predictions(source.consumer(), store.model(model_name), active_learning_config)
-            )
+            else:
+                processes.append(
+                    process_predictions(
+                        source.consumer(), store.model(model_name), active_learning_config
+                    )
+                )

         if self.metric_logging_port:
             start_http_server(self.metric_logging_port)
@@ -229,9 +232,6 @@ async def process_predictions(
         logger.debug('No active learning config found, will not listen to predictions')
         return

-    topic_name = model.model.predictions_view.stream_source.topic_name
-    logger.debug(f'Started listning to {topic_name}')
-
     while True:
         records = await stream_source.read()

@@ -240,7 +240,7 @@
         start_time = timeit.default_timer()

         request = model.model.request_all_predictions.needed_requests[0]
-        job = RetrivalJob.from_dict(records, request).ensure_types([request])
+        job = RetrivalJob.from_dict(records, request).ensure_types([request])  # type: ignore
         job = ActiveLearningJob(
             job,
             model.model,
@@ -264,7 +264,7 @@ def stream_job(values: list[dict], feature_view: FeatureViewStore) -> RetrivalJob:
     if isinstance(feature_view.view.stream_data_source, ColumnFeatureMappable):
         mappings = feature_view.view.stream_data_source.mapping_keys

-    value_job = RetrivalJob.from_dict(values, request)
+    value_job = RetrivalJob.from_dict(values, request)  # type: ignore

     if mappings:
         value_job = value_job.rename(mappings)
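Note: the two worker changes are the two standard ways to satisfy pyright on `X | None` values. A minimal sketch of both, with illustrative names:

```python
def consume(source: str | None) -> str:
    # An assert narrows `str | None` to `str` for the rest of the scope
    assert source is not None
    return source.upper()

def maybe_consume(source: str | None) -> str | None:
    # An explicit else keeps the None case from falling through into
    # code that dereferences the value
    if not source:
        return None
    else:
        return source.upper()
```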