Skip to content

Commit 259a6e7

Browse files
committed
Added check for hybrid Control/Matrix Gate classes + test
1 parent 5af3ae0 commit 259a6e7

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

opensquirrel/squirrel_ir.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from typing import Callable, List, Optional, Tuple
88

99
import numpy as np
10-
1110
from opensquirrel.common import ATOL, X, Y, Z, normalize_angle, normalize_axis
1211

13-
1412
class SquirrelIRVisitor(ABC):
1513
def visit_comment(self, comment: "Comment"):
1614
pass
@@ -159,7 +157,6 @@ def is_identity(self) -> bool:
159157
# Angle and phase are already normalized.
160158
return abs(self.angle) < ATOL and abs(self.phase) < ATOL
161159

162-
163160
class MatrixGate(Gate):
164161
generator: Optional[Callable[..., "MatrixGate"]] = None
165162

@@ -173,8 +170,10 @@ def __init__(self, matrix: np.ndarray, operands: List[Qubit], generator=None, ar
173170

174171
def __eq__(self, other):
175172
# TODO: Determine whether we shall allow for a global phase difference here.
176-
if not isinstance(other, MatrixGate):
173+
if not isinstance(other, ControlledGate | MatrixGate):
177174
return False # FIXME: a MatrixGate can hide a ControlledGate. https://github.com/QuTech-Delft/OpenSquirrel/issues/88
175+
if isinstance(other,ControlledGate):
176+
return _compare_gate_classes(self,other)
178177
return np.allclose(self.matrix, other.matrix)
179178

180179
def __repr__(self):
@@ -200,8 +199,11 @@ def __init__(self, control_qubit: Qubit, target_gate: Gate, generator=None, argu
200199
self.target_gate = target_gate
201200

202201
def __eq__(self, other):
203-
if not isinstance(other, ControlledGate):
202+
203+
if not isinstance(other, ControlledGate | MatrixGate):
204204
return False # FIXME: a MatrixGate can hide a ControlledGate. https://github.com/QuTech-Delft/OpenSquirrel/issues/88
205+
if isinstance(other, MatrixGate):
206+
return _compare_gate_classes(self,other)
205207
if self.control_qubit != other.control_qubit:
206208
return False
207209

@@ -220,7 +222,23 @@ def get_qubit_operands(self) -> List[Qubit]:
220222
def is_identity(self) -> bool:
221223
return self.target_gate.is_identity()
222224

225+
def _get_qubit_num(gate: ControlledGate | MatrixGate) -> int:
226+
operand = gate.get_qubit_operands()
227+
return max(qubit.index for qubit in operand) + 1
228+
229+
def _compare_gate_classes(g1: ControlledGate | MatrixGate, g2: ControlledGate | MatrixGate) -> bool:
230+
231+
from opensquirrel.utils.matrix_expander import get_matrix
232+
233+
matrix_g1 = get_matrix(g1, _get_qubit_num(g1))
234+
matrix_g2 = get_matrix(g2, _get_qubit_num(g2))
235+
236+
if(matrix_g1.shape != matrix_g2.shape):
237+
return False
238+
239+
return np.allclose(matrix_g1, matrix_g2)
223240

241+
224242
def named_gate(gate_generator: Callable[..., Gate]) -> Callable[..., Gate]:
225243
@wraps(gate_generator)
226244
def wrapper(*args, **kwargs):

opensquirrel/utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from opensquirrel.utils.matrix_expander import get_matrix
2+
3+
__all__ = [
4+
"get_matrix",
5+
]

test/utils/test_matrix_expander.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,23 @@ def test_matrix_gate(self):
7575
),
7676
)
7777
)
78+
79+
80+
def test_gate_equality(self):
81+
cnot_matrix_gate = MatrixGate(
82+
np.array(
83+
[
84+
[1, 0, 0, 0],
85+
[0, 1, 0, 0],
86+
[0, 0, 0, 1],
87+
[0, 0, 1, 0],
88+
]
89+
),
90+
operands=[Qubit(1), Qubit(2)]
91+
)
92+
93+
cnot_controlled_gate = ControlledGate(
94+
Qubit(1), BlochSphereRotation(qubit=Qubit(2), axis=(1, 0, 0), angle=math.pi, phase=math.pi / 2)
95+
)
96+
97+
self.assertTrue(cnot_controlled_gate == cnot_matrix_gate)

0 commit comments

Comments
 (0)