diff --git a/aligned/compiler/model.py b/aligned/compiler/model.py index bb579e6..8bdafdd 100644 --- a/aligned/compiler/model.py +++ b/aligned/compiler/model.py @@ -77,6 +77,11 @@ class ModelContractWrapper(Generic[T]): metadata: ModelMetadata contract: Type[T] + def __call__(self) -> T: + # Needs to compiile the model to set the location for the view features + _ = self.compile() + return self.contract() + def compile(self) -> ModelSchema: return ModelContract.compile_with_metadata(self.contract(), self.metadata) diff --git a/aligned/tests/test_model_target.py b/aligned/tests/test_model_target.py index c106b5f..320a3a3 100644 --- a/aligned/tests/test_model_target.py +++ b/aligned/tests/test_model_target.py @@ -4,7 +4,8 @@ import polars as pl import pytest -from aligned import FeatureStore +from aligned import FeatureStore, model_contract, String, Int32 +from aligned.schemas.feature import FeatureLocation @pytest.mark.asyncio @@ -54,7 +55,7 @@ async def test_titanic_model_with_targets_and_scd(titanic_feature_store_scd: Fea dataset = ( await titanic_feature_store_scd.model('titanic') .with_labels() - .features_for(entities.to_dict()) + .features_for(entities.to_dict(as_series=False)) .to_polars() ) @@ -64,3 +65,35 @@ async def test_titanic_model_with_targets_and_scd(titanic_feature_store_scd: Fea assert target_df['survived'].series_equal(expected_data['survived']) assert input_df['is_male'].series_equal(expected_data['is_male']) assert input_df['age'].series_equal(expected_data['age']) + + +@pytest.mark.asyncio +async def test_model_wrapper() -> None: + from aligned.compiler.model import ModelContractWrapper + + @model_contract( + name='test_model', + features=[], + ) + class TestModel: + id = Int32().as_entity() + + a = Int32() + + test_model_features = TestModel() + + @model_contract(name='new_model', features=[test_model_features.a]) + class NewModel: + + id = Int32().as_entity() + + x = String() + + model_wrapper: ModelContractWrapper = NewModel + compiled = model_wrapper.compile() + assert len(compiled.features) == 1 + + feature = list(compiled.features)[0] + + assert feature.location == FeatureLocation.model('test_model') + assert feature.name == 'a'