Skip to content

Commit

Permalink
fix: pydantic nested schema serialization (#346)
Browse files Browse the repository at this point in the history
* fix: pydantic nested schema serialization

* fix: parametrize test models

* fix: replace conditional_dataclass with no-op decorator

* Add AvroBaseModel to unit-tests

* remove debug code

* modify tests to match AvroBaseModel vs AvroModel behavior

* fix linting

---------

Co-authored-by: David Devlin <d.devlin@keelvar.com>
Co-authored-by: Marcos Schroh <2828842+marcosschroh@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 19, 2023
1 parent fd261a6 commit 2b3e37b
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 87 deletions.
7 changes: 3 additions & 4 deletions dataclasses_avroschema/avrodantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> Js
"""
data = dict(self)

standardize_method = standardize_factory or standardize_custom_type

# te standardize called can be replaced if we have a custom implementation of asdict
# for now I think is better to use the native implementation
return {
key: value.asdict() if isinstance(value, AvroBaseModel) else standardize_custom_type(value)
for key, value in data.items()
}
return standardize_method(data)

def validate_avro(self) -> bool:
"""
Expand Down
2 changes: 2 additions & 0 deletions dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def standardize_custom_type(value: typing.Any) -> typing.Any:
return tuple(standardize_custom_type(v) for v in value)
elif issubclass(type(value), enum.Enum):
return value.value
elif is_pydantic_model(type(value)):
return standardize_custom_type(value.asdict())
return value


Expand Down
42 changes: 27 additions & 15 deletions tests/serialization/test_logical_types_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,30 @@
import uuid
from dataclasses import dataclass

import pytest

from dataclasses_avroschema import AvroModel, serialization, types
from dataclasses_avroschema.avrodantic import AvroBaseModel

a_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42, tzinfo=datetime.timezone.utc)


def test_logical_types():
@dataclass
class LogicalTypes(AvroModel):
parametrize_base_model = pytest.mark.parametrize(
"model_class, decorator", [(AvroModel, dataclass), (AvroBaseModel, lambda f: f)]
)


@parametrize_base_model
def test_logical_types(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class LogicalTypes(model_class):
"Some logical types"
birthday: datetime.date
meeting_time: datetime.time
meeting_time_micro: types.TimeMicro
release_datetime: datetime.datetime
release_datetime_micro: types.DateTimeMicro
event_uuid: uuid.uuid4
event_uuid: uuid.UUID

data = {
"birthday": a_datetime.date(),
Expand Down Expand Up @@ -52,11 +61,12 @@ class LogicalTypes(AvroModel):
assert logical_types.to_json() == json.dumps(data_json)


def test_logical_union():
@dataclass
class UnionSchema(AvroModel):
@parametrize_base_model
def test_logical_union(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class UnionSchema(model_class):
"Some Unions"
logical_union: typing.Union[datetime.datetime, datetime.date, uuid.uuid4]
logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID]

data = {
"logical_union": a_datetime.date(),
Expand All @@ -79,17 +89,18 @@ class UnionSchema(AvroModel):
assert logical_types.to_json() == json.dumps(data_json)


def test_logical_types_with_defaults():
@dataclass
class LogicalTypes(AvroModel):
@parametrize_base_model
def test_logical_types_with_defaults(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class LogicalTypes(model_class):
"Some logical types"
implicit_decimal: types.condecimal(max_digits=3, decimal_places=2)
birthday: datetime.date = a_datetime.date()
meeting_time: datetime.time = a_datetime.time()
release_datetime: datetime.datetime = a_datetime
meeting_time_micro: types.TimeMicro = a_datetime.time()
release_datetime_micro: types.DateTimeMicro = a_datetime
event_uuid: uuid.uuid4 = "09f00184-7721-4266-a955-21048a5cc235"
event_uuid: uuid.UUID = uuid.UUID("09f00184-7721-4266-a955-21048a5cc235")
decimal_with_default: types.condecimal(max_digits=6, decimal_places=5) = decimal.Decimal("3.14159")

data = {
Expand Down Expand Up @@ -130,9 +141,10 @@ class LogicalTypes(AvroModel):

# A decimal.Decimal default is serialized into bytes by dataclasses-avroschema to be deserialized by fastavro
# this test is to make sure that process works as expected
def test_decimals_defaults():
@dataclass
class LogicalTypes(AvroModel):
@parametrize_base_model
def test_decimals_defaults(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class LogicalTypes(model_class):
"Some logical types"
explicit: types.condecimal(max_digits=3, decimal_places=2)
explicit_decimal_with_default: types.condecimal(max_digits=6, decimal_places=5) = decimal.Decimal("3.14159")
Expand Down
110 changes: 63 additions & 47 deletions tests/serialization/test_nested_schema_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@
import datetime
import typing

import pytest

from dataclasses_avroschema import AvroModel
from dataclasses_avroschema.avrodantic import AvroBaseModel

parametrize_base_model = pytest.mark.parametrize(
"model_class, decorator", [(AvroModel, dataclasses.dataclass), (AvroBaseModel, lambda f: f)]
)


def test_one_to_one_relationship():
@parametrize_base_model
def test_one_to_one_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable):
"""
Test schema relationship one-to-one serialization
"""

@dataclasses.dataclass
class Address(AvroModel):
@decorator
class Address(model_class):
"An Address"
street: str
street_number: int

@dataclasses.dataclass
class User(AvroModel):
@decorator
class User(model_class):
"An User with Address"
name: str
age: int
Expand Down Expand Up @@ -54,20 +62,21 @@ class User(AvroModel):
assert user.to_dict() == expected


def test_one_to_many_relationship():
@parametrize_base_model
def test_one_to_many_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable):
"""
Test schema relationship one-to-many serialization
"""

@dataclasses.dataclass
class Address(AvroModel):
@decorator
class Address(model_class):
"An Address"
street: str
street_number: int
created_at: datetime.datetime

@dataclasses.dataclass
class User(AvroModel):
@decorator
class User(model_class):
"User with multiple Address"
name: str
age: int
Expand Down Expand Up @@ -119,19 +128,20 @@ class User(AvroModel):
assert user.to_json()


def test_one_to_many_map_relationship():
@parametrize_base_model
def test_one_to_many_map_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable):
"""
Test schema relationship one-to-many using a map serialization
"""

@dataclasses.dataclass
class Address(AvroModel):
@decorator
class Address(model_class):
"An Address"
street: str
street_number: int

@dataclasses.dataclass
class User(AvroModel):
@decorator
class User(model_class):
"User with multiple Address"
name: str
age: int
Expand Down Expand Up @@ -171,23 +181,24 @@ class User(AvroModel):
assert user.to_dict() == expected


def test_nested_schemas_splitted() -> None:
@parametrize_base_model
def test_nested_schemas_splitted(model_class: typing.Type[AvroModel], decorator: typing.Callable) -> None:
"""
This test will cover the cases when nested schemas are
used in a separate way.
"""

@dataclasses.dataclass
class A(AvroModel):
@decorator
class A(model_class):
class Meta:
namespace = "namespace"

@dataclasses.dataclass
class B(AvroModel):
@decorator
class B(model_class):
a: A

@dataclasses.dataclass
class C(AvroModel):
@decorator
class C(model_class):
b: B
a: A

Expand All @@ -198,33 +209,34 @@ class C(AvroModel):
assert c.serialize() == b""


def test_nested_schemas_splitted_with_unions() -> None:
@parametrize_base_model
def test_nested_schemas_splitted_with_unions(model_class: typing.Type[AvroModel], decorator: typing.Callable) -> None:
"""
This test will cover the cases when nested schemas with Unions that are
used in a separate way.
"""

@dataclasses.dataclass
class S1(AvroModel):
@decorator
class S1(model_class):
pass

@dataclasses.dataclass
class S2(AvroModel):
@decorator
class S2(model_class):
pass

@dataclasses.dataclass
class A(AvroModel):
@decorator
class A(model_class):
s: typing.Union[S1, S2]

class Meta:
namespace = "namespace"

@dataclasses.dataclass
class B(AvroModel):
@decorator
class B(model_class):
a: A

@dataclasses.dataclass
class C(AvroModel):
@decorator
class C(model_class):
b: B
a: A

Expand All @@ -235,22 +247,25 @@ class C(AvroModel):
assert c.serialize() == b"\x00\x00"


def test_nested_scheamas_splitted_with_intermediates() -> None:
@dataclasses.dataclass
class A(AvroModel):
@parametrize_base_model
def test_nested_schemas_splitted_with_intermediates(
model_class: typing.Type[AvroModel], decorator: typing.Callable
) -> None:
@decorator
class A(model_class):
class Meta:
namespace = "namespace"

@dataclasses.dataclass
class B(AvroModel):
@decorator
class B(model_class):
a: A

@dataclasses.dataclass
class C(AvroModel):
@decorator
class C(model_class):
a: A

@dataclasses.dataclass
class D(AvroModel):
@decorator
class D(model_class):
b: B
c: C

Expand All @@ -263,14 +278,15 @@ class D(AvroModel):
assert c.serialize() == b""


def test_nested_several_layers():
@dataclasses.dataclass
class Friend(AvroModel):
@parametrize_base_model
def test_nested_several_layers(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class Friend(model_class):
name: str
hobbies: typing.List[str]

@dataclasses.dataclass
class User(AvroModel):
@decorator
class User(model_class):
name: str
friends: typing.List[Friend]

Expand Down
31 changes: 21 additions & 10 deletions tests/serialization/test_primitive_types_serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import dataclasses
import json
import math
from dataclasses import dataclass
import typing

import pytest

from dataclasses_avroschema import AvroModel, types
from dataclasses_avroschema.avrodantic import AvroBaseModel

parametrize_base_model = pytest.mark.parametrize(
"model_class, decorator", [(AvroModel, dataclasses.dataclass), (AvroBaseModel, lambda f: f)]
)


def test_primitive_types(user_dataclass):
Expand All @@ -24,9 +32,10 @@ def test_primitive_types(user_dataclass):
assert user.to_json() == json.dumps(data_json)


def test_primitive_types_with_defaults():
@dataclass
class User(AvroModel):
@parametrize_base_model
def test_primitive_types_with_defaults(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class User(model_class):
name: str = "marcos"
age: int = 20
has_pets: bool = False
Expand Down Expand Up @@ -65,9 +74,10 @@ class User(AvroModel):
assert user.to_json() == json.dumps(data_json)


def test_primitive_types_with_nulls():
@dataclass
class User(AvroModel):
@parametrize_base_model
def test_primitive_types_with_nulls(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class User(model_class):
name: str = None
age: int = 20
has_pets: bool = False
Expand Down Expand Up @@ -106,9 +116,10 @@ class User(AvroModel):
assert user.to_json() == json.dumps(data)


def test_float32_primitive_type():
@dataclass
class User(AvroModel):
@parametrize_base_model
def test_float32_primitive_type(model_class: typing.Type[AvroModel], decorator: typing.Callable):
@decorator
class User(model_class):
height: types.Float32 = None

data = {"height": 178.3}
Expand Down
Loading

0 comments on commit 2b3e37b

Please sign in to comment.