Skip to content

Commit

Permalink
fix rename bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 4, 2024
1 parent 3e8912d commit b3db5a7
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 3 deletions.
6 changes: 4 additions & 2 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from copy import copy

from typing import TYPE_CHECKING, TypeVar, Any, Callable, Coroutine
from dataclasses import dataclass
Expand Down Expand Up @@ -801,8 +802,9 @@ class ColumnFeatureMappable:
mapping_keys: dict[str, str]

def with_renames(self: T, mapping_keys: dict[str, str]) -> T:
self.mapping_keys = mapping_keys # type: ignore
return self
new = copy(self)
new.mapping_keys = mapping_keys # type: ignore
return new

def columns_for(self, features: list[Feature]) -> list[str]:
return [self.mapping_keys.get(feature.name, feature.name) for feature in features]
Expand Down
47 changes: 47 additions & 0 deletions aligned/tests/test_model_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ class NewModel:
assert feature.name == 'a'


def test_with_renames() -> None:
from aligned import FileSource

source = FileSource.parquet_at('test_data/test_model.parquet').with_renames({'some_id': 'id'})
other = source.with_renames({'other_id': 'id'})

assert source.mapping_keys == {'some_id': 'id'}
assert other.mapping_keys == {'other_id': 'id'}


@pytest.mark.asyncio
async def test_model_insert_predictions() -> None:
"""
Expand Down Expand Up @@ -136,6 +146,43 @@ class TestModel:
assert preds.select(expected_frame.columns).equals(expected_frame)


@pytest.mark.asyncio
async def test_model_insert_predictions_csv() -> None:
"""
Test the insert (aka. ish append) method on the feature store.
"""
from aligned import FileSource, FeatureStore

path = 'test_data/test_model.csv'

@model_contract(
name='test_model',
features=[],
prediction_source=FileSource.csv_at(path).with_renames({'some_id': 'id'}),
)
class TestModel:
id = Int32().as_entity()

a = Int32()

store = FeatureStore.experimental()
initial_frame = pl.DataFrame({'id': [1, 2, 3], 'a': [1, 2, 3]})
initial_frame.write_csv(path)

expected_frame = pl.DataFrame({'id': [1, 2, 3, 1, 2, 3], 'a': [10, 14, 20, 1, 2, 3]})

store.add_compiled_model(TestModel.compile()) # type: ignore

await store.insert_into(FeatureLocation.model('test_model'), {'id': [1, 2, 3], 'a': [10, 14, 20]})

preds = await store.model('test_model').all_predictions().to_polars()

stored_data = pl.read_csv(path).select(id=pl.col('some_id'), a=pl.col('a'))
assert stored_data.equals(expected_frame)

assert preds.select(expected_frame.columns).equals(expected_frame)


@pytest.mark.asyncio
async def test_model_upsert_predictions() -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.75"
version = "0.0.76"
description = "A data managment and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down
7 changes: 7 additions & 0 deletions test_data/test_model.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
some_id,a
1,10
2,14
3,20
1,1
2,2
3,3
Binary file modified test_data/test_model.parquet
Binary file not shown.

0 comments on commit b3db5a7

Please sign in to comment.