Skip to content

Commit

Permalink
Allow vanilla Pydantic (de)serialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh authored and DiamondJoseph committed Aug 9, 2024
1 parent f49320c commit edff4c6
Show file tree
Hide file tree
Showing 8 changed files with 498 additions and 2,203 deletions.
2,404 changes: 335 additions & 2,069 deletions schema.json

Large diffs are not rendered by default.

187 changes: 76 additions & 111 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial
from functools import lru_cache
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)

import numpy as np
from pydantic import (
ConfigDict,
Field,
GetCoreSchemaHandler,
TypeAdapter,
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
Expand All @@ -43,13 +36,13 @@


def discriminated_union_of_subclasses(
cls,
super_cls: type,
discriminator: str = "type",
):
) -> type:
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
the type. Raw JSON should look like {"type": <type name>, params for
the type. Raw JSON should look like {<discriminator>: <type name>, params for
<type name>...}.
Example::
Expand Down Expand Up @@ -104,134 +97,106 @@ def calculate(self) -> int:
)
Args:
super_cls: The superclass of the union, Expression in the above example
cls: The superclass of the union, Expression in the above example
discriminator: The discriminator that will be inserted into the
serialized documents for type determination. Defaults to "type".
config: A pydantic config class to be inserted into all
subclasses. Defaults to None.
Returns:
Type | Callable[[Type], Type]: A decorator that adds the necessary
functionality to a class.
Type: decorated superclass with handling for subclasses to be added
to its discriminated union for deserialization
"""
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
cls.__get_pydantic_core_schema__ = classmethod(
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
)
return cls
tagged_union = _TaggedUnion(super_cls, discriminator)
_tagged_unions[super_cls] = tagged_union

def add_subclass_to_union(subclass):
    """Give ``subclass`` a discriminator field that names its own class.

    The discriminator annotation is a ``Literal`` of the subclass name, and
    the attribute of the same name is set to a pydantic ``Field`` whose
    default is that name, so the subclass can be identified when
    deserializing.

    Args:
        subclass: The newly created subclass to tag.
    """
    type_name = subclass.__name__
    # Copy the existing annotations, then add (or overwrite) the discriminator
    annotations = dict(subclass.__annotations__)
    annotations[discriminator] = Literal[type_name]  # type: ignore
    subclass.__annotations__ = annotations
    setattr(subclass, discriminator, Field(type_name, repr=False))  # type: ignore

def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
    """Return the core schema for ``cls``.

    Args:
        cls: The class whose schema is being requested.
        source_type: Unused here; part of the hook signature.
        handler: Callable used to build schemas.

    Returns:
        For the superclass itself, the tagged union schema of its registered
        members. For any other class, that class is first added to the tagged
        union and ``handler(cls)`` is returned.
    """
    if cls is super_cls:
        # The superclass is represented by the union of its members
        return tagged_union.schema(handler)
    # Register the subclass as a member before building its own schema
    tagged_union.add_member(cls)
    return handler(cls)

T = TypeVar("T", type, Callable)
super_cls.__init_subclass__ = classmethod(add_subclass_to_union) # type: ignore
super_cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) # type: ignore
return super_cls


def deserialize_as(cls, obj):
    """Validate ``obj`` against the tagged union registered for ``cls``.

    Args:
        cls: The base class used as the key into ``_tagged_unions``.
        obj: The Python object handed to the type adapter for validation.

    Returns:
        Whatever the registered type adapter's ``validate_python`` returns.

    Raises:
        KeyError: If ``cls`` has no entry in ``_tagged_unions``.
    """
    adapter = _tagged_unions[cls].type_adapter
    return adapter.validate_python(obj)
T = TypeVar("T", type, Callable)


def uses_tagged_union(cls_or_func: T) -> T:
"""
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
T = TypeVar("T", type, Callable)
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for k, v in get_type_hints(cls_or_func).items():
for v in get_type_hints(cls_or_func).values():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls_or_func, k)
tagged_union.add_reference(cls_or_func)
return cls_or_func


_tagged_unions: dict[type, _TaggedUnion] = {}


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter: TypeAdapter = TypeAdapter(None)
self._discriminator = discriminator

def _make_union(self):
if len(self._members) > 0:
return Union[tuple(self._members)] # type: ignore # noqa

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if isclass(cls):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = self._discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._members: list[type] = []
self._references: set[type | Callable] = set()

def add_member(self, cls: type):
if cls in self._members:
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
# called multiple times for the same member, do not process if it wouldn't
# change the member list
return

self._members.append(cls)
union = self._make_union()
if union:
# There is at least one subclass in the union, so set all the referrers
# to use this union
for referrer, fields in self._referrers.items():
if isclass(referrer):
for field in dataclasses.fields(referrer):
if field.name in fields:
field.type = union
self._set_discriminator(referrer, field.name, field.default)
rebuild_dataclass(referrer, force=True)
# Make a type adapter for use in deserialization
self.type_adapter = TypeAdapter(union)

def add_referrer(self, cls: type | Callable, attr_name: str):
# Record that attribute `attr_name` of `cls` refers to this tagged union
self._referrers.setdefault(cls, set()).add(attr_name)
union = self._make_union()
if union:
# There is at least one subclass in the union, so set the referrer
# (which is currently being constructed) to use it.
# Note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
# None is passed as the field when cls has no attribute named attr_name
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


_tagged_unions: dict[type, _TaggedUnion] = {}


def __init_subclass__(discriminator: str, cls: type):
    """Add a discriminator field to ``cls`` and register its union references.

    Args:
        discriminator: Name of the field that identifies the type when
            deserializing.
        cls: The newly created subclass.
    """
    # Add a discriminator field to the class so it can
    # be identified when deserializing, and make sure it is last in the list
    cls.__annotations__ = {
        **cls.__annotations__,
        discriminator: Literal[cls.__name__],  # type: ignore
    }
    # Attach the default under the discriminator's own name rather than a
    # fixed "type" attribute, so the default matches the annotation above
    setattr(cls, discriminator, Field(cls.__name__, repr=False))  # type: ignore
    # Register this class as one that refers to any tagged union named in its
    # type hints, so it can be updated when that union changes
    for k, v in get_type_hints(cls).items():
        # This works for Expression[T] or Expression
        tagged_union = _tagged_unions.get(get_origin(v) or v, None)
        if tagged_union:
            tagged_union.add_referrer(cls, k)
for member in self._members:
if member is not cls:
_TaggedUnion._rebuild(member)
for ref in self._references:
_TaggedUnion._rebuild(ref)

def add_reference(self, cls_or_func: type | Callable):
self._references.add(cls_or_func)

@staticmethod
# https://github.com/bluesky/scanspec/issues/133
def _rebuild(cls_or_func: type | Callable):
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
rebuild_dataclass(cls_or_func, force=True)
if issubclass(cls_or_func, BaseModel):
cls_or_func.model_rebuild(force=True)

def schema(self, handler):
    """Build the tagged union schema covering the current members.

    Args:
        handler: Passed to ``make_schema`` to produce each member's schema.

    Returns:
        The result of ``tagged_union_schema`` for the member schemas, using
        this union's discriminator and the base class name as ``ref``.
    """
    choices = make_schema(tuple(self._members), handler)
    return tagged_union_schema(
        choices,
        discriminator=self._discriminator,
        ref=self._base_class.__name__,
    )


def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
):
    """Register ``cls`` with ``tagged_union``, then build its core schema.

    Args:
        cls: The class whose schema is being requested.
        source_type: The type passed on to ``handler``.
        handler: Callable that produces the core schema.
        tagged_union: The tagged union that ``cls`` is added to.

    Returns:
        The result of ``handler(source_type)``.
    """
    # Registration has to happen after the class has been created, so that
    # previously created classes can refer to this newly created class
    tagged_union.add_member(cls)
    core_schema = handler(source_type)
    return core_schema
@lru_cache(1)
def make_schema(members: tuple[type, ...], handler):
    """Map each member's class name to the schema ``handler`` returns for it.

    The most recent result is cached, keyed on ``members`` and ``handler``.

    Args:
        members: The classes to build schemas for.
        handler: Callable invoked once per member to produce its schema.

    Returns:
        A dict from ``member.__name__`` to ``handler(member)``.
    """
    schemas = {}
    for member in members:
        schemas[member.__name__] = handler(member)
    return schemas


def if_instance_do(x: Any, cls: type, func: Callable):
Expand Down
3 changes: 1 addition & 2 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mpl_toolkits.mplot3d import Axes3D, proj3d
from scipy import interpolate

from .core import Path, uses_tagged_union
from .core import Path
from .regions import Circle, Ellipse, Polygon, Rectangle, Region, find_regions
from .specs import DURATION, Spec

Expand Down Expand Up @@ -86,7 +86,6 @@ def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int
yield unscaled_splines


@uses_tagged_union
def plot_spec(spec: Spec[Any], title: str | None = None):
"""Plot a spec, drawing the path taken through the scan.
Expand Down
13 changes: 6 additions & 7 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import asdict, is_dataclass
from dataclasses import is_dataclass
from typing import Any, Generic

import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.dataclasses import dataclass

from .core import (
AxesPoints,
Axis,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
if_instance_do,
)
Expand Down Expand Up @@ -68,12 +67,12 @@ def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:

def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore
return TypeAdapter(Region).dump_python(self)

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
return deserialize_as(Region, obj)
def deserialize(obj) -> Region:
"""Deserialize a Region from a dictionary."""
return TypeAdapter(Region).validate_python(obj)


def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
Expand Down
7 changes: 2 additions & 5 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
Points = str | list[float]


@dataclass
@uses_tagged_union
@dataclass
class ValidResponse:
"""Response model for spec validation."""

Expand All @@ -43,8 +43,8 @@ class PointsFormat(str, Enum):
BASE64_ENCODED = "BASE64_ENCODED"


@dataclass
@uses_tagged_union
@dataclass
class PointsRequest:
"""A request for generated scan points."""

Expand Down Expand Up @@ -125,7 +125,6 @@ class SmallestStepResponse:


@app.post("/valid", response_model=ValidResponse)
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
Expand Down Expand Up @@ -198,7 +197,6 @@ def bounds(


@app.post("/gap", response_model=GapResponse)
@uses_tagged_union
def gap(
spec: Spec = Body(
...,
Expand All @@ -224,7 +222,6 @@ def gap(


@app.post("/smalleststep", response_model=SmallestStepResponse)
@uses_tagged_union
def smallest_step(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> SmallestStepResponse:
Expand Down
14 changes: 6 additions & 8 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from dataclasses import asdict
from typing import (
Any,
Generic,
)

import numpy as np
from pydantic import Field, validate_call
from pydantic import Field, TypeAdapter, validate_call
from pydantic.dataclasses import dataclass

from .core import (
Expand All @@ -18,7 +17,6 @@
Path,
SnakedFrames,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
gap_between_frames,
if_instance_do,
Expand Down Expand Up @@ -107,13 +105,13 @@ def concat(self, other: Spec) -> Concat[Axis]:
return Concat(self, other)

def serialize(self) -> Mapping[str, Any]:
"""Serialize the spec to a dictionary."""
return asdict(self) # type: ignore
"""Serialize the Spec to a dictionary."""
return TypeAdapter(Spec).dump_python(self)

@staticmethod
def deserialize(obj):
"""Deserialize the spec from a dictionary."""
return deserialize_as(Spec, obj)
def deserialize(obj) -> Spec:
"""Deserialize a Spec from a dictionary."""
return TypeAdapter(Spec).validate_python(obj)


@dataclass(config=StrictConfig)
Expand Down
Loading

0 comments on commit edff4c6

Please sign in to comment.