Skip to content

Commit b56c810

Browse files
committed
Restructure Axis and Axis testing
1 parent 9b18e12 commit b56c810

File tree

2 files changed

+55
-39
lines changed

2 files changed

+55
-39
lines changed

opensquirrel/ir.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -214,28 +214,8 @@ def __init__(self, *axis: AxisLike) -> None:
214214
axis: An ``AxisLike`` to create the axis from.
215215
"""
216216
axis_to_parse = axis[0] if len(axis) == 1 else cast(AxisLike, axis)
217-
if not self.has_valid_value(axis_to_parse):
218-
raise ValueError
219-
self._value = self._parse_and_validate_axislike(axis_to_parse)
220-
221-
@staticmethod
222-
def has_valid_value(*axis: AxisLike) -> bool:
223-
"""Check if the axis has a valid value.
224-
225-
Args:
226-
axis: An ``AxisLike`` to validate.
227-
228-
Returns:
229-
True if the axis is valid, False otherwise.
230-
"""
231-
try:
232-
axis_to_check = axis[0] if len(axis) == 1 else cast(AxisLike, axis)
233-
axis_array = np.asarray(axis_to_check, dtype=float).flatten()
234-
if len(axis_array) != 3:
235-
return False
236-
return not np.all(axis_array == 0)
237-
except (ValueError, TypeError):
238-
return False
217+
self._value = self.parse(axis_to_parse)
218+
self._value = self.normalize(self._value)
239219

240220
@property
241221
def value(self) -> NDArray[np.float64]:
@@ -249,10 +229,11 @@ def value(self, axis: AxisLike) -> None:
249229
Args:
250230
axis: An ``AxisLike`` to create the axis from.
251231
"""
252-
self._value = self._parse_and_validate_axislike(axis)
232+
self._value = self.parse(axis)
233+
self._value = self.normalize(self._value)
253234

254235
@classmethod
255-
def _parse_and_validate_axislike(cls, axis: AxisLike) -> NDArray[np.float64]:
236+
def parse(cls, axis: AxisLike) -> NDArray[np.float64]:
256237
"""Parse and validate an ``AxisLike``.
257238
258239
Check if the `axis` can be cast to a 1DArray of length 3, raise an error otherwise.
@@ -276,10 +257,13 @@ def _parse_and_validate_axislike(cls, axis: AxisLike) -> NDArray[np.float64]:
276257
if len(axis) != 3:
277258
msg = f"axis requires an ArrayLike of length 3, but received an ArrayLike of length {len(axis)}"
278259
raise ValueError(msg)
279-
return cls._normalize_axis(axis)
260+
if np.any(axis != 0).item():
261+
return axis
262+
msg = "axis requires at least one element to be non-zero"
263+
raise ValueError(msg)
280264

281265
@staticmethod
282-
def _normalize_axis(axis: NDArray[np.float64]) -> NDArray[np.float64]:
266+
def normalize(axis: NDArray[np.float64]) -> NDArray[np.float64]:
283267
"""Normalize a NDArray.
284268
285269
Args:

test/test_ir.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import pytest
9-
from numpy.typing import ArrayLike
9+
from numpy.typing import NDArray
1010

1111
from opensquirrel.common import ATOL
1212
from opensquirrel.ir import (
@@ -44,7 +44,7 @@ def test_axis_getter(self, axis: Axis) -> None:
4444
(Axis(0, 1, 0), [0, 1, 0]),
4545
],
4646
)
47-
def test_axis_setter_no_error(self, axis: Axis, new_axis: ArrayLike, expected_axis: ArrayLike) -> None:
47+
def test_axis_setter_no_error(self, axis: Axis, new_axis: AxisLike, expected_axis: list[float]) -> None:
4848
axis.value = new_axis # type: ignore[assignment]
4949
np.testing.assert_array_equal(axis, expected_axis)
5050

@@ -96,27 +96,59 @@ def test_eq_false(self, axis: Axis, other: Any) -> None:
9696
@pytest.mark.parametrize(
9797
("axis", "expected"),
9898
[
99-
([1, 0, 0], True),
100-
([0, 0, 0], False),
101-
([1, 2], False),
102-
([1, 2, 3, 4], False),
103-
([0, 1, 0], True),
99+
([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)),
100+
([0, 0, 0], ValueError),
101+
([1, 2], ValueError),
102+
([1, 2, 3, 4], ValueError),
103+
([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)),
104+
(["a", "b", "c"], TypeError),
104105
],
105106
)
106-
def test_has_valid_value(self, axis: AxisLike, expected: bool) -> None:
107-
assert Axis.has_valid_value(axis) == expected
107+
def test_constructor(self, axis: AxisLike, expected: Any) -> None:
108+
if isinstance(expected, type) and issubclass(expected, Exception):
109+
with pytest.raises(expected):
110+
Axis(axis)
111+
else:
112+
assert isinstance(expected, np.ndarray)
113+
obj = Axis(axis)
114+
np.testing.assert_array_equal(obj.value, expected)
108115

109116
@pytest.mark.parametrize(
110-
("axis", "expected_error"),
117+
("axis", "expected"),
111118
[
119+
([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)),
112120
([0, 0, 0], ValueError),
113121
([1, 2], ValueError),
114122
([1, 2, 3, 4], ValueError),
123+
([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)),
124+
(["a", "b", "c"], TypeError),
125+
],
126+
)
127+
def test_parser(self, axis: AxisLike, expected: Any) -> None:
128+
if isinstance(expected, type) and issubclass(expected, Exception):
129+
with pytest.raises(expected):
130+
Axis.parse(axis)
131+
else:
132+
assert isinstance(expected, np.ndarray)
133+
obj = Axis.parse(axis)
134+
np.testing.assert_array_equal(obj, expected)
135+
136+
@pytest.mark.parametrize(
137+
("axis", "expected"),
138+
[
139+
(np.array([1, 0, 0], dtype=np.float64), np.array([1, 0, 0], dtype=np.float64)),
140+
(np.array([0, 1, 0], dtype=np.float64), np.array([0, 1, 0], dtype=np.float64)),
141+
(np.array([0, 0, 1], dtype=np.float64), np.array([0, 0, 1], dtype=np.float64)),
142+
(
143+
np.array([1, 1, 1], dtype=np.float64),
144+
np.array([1 / np.sqrt(3), 1 / np.sqrt(3), 1 / np.sqrt(3)], dtype=np.float64),
145+
),
115146
],
116147
)
117-
def test_init_with_invalid_value(self, axis: AxisLike, expected_error: type[Exception]) -> None:
118-
with pytest.raises(expected_error):
119-
Axis(axis)
148+
def test_normalize(self, axis: AxisLike, expected: NDArray[np.float64]) -> None:
149+
obj = Axis.normalize(np.array(axis, dtype=np.float64))
150+
assert isinstance(expected, np.ndarray)
151+
np.testing.assert_array_almost_equal(obj, expected)
120152

121153

122154
class TestIR:

0 commit comments

Comments
 (0)