Skip to content

Commit

Permalink
Improved datetime handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Feb 3, 2024
1 parent 6a36dcf commit 5445708
Show file tree
Hide file tree
Showing 13 changed files with 293 additions and 119 deletions.
5 changes: 5 additions & 0 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation
from aligned.request.retrival_request import RequestResult, RetrivalRequest
from aligned.compiler.feature_factory import FeatureFactory
from polars.type_aliases import TimeUnit

if TYPE_CHECKING:
from aligned.retrival_job import RetrivalJob
Expand Down Expand Up @@ -505,6 +506,8 @@ class JoinAsofDataSource(BatchDataSource):
left_on: list[str] | None = None
right_on: list[str] | None = None

timestamp_unit: TimeUnit = 'us'

type_name: str = 'join_asof'

def job_group_key(self) -> str:
Expand All @@ -525,6 +528,7 @@ def all_with_limit(self, limit: int | None) -> RetrivalJob:
right_event_timestamp=self.right_event_timestamp,
left_on=self.left_on,
right_on=self.right_on,
timestamp_unit=self.timestamp_unit,
)
)

Expand All @@ -543,6 +547,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
right_event_timestamp=self.right_event_timestamp,
left_on=self.left_on,
right_on=self.right_on,
timestamp_unit=self.timestamp_unit,
)
.aggregate(request)
.derive_features([request])
Expand Down
35 changes: 34 additions & 1 deletion aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from aligned.request.retrival_request import AggregatedFeature, AggregateOver, RetrivalRequest
from aligned.retrival_job import RequestResult, RetrivalJob
from aligned.schemas.feature import Feature
from aligned.schemas.date_formatter import DateFormatter
from aligned.schemas.feature import Feature, FeatureType
from aligned.sources.local import DataFileReference


Expand Down Expand Up @@ -120,12 +121,39 @@ async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.Laz
return results


def decode_timestamps(df: pl.LazyFrame, request: RetrivalRequest, formatter: DateFormatter) -> pl.LazyFrame:
    """Decode every datetime column in ``df`` that is not already a ``pl.Datetime``.

    Collects the datetime features of ``request`` — plus its event timestamp, if
    any — whose current dtype is not ``pl.Datetime`` and runs each through
    ``formatter.decode_polars``. Columns that are already ``pl.Datetime`` are
    left untouched, and ``df`` is returned unchanged when nothing needs decoding.
    """
    columns: set[str] = set()
    dtypes = dict(zip(df.columns, df.dtypes))

    for feature in request.all_features:
        # NOTE: `FeatureType.datetime()` must be *called* — comparing against the
        # unbound method (`FeatureType.datetime`) is always False, so datetime
        # features would never be decoded. This matches the call form used by
        # EnsureTypesJob elsewhere in this change set.
        if (
            feature.dtype == FeatureType.datetime()
            and feature.name in df.columns
            and not isinstance(dtypes[feature.name], pl.Datetime)
        ):
            columns.add(feature.name)

    if (
        request.event_timestamp
        and request.event_timestamp.name in df.columns
        and not isinstance(dtypes[request.event_timestamp.name], pl.Datetime)
    ):
        columns.add(request.event_timestamp.name)

    if not columns:
        return df

    return df.with_columns([formatter.decode_polars(column).alias(column) for column in columns])


@dataclass
class FileFullJob(RetrivalJob):

source: DataFileReference
request: RetrivalRequest
limit: int | None = field(default=None)
date_formatter: DateFormatter = field(default=DateFormatter.iso_8601())

@property
def request_result(self) -> RequestResult:
Expand Down Expand Up @@ -178,6 +206,7 @@ async def file_transform_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:
if org_name != wanted_name
}
df = df.rename(mapping=renames)
df = decode_timestamps(df, self.request, self.date_formatter)

if self.request.aggregated_features:
df = await aggregate(self.request, df)
Expand All @@ -202,6 +231,7 @@ class FileDateJob(RetrivalJob):
request: RetrivalRequest
start_date: datetime
end_date: datetime
date_formatter: DateFormatter = field(default=DateFormatter.iso_8601())

@property
def request_result(self) -> RequestResult:
Expand Down Expand Up @@ -250,6 +280,7 @@ def file_transform_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:

df = df.rename(mapping=dict(zip(request_features, all_names)))
event_timestamp_column = self.request.event_timestamp.name
df = decode_timestamps(df, self.request, self.date_formatter)

return df.filter(pl.col(event_timestamp_column).is_between(self.start_date, self.end_date))

Expand Down Expand Up @@ -302,6 +333,7 @@ class FileFactualJob(RetrivalJob):
source: DataFileReference | RetrivalJob
requests: list[RetrivalRequest]
facts: RetrivalJob
date_formatter: DateFormatter = field(default=DateFormatter.iso_8601())

@property
def request_result(self) -> RequestResult:
Expand Down Expand Up @@ -387,6 +419,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
if isinstance(self.source, ColumnFeatureMappable):
request_features = self.source.feature_identifier_for(all_names)

df = decode_timestamps(df, request, self.date_formatter)
feature_df = df.select(request_features)

renames = {
Expand Down
45 changes: 35 additions & 10 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from aligned.schemas.date_formatter import DateFormatter

import asyncio
import logging
Expand All @@ -12,6 +13,7 @@

import pandas as pd
import polars as pl
from polars.type_aliases import TimeUnit
from prometheus_client import Histogram

from aligned.exceptions import UnableToFindFileException
Expand Down Expand Up @@ -493,6 +495,7 @@ def join_asof(
right_event_timestamp: str | None = None,
left_on: str | list[str] | None = None,
right_on: str | list[str] | None = None,
timestamp_unit: TimeUnit = 'us',
) -> RetrivalJob:

if isinstance(left_on, str):
Expand All @@ -518,6 +521,7 @@ def join_asof(
right_event_timestamp=right_event_timestamp,
left_on=left_on,
right_on=right_on,
timestamp_unit=timestamp_unit,
)

def join(
Expand Down Expand Up @@ -610,11 +614,15 @@ def derive_features(self, requests: list[RetrivalRequest] | None = None) -> Retr
def combined_features(self, requests: list[RetrivalRequest] | None = None) -> RetrivalJob:
return CombineFactualJob([self], requests or self.retrival_requests)

def ensure_types(self, requests: list[RetrivalRequest] | None = None) -> RetrivalJob:
def ensure_types(
self, requests: list[RetrivalRequest] | None = None, date_formatter: DateFormatter | None = None
) -> RetrivalJob:
if not requests:
requests = self.retrival_requests

return EnsureTypesJob(job=self, requests=requests)
return EnsureTypesJob(
job=self, requests=requests, date_formatter=date_formatter or DateFormatter.iso_8601()
)

def select_columns(self, include_features: set[str]) -> RetrivalJob:
return SelectColumnsJob(include_features, self)
Expand Down Expand Up @@ -800,6 +808,21 @@ def describe(self) -> str:
return f'OnLoadJob {self.on_load} -> {self.job.describe()}'


@dataclass
class EncodeDatesJob(RetrivalJob, ModificationJob):
    """Wraps another job and encodes the listed datetime columns with a formatter.

    Each column name in ``columns`` is passed through ``formatter.encode_polars``
    on the wrapped job's polars output; everything else flows through unchanged.
    """

    job: RetrivalJob
    formatter: DateFormatter
    columns: list[str]

    async def to_polars(self) -> pl.LazyFrame:
        frame = await self.job.to_polars()
        encoded_columns = [self.formatter.encode_polars(name) for name in self.columns]
        return frame.with_columns(encoded_columns)

    async def to_pandas(self) -> pd.DataFrame:
        # Reuse the polars path so the encoding logic lives in one place.
        lazy_frame = await self.to_polars()
        return lazy_frame.collect().to_pandas()


@dataclass
class InMemoryCacheJob(RetrivalJob, ModificationJob):

Expand Down Expand Up @@ -872,6 +895,8 @@ class JoinAsofJob(RetrivalJob):
left_on: list[str] | None
right_on: list[str] | None

timestamp_unit: TimeUnit = field(default='us')

@property
def request_result(self) -> RequestResult:
return RequestResult.from_result_list([self.left_job.request_result, self.right_job.request_result])
Expand All @@ -884,8 +909,10 @@ async def to_polars(self) -> pl.LazyFrame:
left = await self.left_job.to_polars()
right = await self.right_job.to_polars()

return left.join_asof(
right,
return left.with_columns(
pl.col(self.left_event_timestamp).dt.cast_time_unit(self.timestamp_unit),
).join_asof(
right.with_columns(pl.col(self.right_event_timestamp).dt.cast_time_unit(self.timestamp_unit)),
by_left=self.left_on,
by_right=self.right_on,
left_on=self.left_event_timestamp,
Expand Down Expand Up @@ -1698,6 +1725,7 @@ class EnsureTypesJob(RetrivalJob, ModificationJob):

job: RetrivalJob
requests: list[RetrivalRequest]
date_formatter: DateFormatter = field(default_factory=DateFormatter.iso_8601)

@property
def request_result(self) -> RequestResult:
Expand Down Expand Up @@ -1771,14 +1799,11 @@ async def to_polars(self) -> pl.LazyFrame:
df = df.with_columns(pl.col(feature.name).cast(pl.Int8).cast(pl.Boolean))
elif feature.dtype == FeatureType.datetime():
current_dtype = df.select([feature.name]).dtypes[0]

if isinstance(current_dtype, pl.Datetime):
continue
# Convert from ms to us
df = df.with_columns(
(pl.col(feature.name).cast(pl.Int64) * 1000)
.cast(pl.Datetime(time_zone='UTC'))
.alias(feature.name)
)

df = df.with_columns(self.date_formatter.decode_polars(feature.name))
elif (feature.dtype == FeatureType.array()) or (feature.dtype == FeatureType.embedding()):
dtype = df.select(feature.name).dtypes[0]
if dtype == pl.Utf8:
Expand Down
101 changes: 101 additions & 0 deletions aligned/schemas/date_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations
from dataclasses import dataclass, field
import polars as pl
from polars.type_aliases import TimeUnit
from aligned.schemas.codable import Codable
from mashumaro.types import SerializableType


@dataclass
class AllDateFormatters:
    """Registry mapping a formatter's serialized name to its class."""

    supported_formatters: dict[str, type[DateFormatter]]

    # Lazily-created process-wide singleton, populated on first `shared()` call.
    _shared: AllDateFormatters | None = None

    @classmethod
    def shared(cls) -> AllDateFormatters:
        """Return the shared registry, creating it on first use."""
        if cls._shared is not None:
            return cls._shared

        registry: dict[str, type[DateFormatter]] = {}
        for formatter in (Timestamp, StringDateFormatter):
            registry[formatter.name()] = formatter

        cls._shared = AllDateFormatters(registry)
        return cls._shared


class DateFormatter(Codable, SerializableType):
    """Base class for serializable formatters that decode/encode datetime columns.

    Subclasses are registered in ``AllDateFormatters`` under ``name()`` so a
    formatter can round-trip through ``_serialize`` / ``_deserialize``.
    """

    @classmethod
    def name(cls) -> str:
        """Unique key used to (de)serialize this formatter type."""
        raise NotImplementedError(cls)

    def decode_polars(self, column: str) -> pl.Expr:
        """Expression converting ``column`` into a ``pl.Datetime`` column."""
        raise NotImplementedError(type(self))

    def encode_polars(self, column: str) -> pl.Expr:
        """Expression converting a datetime ``column`` back to its stored form."""
        raise NotImplementedError(type(self))

    def _serialize(self) -> dict:
        assert type(self).name() in AllDateFormatters.shared().supported_formatters
        data = self.to_dict()
        data['name'] = type(self).name()
        return data

    @classmethod
    def _deserialize(cls, data: dict) -> DateFormatter:
        formatter_name = data.pop('name')
        formatters = AllDateFormatters.shared().supported_formatters
        if formatter_name not in formatters:
            raise ValueError(
                f"Unknown formatter name: {formatter_name}. Supported formatters: {formatters.keys()}"
            )
        formatter_class = formatters[formatter_name]
        return formatter_class.from_dict(data)

    @staticmethod
    def string_format(format: str) -> StringDateFormatter:
        return StringDateFormatter(format)

    @staticmethod
    def iso_8601() -> StringDateFormatter:
        # Polars' `str.to_datetime` / `dt.strftime` expect chrono strftime
        # specifiers. The previous Java-style pattern ('yyyy-MM-ddTHH:mm:ssZ')
        # contained no '%' specifiers, so encoding would emit that literal text
        # for every row and decoding could never parse a real timestamp.
        # `%.f` tolerates optional fractional seconds; `%z` matches a UTC offset
        # — confirm against stored data whether a bare 'Z' suffix also occurs.
        return StringDateFormatter('%Y-%m-%dT%H:%M:%S%.f%z')

    @staticmethod
    def unix_timestamp(time_unit: TimeUnit = 'us') -> Timestamp:
        return Timestamp(time_unit)


@dataclass
class Timestamp(DateFormatter):
    """Formatter for datetimes stored as integer epoch timestamps."""

    # Epoch resolution used both when decoding from and encoding to integers.
    time_unit: TimeUnit = 'us'

    @classmethod
    def name(cls) -> str:
        return 'timestamp'

    def decode_polars(self, column: str) -> pl.Expr:
        # Interpret the integer column as an epoch in the configured unit.
        return pl.from_epoch(column, time_unit=self.time_unit)

    def encode_polars(self, column: str) -> pl.Expr:
        # Inverse of `decode_polars`: datetime back to an epoch integer.
        return pl.col(column).dt.timestamp(self.time_unit)


@dataclass
class StringDateFormatter(DateFormatter):
    """Formatter for datetimes stored as strings in a given format."""

    date_format: str
    # Optional parse hints forwarded to polars when decoding.
    time_unit: TimeUnit | None = None
    time_zone: str | None = None

    @classmethod
    def name(cls) -> str:
        return 'string_form'

    def decode_polars(self, column: str) -> pl.Expr:
        string_column = pl.col(column)
        return string_column.str.to_datetime(
            self.date_format,
            time_unit=self.time_unit,
            time_zone=self.time_zone,
        )

    def encode_polars(self, column: str) -> pl.Expr:
        # Render the datetime back to text using the same format string.
        return pl.col(column).dt.strftime(self.date_format)
14 changes: 12 additions & 2 deletions aligned/sources/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from aligned.storage import Storage
from aligned.feature_store import FeatureStore
from aligned.feature_source import WritableFeatureSource
from aligned.schemas.date_formatter import DateFormatter

if TYPE_CHECKING:
from aligned.compiler.feature_factory import FeatureFactory
Expand Down Expand Up @@ -100,6 +101,7 @@ class CsvFileSource(BatchDataSource, ColumnFeatureMappable, StatisticEricher, Da
path: str
mapping_keys: dict[str, str] = field(default_factory=dict)
csv_config: CsvConfig = field(default_factory=CsvConfig)
formatter: DateFormatter = field(default_factory=DateFormatter.iso_8601)

type_name: str = 'csv'

Expand Down Expand Up @@ -489,9 +491,17 @@ def json_at(path: str) -> StorageFileSource:

@staticmethod
def csv_at(
path: str, mapping_keys: dict[str, str] | None = None, csv_config: CsvConfig | None = None
path: str,
mapping_keys: dict[str, str] | None = None,
csv_config: CsvConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> CsvFileSource:
return CsvFileSource(path, mapping_keys=mapping_keys or {}, csv_config=csv_config or CsvConfig())
return CsvFileSource(
path,
mapping_keys=mapping_keys or {},
csv_config=csv_config or CsvConfig(),
formatter=date_formatter or DateFormatter.iso_8601(),
)

@staticmethod
def parquet_at(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.66"
version = "0.0.67"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down
Binary file modified test_data/credit_history_mater.parquet
Binary file not shown.
Loading

0 comments on commit 5445708

Please sign in to comment.