diff --git a/aligned/data_source/batch_data_source.py b/aligned/data_source/batch_data_source.py
index dc71ddc..634609c 100644
--- a/aligned/data_source/batch_data_source.py
+++ b/aligned/data_source/batch_data_source.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from copy import copy
 
 from typing import TYPE_CHECKING, TypeVar, Any, Callable, Coroutine
 from dataclasses import dataclass
@@ -801,8 +802,9 @@ class ColumnFeatureMappable:
     mapping_keys: dict[str, str]
 
     def with_renames(self: T, mapping_keys: dict[str, str]) -> T:
-        self.mapping_keys = mapping_keys  # type: ignore
-        return self
+        new = copy(self)
+        new.mapping_keys = mapping_keys  # type: ignore
+        return new
 
     def columns_for(self, features: list[Feature]) -> list[str]:
         return [self.mapping_keys.get(feature.name, feature.name) for feature in features]
diff --git a/aligned/tests/test_model_target.py b/aligned/tests/test_model_target.py
index 05a1c3e..a08c5be 100644
--- a/aligned/tests/test_model_target.py
+++ b/aligned/tests/test_model_target.py
@@ -99,6 +99,16 @@ class NewModel:
     assert feature.name == 'a'
 
 
+def test_with_renames() -> None:
+    from aligned import FileSource
+
+    source = FileSource.parquet_at('test_data/test_model.parquet').with_renames({'some_id': 'id'})
+    other = source.with_renames({'other_id': 'id'})
+
+    assert source.mapping_keys == {'some_id': 'id'}
+    assert other.mapping_keys == {'other_id': 'id'}
+
+
 @pytest.mark.asyncio
 async def test_model_insert_predictions() -> None:
     """
@@ -136,6 +146,43 @@ class TestModel:
     assert preds.select(expected_frame.columns).equals(expected_frame)
 
 
+@pytest.mark.asyncio
+async def test_model_insert_predictions_csv() -> None:
+    """
+    Test the insert (aka. ish append) method on the feature store.
+    """
+    from aligned import FileSource, FeatureStore
+
+    path = 'test_data/test_model.csv'
+
+    @model_contract(
+        name='test_model',
+        features=[],
+        prediction_source=FileSource.csv_at(path).with_renames({'some_id': 'id'}),
+    )
+    class TestModel:
+        id = Int32().as_entity()
+
+        a = Int32()
+
+    store = FeatureStore.experimental()
+    initial_frame = pl.DataFrame({'id': [1, 2, 3], 'a': [1, 2, 3]})
+    initial_frame.write_csv(path)
+
+    expected_frame = pl.DataFrame({'id': [1, 2, 3, 1, 2, 3], 'a': [10, 14, 20, 1, 2, 3]})
+
+    store.add_compiled_model(TestModel.compile())  # type: ignore
+
+    await store.insert_into(FeatureLocation.model('test_model'), {'id': [1, 2, 3], 'a': [10, 14, 20]})
+
+    preds = await store.model('test_model').all_predictions().to_polars()
+
+    stored_data = pl.read_csv(path).select(id=pl.col('some_id'), a=pl.col('a'))
+    assert stored_data.equals(expected_frame)
+
+    assert preds.select(expected_frame.columns).equals(expected_frame)
+
+
 @pytest.mark.asyncio
 async def test_model_upsert_predictions() -> None:
     """
diff --git a/pyproject.toml b/pyproject.toml
index e18fb32..27f723b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.75"
+version = "0.0.76"
 description = "A data managment and lineage tool for ML applications."
 authors = ["Mats E. Mollestad "]
 license = "Apache-2.0"
diff --git a/test_data/test_model.csv b/test_data/test_model.csv
new file mode 100644
index 0000000..9bfa2bd
--- /dev/null
+++ b/test_data/test_model.csv
@@ -0,0 +1,7 @@
+some_id,a
+1,10
+2,14
+3,20
+1,1
+2,2
+3,3
diff --git a/test_data/test_model.parquet b/test_data/test_model.parquet
index 5bab2d3..4db5d82 100644
Binary files a/test_data/test_model.parquet and b/test_data/test_model.parquet differ