Commit 9d5737a

Fixed date formatting bug
MatsMoll committed Mar 5, 2024
1 parent b3db5a7 commit 9d5737a
Showing 8 changed files with 146 additions and 17 deletions.
20 changes: 10 additions & 10 deletions aligned/feature_store.py
@@ -677,18 +677,18 @@ async def insert_into(
 
         columns = write_request.all_returned_columns
 
-        if isinstance(source, ColumnFeatureMappable):
-            new_cols = source.feature_identifier_for(columns)
-
-            mappings = dict(zip(columns, new_cols))
-            values = values.rename(mappings)
-            columns = new_cols
-            existing_df = (await source.to_lazy_polars()).rename(mappings)
-        else:
-            existing_df = await source.to_lazy_polars()
-
         new_df = (await values.to_lazy_polars()).select(columns)
         try:
+            if isinstance(source, ColumnFeatureMappable):
+                new_cols = source.feature_identifier_for(columns)
+
+                mappings = dict(zip(columns, new_cols))
+                values = values.rename(mappings)
+                columns = new_cols
+                existing_df = (await source.to_lazy_polars()).rename(mappings)
+            else:
+                existing_df = await source.to_lazy_polars()
+
             write_df = pl.concat([new_df, existing_df.select(columns)], how='vertical_relaxed')
         except UnableToFindFileException:
             write_df = new_df
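The substance of this hunk: the source read moves inside the try block, so inserting into a target that does not yet exist now falls back to writing only the new rows instead of raising UnableToFindFileException. A minimal standalone sketch of the concat-with-fallback behavior, using hypothetical frames rather than the store's own types:

import polars as pl

# `vertical_relaxed` stacks the frames and coerces compatible dtypes
# (e.g. Int64 vs Float64) to a common supertype instead of failing on
# an exact schema mismatch.
new_df = pl.LazyFrame({'id': [3], 'value': [2.5]})
existing_df = pl.LazyFrame({'id': [1, 2], 'value': [1, 2]})  # `value` is Int64 here

write_df = pl.concat([new_df, existing_df.select(new_df.columns)], how='vertical_relaxed')
print(write_df.collect())  # three rows, `value` promoted to Float64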
4 changes: 4 additions & 0 deletions aligned/retrival_job.py
@@ -1770,6 +1770,8 @@ async def to_pandas(self) -> pd.DataFrame:
                 )
             elif feature.dtype == FeatureType.json():
                 pass
+            elif feature.dtype == FeatureType.datetime():
+                pass
             else:
                 if feature.dtype.is_numeric:
                     df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
@@ -1815,6 +1817,8 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
                 df = df.with_columns(pl.col(feature.name).str.json_extract(pl.List(pl.Utf8)))
             elif feature.dtype == FeatureType.json():
                 pass
+            elif feature.dtype == FeatureType.datetime():
+                pass
             else:
                 df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type, strict=False))

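Both conversion paths gain a datetime branch that intentionally does nothing: by the time these loops run, the configured date formatter has already decoded the timestamp columns, so they must not fall through to the generic numeric or polars-type cast below. A rough illustration of the polars side, with a hypothetical frame:

import polars as pl

df = pl.LazyFrame({'event_timestamp': [1587924064746575]})

# The formatter decodes the raw epoch column up front...
df = df.with_columns(pl.from_epoch('event_timestamp', time_unit='us'))

# ...so the cast loop should leave it alone; the column is already Datetime,
# and re-casting it as if it held raw values is presumably the kind of
# mangling this commit guards against.
print(df.collect().schema)  # event_timestamp: Datetime(time_unit='us')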
7 changes: 5 additions & 2 deletions aligned/schemas/date_formatter.py
@@ -61,20 +61,23 @@ def iso_8601() -> StringDateFormatter:
         return StringDateFormatter('yyyy-MM-ddTHH:mm:ssZ')
 
     @staticmethod
-    def unix_timestamp(time_unit: TimeUnit = 'us') -> Timestamp:
-        return Timestamp(time_unit)
+    def unix_timestamp(time_unit: TimeUnit = 'us', time_zone: str | None = 'UTC') -> Timestamp:
+        return Timestamp(time_unit, time_zone)
 
 
 @dataclass
 class Timestamp(DateFormatter):
 
     time_unit: TimeUnit = field(default='us')
+    time_zone: str | None = field(default='UTC')
 
     @classmethod
     def name(cls) -> str:
         return 'timestamp'
 
     def decode_polars(self, column: str) -> pl.Expr:
+        if self.time_zone:
+            return pl.from_epoch(column, self.time_unit).dt.replace_time_zone(self.time_zone)
         return pl.from_epoch(column, self.time_unit)
 
     def encode_polars(self, column: str) -> pl.Expr:
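The heart of the date formatting fix appears to be here: Timestamp now carries a time zone (UTC by default), so decoded epochs come back as zone-aware datetimes rather than naive ones. A small sketch of the new decode path, using an epoch value taken from the test data below:

import polars as pl

df = pl.DataFrame({'event_timestamp': [1587924064746575]})

decoded = df.with_columns(
    pl.from_epoch('event_timestamp', 'us').dt.replace_time_zone('UTC')
)
# 1587924064746575 µs since the epoch -> 2020-04-26 18:01:04.746575 UTC
print(decoded)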
69 changes: 65 additions & 4 deletions aligned/sources/local.py
@@ -12,7 +12,7 @@
 
 from aligned.data_file import DataFileReference, upsert_on_column
 from aligned.data_source.batch_data_source import BatchDataSource, ColumnFeatureMappable
-from aligned.enricher import CsvFileEnricher, Enricher, LoadedStatEnricher, StatisticEricher, TimespanSelector
+from aligned.enricher import CsvFileEnricher, Enricher, LoadedStatEnricher, TimespanSelector
 from aligned.exceptions import UnableToFindFileException
 from aligned.local.job import FileDateJob, FileFactualJob, FileFullJob
 from aligned.request.retrival_request import RetrivalRequest
@@ -97,7 +97,7 @@ class CsvConfig(Codable):
 
 
 @dataclass
-class CsvFileSource(BatchDataSource, ColumnFeatureMappable, StatisticEricher, DataFileReference):
+class CsvFileSource(BatchDataSource, ColumnFeatureMappable, DataFileReference, WritableFeatureSource):
     """
     A source pointing to a CSV file
     """
@@ -143,6 +143,60 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
         except OSError:
             raise UnableToFindFileException(self.path)
 
+    async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+        if len(requests) != 1:
+            raise ValueError('Csv files only support one write request as of now')
+
+        request = requests[0]
+
+        data = await job.to_lazy_polars()
+        potential_timestamps = request.all_features
+
+        if request.event_timestamp:
+            potential_timestamps.add(request.event_timestamp)
+
+        for feature in potential_timestamps:
+            if feature.dtype.name == 'datetime':
+                data = data.with_columns(self.formatter.encode_polars(feature.name))
+
+        if self.mapping_keys:
+            mapping = {self.mapping_keys.get(name, name): name for name in data.columns}
+            data = data.rename(mapping)
+
+        new_df = data.select(request.all_returned_columns)
+        entities = list(request.entity_names)
+        try:
+            existing_df = await self.to_lazy_polars()
+            write_df = upsert_on_column(entities, new_df, existing_df)
+        except UnableToFindFileException:
+            write_df = new_df
+
+        await self.write_polars(write_df)
+
+    async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+        if len(requests) != 1:
+            raise ValueError('Csv files only support one write request as of now')
+
+        request = requests[0]
+
+        data = await job.to_lazy_polars()
+        for feature in request.features:
+            if feature.dtype.name == 'datetime':
+                data = data.with_columns(self.formatter.encode_polars(feature.name))
+
+        if self.mapping_keys:
+            mapping = {self.mapping_keys.get(name, name): name for name in data.columns}
+            data = data.rename(mapping)
+
+        try:
+            existing_df = await self.to_lazy_polars()
+
+            write_df = pl.concat([data, existing_df.select(data.columns)], how='vertical_relaxed')
+        except UnableToFindFileException:
+            write_df = data
+
+        await self.write_polars(write_df.select(request.all_returned_columns))
+
     async def write_pandas(self, df: pd.DataFrame) -> None:
         create_parent_dir(self.path)
         df.to_csv(
@@ -180,12 +234,18 @@ def enricher(self) -> CsvFileEnricher:
         return CsvFileEnricher(file=self.path)
 
     def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
-        return FileFullJob(self, request, limit)
+        return FileFullJob(self, request, limit, date_formatter=self.formatter)
 
     def all_between_dates(
         self, request: RetrivalRequest, start_date: datetime, end_date: datetime
     ) -> RetrivalJob:
-        return FileDateJob(source=self, request=request, start_date=start_date, end_date=end_date)
+        return FileDateJob(
+            source=self,
+            request=request,
+            start_date=start_date,
+            end_date=end_date,
+            date_formatter=self.formatter,
+        )
 
     @classmethod
     def multi_source_features_for(
@@ -204,6 +264,7 @@ def multi_source_features_for(
             source=source,
             requests=[request for _, request in requests],
             facts=facts,
+            date_formatter=source.formatter,
         )
 
     async def schema(self) -> dict[str, FeatureFactory]:
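With these methods, CsvFileSource becomes a WritableFeatureSource with two write modes: insert appends the new rows (falling back to a fresh file when the read raises UnableToFindFileException), while upsert replaces rows whose entity keys already exist and keeps the rest. A rough sketch of the upsert semantics in plain polars; this is an assumption about what upsert_on_column does, not its actual implementation:

import polars as pl

def upsert_sketch(entities: list[str], new_df: pl.LazyFrame, existing_df: pl.LazyFrame) -> pl.LazyFrame:
    # Stack the incoming rows first, then keep the first row per entity key,
    # so a new row wins over a stored row with the same key.
    stacked = pl.concat([new_df, existing_df.select(new_df.columns)], how='vertical_relaxed')
    return stacked.unique(subset=entities, keep='first', maintain_order=True)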
47 changes: 47 additions & 0 deletions aligned/sources/tests/test_parquet.py
@@ -1,7 +1,10 @@
 import pytest
+import polars as pl
+from pathlib import Path
 
 from aligned import FeatureStore, FileSource
 from aligned.feature_view.feature_view import FeatureView
+from aligned.schemas.date_formatter import DateFormatter
 from conftest import DataTest
 
 
@@ -100,3 +103,47 @@ async def test_parquet_without_event_timestamp(
 
     ordered_columns = data.select(expected.columns)
     assert ordered_columns.equals(expected), f'Expected: {expected}\nGot: {ordered_columns}'
+
+
+@pytest.mark.asyncio
+async def test_read_csv(point_in_time_data_test: DataTest) -> None:
+
+    store = FeatureStore.experimental()
+
+    for source in point_in_time_data_test.sources:
+        view = source.view
+        view_name = view.metadata.name
+        if '_agg' in view_name:
+            continue
+
+        file_source = FileSource.csv_at(
+            f'test_data/{view_name}.csv', date_formatter=DateFormatter.unix_timestamp()
+        )
+
+        view.metadata = FeatureView.metadata_with(  # type: ignore
+            name=view.metadata.name,
+            description=view.metadata.description,
+            batch_source=file_source,
+        )
+        compiled = view.compile_instance()
+        assert compiled.source.path == file_source.path  # type: ignore
+
+        store.add_compiled_view(compiled)
+
+        Path(file_source.path).unlink(missing_ok=True)
+
+        await store.feature_view(compiled.name).insert(
+            store.feature_view(compiled.name).process_input(source.data)
+        )
+
+        csv = pl.read_csv(file_source.path)
+        schemas = dict(csv.schema)
+
+        for feature in view.compile().request_all.request_result.features:
+            if feature.dtype.name == 'datetime':
+                assert schemas[feature.name].is_numeric()
+
+        # Polars
+        stored = await store.feature_view(compiled.name).all().to_polars()
+        df = stored.select(source.data.columns)
+        assert df.equals(source.data)
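The new test pins down the round trip: with DateFormatter.unix_timestamp(), datetime features land in the CSV as numeric epoch columns and read back equal to the source frame. A hypothetical, untested variant using the ISO formatter from the same module, which should persist datetimes as strings instead:

from aligned import FileSource
from aligned.schemas.date_formatter import DateFormatter

# Hypothetical alternative: store datetimes as ISO-8601 strings rather than epochs.
iso_source = FileSource.csv_at(
    'test_data/loan.csv', date_formatter=DateFormatter.iso_8601()
)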
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.76"
+version = "0.0.77"
 description = "A data managment and lineage tool for ML applications."
 authors = ["Mats E. Mollestad <mats@mollestad.no>"]
 license = "Apache-2.0"
7 changes: 7 additions & 0 deletions test_data/credit_history.csv
@@ -0,0 +1,7 @@
+due_sum,dob_ssn,credit_card_due,bankruptcies,student_loan_due,event_timestamp
+30747,19530219_5179,8419,0,22328,1587924064746575
+5459,19520816_8737,2944,0,2515,1587924064746575
+33833,19860413_2537,833,0,33000,1587924064746575
+54891,19530219_5179,5936,0,48955,1588010464746575
+11076,19520816_8737,1575,0,9501,1588010464746575
+41773,19860413_2537,6263,0,35510,1588010464746575
7 changes: 7 additions & 0 deletions test_data/loan.csv
@@ -0,0 +1,7 @@
+loan_amount,loan_status,loan_id,event_timestamp,personal_income
+35000,True,10000,1587924064746575,59000
+1000,False,10001,1587924064746575,9600
+5500,True,10002,1587924064746575,9600
+35000,True,10000,1588010464746575,65500
+35000,True,10001,1588010464746575,54400
+2500,True,10002,1588010464746575,9900
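Both fixtures encode event_timestamp as microsecond unix epochs, two snapshots exactly one day apart. A quick sanity check with the default formatter's decoding, assuming polars is available:

import polars as pl

ts = pl.DataFrame({'event_timestamp': [1587924064746575, 1588010464746575]})
print(ts.with_columns(
    pl.from_epoch('event_timestamp', 'us').dt.replace_time_zone('UTC')
))
# -> 2020-04-26 18:01:04.746575 UTC and 2020-04-27 18:01:04.746575 UTC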
