Skip to content

Commit 0cd5889

Browse files
committed
initial commit
1 parent f49320c commit 0cd5889

File tree

5 files changed

+57
-107
lines changed

5 files changed

+57
-107
lines changed

src/scanspec/core.py

Lines changed: 22 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
from collections.abc import Callable, Iterable, Iterator, Sequence
54
from functools import partial
65
from inspect import isclass
@@ -10,8 +9,6 @@
109
Literal,
1110
TypeVar,
1211
Union,
13-
get_origin,
14-
get_type_hints,
1512
)
1613

1714
import numpy as np
@@ -21,8 +18,8 @@
2118
GetCoreSchemaHandler,
2219
TypeAdapter,
2320
)
24-
from pydantic.dataclasses import rebuild_dataclass
2521
from pydantic.fields import FieldInfo
22+
from pydantic_core.core_schema import tagged_union_schema
2623

2724
__all__ = [
2825
"if_instance_do",
@@ -116,7 +113,9 @@ def calculate(self) -> int:
116113
"""
117114
tagged_union = _TaggedUnion(cls, discriminator)
118115
_tagged_unions[cls] = tagged_union
119-
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
116+
cls.__init_subclass__ = classmethod(
117+
partial(__init_subclass__, tagged_union, discriminator)
118+
)
120119
cls.__get_pydantic_core_schema__ = classmethod(
121120
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
122121
)
@@ -126,35 +125,13 @@ def calculate(self) -> int:
126125
T = TypeVar("T", type, Callable)
127126

128127

129-
def deserialize_as(cls, obj):
130-
return _tagged_unions[cls].type_adapter.validate_python(obj)
131-
132-
133-
def uses_tagged_union(cls_or_func: T) -> T:
134-
"""
135-
Decorator that processes the type hints of a class or function to detect and
136-
register any tagged unions. If a tagged union is detected in the type hints,
137-
it registers the class or function as a referrer to that tagged union.
138-
Args:
139-
cls_or_func (T): The class or function to be processed for tagged unions.
140-
Returns:
141-
T: The original class or function, unmodified.
142-
"""
143-
for k, v in get_type_hints(cls_or_func).items():
144-
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
145-
if tagged_union:
146-
tagged_union.add_referrer(cls_or_func, k)
147-
return cls_or_func
148-
149-
150128
class _TaggedUnion:
151129
def __init__(self, base_class: type, discriminator: str):
152130
self._base_class = base_class
153131
# The members of the tagged union, i.e. subclasses of the baseclasses
154132
self._members: list[type] = []
155133
# Classes and their field names that refer to this tagged union
156-
self._referrers: dict[type | Callable, set[str]] = {}
157-
self.type_adapter: TypeAdapter = TypeAdapter(None)
134+
self.type_adapter: TypeAdapter | None = None
158135
self._discriminator = discriminator
159136

160137
def _make_union(self):
@@ -173,55 +150,36 @@ def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
173150
def add_member(self, cls: type):
174151
if cls in self._members:
175152
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
176-
# called muliple times for the same member, do no process if it wouldn't
153+
# called multiple times for the same member, do no process if it wouldn't
177154
# change the member list
178155
return
179156

180157
self._members.append(cls)
181158
union = self._make_union()
182159
if union:
183-
# There are more than 1 subclasses in the union, so set all the referrers
184-
# to use this union
185-
for referrer, fields in self._referrers.items():
186-
if isclass(referrer):
187-
for field in dataclasses.fields(referrer):
188-
if field.name in fields:
189-
field.type = union
190-
self._set_discriminator(referrer, field.name, field.default)
191-
rebuild_dataclass(referrer, force=True)
192-
# Make a type adapter for use in deserialization
193160
self.type_adapter = TypeAdapter(union)
194161

195-
def add_referrer(self, cls: type | Callable, attr_name: str):
196-
self._referrers.setdefault(cls, set()).add(attr_name)
197-
union = self._make_union()
198-
if union:
199-
# There are more than 1 subclasses in the union, so set the referrer
200-
# (which is currently being constructed) to use it
201-
# note that we use annotations as the class has not been turned into
202-
# a dataclass yet
203-
cls.__annotations__[attr_name] = union
204-
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))
205-
206162

207163
_tagged_unions: dict[type, _TaggedUnion] = {}
208164

209165

210-
def __init_subclass__(discriminator: str, cls: type):
166+
def __init_subclass__(tagged_union: _TaggedUnion, discriminator: str, cls: type):
211167
# Add a discriminator field to the class so it can
212-
# be identified when deserailizing, and make sure it is last in the list
168+
# be identified when deserializing, and make sure it is last in the list
213169
cls.__annotations__ = {
214170
**cls.__annotations__,
215171
discriminator: Literal[cls.__name__], # type: ignore
216172
}
217173
cls.type = Field(cls.__name__, repr=False) # type: ignore
218-
# Replace any bare annotation with a discriminated union of subclasses
219-
# and register this class as one that refers to that union so it can be updated
220-
for k, v in get_type_hints(cls).items():
221-
# This works for Expression[T] or Expression
222-
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
223-
if tagged_union:
224-
tagged_union.add_referrer(cls, k)
174+
175+
def __get_pydantic_core_schema__(
176+
cls, source_type: Any, handler: GetCoreSchemaHandler
177+
):
178+
handler.generate_schema(cls)
179+
return handler(cls)
180+
181+
cls.__get_pydantic_core_schema__ = classmethod(__get_pydantic_core_schema__)
182+
tagged_union.add_member(cls)
225183

226184

227185
def __get_pydantic_core_schema__(
@@ -230,8 +188,11 @@ def __get_pydantic_core_schema__(
230188
# Rebuild any dataclass (including this one) that references this union
231189
# Note that this has to be done after the creation of the dataclass so that
232190
# previously created classes can refer to this newly created class
233-
tagged_union.add_member(cls)
234-
return handler(source_type)
191+
# return handler(tagged_union._make_union())
192+
return tagged_union_schema(
193+
{member.__name__: handler(member) for member in tagged_union._members},
194+
tagged_union._discriminator,
195+
)
235196

236197

237198
def if_instance_do(x: Any, cls: type, func: Callable):

src/scanspec/regions.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
AxesPoints,
1313
Axis,
1414
StrictConfig,
15-
deserialize_as,
1615
discriminated_union_of_subclasses,
1716
if_instance_do,
1817
)
@@ -70,10 +69,27 @@ def serialize(self) -> Mapping[str, Any]:
7069
"""Serialize the Region to a dictionary."""
7170
return asdict(self) # type: ignore
7271

73-
@staticmethod
74-
def deserialize(obj):
75-
"""Deserialize the Region from a dictionary."""
76-
return deserialize_as(Region, obj)
72+
73+
@dataclass(config=StrictConfig)
74+
class Range(Region[Axis]):
75+
"""Mask contains points of axis >= min and <= max.
76+
77+
>>> r = Range("x", 1, 2)
78+
>>> r.mask({"x": np.array([0, 1, 2, 3, 4])})
79+
array([False, True, True, False, False])
80+
"""
81+
82+
axis: Axis = Field(description="The name matching the axis to mask in spec")
83+
min: float = Field(description="The minimum inclusive value in the region")
84+
max: float = Field(description="The minimum inclusive value in the region")
85+
86+
def axis_sets(self) -> list[set[Axis]]:
87+
return [{self.axis}]
88+
89+
def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
90+
v = points[self.axis]
91+
mask = np.bitwise_and(v >= self.min, v <= self.max)
92+
return mask
7793

7894

7995
def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
@@ -186,28 +202,6 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
186202
return mask
187203

188204

189-
@dataclass(config=StrictConfig)
190-
class Range(Region[Axis]):
191-
"""Mask contains points of axis >= min and <= max.
192-
193-
>>> r = Range("x", 1, 2)
194-
>>> r.mask({"x": np.array([0, 1, 2, 3, 4])})
195-
array([False, True, True, False, False])
196-
"""
197-
198-
axis: Axis = Field(description="The name matching the axis to mask in spec")
199-
min: float = Field(description="The minimum inclusive value in the region")
200-
max: float = Field(description="The minimum inclusive value in the region")
201-
202-
def axis_sets(self) -> list[set[Axis]]:
203-
return [{self.axis}]
204-
205-
def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
206-
v = points[self.axis]
207-
mask = np.bitwise_and(v >= self.min, v <= self.max)
208-
return mask
209-
210-
211205
@dataclass(config=StrictConfig)
212206
class Rectangle(Region[Axis]):
213207
"""Mask contains points of axis within a rotated xy rectangle.

src/scanspec/service.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic import Field
1212
from pydantic.dataclasses import dataclass
1313

14-
from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
14+
from scanspec.core import AxesPoints, Frames, Path
1515

1616
from .specs import Line, Spec
1717

@@ -27,7 +27,6 @@
2727

2828

2929
@dataclass
30-
@uses_tagged_union
3130
class ValidResponse:
3231
"""Response model for spec validation."""
3332

@@ -44,7 +43,6 @@ class PointsFormat(str, Enum):
4443

4544

4645
@dataclass
47-
@uses_tagged_union
4846
class PointsRequest:
4947
"""A request for generated scan points."""
5048

@@ -125,7 +123,6 @@ class SmallestStepResponse:
125123

126124

127125
@app.post("/valid", response_model=ValidResponse)
128-
@uses_tagged_union
129126
def valid(
130127
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
131128
) -> ValidResponse | JSONResponse:
@@ -198,7 +195,6 @@ def bounds(
198195

199196

200197
@app.post("/gap", response_model=GapResponse)
201-
@uses_tagged_union
202198
def gap(
203199
spec: Spec = Body(
204200
...,
@@ -224,7 +220,6 @@ def gap(
224220

225221

226222
@app.post("/smalleststep", response_model=SmallestStepResponse)
227-
@uses_tagged_union
228223
def smallest_step(
229224
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
230225
) -> SmallestStepResponse:

src/scanspec/specs.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable, Mapping
4-
from dataclasses import asdict
3+
from collections.abc import Callable
54
from typing import (
6-
Any,
75
Generic,
86
)
97

@@ -18,7 +16,6 @@
1816
Path,
1917
SnakedFrames,
2018
StrictConfig,
21-
deserialize_as,
2219
discriminated_union_of_subclasses,
2320
gap_between_frames,
2421
if_instance_do,
@@ -106,15 +103,6 @@ def concat(self, other: Spec) -> Concat[Axis]:
106103
"""`Concat` the Spec with another, iterating one after the other."""
107104
return Concat(self, other)
108105

109-
def serialize(self) -> Mapping[str, Any]:
110-
"""Serialize the spec to a dictionary."""
111-
return asdict(self) # type: ignore
112-
113-
@staticmethod
114-
def deserialize(obj):
115-
"""Deserialize the spec from a dictionary."""
116-
return deserialize_as(Spec, obj)
117-
118106

119107
@dataclass(config=StrictConfig)
120108
class Product(Spec[Axis]):

tests/test_basemodel.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pydantic import BaseModel, Field
2+
3+
from scanspec.specs import Line, Spec
4+
5+
6+
def test_base_model():
7+
class Foo(BaseModel):
8+
# class Foo(BaseModel):
9+
spec: Spec = Field(description="This is for test")
10+
# spec: float = 1.0
11+
12+
Foo(spec=Line("x", 1, 2, 5))

0 commit comments

Comments
 (0)