|
6 | 6 |
|
7 | 7 | import numpy as np
|
8 | 8 | import pytest
|
9 |
| -from numpy.typing import ArrayLike |
| 9 | +from numpy.typing import NDArray |
10 | 10 |
|
11 | 11 | from opensquirrel.common import ATOL
|
12 | 12 | from opensquirrel.ir import (
|
@@ -44,7 +44,7 @@ def test_axis_getter(self, axis: Axis) -> None:
|
44 | 44 | (Axis(0, 1, 0), [0, 1, 0]),
|
45 | 45 | ],
|
46 | 46 | )
|
47 |
| - def test_axis_setter_no_error(self, axis: Axis, new_axis: ArrayLike, expected_axis: ArrayLike) -> None: |
| 47 | + def test_axis_setter_no_error(self, axis: Axis, new_axis: AxisLike, expected_axis: list[float]) -> None: |
48 | 48 | axis.value = new_axis # type: ignore[assignment]
|
49 | 49 | np.testing.assert_array_equal(axis, expected_axis)
|
50 | 50 |
|
@@ -96,27 +96,59 @@ def test_eq_false(self, axis: Axis, other: Any) -> None:
|
96 | 96 | @pytest.mark.parametrize(
|
97 | 97 | ("axis", "expected"),
|
98 | 98 | [
|
99 |
| - ([1, 0, 0], True), |
100 |
| - ([0, 0, 0], False), |
101 |
| - ([1, 2], False), |
102 |
| - ([1, 2, 3, 4], False), |
103 |
| - ([0, 1, 0], True), |
| 99 | + ([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)), |
| 100 | + ([0, 0, 0], ValueError), |
| 101 | + ([1, 2], ValueError), |
| 102 | + ([1, 2, 3, 4], ValueError), |
| 103 | + ([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)), |
| 104 | + (["a", "b", "c"], TypeError), |
104 | 105 | ],
|
105 | 106 | )
|
106 |
| - def test_has_valid_value(self, axis: AxisLike, expected: bool) -> None: |
107 |
| - assert Axis.has_valid_value(axis) == expected |
| 107 | + def test_constructor(self, axis: AxisLike, expected: Any) -> None: |
| 108 | + if isinstance(expected, type) and issubclass(expected, Exception): |
| 109 | + with pytest.raises(expected): |
| 110 | + Axis(axis) |
| 111 | + else: |
| 112 | + assert isinstance(expected, np.ndarray) |
| 113 | + obj = Axis(axis) |
| 114 | + np.testing.assert_array_equal(obj.value, expected) |
108 | 115 |
|
109 | 116 | @pytest.mark.parametrize(
|
110 |
| - ("axis", "expected_error"), |
| 117 | + ("axis", "expected"), |
111 | 118 | [
|
| 119 | + ([1, 0, 0], np.array([1, 0, 0], dtype=np.float64)), |
112 | 120 | ([0, 0, 0], ValueError),
|
113 | 121 | ([1, 2], ValueError),
|
114 | 122 | ([1, 2, 3, 4], ValueError),
|
| 123 | + ([0, 1, 0], np.array([0, 1, 0], dtype=np.float64)), |
| 124 | + (["a", "b", "c"], TypeError), |
| 125 | + ], |
| 126 | + ) |
| 127 | + def test_parser(self, axis: AxisLike, expected: Any) -> None: |
| 128 | + if isinstance(expected, type) and issubclass(expected, Exception): |
| 129 | + with pytest.raises(expected): |
| 130 | + Axis.parse(axis) |
| 131 | + else: |
| 132 | + assert isinstance(expected, np.ndarray) |
| 133 | + obj = Axis.parse(axis) |
| 134 | + np.testing.assert_array_equal(obj, expected) |
| 135 | + |
| 136 | + @pytest.mark.parametrize( |
| 137 | + ("axis", "expected"), |
| 138 | + [ |
| 139 | + (np.array([1, 0, 0], dtype=np.float64), np.array([1, 0, 0], dtype=np.float64)), |
| 140 | + (np.array([0, 1, 0], dtype=np.float64), np.array([0, 1, 0], dtype=np.float64)), |
| 141 | + (np.array([0, 0, 1], dtype=np.float64), np.array([0, 0, 1], dtype=np.float64)), |
| 142 | + ( |
| 143 | + np.array([1, 1, 1], dtype=np.float64), |
| 144 | + np.array([1 / np.sqrt(3), 1 / np.sqrt(3), 1 / np.sqrt(3)], dtype=np.float64), |
| 145 | + ), |
115 | 146 | ],
|
116 | 147 | )
|
117 |
| - def test_init_with_invalid_value(self, axis: AxisLike, expected_error: type[Exception]) -> None: |
118 |
| - with pytest.raises(expected_error): |
119 |
| - Axis(axis) |
| 148 | + def test_normalize(self, axis: AxisLike, expected: NDArray[np.float64]) -> None: |
| 149 | + obj = Axis.normalize(np.array(axis, dtype=np.float64)) |
| 150 | + assert isinstance(expected, np.ndarray) |
| 151 | + np.testing.assert_array_almost_equal(obj, expected) |
120 | 152 |
|
121 | 153 |
|
122 | 154 | class TestIR:
|
|
0 commit comments