Skip to content

Commit

Permalink
fix: updated model wrapper logic so it can be referenced
Browse files Browse the repository at this point in the history
  • Loading branch information
Mats E. Mollestad committed Oct 16, 2023
1 parent 9092b6f commit 9d19b0a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
5 changes: 5 additions & 0 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class ModelContractWrapper(Generic[T]):
metadata: ModelMetadata
contract: Type[T]

def __call__(self) -> T:
# Needs to compiile the model to set the location for the view features
_ = self.compile()
return self.contract()

def compile(self) -> ModelSchema:
return ModelContract.compile_with_metadata(self.contract(), self.metadata)

Expand Down
37 changes: 35 additions & 2 deletions aligned/tests/test_model_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import polars as pl
import pytest

from aligned import FeatureStore
from aligned import FeatureStore, model_contract, String, Int32
from aligned.schemas.feature import FeatureLocation


@pytest.mark.asyncio
Expand Down Expand Up @@ -54,7 +55,7 @@ async def test_titanic_model_with_targets_and_scd(titanic_feature_store_scd: Fea
dataset = (
await titanic_feature_store_scd.model('titanic')
.with_labels()
.features_for(entities.to_dict())
.features_for(entities.to_dict(as_series=False))
.to_polars()
)

Expand All @@ -64,3 +65,35 @@ async def test_titanic_model_with_targets_and_scd(titanic_feature_store_scd: Fea
assert target_df['survived'].series_equal(expected_data['survived'])
assert input_df['is_male'].series_equal(expected_data['is_male'])
assert input_df['age'].series_equal(expected_data['age'])


@pytest.mark.asyncio
async def test_model_wrapper() -> None:
from aligned.compiler.model import ModelContractWrapper

@model_contract(
name='test_model',
features=[],
)
class TestModel:
id = Int32().as_entity()

a = Int32()

test_model_features = TestModel()

@model_contract(name='new_model', features=[test_model_features.a])
class NewModel:

id = Int32().as_entity()

x = String()

model_wrapper: ModelContractWrapper = NewModel
compiled = model_wrapper.compile()
assert len(compiled.features) == 1

feature = list(compiled.features)[0]

assert feature.location == FeatureLocation.model('test_model')
assert feature.name == 'a'

0 comments on commit 9d19b0a

Please sign in to comment.