7
7
from typing import Callable , List , Optional , Tuple
8
8
9
9
import numpy as np
10
-
11
10
from opensquirrel .common import ATOL , X , Y , Z , normalize_angle , normalize_axis
12
11
13
-
14
12
class SquirrelIRVisitor (ABC ):
15
13
def visit_comment (self , comment : "Comment" ):
16
14
pass
@@ -159,7 +157,6 @@ def is_identity(self) -> bool:
159
157
# Angle and phase are already normalized.
160
158
return abs (self .angle ) < ATOL and abs (self .phase ) < ATOL
161
159
162
-
163
160
class MatrixGate (Gate ):
164
161
generator : Optional [Callable [..., "MatrixGate" ]] = None
165
162
@@ -173,8 +170,10 @@ def __init__(self, matrix: np.ndarray, operands: List[Qubit], generator=None, ar
173
170
174
171
def __eq__ (self , other ):
175
172
# TODO: Determine whether we shall allow for a global phase difference here.
176
- if not isinstance (other , MatrixGate ):
173
+ if not isinstance (other , ControlledGate | MatrixGate ):
177
174
return False # FIXME: a MatrixGate can hide a ControlledGate. https://github.com/QuTech-Delft/OpenSquirrel/issues/88
175
+ if isinstance (other ,ControlledGate ):
176
+ return _compare_gate_classes (self ,other )
178
177
return np .allclose (self .matrix , other .matrix )
179
178
180
179
def __repr__ (self ):
@@ -200,8 +199,11 @@ def __init__(self, control_qubit: Qubit, target_gate: Gate, generator=None, argu
200
199
self .target_gate = target_gate
201
200
202
201
def __eq__ (self , other ):
203
- if not isinstance (other , ControlledGate ):
202
+
203
+ if not isinstance (other , ControlledGate | MatrixGate ):
204
204
return False # FIXME: a MatrixGate can hide a ControlledGate. https://github.com/QuTech-Delft/OpenSquirrel/issues/88
205
+ if isinstance (other , MatrixGate ):
206
+ return _compare_gate_classes (self ,other )
205
207
if self .control_qubit != other .control_qubit :
206
208
return False
207
209
@@ -220,7 +222,23 @@ def get_qubit_operands(self) -> List[Qubit]:
220
222
def is_identity (self ) -> bool :
221
223
return self .target_gate .is_identity ()
222
224
225
+ def _get_qubit_num (gate : ControlledGate | MatrixGate ) -> int :
226
+ operand = gate .get_qubit_operands ()
227
+ return max (qubit .index for qubit in operand ) + 1
228
+
229
+ def _compare_gate_classes (g1 : ControlledGate | MatrixGate , g2 : ControlledGate | MatrixGate ) -> bool :
230
+
231
+ from opensquirrel .utils .matrix_expander import get_matrix
232
+
233
+ matrix_g1 = get_matrix (g1 , _get_qubit_num (g1 ))
234
+ matrix_g2 = get_matrix (g2 , _get_qubit_num (g2 ))
235
+
236
+ if (matrix_g1 .shape != matrix_g2 .shape ):
237
+ return False
238
+
239
+ return np .allclose (matrix_g1 , matrix_g2 )
223
240
241
+
224
242
def named_gate (gate_generator : Callable [..., Gate ]) -> Callable [..., Gate ]:
225
243
@wraps (gate_generator )
226
244
def wrapper (* args , ** kwargs ):
0 commit comments