Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh authored and DiamondJoseph committed Aug 8, 2024
1 parent f49320c commit dc38b6c
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 142 deletions.
123 changes: 28 additions & 95 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)

import numpy as np
Expand All @@ -21,8 +16,7 @@
GetCoreSchemaHandler,
TypeAdapter,
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
Expand Down Expand Up @@ -107,131 +101,70 @@ def calculate(self) -> int:
super_cls: The superclass of the union, Expression in the above example
discriminator: The discriminator that will be inserted into the
serialized documents for type determination. Defaults to "type".
config: A pydantic config class to be inserted into all
subclasses. Defaults to None.
Returns:
Type | Callable[[Type], Type]: A decorator that adds the necessary
Type: A decorator that adds the necessary
functionality to a class.
"""
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
cls.__init_subclass__ = classmethod(
partial(_add_subclass_to_tagged_union, tagged_union, discriminator)
)
cls.__get_pydantic_core_schema__ = classmethod(
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
partial(_schema_of_tagged_union, tagged_union=tagged_union)
)
return cls


T = TypeVar("T", type, Callable)


def deserialize_as(cls, obj):
return _tagged_unions[cls].type_adapter.validate_python(obj)


def uses_tagged_union(cls_or_func: T) -> T:
"""
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for k, v in get_type_hints(cls_or_func).items():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls_or_func, k)
return cls_or_func


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter: TypeAdapter = TypeAdapter(None)
self.type_adapter: TypeAdapter | None = None
self._discriminator = discriminator

def _make_union(self):
if len(self._members) > 0:
return Union[tuple(self._members)] # type: ignore # noqa

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if isclass(cls):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = self._discriminator

def add_member(self, cls: type):
if cls in self._members:
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
# called muliple times for the same member, do no process if it wouldn't
# change the member list
return

self._members.append(cls)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set all the referrers
# to use this union
for referrer, fields in self._referrers.items():
if isclass(referrer):
for field in dataclasses.fields(referrer):
if field.name in fields:
field.type = union
self._set_discriminator(referrer, field.name, field.default)
rebuild_dataclass(referrer, force=True)
# Make a type adapter for use in deserialization
self.type_adapter = TypeAdapter(union)

def add_referrer(self, cls: type | Callable, attr_name: str):
self._referrers.setdefault(cls, set()).add(attr_name)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set the referrer
# (which is currently being constructed) to use it
# note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


_tagged_unions: dict[type, _TaggedUnion] = {}


def __init_subclass__(discriminator: str, cls: type):

def schema(self, handler):
return tagged_union_schema(
{member.__name__: handler(member) for member in self._members},
self._discriminator,
)


def _add_subclass_to_tagged_union(
tagged_union: _TaggedUnion, discriminator: str, cls: type
):
# Add a discriminator field to the class so it can
# be identified when deserailizing, and make sure it is last in the list
# be identified when deserializing, and make sure it is last in the list
cls.__annotations__ = {
**cls.__annotations__,
discriminator: Literal[cls.__name__], # type: ignore
}
cls.type = Field(cls.__name__, repr=False) # type: ignore
# Replace any bare annotation with a discriminated union of subclasses
# and register this class as one that refers to that union so it can be updated
for k, v in get_type_hints(cls).items():
# This works for Expression[T] or Expression
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls, k)
setattr(cls, discriminator, Field(cls.__name__, repr=False)) # type: ignore

def _return_handler_of_cls(cls, source_type: Any, handler: GetCoreSchemaHandler):
return handler(cls)

cls.__get_pydantic_core_schema__ = classmethod(_return_handler_of_cls)
tagged_union.add_member(cls)

def __get_pydantic_core_schema__(

def _schema_of_tagged_union(
cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
):
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
tagged_union.add_member(cls)
return handler(source_type)
return tagged_union.schema(handler)


def if_instance_do(x: Any, cls: type, func: Callable):
Expand Down
50 changes: 22 additions & 28 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
AxesPoints,
Axis,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
if_instance_do,
)
Expand Down Expand Up @@ -70,11 +69,6 @@ def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
return deserialize_as(Region, obj)


def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
"""Return a mask of the points inside the region.
Expand Down Expand Up @@ -119,6 +113,28 @@ def axis_sets(self) -> list[set[Axis]]:
return axis_sets


@dataclass(config=StrictConfig)
class Range(Region[Axis]):
"""Mask contains points of axis >= min and <= max.
>>> r = Range("x", 1, 2)
>>> r.mask({"x": np.array([0, 1, 2, 3, 4])})
array([False, True, True, False, False])
"""

axis: Axis = Field(description="The name matching the axis to mask in spec")
min: float = Field(description="The minimum inclusive value in the region")
max: float = Field(description="The minimum inclusive value in the region")

def axis_sets(self) -> list[set[Axis]]:
return [{self.axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
v = points[self.axis]
mask = np.bitwise_and(v >= self.min, v <= self.max)
return mask


# Naming so we don't clash with typing.Union
@dataclass(config=StrictConfig)
class UnionOf(CombinationOf[Axis]):
Expand Down Expand Up @@ -186,28 +202,6 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
return mask


@dataclass(config=StrictConfig)
class Range(Region[Axis]):
"""Mask contains points of axis >= min and <= max.
>>> r = Range("x", 1, 2)
>>> r.mask({"x": np.array([0, 1, 2, 3, 4])})
array([False, True, True, False, False])
"""

axis: Axis = Field(description="The name matching the axis to mask in spec")
min: float = Field(description="The minimum inclusive value in the region")
max: float = Field(description="The minimum inclusive value in the region")

def axis_sets(self) -> list[set[Axis]]:
return [{self.axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
v = points[self.axis]
mask = np.bitwise_and(v >= self.min, v <= self.max)
return mask


@dataclass(config=StrictConfig)
class Rectangle(Region[Axis]):
"""Mask contains points of axis within a rotated xy rectangle.
Expand Down
7 changes: 1 addition & 6 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
from scanspec.core import AxesPoints, Frames, Path

from .specs import Line, Spec

Expand All @@ -27,7 +27,6 @@


@dataclass
@uses_tagged_union
class ValidResponse:
"""Response model for spec validation."""

Expand All @@ -44,7 +43,6 @@ class PointsFormat(str, Enum):


@dataclass
@uses_tagged_union
class PointsRequest:
"""A request for generated scan points."""

Expand Down Expand Up @@ -125,7 +123,6 @@ class SmallestStepResponse:


@app.post("/valid", response_model=ValidResponse)
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
Expand Down Expand Up @@ -198,7 +195,6 @@ def bounds(


@app.post("/gap", response_model=GapResponse)
@uses_tagged_union
def gap(
spec: Spec = Body(
...,
Expand All @@ -224,7 +220,6 @@ def gap(


@app.post("/smalleststep", response_model=SmallestStepResponse)
@uses_tagged_union
def smallest_step(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> SmallestStepResponse:
Expand Down
14 changes: 1 addition & 13 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from dataclasses import asdict
from collections.abc import Callable
from typing import (
Any,
Generic,
)

Expand All @@ -18,7 +16,6 @@
Path,
SnakedFrames,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
gap_between_frames,
if_instance_do,
Expand Down Expand Up @@ -106,15 +103,6 @@ def concat(self, other: Spec) -> Concat[Axis]:
"""`Concat` the Spec with another, iterating one after the other."""
return Concat(self, other)

def serialize(self) -> Mapping[str, Any]:
"""Serialize the spec to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the spec from a dictionary."""
return deserialize_as(Spec, obj)


@dataclass(config=StrictConfig)
class Product(Spec[Axis]):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pydantic import BaseModel, Field

from scanspec.specs import Line, Spec


def test_base_model():
class Foo(BaseModel):
# class Foo(BaseModel):
spec: Spec = Field(description="This is for test")
# spec: float = 1.0

Foo(spec=Line("x", 1, 2, 5))

0 comments on commit dc38b6c

Please sign in to comment.