Skip to content

Commit

Permalink
implemented dynamic balauca mcx
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Jan 2, 2025
1 parent 4a37a77 commit bc22156
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 39 deletions.
75 changes: 70 additions & 5 deletions src/qrisp/alg_primitives/mcx_algs/balauca.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
********************************************************************************/
"""

from jax.core import Tracer
import numpy as np

from jax.core import Tracer
import jax.numpy as jnp
from jax.lax import cond
from jax import jit

from qrisp.circuit import XGate, PGate, convert_to_qb_list, Qubit
from qrisp.qtypes import QuantumBool, QuantumVariable
from qrisp.core.gate_application_functions import x, cx, mcx
from qrisp.alg_primitives.mcx_algs.circuit_library import reduced_maslov_qc, margolus_qc, reduced_margolus_qc
from qrisp.alg_primitives.mcx_algs.gidney import GidneyLogicalAND
from qrisp.alg_primitives.mcx_algs.jones import jones_toffoli
from qrisp.environments.quantum_inversion import invert
from qrisp.jasp import check_for_tracing_mode, AbstractQubit
from qrisp.environments import invert, control, conjugate
from qrisp.jasp import check_for_tracing_mode, AbstractQubit, qache, jrange, make_tracer

# Ancilla supported multi controlled X with logarithmic depth based on
# https://www.iccs-meeting.org/archive/iccs2022/papers/133530169.pdf
Expand Down Expand Up @@ -433,4 +437,65 @@ def margolus(control, target):
qubit_list[2 * i + 4],
)

[qbl.delete() for qbl in dirty_ancilla_qbls]
[qbl.delete() for qbl in dirty_ancilla_qbls]



@jit
def extract_boolean_digit(integer, digit):
return jnp.bool((integer>>digit & 1))


def ctrl_state_conjugator(ctrls, ctrl_state):
for i in jrange(ctrls.size):
with control(~extract_boolean_digit(ctrl_state, i)):
x(ctrls[i])


@qache
def jasp_balauca_mcx(ctrls, target, ctrl_state):

from qrisp import mcx
ctrl_state = jnp.int32(ctrl_state)
ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**ctrls.size, lambda x : x, ctrl_state)

N = ctrls.size

with conjugate(ctrl_state_conjugator)(ctrls, ctrl_state):

with control(N == 1):
cx(ctrls[0], target[0])
with control(N == 2):
mcx([ctrls[0], ctrls[1]], target[0])
with control(N > 2):
balauca_anc = QuantumVariable(N-2+N%2)
with conjugate(jasp_balauca_helper)(ctrls, balauca_anc):
mcx([balauca_anc[balauca_anc.size-1], balauca_anc[balauca_anc.size-2]], target[0])
balauca_anc.delete()

def jasp_balauca_helper(ctrls, balauca_anc):
from qrisp import mcx

N = ctrls.size
n = jnp.int32(jnp.ceil(jnp.log2(N)))

for i in jrange(N//2):
mcx([ctrls[2*i], ctrls[2*i+1]], balauca_anc[i])

with control(jnp.bool(N%2)):
cx(ctrls[N-1], balauca_anc[N//2-1+N%2])

n = jnp.int32(jnp.ceil(jnp.log2(N)))

l = make_tracer(0)
k = N
for i in jrange(n-2):
k = jnp.int32(jnp.ceil(k/2))

for j in jrange(k//2):
mcx([balauca_anc[l+2*j], balauca_anc[l+2*j+1]],
balauca_anc[l+k+j],
method = "gidney")

l = cond(jnp.bool(k%2), lambda x : x-1, lambda x : x, l)
l += k
78 changes: 45 additions & 33 deletions src/qrisp/core/gate_application_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def append_operation(operation, qubits=[], clbits=[], param_tracers = []):
from qrisp import find_qs

qs = find_qs(qubits)

qs.append(operation, qubits, clbits, param_tracers = param_tracers)


Expand Down Expand Up @@ -468,50 +467,61 @@ def benchmark_mcx(n, methods):
from qrisp.core import QuantumVariable
from qrisp.qtypes import QuantumBool

new_controls = []

for qbl in controls:
if isinstance(qbl, QuantumBool):
new_controls.append(qbl[0])
else:
new_controls.append(qbl)

if isinstance(target, (list, QuantumVariable)):
if not check_for_tracing_mode():

if len(target) > 1:
raise Exception("Target of mcx contained more than one qubit")
target = target[0]
new_controls = []

for qbl in controls:
if isinstance(qbl, QuantumBool):
new_controls.append(qbl[0])
else:
new_controls.append(qbl)

if isinstance(target, (list, QuantumVariable)):

if len(target) > 1:
raise Exception("Target of mcx contained more than one qubit")
target = target[0]


qubits_0 = new_controls
qubits_1 = [target]

qubits_0 = new_controls
qubits_1 = [target]

n = len(qubits_0)

if n == 0:
return controls, target

if not isinstance(ctrl_state, str):
if ctrl_state == -1:
ctrl_state += 2**n
ctrl_state = bin_rep(ctrl_state, n)[::-1]
n = len(qubits_0)

if len(ctrl_state) != n:
raise Exception(
f"Given control state {ctrl_state} does not match control qubit amount {n}"
)
if n == 0:
return controls, target
elif n == 1:
append_operation(
std_ops.MCXGate(len(qubits_0), ctrl_state, method=method),
qubits_0 + qubits_1,
)
return

if not isinstance(ctrl_state, str):
if ctrl_state == -1:
ctrl_state += 2**n
ctrl_state = bin_rep(ctrl_state, n)[::-1]

if len(ctrl_state) != n:
raise Exception(
f"Given control state {ctrl_state} does not match control qubit amount {n}"
)
else:
qubits_0 = controls
qubits_1 = [target]
from qrisp.alg_primitives.mcx_algs import (
balauca_dirty,
balauca_mcx,
hybrid_mcx,
maslov_mcx,
yong_mcx,
jasp_balauca_mcx

)

if method in ["gray", "gray_pt", "gray_pt_inv"] or len(qubits_0) == 1:
if len(qubits_0) == 1:
method = "gray"
if method in ["gray", "gray_pt", "gray_pt_inv"]:
append_operation(
std_ops.MCXGate(len(qubits_0), ctrl_state, method=method),
qubits_0 + qubits_1,
Expand Down Expand Up @@ -552,7 +562,10 @@ def benchmark_mcx(n, methods):
[qv.delete() for qv in ancilla]

elif method == "balauca":
balauca_mcx(qubits_0, qubits_1, ctrl_state=ctrl_state)
if check_for_tracing_mode():
jasp_balauca_mcx(qubits_0, qubits_1, ctrl_state)
else:
balauca_mcx(qubits_0, qubits_1, ctrl_state=ctrl_state)

elif method == "balauca_dirty":
balauca_dirty(qubits_0, qubits_1, k=num_ancilla, ctrl_state=ctrl_state)
Expand Down Expand Up @@ -580,7 +593,6 @@ def benchmark_mcx(n, methods):
# return mcx(qubits_0, qubits_1, method = "maslov", ctrl_state = ctrl_state)
# else:
# return mcx(qubits_0, qubits_1, method = "balauca", ctrl_state = ctrl_state) # noqa:501

gate = std_ops.MCXGate(len(qubits_0), ctrl_state, method="auto")
append_operation(gate, qubits_0 + qubits_1)

Expand Down
3 changes: 3 additions & 0 deletions src/qrisp/core/session_merging_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import weakref
from jaxlib.xla_extension import ArrayImpl

# This module contains the necessary tools to merge QuantumSessions

Expand Down Expand Up @@ -353,6 +354,8 @@ def recursive_qs_search(input):
else:
input = list(input)
for i in range(len(input)):
if isinstance(input[i], ArrayImpl):
continue
result += recursive_qs_search(input[i])
else:
if isinstance(input, QuantumSession):
Expand Down
13 changes: 12 additions & 1 deletion src/qrisp/jasp/jasp_expression/centerclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,4 +1092,15 @@ def return_function(*args):
def check_aval_equivalence(invars_1, invars_2):
avals_1 = [invar.aval for invar in invars_1]
avals_2 = [invar.aval for invar in invars_2]
return all([type(avals_1[i]) == type(avals_2[i]) for i in range(len(avals_1))])
return all([type(avals_1[i]) == type(avals_2[i]) for i in range(len(avals_1))])

def make_tracer(x):
if isinstance(x, bool):
dtype = jnp.float32
elif isinstance(x, int):
dtype = jnp.int32
elif isinstance(x, float):
dtype = jnp.float32
elif isinstance(x, complex):
dtype = jnp.complex32
return jax.jit(lambda: jnp.array(x, dtype))()
41 changes: 41 additions & 0 deletions tests/jax_tests/test_dynamic_mcx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
\********************************************************************************
* Copyright (c) 2023 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

from qrisp import *
from qrisp.jasp import *

def test_dynamic_mcx():

@terminal_sampling
def main(i, j):
qf = QuantumFloat(i)
h(qf)
qbl = QuantumBool()
mcx(qf.reg, qbl[0], method = "balauca", ctrl_state = j)
return qf, qbl

for i in range(1, 5):
for j in range(2**i):
res_dict = main(i, j)

for k in res_dict.keys():
if k[0] == j:
assert k[1]
else:
assert not k[1]

0 comments on commit bc22156

Please sign in to comment.