From bc221569cde0edc61d08de6b53f6146c289fd341 Mon Sep 17 00:00:00 2001 From: positr0nium Date: Thu, 2 Jan 2025 20:13:16 +0100 Subject: [PATCH] implemented dynamic balauca mcx --- src/qrisp/alg_primitives/mcx_algs/balauca.py | 75 ++++++++++++++++-- src/qrisp/core/gate_application_functions.py | 78 +++++++++++-------- src/qrisp/core/session_merging_tools.py | 3 + src/qrisp/jasp/jasp_expression/centerclass.py | 13 +++- tests/jax_tests/test_dynamic_mcx.py | 41 ++++++++++ 5 files changed, 171 insertions(+), 39 deletions(-) create mode 100644 tests/jax_tests/test_dynamic_mcx.py diff --git a/src/qrisp/alg_primitives/mcx_algs/balauca.py b/src/qrisp/alg_primitives/mcx_algs/balauca.py index be043d92..06093400 100644 --- a/src/qrisp/alg_primitives/mcx_algs/balauca.py +++ b/src/qrisp/alg_primitives/mcx_algs/balauca.py @@ -16,16 +16,20 @@ ********************************************************************************/ """ -from jax.core import Tracer import numpy as np + +from jax.core import Tracer +import jax.numpy as jnp +from jax.lax import cond +from jax import jit + from qrisp.circuit import XGate, PGate, convert_to_qb_list, Qubit from qrisp.qtypes import QuantumBool, QuantumVariable from qrisp.core.gate_application_functions import x, cx, mcx from qrisp.alg_primitives.mcx_algs.circuit_library import reduced_maslov_qc, margolus_qc, reduced_margolus_qc from qrisp.alg_primitives.mcx_algs.gidney import GidneyLogicalAND -from qrisp.alg_primitives.mcx_algs.jones import jones_toffoli -from qrisp.environments.quantum_inversion import invert -from qrisp.jasp import check_for_tracing_mode, AbstractQubit +from qrisp.environments import invert, control, conjugate +from qrisp.jasp import check_for_tracing_mode, AbstractQubit, qache, jrange, make_tracer # Ancilla supported multi controlled X with logarithmic depth based on # https://www.iccs-meeting.org/archive/iccs2022/papers/133530169.pdf @@ -433,4 +437,65 @@ def margolus(control, target): qubit_list[2 * i + 4], ) - [qbl.delete() for qbl in dirty_ancilla_qbls] \ No newline at end of file + [qbl.delete() for qbl in dirty_ancilla_qbls] + + + +@jit +def extract_boolean_digit(integer, digit): + return jnp.bool((integer>>digit & 1)) + + +def ctrl_state_conjugator(ctrls, ctrl_state): + for i in jrange(ctrls.size): + with control(~extract_boolean_digit(ctrl_state, i)): + x(ctrls[i]) + + +@qache +def jasp_balauca_mcx(ctrls, target, ctrl_state): + + from qrisp import mcx + ctrl_state = jnp.int32(ctrl_state) + ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**ctrls.size, lambda x : x, ctrl_state) + + N = ctrls.size + + with conjugate(ctrl_state_conjugator)(ctrls, ctrl_state): + + with control(N == 1): + cx(ctrls[0], target[0]) + with control(N == 2): + mcx([ctrls[0], ctrls[1]], target[0]) + with control(N > 2): + balauca_anc = QuantumVariable(N-2+N%2) + with conjugate(jasp_balauca_helper)(ctrls, balauca_anc): + mcx([balauca_anc[balauca_anc.size-1], balauca_anc[balauca_anc.size-2]], target[0]) + balauca_anc.delete() + +def jasp_balauca_helper(ctrls, balauca_anc): + from qrisp import mcx + + N = ctrls.size + n = jnp.int32(jnp.ceil(jnp.log2(N))) + + for i in jrange(N//2): + mcx([ctrls[2*i], ctrls[2*i+1]], balauca_anc[i]) + + with control(jnp.bool(N%2)): + cx(ctrls[N-1], balauca_anc[N//2-1+N%2]) + + n = jnp.int32(jnp.ceil(jnp.log2(N))) + + l = make_tracer(0) + k = N + for i in jrange(n-2): + k = jnp.int32(jnp.ceil(k/2)) + + for j in jrange(k//2): + mcx([balauca_anc[l+2*j], balauca_anc[l+2*j+1]], + balauca_anc[l+k+j], + method = "gidney") + + l = cond(jnp.bool(k%2), lambda x : x-1, lambda x : x, l) + l += k \ No newline at end of file diff --git a/src/qrisp/core/gate_application_functions.py b/src/qrisp/core/gate_application_functions.py index 5dc098ee..fe509223 100644 --- a/src/qrisp/core/gate_application_functions.py +++ b/src/qrisp/core/gate_application_functions.py @@ -25,7 +25,6 @@ def append_operation(operation, qubits=[], clbits=[], param_tracers = []): from qrisp import find_qs qs = find_qs(qubits) - qs.append(operation, qubits, clbits, param_tracers = param_tracers) @@ -468,50 +467,61 @@ def benchmark_mcx(n, methods): from qrisp.core import QuantumVariable from qrisp.qtypes import QuantumBool - new_controls = [] - for qbl in controls: - if isinstance(qbl, QuantumBool): - new_controls.append(qbl[0]) - else: - new_controls.append(qbl) - - if isinstance(target, (list, QuantumVariable)): + if not check_for_tracing_mode(): - if len(target) > 1: - raise Exception("Target of mcx contained more than one qubit") - target = target[0] + new_controls = [] + + for qbl in controls: + if isinstance(qbl, QuantumBool): + new_controls.append(qbl[0]) + else: + new_controls.append(qbl) + if isinstance(target, (list, QuantumVariable)): + + if len(target) > 1: + raise Exception("Target of mcx contained more than one qubit") + target = target[0] + + + qubits_0 = new_controls + qubits_1 = [target] - qubits_0 = new_controls - qubits_1 = [target] - - n = len(qubits_0) - - if n == 0: - return controls, target - - if not isinstance(ctrl_state, str): - if ctrl_state == -1: - ctrl_state += 2**n - ctrl_state = bin_rep(ctrl_state, n)[::-1] + n = len(qubits_0) - if len(ctrl_state) != n: - raise Exception( - f"Given control state {ctrl_state} does not match control qubit amount {n}" - ) + if n == 0: + return controls, target + elif n == 1: + append_operation( + std_ops.MCXGate(len(qubits_0), ctrl_state, method=method), + qubits_0 + qubits_1, + ) + return + + if not isinstance(ctrl_state, str): + if ctrl_state == -1: + ctrl_state += 2**n + ctrl_state = bin_rep(ctrl_state, n)[::-1] + if len(ctrl_state) != n: + raise Exception( + f"Given control state {ctrl_state} does not match control qubit amount {n}" + ) + else: + qubits_0 = controls + qubits_1 = [target] from qrisp.alg_primitives.mcx_algs import ( balauca_dirty, balauca_mcx, hybrid_mcx, maslov_mcx, yong_mcx, + jasp_balauca_mcx + ) - if method in ["gray", "gray_pt", "gray_pt_inv"] or len(qubits_0) == 1: - if len(qubits_0) == 1: - method = "gray" + if method in ["gray", "gray_pt", "gray_pt_inv"]: append_operation( std_ops.MCXGate(len(qubits_0), ctrl_state, method=method), qubits_0 + qubits_1, @@ -552,7 +562,10 @@ def benchmark_mcx(n, methods): [qv.delete() for qv in ancilla] elif method == "balauca": - balauca_mcx(qubits_0, qubits_1, ctrl_state=ctrl_state) + if check_for_tracing_mode(): + jasp_balauca_mcx(qubits_0, qubits_1, ctrl_state) + else: + balauca_mcx(qubits_0, qubits_1, ctrl_state=ctrl_state) elif method == "balauca_dirty": balauca_dirty(qubits_0, qubits_1, k=num_ancilla, ctrl_state=ctrl_state) @@ -580,7 +593,6 @@ def benchmark_mcx(n, methods): # return mcx(qubits_0, qubits_1, method = "maslov", ctrl_state = ctrl_state) # else: # return mcx(qubits_0, qubits_1, method = "balauca", ctrl_state = ctrl_state) # noqa:501 - gate = std_ops.MCXGate(len(qubits_0), ctrl_state, method="auto") append_operation(gate, qubits_0 + qubits_1) diff --git a/src/qrisp/core/session_merging_tools.py b/src/qrisp/core/session_merging_tools.py index 52af7ff4..59b5b7b5 100644 --- a/src/qrisp/core/session_merging_tools.py +++ b/src/qrisp/core/session_merging_tools.py @@ -17,6 +17,7 @@ """ import weakref +from jaxlib.xla_extension import ArrayImpl # This module contains the necessary tools to merge QuantumSessions @@ -353,6 +354,8 @@ def recursive_qs_search(input): else: input = list(input) for i in range(len(input)): + if isinstance(input[i], ArrayImpl): + continue result += recursive_qs_search(input[i]) else: if isinstance(input, QuantumSession): diff --git a/src/qrisp/jasp/jasp_expression/centerclass.py b/src/qrisp/jasp/jasp_expression/centerclass.py index 7ab4ddad..35715131 100644 --- a/src/qrisp/jasp/jasp_expression/centerclass.py +++ b/src/qrisp/jasp/jasp_expression/centerclass.py @@ -1092,4 +1092,15 @@ def return_function(*args): def check_aval_equivalence(invars_1, invars_2): avals_1 = [invar.aval for invar in invars_1] avals_2 = [invar.aval for invar in invars_2] - return all([type(avals_1[i]) == type(avals_2[i]) for i in range(len(avals_1))]) \ No newline at end of file + return all([type(avals_1[i]) == type(avals_2[i]) for i in range(len(avals_1))]) + +def make_tracer(x): + if isinstance(x, bool): + dtype = jnp.float32 + elif isinstance(x, int): + dtype = jnp.int32 + elif isinstance(x, float): + dtype = jnp.float32 + elif isinstance(x, complex): + dtype = jnp.complex32 + return jax.jit(lambda: jnp.array(x, dtype))() \ No newline at end of file diff --git a/tests/jax_tests/test_dynamic_mcx.py b/tests/jax_tests/test_dynamic_mcx.py new file mode 100644 index 00000000..e7c7e823 --- /dev/null +++ b/tests/jax_tests/test_dynamic_mcx.py @@ -0,0 +1,41 @@ +""" +\******************************************************************************** +* Copyright (c) 2023 the Qrisp authors +* +* This program and the accompanying materials are made available under the +* terms of the Eclipse Public License 2.0 which is available at +* http://www.eclipse.org/legal/epl-2.0. +* +* This Source Code may also be made available under the following Secondary +* Licenses when the conditions for such availability set forth in the Eclipse +* Public License, v. 2.0 are satisfied: GNU General Public License, version 2 +* with the GNU Classpath Exception which is +* available at https://www.gnu.org/software/classpath/license.html. +* +* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 +********************************************************************************/ +""" + +from qrisp import * +from qrisp.jasp import * + +def test_dynamic_mcx(): + + @terminal_sampling + def main(i, j): + qf = QuantumFloat(i) + h(qf) + qbl = QuantumBool() + mcx(qf.reg, qbl[0], method = "balauca", ctrl_state = j) + return qf, qbl + + for i in range(1, 5): + for j in range(2**i): + res_dict = main(i, j) + + for k in res_dict.keys(): + if k[0] == j: + assert k[1] + else: + assert not k[1] + \ No newline at end of file