Merge pull request #18 from MatsMoll/matsei/chore-cleanup
chore: A lot of cleanup in the FeatureType
MatsMoll authored Nov 6, 2023
2 parents 02508c6 + caddd02 commit ed69ece
Showing 28 changed files with 566 additions and 378 deletions.
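The bulk of this commit is one mechanical API change: instead of building a throwaway `FeatureType('')` and reading the dtype off it as a property, call sites now use explicit classmethod constructors such as `FeatureType.float()` and `FeatureType.bool()`. A minimal sketch of the shape that implies, assuming the real class (in `aligned/schemas/feature.py`) is a dataclass keyed on a type name:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FeatureType:
    name: str

    # Classmethod constructors, in the style this commit moves to.
    @classmethod
    def bool(cls) -> FeatureType:
        return cls(name='bool')

    @classmethod
    def float(cls) -> FeatureType:
        return cls(name='float')

    @classmethod
    def datetime(cls) -> FeatureType:
        return cls(name='datetime')


# Before: FeatureType('').float  -- builds an empty instance just to read a property
# After:  FeatureType.float()    -- no dummy instance, and a typo fails loudly
dtype = FeatureType.float()
assert dtype == FeatureType.float()
```

Because a dataclass compares by its fields, equality checks such as the `when.dtype == FeatureType.bool()` assertions in the diff below keep working unchanged.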
59 changes: 5 additions & 54 deletions aligned/cli.py
@@ -2,24 +2,24 @@
import logging
import os
import sys
-from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
-from datetime import datetime
from functools import wraps
from pathlib import Path
-from typing import Any
+from typing import Any, TYPE_CHECKING

import click
from pytz import utc # type: ignore

from aligned.compiler.repo_reader import RepoReader, RepoReference
from aligned.feature_source import WritableFeatureSource
from aligned.schemas.codable import Codable
from aligned.schemas.feature import Feature
from aligned.schemas.repo_definition import RepoDefinition
from aligned.worker import StreamWorker

+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from datetime import datetime


def coro(func: Callable) -> Callable:
@wraps(func)
@@ -265,55 +265,6 @@ async def serve_worker_command(repo_path: str, worker_path: str, env_file: str)
await worker.start()


-@cli.command('materialize')
-@coro
-@click.option(
-    '--repo-path',
-    default='.',
-    help='The path to the repo',
-)
-@click.option(
-    '--env-file',
-    default='.env',
-    help='The path to env variables',
-)
-@click.option(
-    '--days',
-    help='The number of days to materialize',
-)
-@click.option(
-    '--view',
-    help='The feature view to materialize',
-)
-async def materialize_command(repo_path: str, env_file: str, days: str, view: str) -> None:
-    """
-    Materializes the feature store
-    """
-    from aligned.feature_store import FeatureStore
-
-    dir = Path.cwd() if repo_path == '.' else Path(repo_path).absolute()
-    load_envs(dir / env_file)
-
-    sys.path.append(str(dir))
-    repo_def = await RepoDefinition.from_path(repo_path)
-    store = FeatureStore.from_definition(repo_def)
-    batch_store = store.offline_store()
-
-    if not isinstance(store.feature_source, WritableFeatureSource):
-        raise ValueError('Batch feature sources are not supported for materialization')
-
-    number_of_days = int(days)
-    views = [view] if view else list(store.feature_views.keys())
-
-    click.echo(f'Materializing the last {number_of_days} days')
-    for feature_view in views:
-        fv_store = batch_store.feature_view(feature_view)
-        click.echo(f'Materializing {feature_view}')
-        await store.feature_source.write(
-            fv_store.previous(days=number_of_days), fv_store.view.request_all.needed_requests
-        )


@dataclass
class CategoricalFeatureSummary(Codable):
missing_percentage: float
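The import shuffle above is the standard `TYPE_CHECKING` pattern: `Callable` and `datetime` are only used in annotations, so the guarded block removes their runtime import cost without losing type coverage. A minimal sketch of the pattern (this assumes `from __future__ import annotations` already sits at the top of `cli.py`, just above this hunk; without lazy annotations the guarded names would raise `NameError` at runtime):

```python
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported by type checkers; never executed at runtime.
    from collections.abc import Callable


def coro(func: Callable) -> Callable:
    # The annotation is a plain string at runtime, so the guarded
    # import above is never needed to define or call this function.
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
```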
34 changes: 17 additions & 17 deletions aligned/compiler/feature_factory.py
@@ -98,7 +98,7 @@ def __set_name__(self, owner, name):
def compile(self) -> ClassTargetProbability:
return ClassTargetProbability(
outcome=LiteralValue.from_value(self.of_value),
-feature=Feature(self._name, dtype=FeatureType('').float),
+feature=Feature(self._name, dtype=FeatureType.float()),
)


@@ -133,7 +133,7 @@ def listen_to_ground_truth_event(self, stream: StreamDataSource) -> RegressionLa
)

def send_ground_truth_event(self, when: Bool, sink_to: StreamDataSource) -> RegressionLabel:
-assert when.dtype == FeatureType('').bool, 'A trigger needs a boolean condition'
+assert when.dtype == FeatureType.bool(), 'A trigger needs a boolean condition'

return RegressionLabel(
self.feature, EventTrigger(when, sink_to), ground_truth_event=self.ground_truth_event
@@ -188,7 +188,7 @@ def listen_to_ground_truth_event(self, stream: StreamDataSource) -> Classificati
)

def send_ground_truth_event(self, when: Bool, sink_to: StreamDataSource) -> ClassificationLabel:
-assert when.dtype == FeatureType('').bool, 'A trigger needs a boolean condition'
+assert when.dtype == FeatureType.bool(), 'A trigger needs a boolean condition'

return ClassificationLabel(self.feature, EventTrigger(when, sink_to))

@@ -534,7 +534,7 @@ def __sub__(self, other: FeatureFactory) -> Float:
from aligned.compiler.transformation_factory import DifferanceBetweenFactory, TimeDifferanceFactory

feature = Float()
-if self.dtype == FeatureType('').datetime:
+if self.dtype == FeatureType.datetime():
feature.transformation = TimeDifferanceFactory(self, other)
else:
feature.transformation = DifferanceBetweenFactory(self, other)
@@ -754,7 +754,7 @@ def day_of_year(self) -> Int32:
class Bool(EquatableFeature, LogicalOperatableFeature):
@property
def dtype(self) -> FeatureType:
-return FeatureType('').bool
+return FeatureType.bool()

def copy_type(self) -> Bool:
return Bool()
@@ -766,7 +766,7 @@ def copy_type(self) -> Float:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').float
+return FeatureType.float()

def aggregate(self) -> ArithmeticAggregation:
return ArithmeticAggregation(self)
@@ -778,7 +778,7 @@ def copy_type(self) -> Int32:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').int32
+return FeatureType.int32()

def aggregate(self) -> ArithmeticAggregation:
return ArithmeticAggregation(self)
@@ -790,7 +790,7 @@ def copy_type(self) -> Int64:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').int64
+return FeatureType.int64()

def aggregate(self) -> ArithmeticAggregation:
return ArithmeticAggregation(self)
@@ -802,7 +802,7 @@ def copy_type(self) -> UUID:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').uuid
+return FeatureType.uuid()

def aggregate(self) -> CategoricalAggregation:
return CategoricalAggregation(self)
@@ -851,7 +851,7 @@ def copy_type(self) -> String:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').string
+return FeatureType.string()

def aggregate(self) -> StringAggregation:
return StringAggregation(self)
@@ -924,7 +924,7 @@ def copy_type(self: Json) -> Json:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').string
+return FeatureType.string()

def json_path_value_at(self, path: str, as_type: T) -> T:
from aligned.compiler.transformation_factory import JsonPathFactory
@@ -967,7 +967,7 @@ def aggregate(self) -> CategoricalAggregation:
class Timestamp(DateFeature, ArithmeticFeature):
@property
def dtype(self) -> FeatureType:
-return FeatureType('').datetime
+return FeatureType.datetime()


class EventTimestamp(DateFeature, ArithmeticFeature):
@@ -976,7 +976,7 @@ class EventTimestamp(DateFeature, ArithmeticFeature):

@property
def dtype(self) -> FeatureType:
-return FeatureType('').datetime
+return FeatureType.datetime()

def __init__(self, ttl: timedelta | None = None):
self.ttl = ttl
@@ -1000,7 +1000,7 @@ def copy_type(self) -> Embedding:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').embedding
+return FeatureType.embedding()

def indexed(
self,
@@ -1034,7 +1034,7 @@ def copy_type(self) -> List:

@property
def dtype(self) -> FeatureType:
-return FeatureType('').array
+return FeatureType.array()

def contains(self, value: Any) -> Bool:
from aligned.compiler.transformation_factory import ArrayContainsFactory
@@ -1047,7 +1047,7 @@ def contains(self, value: Any) -> Bool:
class ImageUrl(StringValidatable):
@property
def dtype(self) -> FeatureType:
-return FeatureType('').string
+return FeatureType.string()

def copy_type(self) -> ImageUrl:
return ImageUrl()
@@ -1063,7 +1063,7 @@ def load_image(self) -> Image:
class Image(FeatureFactory):
@property
def dtype(self) -> FeatureType:
-return FeatureType('').array
+return FeatureType.array()

def copy_type(self) -> Image:
return Image()
5 changes: 3 additions & 2 deletions aligned/compiler/model.py
@@ -175,6 +175,7 @@ class MyModel(ModelContract):
derived_features=set(),
model_version_column=None,
source=metadata.predictions_source,
+historical_source=metadata.historical_source,
stream_source=metadata.predictions_stream,
classification_targets=set(),
regression_targets=set(),
@@ -229,7 +230,7 @@ class MyModel(ModelContract):
inference_view.features.add(
Feature(
var_name,
-FeatureType('').float,
+FeatureType.float(),
f"The probability of target named {feature_name} being '{feature.of_value}'.",
)
)
@@ -256,7 +257,7 @@ class MyModel(ModelContract):
dtype=transformation.dtype,
transformation=transformation,
depending_on={
-FeatureReferance(feat, FeatureLocation.model(metadata.name), dtype=FeatureType('').float)
+FeatureReferance(feat, FeatureLocation.model(metadata.name), dtype=FeatureType.float())
for feat in transformation.column_mappings.keys()
},
depth=1,
11 changes: 11 additions & 0 deletions aligned/data_file.py
@@ -2,6 +2,17 @@
import polars as pl


+def upsert_on_column(columns: list[str], new_data: pl.LazyFrame, existing_data: pl.LazyFrame) -> pl.LazyFrame:
+
+    column_diff = set(new_data.columns).difference(existing_data.columns)
+
+    if column_diff:
+        raise ValueError(f'Mismatching columns, missing columns {column_diff}.')
+
+    combined = pl.concat([new_data, existing_data.select(new_data.columns)])
+    return combined.unique(columns, keep='first')


class DataFileReference:
"""
A reference to a data file.
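The new `upsert_on_column` helper gives file-backed sources upsert semantics: incoming rows win on key collisions because they are concatenated first and `unique(keys, keep='first')` keeps the first occurrence per key. A usage sketch with made-up data:

```python
import polars as pl

from aligned.data_file import upsert_on_column

existing = pl.LazyFrame({'id': [1, 2, 3], 'value': ['a', 'b', 'c']})
new = pl.LazyFrame({'id': [2, 4], 'value': ['B', 'd']})

result = upsert_on_column(['id'], new_data=new, existing_data=existing)
# id 2 takes the new value 'B', id 4 is appended, ids 1 and 3 are untouched.
# Row order is not guaranteed by unique(), so sort before inspecting.
print(result.collect().sort('id'))
```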
33 changes: 27 additions & 6 deletions aligned/data_source/batch_data_source.py
@@ -5,6 +5,7 @@
from dataclasses import dataclass

from mashumaro.types import SerializableType
+from aligned.data_file import DataFileReference

from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import DerivedFeature
@@ -24,7 +25,7 @@ class BatchDataSourceFactory:
_shared: BatchDataSourceFactory | None = None

def __init__(self) -> None:
-from aligned.sources.local import CsvFileSource, ParquetFileSource
+from aligned.sources.local import CsvFileSource, ParquetFileSource, DeltaFileSource
from aligned.sources.psql import PostgreSQLDataSource
from aligned.sources.redshift import RedshiftSQLDataSource
from aligned.sources.s3 import AwsS3CsvDataSource, AwsS3ParquetDataSource
@@ -33,6 +34,7 @@ def __init__(self) -> None:
PostgreSQLDataSource,
ParquetFileSource,
CsvFileSource,
+DeltaFileSource,
AwsS3CsvDataSource,
AwsS3ParquetDataSource,
RedshiftSQLDataSource,
@@ -128,6 +130,12 @@ def _deserialize(cls, value: dict) -> BatchDataSource:
def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
if isinstance(self, BatchSourceModification):
return self.wrap_job(self.source.all_data(request, limit))

+if isinstance(self, DataFileReference):
+    from aligned.local.job import FileFullJob
+
+    return FileFullJob(self, request=request, limit=limit)

raise NotImplementedError()

def all_between_dates(
@@ -136,22 +144,35 @@ def all_between_dates(
start_date: datetime,
end_date: datetime,
) -> RetrivalJob:

if isinstance(self, BatchSourceModification):
return self.wrap_job(self.source.all_between_dates(request, start_date, end_date))

+if isinstance(self, DataFileReference):
+    from aligned.local.job import FileDateJob
+
+    return FileDateJob(self, request=request, start_date=start_date, end_date=end_date)

raise NotImplementedError()

@classmethod
def multi_source_features_for(
cls: type[T], facts: RetrivalJob, requests: list[tuple[T, RetrivalRequest]]
) -> RetrivalJob:
-        if len(requests) != 1:
-            raise NotImplementedError()
-
-        source, _ = requests[0]
-        if not isinstance(source, BatchSourceModification):
+        sources = {source for source, _ in requests}
+        if len(sources) != 1:
             raise NotImplementedError()
 
-        return source.wrap_job(type(source.source).multi_source_features_for(facts, requests))
+        source, _ = requests[0]
+        if isinstance(source, BatchSourceModification):
+            return source.wrap_job(type(source.source).multi_source_features_for(facts, requests))
+        elif isinstance(source, DataFileReference):
+            from aligned.local.job import FileFactualJob
+
+            return FileFactualJob(source, [request for _, request in requests], facts)
+        else:
+            raise NotImplementedError(f'Type: {cls} have not implemented how to load fact data')

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return type(self).multi_source_features_for(facts, [(self, request)])
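One behavioral change in `multi_source_features_for` is easy to miss: the old guard rejected any call with more than one request, while the new set-based guard accepts several requests so long as they all target one distinct source, which in turn requires the sources to be hashable and equality-comparable. A stand-in illustration of just that check (hypothetical values, not the real source and request types):

```python
# Stand-ins for (BatchDataSource, RetrivalRequest) pairs.
src = 'csv://features.csv'
requests = [(src, 'request_a'), (src, 'request_b')]

sources = {source for source, _ in requests}
if len(sources) != 1:
    raise NotImplementedError()

# Two requests against one distinct source now pass this check,
# where the old `len(requests) != 1` guard would have raised.
print(f'{len(requests)} requests, {len(sources)} distinct source(s)')
```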