Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vanilla (de)serialization of discriminated unions #132

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ commands =
pre-commit: pre-commit run --all-files {posargs}
type-checking: mypy src tests {posargs}
tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs}
docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html
docs: sphinx-{posargs:build -E --keep-going} -T docs build/html
"""

[tool.ruff]
Expand Down
2,404 changes: 335 additions & 2,069 deletions schema.json

Large diffs are not rendered by default.

195 changes: 83 additions & 112 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial
from functools import lru_cache
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)

import numpy as np
from pydantic import (
ConfigDict,
Field,
GetCoreSchemaHandler,
TypeAdapter,
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
Expand All @@ -43,15 +37,20 @@


def discriminated_union_of_subclasses(
cls,
super_cls: type,
discriminator: str = "type",
):
) -> type:
"""Add all subclasses of super_cls to a discriminated union.

For all subclasses of super_cls, add a discriminator field to identify
the type. Raw JSON should look like {"type": <type name>, params for
the type. Raw JSON should look like {<discriminator>: <type name>, params for
<type name>...}.

Subclasses that extend this class must be Pydantic dataclasses. Types that
need their schema to be updated when a new type that extends super_cls is
created must be either Pydantic dataclasses or BaseModels, and must be
decorated with @uses_tagged_union.

Example::

@discriminated_union_of_subclasses
Expand Down Expand Up @@ -107,131 +106,103 @@ def calculate(self) -> int:
super_cls: The superclass of the union, Expression in the above example
discriminator: The discriminator that will be inserted into the
serialized documents for type determination. Defaults to "type".
config: A pydantic config class to be inserted into all
subclasses. Defaults to None.

Returns:
Type | Callable[[Type], Type]: A decorator that adds the necessary
functionality to a class.
Type: decorated superclass with handling for subclasses to be added
to its discriminated union for deserialization
"""
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
cls.__get_pydantic_core_schema__ = classmethod(
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
)
return cls
tagged_union = _TaggedUnion(super_cls, discriminator)
_tagged_unions[super_cls] = tagged_union

def add_subclass_to_union(subclass):
# Add a discriminator field to a subclass so it can
# be identified when deserializing
subclass.__annotations__ = {
**subclass.__annotations__,
discriminator: Literal[subclass.__name__], # type: ignore
}
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
return tagged_union.schema(handler)

T = TypeVar("T", type, Callable)
super_cls.__init_subclass__ = classmethod(add_subclass_to_union) # type: ignore
super_cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) # type: ignore
return super_cls


def deserialize_as(cls, obj):
return _tagged_unions[cls].type_adapter.validate_python(obj)
T = TypeVar("T", type, Callable)


def uses_tagged_union(cls_or_func: T) -> T:
"""
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
T = TypeVar("T", type, Callable)
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for k, v in get_type_hints(cls_or_func).items():
for v in get_type_hints(cls_or_func).values():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls_or_func, k)
tagged_union.add_reference(cls_or_func)
return cls_or_func


_tagged_unions: dict[type, _TaggedUnion] = {}
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved


class _TaggedUnion:
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter: TypeAdapter = TypeAdapter(None)
self._discriminator = discriminator

def _make_union(self):
if len(self._members) > 0:
return Union[tuple(self._members)] # type: ignore # noqa

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if isclass(cls):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = self._discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
self._references: set[type | Callable] = set()

def add_member(self, cls: type):
if cls in self._members:
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
# called multiple times for the same member; do not process if it wouldn't
# change the member list
if cls in self._subclasses:
return

self._members.append(cls)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set all the referrers
# to use this union
for referrer, fields in self._referrers.items():
if isclass(referrer):
for field in dataclasses.fields(referrer):
if field.name in fields:
field.type = union
self._set_discriminator(referrer, field.name, field.default)
rebuild_dataclass(referrer, force=True)
# Make a type adapter for use in deserialization
self.type_adapter = TypeAdapter(union)

def add_referrer(self, cls: type | Callable, attr_name: str):
self._referrers.setdefault(cls, set()).add(attr_name)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set the referrer
# (which is currently being constructed) to use it
# note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


_tagged_unions: dict[type, _TaggedUnion] = {}


def __init_subclass__(discriminator: str, cls: type):
# Add a discriminator field to the class so it can
# be identified when deserializing, and make sure it is last in the list
cls.__annotations__ = {
**cls.__annotations__,
discriminator: Literal[cls.__name__], # type: ignore
}
cls.type = Field(cls.__name__, repr=False) # type: ignore
# Replace any bare annotation with a discriminated union of subclasses
# and register this class as one that refers to that union so it can be updated
for k, v in get_type_hints(cls).items():
# This works for Expression[T] or Expression
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls, k)
self._subclasses.append(cls)
for member in self._subclasses:
if member is not cls:
_TaggedUnion._rebuild(member)
for ref in self._references:
_TaggedUnion._rebuild(ref)

def add_reference(self, cls_or_func: type | Callable):
self._references.add(cls_or_func)

@staticmethod
# https://github.com/bluesky/scanspec/issues/133
def _rebuild(cls_or_func: type | Callable):
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
rebuild_dataclass(cls_or_func, force=True)
if issubclass(cls_or_func, BaseModel):
cls_or_func.model_rebuild(force=True)

def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:
return tagged_union_schema(
make_schema(tuple(self._subclasses), handler),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)


def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
):
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
tagged_union.add_member(cls)
return handler(source_type)
@lru_cache(1)
def make_schema(members: tuple[type, ...], handler):
    """Map each member class name to the schema ``handler`` produces for it.

    Args:
        members: The classes to build schemas for, in order.
        handler: Callable invoked once per member with the class itself.

    Returns:
        A dict keyed by ``member.__name__`` holding ``handler(member)``.

    Only the most recent ``(members, handler)`` argument combination is
    memoised, so an immediately repeated call with the same arguments
    returns the previously built dict without invoking ``handler`` again.
    """
    schemas = {}
    for member in members:
        schemas[member.__name__] = handler(member)
    return schemas


def if_instance_do(x: Any, cls: type, func: Callable):
Expand Down
3 changes: 1 addition & 2 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mpl_toolkits.mplot3d import Axes3D, proj3d
from scipy import interpolate

from .core import Path, uses_tagged_union
from .core import Path
from .regions import Circle, Ellipse, Polygon, Rectangle, Region, find_regions
from .specs import DURATION, Spec

Expand Down Expand Up @@ -86,7 +86,6 @@ def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int
yield unscaled_splines


@uses_tagged_union
def plot_spec(spec: Spec[Any], title: str | None = None):
"""Plot a spec, drawing the path taken through the scan.

Expand Down
13 changes: 6 additions & 7 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import asdict, is_dataclass
from dataclasses import is_dataclass
from typing import Any, Generic

import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.dataclasses import dataclass

from .core import (
AxesPoints,
Axis,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
if_instance_do,
)
Expand Down Expand Up @@ -68,12 +67,12 @@ def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:

def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore
return TypeAdapter(Region).dump_python(self)

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
return deserialize_as(Region, obj)
def deserialize(obj) -> Region:
"""Deserialize a Region from a dictionary."""
return TypeAdapter(Region).validate_python(obj)


def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
Expand Down
7 changes: 2 additions & 5 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
Points = str | list[float]


@dataclass
@uses_tagged_union
@dataclass
class ValidResponse:
"""Response model for spec validation."""

Expand All @@ -43,8 +43,8 @@ class PointsFormat(str, Enum):
BASE64_ENCODED = "BASE64_ENCODED"


@dataclass
@uses_tagged_union
@dataclass
class PointsRequest:
"""A request for generated scan points."""

Expand Down Expand Up @@ -125,7 +125,6 @@ class SmallestStepResponse:


@app.post("/valid", response_model=ValidResponse)
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
Expand Down Expand Up @@ -198,7 +197,6 @@ def bounds(


@app.post("/gap", response_model=GapResponse)
@uses_tagged_union
def gap(
spec: Spec = Body(
...,
Expand All @@ -224,7 +222,6 @@ def gap(


@app.post("/smalleststep", response_model=SmallestStepResponse)
@uses_tagged_union
def smallest_step(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> SmallestStepResponse:
Expand Down
Loading
Loading