From 758d9c7d5906a27d7c35b6664d10fbd8a6b5e882 Mon Sep 17 00:00:00 2001 From: Steve Lorello <42971704+slorello89@users.noreply.github.com> Date: Tue, 14 Jan 2025 13:29:20 -0500 Subject: [PATCH] fixing issue with inhereted defaults (#673) * fixing issue with inhereted defaults --- .github/workflows/ci.yml | 2 +- aredis_om/model/model.py | 17 ++++++++++++++++- tests/test_hash_model.py | 25 +++++++++++++++++++++++++ tests/test_json_model.py | 26 +++++++++++++++++++++++++- 4 files changed, 67 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40aa64ff..1f31a115 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - pyver: [ "3.8", "3.9", "3.10", "3.11", "3.12", "pypy-3.8", "pypy-3.9", "pypy-3.10" ] + pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ] redisstack: [ "latest" ] fail-fast: false services: diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index bee37591..c8dbe26d 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1320,6 +1320,12 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = meta or getattr(new_class, "Meta", None) base_meta = getattr(new_class, "_meta", None) + if len(bases) == 1: + for f_name in bases[0].model_fields: + field = bases[0].model_fields[f_name] + print(field) + new_class.model_fields[f_name] = field + if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -1455,7 +1461,16 @@ class Config: def __init__(__pydantic_self__, **data: Any) -> None: __pydantic_self__.validate_primary_key() - super().__init__(**data) + missing_fields = __pydantic_self__.model_fields.keys() - data.keys() - {"pk"} + + kwargs = data.copy() + + # This is a hack, we need to manually make sure we are setting up defaults correctly when we encounter them + # because inheritance apparently won't cover that in pydantic 2.0. + for field in missing_fields: + default_value = __pydantic_self__.model_fields.get(field).default # type: ignore + kwargs[field] = default_value + super().__init__(**kwargs) def __lt__(self, other): """Default sort: compare primary key of models.""" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 97373671..875bfdb7 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -951,3 +951,28 @@ class TestLiterals(HashModel): await item.save() rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first() assert rematerialized.pk == item.pk + + +@py_test_mark_asyncio +async def test_child_class_expression_proxy(): + # https://github.com/redis/redis-om-python/issues/669 seeing weird issue with child classes initalizing all their undefined members as ExpressionProxies + class Model(HashModel): + first_name: str + last_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + other_name: str + # is_new: bool = Field(default=True) + + await Migrator().run() + m = Child(first_name="Steve", last_name="Lorello", other_name="foo") + await m.save() + print(m.age) + assert m.age == 18 + + rematerialized = await Child.find(Child.pk == m.pk).first() + + assert rematerialized.age == 18 + assert rematerialized.bio is None diff --git a/tests/test_json_model.py b/tests/test_json_model.py index d963647b..92440dd9 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -12,6 +12,7 @@ import pytest import pytest_asyncio +from more_itertools.more import first from aredis_om import ( EmbeddedJsonModel, @@ -1159,7 +1160,29 @@ class TestLiterals(JsonModel): assert rematerialized.pk == item.pk -@py_test_mark_asyncio +async def test_child_class_expression_proxy(): + # https://github.com/redis/redis-om-python/issues/669 seeing weird issue with child classes initalizing all their undefined members as ExpressionProxies + class Model(JsonModel): + first_name: str + last_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + is_new: bool = Field(default=True) + + await Migrator().run() + m = Child(first_name="Steve", last_name="Lorello") + await m.save() + print(m.age) + assert m.age == 18 + + rematerialized = await Child.find(Child.pk == m.pk).first() + + assert rematerialized.age == 18 + assert rematerialized.age != 19 + assert rematerialized.bio is None + async def test_merged_model_error(): class Player(EmbeddedJsonModel): username: str = Field(index=True) @@ -1173,3 +1196,4 @@ class Game(JsonModel): ) print(q.query) assert q.query == "(@player1_username:{username})| (@player2_username:{username})" +