Skip to content

Commit

Permalink
Ensure order of generated Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Aug 6, 2024
1 parent dc65bd1 commit d2b7495
Show file tree
Hide file tree
Showing 6 changed files with 306 additions and 305 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ ignore_missing_imports = true # Ignore missing stubs in imported modules
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
addopts = """
--tb=native -vv --doctest-modules --doctest-glob="*.rst"
-W ignore::DeprecationWarning
"""
# https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings
filterwarnings = "error"
Expand Down
578 changes: 289 additions & 289 deletions schema.json

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
StrictConfig: ConfigDict = {"extra": "forbid"}


def discriminated_union_of_subclasses(cls):
def discriminated_union_of_subclasses(
cls,
discriminator: str = "type",
):
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
Expand Down Expand Up @@ -118,7 +121,7 @@ def calculate(self) -> int:
Union[Type, Callable[[Type], Type]]: A decorator that adds the necessary
functionality to a class.
"""
tagged_union = _TaggedUnion(cls)
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(__init_subclass__)
cls.__get_pydantic_core_schema__ = classmethod(
Expand All @@ -143,13 +146,14 @@ def uses_tagged_union(cls_or_func: T) -> T:


class _TaggedUnion:
def __init__(self, base_class: type):
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: set[type] = set()
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter = TypeAdapter(None)
self.type_adapter: TypeAdapter = TypeAdapter(None)
self._discriminator = discriminator

def _make_union(self):
# Make a union of members
Expand All @@ -165,7 +169,7 @@ def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = "type"
field.discriminator = self._discriminator

def add_member(self, cls: type):
if cls in self._members:
Expand All @@ -175,7 +179,7 @@ def add_member(self, cls: type):
return
if cls is self._base_class:
return
self._members.add(cls)
self._members.append(cls)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set all the referrers
Expand All @@ -199,8 +203,6 @@ def add_referrer(self, cls: type | Callable, attr_name: str):
# note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
if not isclass(cls):
print(dir(cls.__defaults__))
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


Expand Down
4 changes: 2 additions & 2 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ class Polygon(Region[Axis]):
x_axis: Axis = Field(description="The name matching the x axis of the spec")
y_axis: Axis = Field(description="The name matching the y axis of the spec")
x_verts: List[float] = Field(
description="The Nx1 x coordinates of the polygons vertices", min_len=3
description="The Nx1 x coordinates of the polygons vertices", min_length=3
)
y_verts: List[float] = Field(
description="The Nx1 y coordinates of the polygons vertices", min_len=3
description="The Nx1 y coordinates of the polygons vertices", min_length=3
)

def axis_sets(self) -> List[Set[Axis]]:
Expand Down
2 changes: 1 addition & 1 deletion src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class SmallestStepResponse:
@app.post("/valid", response_model=ValidResponse)
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC], discriminator="type"),
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> Union[ValidResponse, JSONResponse]:
"""Validate wether a ScanSpec can produce a viable scan.
Expand Down
6 changes: 3 additions & 3 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def bounded(
return cls(axis, start, stop, num)


Line.bounded = validate_call(Line.bounded)
Line.bounded = validate_call(Line.bounded) # type:ignore


@dataclass(config=StrictConfig)
Expand Down Expand Up @@ -281,7 +281,7 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
)


Static.duration = validate_call(Static.duration)
Static.duration = validate_call(Static.duration) # type:ignore


@dataclass(config=StrictConfig)
Expand Down Expand Up @@ -369,7 +369,7 @@ def spaced(
)


Spiral.spaced = validate_call(Spiral.spaced)
Spiral.spaced = validate_call(Spiral.spaced) # type:ignore


@dataclass(config=StrictConfig)
Expand Down

0 comments on commit d2b7495

Please sign in to comment.