Skip to content

Commit

Permalink
Trotterization for commuting groups, bugfixes change_of_basis
Browse files Browse the repository at this point in the history
  • Loading branch information
renezander90 committed Nov 16, 2024
1 parent f3b457a commit 76fbe8f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 78 deletions.
39 changes: 23 additions & 16 deletions src/qrisp/operators/qubit/commutativity_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ def construct_change_of_basis(S):
S_reduced, independent_cols = gaussian_elimination_mod2(S, show_pivots=True)
k = len(independent_cols)

#S0 = np.vstack((z_matrix[:,independent_cols], x_matrix[:,independent_cols]))
S0 = S[:,independent_cols]

R0_inv = S_reduced[:k, :]

####################
Expand All @@ -187,11 +185,13 @@ def construct_change_of_basis(S):
# Find independent rows in X component of S0
S0X_reduced, independent_rows = gaussian_elimination_mod2(S0[-n:, :], type='column', show_pivots=True)

# Construnct S1 by applying a Hadamard (i.e., a swap) to the rows of S0 not in independent_rows
h_list = [i for i in range(n) if i not in independent_rows]
S1 = S0.copy()
for i in h_list:
S1[[i, n+i]] = S1[[n+i, i]]
h_list = []
# Construct S1 by applying a Hadamard (i.e., a swap) to the rows of S0 not in independent_rows
if len(independent_rows)<k:
h_list = [i for i in range(n) if i not in independent_rows]
for i in h_list:
S1[[i, n+i]] = S1[[n+i, i]]

# Find independent rows in X component of S1
S1X_reduced, independent_rows = gaussian_elimination_mod2(S1[-n:, :], type="column", show_pivots=True)
Expand All @@ -214,19 +214,26 @@ def construct_change_of_basis(S):
S2 = S1 @ R1 % 2

####################
# Step 3: Calculate S3: Basis extension
# Step 3: Calculate S3: Basis extension if n>k
####################

C = S2[:k, :]
D = S2[k:n, :]
F = S2[-(n-k):, :]
if n>k:

C = S2[:k, :]
D = S2[k:n, :]
F = S2[-(n-k):, :]

S3 = np.block([[C, np.transpose(D)],
[D, np.zeros((n-k,n-k), dtype=int)],
[np.eye(k, dtype=int), np.zeros((k,n-k), dtype=int)],
[F, np.eye(n-k, dtype=int)]])
R2_inv = np.block([[np.eye(k, dtype=int)],
[np.zeros((n-k,k), dtype=int)]])

else:

S3 = np.block([[C, np.transpose(D)],
[D, np.zeros((n-k,n-k), dtype=int)],
[np.eye(k, dtype=int), np.zeros((k,n-k), dtype=int)],
[F, np.eye(n-k, dtype=int)]])
R2_inv = np.block([[np.eye(k, dtype=int)],
[np.zeros((n-k,k), dtype=int)]])
S3 = S2
R2_inv = np.eye(n, dtype=int)

####################
# Step 4: Calculate S4
Expand Down
146 changes: 84 additions & 62 deletions src/qrisp/operators/qubit/qubit_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,59 +935,80 @@ def change_of_basis(self, qarg, method="commuting_qw"):
# Calculate S: Matrix where the colums correspond to the binary representation (Z/X) of the Pauli terms
x_vectors = []
z_vectors = []
coeffs = []
for term, coeff in self.terms_dict.items():
x_vector, z_vector = term.binary_representation(n)
x_vectors.append(x_vector)
z_vectors.append(z_vector)
coeffs.append(coeff)
x_matrix = np.stack(x_vectors, axis=1)
z_matrix = np.stack(z_vectors, axis=1)

S = np.vstack((z_matrix, x_matrix))
# Find qubits (rows) on which Pauli X,Y,Z operatos act
qb_indices = []
for k in range(n):
if not (np.all(x_matrix[k] == 0) and np.all(z_matrix[k] == 0)):
qb_indices.append(k)
m = len(qb_indices)

# Construct and apply change of basis
A, R_inv, h_list, s_list, perm = construct_change_of_basis(S)

def inv_graph_state(qarg):
for i in range(n):
for j in range(i):
if A[i,j]==1:
cz(qarg[perm[i]],qarg[perm[j]])
h(qarg[:n])

def change_of_basis(qarg):
for i in h_list:
h(qarg[perm[i]])
for i in s_list:
s(qarg[perm[i]])
inv_graph_state(qarg)

change_of_basis(qarg)

# Construct new QubitOperator
#
# Factor (-1) appears if S gate is applied to X, or Hadamard gate H is applied to Y:
# S^dagger X S = -Y
# S^dagger Y S = X
# S^dagger Z S = Z
# H X H = Z
# H Y H = -Y
# H Z H = X
# For the original Pauli terms this translates to: Factor (-1) appears if S gate is applied to Y, or Hadamard gate H is applied to Y. (No factor (-1) occurs if S*H is applied.)

s_vector = np.zeros(n, dtype=int)
s_vector[s_list] = 1
h_vector = np.zeros(n, dtype=int)
h_vector[h_list] = 1
sh_vector = s_vector + h_vector % 2
sign_vector = sh_vector @ (x_matrix*z_matrix) % 2

for index,z_vector in enumerate(R_inv.T):
new_factor_dict = {perm[i]:"Z" for i in range(n) if z_vector[i]==1}
new_factor_dicts.append(new_factor_dict)
prefactor = (-1)**sign_vector[index]
prefactors.append(prefactor)
if m==0:
new_factor_dicts = [{}]*self.len()
prefactors = [1]*self.len()
else:
S = np.vstack((z_matrix[qb_indices], x_matrix[qb_indices]))

# Construct and apply change of basis
A, R_inv, h_list, s_list, perm = construct_change_of_basis(S)

def inv_graph_state(qarg):
for i in range(m):
for j in range(i):
if A[i,j]==1:
cz(qarg[qb_indices[perm[i]]],qarg[qb_indices[perm[j]]])
for i in qb_indices:
h(qarg[i])

def change_of_basis(qarg):
for i in h_list:
h(qarg[qb_indices[i]])
for i in s_list:
s(qarg[qb_indices[perm[i]]])
inv_graph_state(qarg)

change_of_basis(qarg)

# Construct new QubitOperator
#
# Factor (-1) appears if S gate is applied to X, or Hadamard gate H is applied to Y:
# S^dagger X S = -Y
# S^dagger Y S = X
# S^dagger Z S = Z
# H X H = Z
# H Y H = -Y
# H Z H = X
# For the original Pauli terms this translates to: Factor (-1) appears if S gate is applied to Y, or Hadamard gate H is applied to Y. (No factor (-1) occurs if S*H is applied.)

s_vector = np.zeros(m, dtype=int)
s_vector[s_list] = 1
h_vector = np.zeros(m, dtype=int)
h_vector[h_list] = 1
sh_vector = s_vector[perm] + h_vector % 2
sign_vector = sh_vector @ (x_matrix[qb_indices]*z_matrix[qb_indices]) % 2

# Lower triangular part of A
A_low = np.tril(A)

for index,z_vector in enumerate(R_inv.T):

# Count the number of rows of the square submatrix A defined by z_vector (rows/columns), such that the number of 1's in each row is odd
# This number is always even since A is a symmetric matrix with 0's on the diagonal
n1 = sum((z_vector @ A)*z_vector % 2)

# Calculate the paritiy of the sum of the numbers of 1's with position j>i for each row of the square submatrix A defined by z_vector (rows/columns)
n2 = sum((z_vector @ A_low)*z_vector) % 2

new_factor_dict = {qb_indices[perm[i]]:"Z" for i in range(m) if z_vector[i]==1}
new_factor_dicts.append(new_factor_dict)
prefactor = (-1)**sign_vector[index]*(-1)**(n1/2+n2)
prefactors.append(prefactor)

# Ladder operators
for term, coeff in self.terms_dict.items():
Expand Down Expand Up @@ -1317,7 +1338,7 @@ def get_measurement(
# Trotterization
#

def trotterization(self):
def trotterization(self, method='commuting_qw'):
r"""
Returns a function for performing Hamiltonian simulation, i.e., approximately implementing the unitary operator $e^{itH}$ via Trotterization.
Expand Down Expand Up @@ -1345,23 +1366,24 @@ def trotterization(self):
"""

def change_of_basis(qarg, terms_dict):
for index, factor in terms_dict.items():
if factor=="X":
h(qarg[index])
if factor=="Y":
s(qarg[index])
h(qarg[index])
x(qarg[index])

commuting_groups = self.group_up(lambda a, b: a.commute(b))

def trotter_step(qarg, t, steps):
for com_group in commuting_groups:
qw_groups, bases = com_group.commuting_qw_groups(show_bases=True)
for index,basis in enumerate(bases):
qw_group = qw_groups[index]
with conjugate(qw_group.change_of_basis)(qarg) as diagonal_operator:

if method=='commuting_qw':
def trotter_step(qarg, t, steps):
for com_group in commuting_groups:
qw_groups, bases = com_group.commuting_qw_groups(show_bases=True)
for index,basis in enumerate(bases):
qw_group = qw_groups[index]
with conjugate(qw_group.change_of_basis)(qarg) as diagonal_operator:
intersect_groups = diagonal_operator.group_up(lambda a, b: not a.intersect(b))
for intersect_group in intersect_groups:
for term,coeff in intersect_group.terms_dict.items():
term.simulate(coeff*t/steps, qarg, do_change_of_basis = False)

if method=='commuting':
def trotter_step(qarg, t, steps):
for com_group in commuting_groups:
with conjugate(com_group.change_of_basis)(qarg,method="commuting") as diagonal_operator:
intersect_groups = diagonal_operator.group_up(lambda a, b: not a.intersect(b))
for intersect_group in intersect_groups:
for term,coeff in intersect_group.terms_dict.items():
Expand Down

0 comments on commit 76fbe8f

Please sign in to comment.