Skip to content

Commit 4a98e09

Browse files
authored
Merge pull request #406 from QuTech-Delft/CQT-289-Avoid-creating-an-Axis-0-0-0
[CQT-289] Avoid creating an Axis (0, 0, 0)
2 parents 231268c + 17e4c5f commit 4a98e09

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

opensquirrel/ir.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def __init__(self, *axis: AxisLike) -> None:
214214
axis: An ``AxisLike`` to create the axis from.
215215
"""
216216
axis_to_parse = axis[0] if len(axis) == 1 else cast(AxisLike, axis)
217-
self._value = self._parse_and_validate_axislike(axis_to_parse)
217+
self._value = self.normalize(self.parse(axis_to_parse))
218218

219219
@property
220220
def value(self) -> NDArray[np.float64]:
@@ -228,10 +228,11 @@ def value(self, axis: AxisLike) -> None:
228228
Args:
229229
axis: An ``AxisLike`` to create the axis from.
230230
"""
231-
self._value = self._parse_and_validate_axislike(axis)
231+
self._value = self.parse(axis)
232+
self._value = self.normalize(self._value)
232233

233-
@classmethod
234-
def _parse_and_validate_axislike(cls, axis: AxisLike) -> NDArray[np.float64]:
234+
@staticmethod
235+
def parse(axis: AxisLike) -> NDArray[np.float64]:
235236
"""Parse and validate an ``AxisLike``.
236237
237238
Check if the `axis` can be cast to a 1DArray of length 3, raise an error otherwise.
@@ -255,10 +256,13 @@ def _parse_and_validate_axislike(cls, axis: AxisLike) -> NDArray[np.float64]:
255256
if len(axis) != 3:
256257
msg = f"axis requires an ArrayLike of length 3, but received an ArrayLike of length {len(axis)}"
257258
raise ValueError(msg)
258-
return cls._normalize_axis(axis)
259+
if np.all(axis == 0):
260+
msg = "axis requires at least one element to be non-zero"
261+
raise ValueError(msg)
262+
return axis
259263

260264
@staticmethod
261-
def _normalize_axis(axis: NDArray[np.float64]) -> NDArray[np.float64]:
265+
def normalize(axis: NDArray[np.float64]) -> NDArray[np.float64]:
262266
"""Normalize a NDArray.
263267
264268
Args:

test/test_ir.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
import numpy as np
88
import pytest
9-
from numpy.typing import ArrayLike
9+
from numpy.typing import NDArray
1010

1111
from opensquirrel import I
1212
from opensquirrel.common import ATOL
1313
from opensquirrel.ir import (
1414
Axis,
15+
AxisLike,
1516
Bit,
1617
BlochSphereRotation,
1718
ControlledGate,
@@ -44,7 +45,7 @@ def test_axis_getter(self, axis: Axis) -> None:
4445
(Axis(0, 1, 0), [0, 1, 0]),
4546
],
4647
)
47-
def test_axis_setter_no_error(self, axis: Axis, new_axis: ArrayLike, expected_axis: ArrayLike) -> None:
48+
def test_axis_setter_no_error(self, axis: Axis, new_axis: AxisLike, expected_axis: list[float]) -> None:
4849
axis.value = new_axis # type: ignore[assignment]
4950
np.testing.assert_array_equal(axis, expected_axis)
5051

@@ -59,6 +60,7 @@ def test_axis_setter_no_error(self, axis: Axis, new_axis: ArrayLike, expected_ax
5960
ValueError,
6061
"axis requires an ArrayLike of length 3, but received an ArrayLike of length 4",
6162
),
63+
([0, 0, 0], ValueError, "axis requires at least one element to be non-zero"),
6264
],
6365
)
6466
def test_axis_setter_with_error(
@@ -93,6 +95,63 @@ def test_eq_true(self, axis: Axis, other: Any) -> None:
9395
def test_eq_false(self, axis: Axis, other: Any) -> None:
9496
assert axis != other
9597

98+
@pytest.mark.parametrize(
99+
("axis", "expected"),
100+
[
101+
([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)),
102+
([0, 0, 0], ValueError),
103+
([1, 2], ValueError),
104+
([1, 2, 3, 4], ValueError),
105+
([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)),
106+
(["a", "b", "c"], TypeError),
107+
],
108+
)
109+
def test_constructor(self, axis: AxisLike, expected: Any) -> None:
110+
if isinstance(expected, type) and issubclass(expected, Exception):
111+
with pytest.raises(expected):
112+
Axis(axis)
113+
else:
114+
assert isinstance(expected, np.ndarray)
115+
obj = Axis(axis)
116+
np.testing.assert_array_equal(obj.value, expected)
117+
118+
@pytest.mark.parametrize(
119+
("axis", "expected"),
120+
[
121+
([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)),
122+
([0, 0, 0], ValueError),
123+
([1, 2], ValueError),
124+
([1, 2, 3, 4], ValueError),
125+
([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)),
126+
(["a", "b", "c"], TypeError),
127+
],
128+
)
129+
def test_parse(self, axis: AxisLike, expected: Any) -> None:
130+
if isinstance(expected, type) and issubclass(expected, Exception):
131+
with pytest.raises(expected):
132+
Axis.parse(axis)
133+
else:
134+
assert isinstance(expected, np.ndarray)
135+
obj = Axis.parse(axis)
136+
np.testing.assert_array_equal(obj, expected)
137+
138+
@pytest.mark.parametrize(
139+
("axis", "expected"),
140+
[
141+
(np.array([1, 0, 0], dtype=np.float64), np.array([1, 0, 0], dtype=np.float64)),
142+
(np.array([0, 1, 0], dtype=np.float64), np.array([0, 1, 0], dtype=np.float64)),
143+
(np.array([0, 0, 1], dtype=np.float64), np.array([0, 0, 1], dtype=np.float64)),
144+
(
145+
np.array([1, 1, 1], dtype=np.float64),
146+
np.array([1 / np.sqrt(3), 1 / np.sqrt(3), 1 / np.sqrt(3)], dtype=np.float64),
147+
),
148+
],
149+
)
150+
def test_normalize(self, axis: AxisLike, expected: NDArray[np.float64]) -> None:
151+
obj = Axis.normalize(np.array(axis, dtype=np.float64))
152+
assert isinstance(expected, np.ndarray)
153+
np.testing.assert_array_almost_equal(obj, expected)
154+
96155

97156
class TestIR:
98157
def test_cnot_equality(self) -> None:

0 commit comments

Comments
 (0)