Skip to content

Commit dbc2615

Browse files
ZohebShaikhDiamondJoseph
authored andcommitted
initial commit
1 parent f49320c commit dbc2615

File tree

5 files changed

+93
-130
lines changed

5 files changed

+93
-130
lines changed

src/scanspec/core.py

Lines changed: 45 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,22 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
from collections.abc import Callable, Iterable, Iterator, Sequence
5-
from functools import partial
6-
from inspect import isclass
4+
from functools import lru_cache
75
from typing import (
86
Any,
97
Generic,
108
Literal,
119
TypeVar,
12-
Union,
13-
get_origin,
14-
get_type_hints,
1510
)
1611

1712
import numpy as np
1813
from pydantic import (
1914
ConfigDict,
2015
Field,
2116
GetCoreSchemaHandler,
22-
TypeAdapter,
2317
)
2418
from pydantic.dataclasses import rebuild_dataclass
25-
from pydantic.fields import FieldInfo
19+
from pydantic_core.core_schema import tagged_union_schema
2620

2721
__all__ = [
2822
"if_instance_do",
@@ -43,13 +37,13 @@
4337

4438

4539
def discriminated_union_of_subclasses(
46-
cls,
40+
cls: type,
4741
discriminator: str = "type",
48-
):
42+
) -> type:
4943
"""Add all subclasses of super_cls to a discriminated union.
5044
5145
For all subclasses of super_cls, add a discriminator field to identify
52-
the type. Raw JSON should look like {"type": <type name>, params for
46+
the type. Raw JSON should look like {<discriminator>: <type name>, params for
5347
<type name>...}.
5448
5549
Example::
@@ -107,131 +101,69 @@ def calculate(self) -> int:
107101
super_cls: The superclass of the union, Expression in the above example
108102
discriminator: The discriminator that will be inserted into the
109103
serialized documents for type determination. Defaults to "type".
110-
config: A pydantic config class to be inserted into all
111-
subclasses. Defaults to None.
112104
113105
Returns:
114-
Type | Callable[[Type], Type]: A decorator that adds the necessary
106+
Type: A decorator that adds the necessary
115107
functionality to a class.
116108
"""
117109
tagged_union = _TaggedUnion(cls, discriminator)
118-
_tagged_unions[cls] = tagged_union
119-
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
120-
cls.__get_pydantic_core_schema__ = classmethod(
121-
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
122-
)
123-
return cls
124110

111+
def add_subclass_to_union(subclass):
112+
# Add a discriminator field to a subclass so it can
113+
# be identified when deserializing
114+
subclass.__annotations__ = {
115+
**subclass.__annotations__,
116+
discriminator: Literal[subclass.__name__], # type: ignore
117+
}
118+
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore
125119

126-
T = TypeVar("T", type, Callable)
120+
def default_handler(subclass, source_type: Any, handler: GetCoreSchemaHandler):
121+
tagged_union.add_member(subclass)
122+
return handler(subclass)
127123

124+
subclass.__get_pydantic_core_schema__ = classmethod(default_handler)
128125

129-
def deserialize_as(cls, obj):
130-
return _tagged_unions[cls].type_adapter.validate_python(obj)
126+
def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
127+
# Rebuild any dataclass (including this one) that references this union
128+
# Note that this has to be done after the creation of the dataclass so that
129+
# previously created classes can refer to this newly created class
130+
return tagged_union.schema(handler)
131131

132+
cls.__init_subclass__ = classmethod(add_subclass_to_union)
133+
cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union)
134+
return cls
132135

133-
def uses_tagged_union(cls_or_func: T) -> T:
134-
"""
135-
Decorator that processes the type hints of a class or function to detect and
136-
register any tagged unions. If a tagged union is detected in the type hints,
137-
it registers the class or function as a referrer to that tagged union.
138-
Args:
139-
cls_or_func (T): The class or function to be processed for tagged unions.
140-
Returns:
141-
T: The original class or function, unmodified.
142-
"""
143-
for k, v in get_type_hints(cls_or_func).items():
144-
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
145-
if tagged_union:
146-
tagged_union.add_referrer(cls_or_func, k)
147-
return cls_or_func
136+
137+
T = TypeVar("T", type, Callable)
148138

149139

150140
class _TaggedUnion:
151141
def __init__(self, base_class: type, discriminator: str):
152142
self._base_class = base_class
153-
# The members of the tagged union, i.e. subclasses of the baseclasses
154-
self._members: list[type] = []
155143
# Classes and their field names that refer to this tagged union
156-
self._referrers: dict[type | Callable, set[str]] = {}
157-
self.type_adapter: TypeAdapter = TypeAdapter(None)
158144
self._discriminator = discriminator
159-
160-
def _make_union(self):
161-
if len(self._members) > 0:
162-
return Union[tuple(self._members)] # type: ignore # noqa
163-
164-
def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
165-
# Set the field to use the `type` discriminator on deserialize
166-
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
167-
if isclass(cls):
168-
assert isinstance(
169-
field, FieldInfo
170-
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
171-
field.discriminator = self._discriminator
145+
# The members of the tagged union, i.e. subclasses of the baseclass
146+
self._members: list[type] = []
172147

173148
def add_member(self, cls: type):
174149
if cls in self._members:
175-
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
176-
# called muliple times for the same member, do no process if it wouldn't
177-
# change the member list
178150
return
179-
180151
self._members.append(cls)
181-
union = self._make_union()
182-
if union:
183-
# There are more than 1 subclasses in the union, so set all the referrers
184-
# to use this union
185-
for referrer, fields in self._referrers.items():
186-
if isclass(referrer):
187-
for field in dataclasses.fields(referrer):
188-
if field.name in fields:
189-
field.type = union
190-
self._set_discriminator(referrer, field.name, field.default)
191-
rebuild_dataclass(referrer, force=True)
192-
# Make a type adapter for use in deserialization
193-
self.type_adapter = TypeAdapter(union)
194-
195-
def add_referrer(self, cls: type | Callable, attr_name: str):
196-
self._referrers.setdefault(cls, set()).add(attr_name)
197-
union = self._make_union()
198-
if union:
199-
# There are more than 1 subclasses in the union, so set the referrer
200-
# (which is currently being constructed) to use it
201-
# note that we use annotations as the class has not been turned into
202-
# a dataclass yet
203-
cls.__annotations__[attr_name] = union
204-
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))
205-
206-
207-
_tagged_unions: dict[type, _TaggedUnion] = {}
208-
209-
210-
def __init_subclass__(discriminator: str, cls: type):
211-
# Add a discriminator field to the class so it can
212-
# be identified when deserailizing, and make sure it is last in the list
213-
cls.__annotations__ = {
214-
**cls.__annotations__,
215-
discriminator: Literal[cls.__name__], # type: ignore
216-
}
217-
cls.type = Field(cls.__name__, repr=False) # type: ignore
218-
# Replace any bare annotation with a discriminated union of subclasses
219-
# and register this class as one that refers to that union so it can be updated
220-
for k, v in get_type_hints(cls).items():
221-
# This works for Expression[T] or Expression
222-
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
223-
if tagged_union:
224-
tagged_union.add_referrer(cls, k)
225-
226-
227-
def __get_pydantic_core_schema__(
228-
cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
229-
):
230-
# Rebuild any dataclass (including this one) that references this union
231-
# Note that this has to be done after the creation of the dataclass so that
232-
# previously created classes can refer to this newly created class
233-
tagged_union.add_member(cls)
234-
return handler(source_type)
152+
for member in self._members:
153+
if member != cls:
154+
rebuild_dataclass(member, force=True)
155+
156+
def schema(self, handler):
157+
return tagged_union_schema(
158+
make_schema(tuple(self._members), handler),
159+
discriminator=self._discriminator,
160+
ref=self._base_class.__name__,
161+
)
162+
163+
164+
@lru_cache(1)
165+
def make_schema(members: tuple[type, ...], handler):
166+
return {member.__name__: handler(member) for member in members}
235167

236168

237169
def if_instance_do(x: Any, cls: type, func: Callable):

src/scanspec/regions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from typing import Any, Generic
66

77
import numpy as np
8-
from pydantic import BaseModel, Field
8+
from pydantic import BaseModel, Field, TypeAdapter
99
from pydantic.dataclasses import dataclass
1010

1111
from .core import (
1212
AxesPoints,
1313
Axis,
1414
StrictConfig,
15-
deserialize_as,
1615
discriminated_union_of_subclasses,
1716
if_instance_do,
1817
)
@@ -71,9 +70,8 @@ def serialize(self) -> Mapping[str, Any]:
7170
return asdict(self) # type: ignore
7271

7372
@staticmethod
74-
def deserialize(obj):
75-
"""Deserialize the Region from a dictionary."""
76-
return deserialize_as(Region, obj)
73+
def deserialize(obj: Mapping[str, Any]) -> Region:
74+
return TypeAdapter(Region).validate_python(obj)
7775

7876

7977
def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:

src/scanspec/service.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic import Field
1212
from pydantic.dataclasses import dataclass
1313

14-
from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
14+
from scanspec.core import AxesPoints, Frames, Path
1515

1616
from .specs import Line, Spec
1717

@@ -27,7 +27,6 @@
2727

2828

2929
@dataclass
30-
@uses_tagged_union
3130
class ValidResponse:
3231
"""Response model for spec validation."""
3332

@@ -44,7 +43,6 @@ class PointsFormat(str, Enum):
4443

4544

4645
@dataclass
47-
@uses_tagged_union
4846
class PointsRequest:
4947
"""A request for generated scan points."""
5048

@@ -125,7 +123,6 @@ class SmallestStepResponse:
125123

126124

127125
@app.post("/valid", response_model=ValidResponse)
128-
@uses_tagged_union
129126
def valid(
130127
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
131128
) -> ValidResponse | JSONResponse:
@@ -198,7 +195,6 @@ def bounds(
198195

199196

200197
@app.post("/gap", response_model=GapResponse)
201-
@uses_tagged_union
202198
def gap(
203199
spec: Spec = Body(
204200
...,
@@ -224,7 +220,6 @@ def gap(
224220

225221

226222
@app.post("/smalleststep", response_model=SmallestStepResponse)
227-
@uses_tagged_union
228223
def smallest_step(
229224
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
230225
) -> SmallestStepResponse:

src/scanspec/specs.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99

1010
import numpy as np
11-
from pydantic import Field, validate_call
11+
from pydantic import Field, TypeAdapter, validate_call
1212
from pydantic.dataclasses import dataclass
1313

1414
from .core import (
@@ -18,7 +18,6 @@
1818
Path,
1919
SnakedFrames,
2020
StrictConfig,
21-
deserialize_as,
2221
discriminated_union_of_subclasses,
2322
gap_between_frames,
2423
if_instance_do,
@@ -107,13 +106,12 @@ def concat(self, other: Spec) -> Concat[Axis]:
107106
return Concat(self, other)
108107

109108
def serialize(self) -> Mapping[str, Any]:
110-
"""Serialize the spec to a dictionary."""
109+
"""Serialize the Spec to a dictionary."""
111110
return asdict(self) # type: ignore
112111

113112
@staticmethod
114-
def deserialize(obj):
115-
"""Deserialize the spec from a dictionary."""
116-
return deserialize_as(Spec, obj)
113+
def deserialize(obj: Mapping[str, Any]) -> Spec:
114+
return TypeAdapter(Spec).validate_python(obj)
117115

118116

119117
@dataclass(config=StrictConfig)

tests/test_basemodel.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
from pydantic import BaseModel, TypeAdapter
3+
4+
from scanspec.specs import Line, Spec
5+
6+
7+
class Foo(BaseModel):
8+
spec: Spec
9+
10+
11+
simple_foo = Foo(spec=Line("x", 1, 2, 5))
12+
nested_foo = Foo(spec=(Line("x", 1, 2, 5) + Line("x", 1, 2, 5)) * Line("y", 1, 2, 5))
13+
14+
15+
@pytest.mark.parametrize("model", [simple_foo, nested_foo])
16+
def test_model_validation(model: Foo):
17+
# To/from Python dict
18+
as_dict = model.model_dump()
19+
deserialized = Foo.model_validate(as_dict)
20+
assert deserialized == model
21+
22+
# To/from Json dict
23+
as_json = model.model_dump_json()
24+
deserialized = Foo.model_validate_json(as_json)
25+
assert deserialized == model
26+
27+
28+
@pytest.mark.parametrize("model", [simple_foo, nested_foo])
29+
def test_type_adapter(model: Foo):
30+
type_adapter = TypeAdapter(Foo)
31+
32+
# To/from Python dict
33+
as_dict = model.model_dump()
34+
deserialized = type_adapter.validate_python(as_dict)
35+
assert deserialized == model
36+
37+
# To/from Json dict
38+
as_json = model.model_dump_json()
39+
deserialized = type_adapter.validate_json(as_json)
40+
assert deserialized == model

0 commit comments

Comments
 (0)