Skip to content

Commit

Permalink
fix: revert Pydantic custom serializer (#435)
Browse files Browse the repository at this point in the history
* The Pydantic custom serializer caused `model_dump` to include only the
minimal fields used for ga4gh serialization. This led to issues in
downstream apps that leverage FastAPI.
  • Loading branch information
korikuzma authored Jul 19, 2024
1 parent 03425d3 commit 4b0fa9f
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"pydantic~=2.1",
"bioutils",
"requests",
"canonicaljson",
]

[project.optional-dependencies]
Expand Down
3 changes: 1 addition & 2 deletions src/ga4gh/core/entity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, Annotated, Optional, Union, List
from enum import Enum

from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict

from ga4gh.core import GA4GH_IR_REGEXP

Expand Down Expand Up @@ -78,7 +78,6 @@ class IRI(RootModel):
def __hash__(self):
return self.root.__hash__()

@model_serializer(when_used='json')
def ga4gh_serialize(self):
m = GA4GH_IR_REGEXP.match(self.root)
if m is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/ga4gh/core/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
For that reason, they are implemented here in one file.
"""

from canonicaljson import encode_canonical_json
import contextvars
import re
from contextlib import ContextDecorator
Expand Down Expand Up @@ -194,6 +194,6 @@ def ga4gh_serialize(obj: BaseModel, as_version: PrevVrsVersion | None = None) ->
PrevVrsVersion.validate(as_version)

if as_version is None:
return obj.model_dump_json().encode("utf-8")
return encode_canonical_json(obj.ga4gh_serialize())
else:
return obj.ga4gh_serialize_as_version(as_version)
11 changes: 5 additions & 6 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from ga4gh.core.pydantic import get_pydantic_root

from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict
from canonicaljson import encode_canonical_json
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict

from ga4gh.core.pydantic import (
getattr_in
Expand Down Expand Up @@ -178,7 +179,7 @@ def _recurse_ga4gh_serialize(obj):
elif isinstance(obj, _ValueObject):
return obj.ga4gh_serialize()
elif isinstance(obj, RootModel):
return _recurse_ga4gh_serialize(obj.model_dump(mode='json'))
return _recurse_ga4gh_serialize(obj.model_dump())
elif isinstance(obj, str):
return obj
elif isinstance(obj, list):
Expand All @@ -193,9 +194,8 @@ class _ValueObject(DomainEntity, ABC):
"""

def __hash__(self):
return self.model_dump_json().__hash__()
return encode_canonical_json(self.ga4gh_serialize()).decode("utf-8").__hash__()

@model_serializer(when_used='json')
def ga4gh_serialize(self) -> Dict:
out = OrderedDict()
for k in self.ga4gh.keys:
Expand Down Expand Up @@ -242,7 +242,7 @@ def compute_digest(self, store=True, as_version: PrevVrsVersion | None = None) -
returned following the conventions of the VRS version indicated by ``as_version``.
"""
if as_version is None:
digest = sha512t24u(self.model_dump_json().encode("utf-8"))
digest = sha512t24u(encode_canonical_json(self.ga4gh_serialize()))
if store:
self.digest = digest
else:
Expand Down Expand Up @@ -580,7 +580,6 @@ class CisPhasedBlock(_VariationBase):
)
sequenceReference: Optional[SequenceReference] = Field(None, description="An optional Sequence Reference on which all of the in-cis Alleles are found. When defined, this may be used to implicitly define the `sequenceReference` attribute for each of the CisPhasedBlock member Alleles.")

@model_serializer(when_used="json")
def ga4gh_serialize(self) -> Dict:
out = _ValueObject.ga4gh_serialize(self)
out["members"] = sorted(out["members"])
Expand Down
4 changes: 3 additions & 1 deletion tests/validation/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def ga4gh_1_3_serialize(*args, **kwargs):
return ga4gh_serialize(*args, **kwargs)

fxs = {
"ga4gh_serialize": lambda o: ga4gh_serialize(o).decode() if ga4gh_serialize(o) else None,
"ga4gh_serialize": ga4gh_serialize,
"ga4gh_digest": ga4gh_digest,
"ga4gh_identify": ga4gh_identify,
"ga4gh_1_3_digest": ga4gh_1_3_digest,
Expand Down Expand Up @@ -60,6 +60,8 @@ def flatten_tests(vts):
def test_validation(cls, data, fn, exp):
o = getattr(models, cls)(**data)
fx = fxs[fn]
if fn == "ga4gh_serialize":
exp = exp.encode("utf-8")
assert fx(o) == exp


Expand Down

0 comments on commit 4b0fa9f

Please sign in to comment.