Skip to content

Commit ec6be11

Browse files
committed
added more tests
1 parent a7c804b commit ec6be11

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

src/scanspec/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,7 @@ def add_member(self, cls: type):
198198
# called muliple times for the same member, do no process if it wouldn't
199199
# change the member list
200200
return
201-
if cls is self._base_class:
202-
return
201+
203202
self._members.append(cls)
204203
union = self._make_union()
205204
if union:

src/scanspec/regions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterator
4-
from dataclasses import is_dataclass
5-
from typing import Generic
4+
from dataclasses import asdict, is_dataclass
5+
from typing import Any, Generic, Mapping
66

77
import numpy as np
88
from pydantic import BaseModel, Field
@@ -66,6 +66,10 @@ def __sub__(self, other) -> DifferenceOf[Axis]:
6666
def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:
6767
return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o))
6868

69+
def serialize(self) -> Mapping[str, Any]:
70+
"""Serialize the Region to a dictionary."""
71+
return asdict(self) # type: ignore
72+
6973
@staticmethod
7074
def deserialize(obj):
7175
"""Deserialize the Region from a dictionary."""

tests/test_serialization.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from pydantic import ValidationError
66

7-
from scanspec.regions import Circle, Rectangle, UnionOf
7+
from scanspec.regions import Circle, Rectangle, Region, UnionOf
88
from scanspec.specs import Line, Mask, Spec, Spiral
99

1010

@@ -15,6 +15,20 @@ def test_line_serializes() -> None:
1515
assert Spec.deserialize(serialized) == ob
1616

1717

18+
def test_circle_serializes() -> None:
19+
ob = Circle("x", "y", x_middle=0, y_middle=1, radius=4)
20+
serialized = {
21+
"x_axis": "x",
22+
"y_axis": "y",
23+
"x_middle": 0.0,
24+
"y_middle": 1.0,
25+
"radius": 4.0,
26+
"type": "Circle",
27+
}
28+
assert ob.serialize() == serialized
29+
assert Region.deserialize(serialized) == ob
30+
31+
1832
def test_masked_circle_serializes() -> None:
1933
ob = Mask(Line("x", 0, 1, 4), Circle("x", "y", x_middle=0, y_middle=1, radius=4))
2034
serialized = {

0 commit comments

Comments
 (0)