Skip to content

Commit 73b786d

Browse files
author
pablolh
committed
Fix matrix expander
Let's fix modules one by one. This commit was initially only intended to move this piece of code to its own folder, to add structure to the codebase, but it turned out I discovered a bug in it. - Fix untested bug where expanded matrix was wrong (not even unitary anymore) - Move matrix expander to a utils folder - Fix formatting - Add auxiliary functions - Add doctest and test
1 parent 38cb4e5 commit 73b786d

File tree

5 files changed

+215
-46
lines changed

5 files changed

+215
-46
lines changed

opensquirrel/matrix_expander.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

opensquirrel/test_interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from opensquirrel.common import ArgType
44
from opensquirrel.gates import querySemantic, querySignature
5-
from opensquirrel.matrix_expander import getBigMatrix
5+
from opensquirrel.utils.matrix_expander import get_expanded_matrix
66

77

88
class TestInterpreter:
@@ -24,7 +24,7 @@ def process(self, squirrelAST):
2424
semantic = querySemantic(
2525
self.gates, gateName, *[gateArgs[i] for i in range(len(gateArgs)) if signature[i] != ArgType.QUBIT]
2626
)
27-
bigMatrix = getBigMatrix(semantic, qubitOperands, totalQubits=squirrelAST.nQubits)
27+
bigMatrix = get_expanded_matrix(semantic, qubitOperands, total_qubits=squirrelAST.nQubits)
2828
totalUnitary = bigMatrix @ totalUnitary
2929

3030
return totalUnitary

opensquirrel/utils/__init__.py

Whitespace-only changes.

opensquirrel/utils/matrix_expander.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import math
2+
3+
import numpy as np
4+
5+
from opensquirrel.common import Can1
6+
from opensquirrel.gates import MultiQubitMatrixSemantic, Semantic, SingleQubitAxisAngleSemantic
7+
8+
9+
def extract_bits(x: int, bit_indices: [int]) -> int:
10+
"""
11+
Extract the bits of input at given indices, placing the bits in order from least to most significant.
12+
Equivalent to pext instruction.
13+
14+
Args:
15+
x: A non-negative integer from which one wants to extract the bits.
16+
bit_indices: The indices of the bits to extract, 0 being the least significant bit.
17+
18+
Returns:
19+
The extracted bits of x in order, as a non-negative integer.
20+
21+
Examples:
22+
>>> extract_bits(1, [0]) # 0b01
23+
1
24+
>>> extract_bits(1111, [2]) # 0b01
25+
1
26+
>>> extract_bits(1111, [5]) # 0b0
27+
0
28+
>>> extract_bits(1111, [2, 5]) # 0b01
29+
1
30+
>>> extract_bits(101, [1, 0]) # 0b10
31+
2
32+
>>> extract_bits(101, [0, 1]) # 0b01
33+
1
34+
"""
35+
result = 0
36+
for i, bit_index in enumerate(bit_indices):
37+
result |= ((x & (1 << bit_index)) >> bit_index) << i
38+
39+
return result
40+
41+
42+
def deposit_bits(x: int, bit_indices: [int]) -> int:
43+
"""
44+
Creates an integer whose bit values at given indices are taken from the bits of x, or 0 if they are not specified.
45+
Equivalently, takes the bits from x and places them at given indices in the result.
46+
Equivalent to pdep instruction.
47+
48+
Args:
49+
x: A non-negative integer giving the bit values.
50+
bit_indices: The indices where to deposit the bits of x in the result.
51+
52+
Returns:
53+
A bitstring whose bit values are taken from the bits of x.
54+
55+
Examples:
56+
>>> deposit_bits(0b0, [5]) # 0b000000
57+
0
58+
>>> deposit_bits(0b1, [5]) # 0b100000
59+
32
60+
>>> deposit_bits(0b000, [1, 2, 3]) # 0b0000
61+
0
62+
>>> deposit_bits(0b001, [1, 2, 3]) # 0b0010
63+
2
64+
>>> deposit_bits(0b011, [1, 2, 3]) # 0b0110
65+
6
66+
>>> deposit_bits(0b0101, [1, 2, 3]) # 0b1010
67+
10
68+
"""
69+
result = 0
70+
for i, bit_index in enumerate(bit_indices):
71+
result |= ((x & (1 << i)) >> i) << bit_index
72+
73+
return result
74+
75+
76+
def clear_bits(x: int, bit_indices: [int]) -> int:
77+
"""
78+
Clears given bits of input.
79+
80+
Args:
81+
x: A non-negative integer.
82+
bit_indices: Some bit indices to clear in x.
83+
84+
Returns:
85+
x with given bits reset to 0.
86+
87+
Examples:
88+
>>> clear_bits(0b1111, [1, 3]) # 0b0101
89+
5
90+
"""
91+
result = x
92+
for index in bit_indices:
93+
result &= ~(1 << index)
94+
95+
return result
96+
97+
98+
def get_expanded_matrix(semantic: Semantic, qubit_operands: [int], total_qubits: int) -> np.ndarray:
99+
"""
100+
Get the unitary matrix corresponding to the gate applied to those qubit operands.
101+
This can be used for, e.g.,
102+
- testing,
103+
- permuting the operands of a multi-qubit gates,
104+
- simulating a circuit (simulation in this way is inefficient for large numbers of qubits).
105+
106+
Args:
107+
semantic: The semantic of the gate.
108+
qubit_operands: The qubit indices on which the gate operates.
109+
total_qubits: The total number of qubits.
110+
111+
Example:
112+
>>> X = SingleQubitAxisAngleSemantic((1, 0, 0), math.pi, math.pi / 2)
113+
>>> get_expanded_matrix(X, [1], 2).astype(int) # X q[1]
114+
array([[0, 0, 1, 0],
115+
[0, 0, 0, 1],
116+
[1, 0, 0, 0],
117+
[0, 1, 0, 0]])
118+
119+
>>> CNOT = MultiQubitMatrixSemantic(np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]))
120+
>>> get_expanded_matrix(CNOT, [0, 2], 3) # CNOT q[0], q[2]
121+
array([[1, 0, 0, 0, 0, 0, 0, 0],
122+
[0, 0, 0, 0, 0, 1, 0, 0],
123+
[0, 0, 1, 0, 0, 0, 0, 0],
124+
[0, 0, 0, 0, 0, 0, 0, 1],
125+
[0, 0, 0, 0, 1, 0, 0, 0],
126+
[0, 1, 0, 0, 0, 0, 0, 0],
127+
[0, 0, 0, 0, 0, 0, 1, 0],
128+
[0, 0, 0, 1, 0, 0, 0, 0]])
129+
>>> get_expanded_matrix(CNOT, [1, 2], 3) # CNOT q[1], q[2]
130+
array([[1, 0, 0, 0, 0, 0, 0, 0],
131+
[0, 1, 0, 0, 0, 0, 0, 0],
132+
[0, 0, 0, 0, 0, 0, 1, 0],
133+
[0, 0, 0, 0, 0, 0, 0, 1],
134+
[0, 0, 0, 0, 1, 0, 0, 0],
135+
[0, 0, 0, 0, 0, 1, 0, 0],
136+
[0, 0, 1, 0, 0, 0, 0, 0],
137+
[0, 0, 0, 1, 0, 0, 0, 0]])
138+
"""
139+
if isinstance(semantic, SingleQubitAxisAngleSemantic):
140+
assert len(qubit_operands) == 1
141+
142+
which_qubit = qubit_operands[0]
143+
144+
axis, angle, phase = semantic.axis, semantic.angle, semantic.phase
145+
result = np.kron(
146+
np.kron(np.eye(1 << (total_qubits - which_qubit - 1)), Can1(axis, angle, phase)), np.eye(1 << which_qubit)
147+
)
148+
assert result.shape == (1 << total_qubits, 1 << total_qubits)
149+
return result
150+
151+
assert isinstance(semantic, MultiQubitMatrixSemantic)
152+
153+
# The convention is to write gate matrices with operands reversed.
154+
# For instance, the first operand of CNOT is the control qubit, and this is written as
155+
# 1, 0, 0, 0
156+
# 0, 1, 0, 0
157+
# 0, 0, 0, 1
158+
# 0, 0, 1, 0
159+
# which corresponds to control being q[1] and target being q[0],
160+
# since qubit #i corresponds to the i-th LEAST significant bit.
161+
qubit_operands.reverse()
162+
163+
m = semantic.matrix
164+
165+
assert m.shape == (1 << len(qubit_operands), 1 << len(qubit_operands))
166+
167+
result = np.zeros((1 << total_qubits, 1 << total_qubits), dtype=m.dtype)
168+
169+
for input in range(1 << total_qubits):
170+
small_matrix_col_index = extract_bits(input, qubit_operands)
171+
172+
col = m[:, small_matrix_col_index]
173+
174+
for output, value in enumerate(col):
175+
large_output = clear_bits(input, qubit_operands)
176+
177+
large_output |= deposit_bits(output, qubit_operands)
178+
179+
result[large_output][input] = value
180+
181+
assert result.shape == (1 << total_qubits, 1 << total_qubits)
182+
return result

test/test_testinterpreter.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,37 @@ def test_hadamard_cnot(self):
198198
)
199199
)
200200

201+
def test_hadamard_cnot_0_2(self):
202+
circuit = Circuit.from_string(
203+
DefaultGates,
204+
r"""
205+
version 3.0
206+
qubit[3] q
207+
208+
h q[0]
209+
cnot q[0], q[2]
210+
""",
211+
)
212+
print(circuit.test_get_circuit_matrix())
213+
self.assertTrue(
214+
np.allclose(
215+
circuit.test_get_circuit_matrix(),
216+
math.sqrt(0.5)
217+
* np.array(
218+
[
219+
[1, 1, 0, 0, 0, 0, 0, 0],
220+
[0, 0, 0, 0, 1, -1, 0, 0],
221+
[0, 0, 1, 1, 0, 0, 0, 0],
222+
[0, 0, 0, 0, 0, 0, 1, -1],
223+
[0, 0, 0, 0, 1, 1, 0, 0],
224+
[1, -1, 0, 0, 0, 0, 0, 0],
225+
[0, 0, 0, 0, 0, 0, 1, 1],
226+
[0, 0, 1, -1, 0, 0, 0, 0],
227+
]
228+
),
229+
)
230+
)
231+
201232

202233
if __name__ == "__main__":
203234
unittest.main()

0 commit comments

Comments
 (0)