Skip to content

Commit 0183cf1

Browse files
committed
implemented dynamic mcp
1 parent f081ca4 commit 0183cf1

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

src/qrisp/alg_primitives/mcx_algs/balauca.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,12 @@ def jasp_balauca_mcx(ctrls, target, ctrl_state):
488488
def jasp_balauca_mcp(phi, ctrls, ctrl_state):
489489

490490
from qrisp import mcx, QuantumBool, cp, p
491+
N = jlen(ctrls)
492+
491493
ctrl_state = jnp.int64(ctrl_state)
492-
ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**ctrls.size, lambda x : x, ctrl_state)
494+
ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**N, lambda x : x, ctrl_state)
493495
target = QuantumBool()
494496

495-
N = ctrls.size
496-
497497
with conjugate(ctrl_state_conjugator)(ctrls, ctrl_state):
498498

499499
with control(N == 1):

src/qrisp/core/gate_application_functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def mcp(phi, qubits, method="auto", ctrl_state=-1):
707707
708708
"""
709709

710-
from qrisp.alg_primitives.mcx_algs import hybrid_mcx
710+
from qrisp.alg_primitives.mcx_algs import hybrid_mcx, jasp_balauca_mcp
711711
from qrisp import QuantumBool
712712
from qrisp.misc import bin_rep, gate_wrap
713713
import numpy as np
@@ -726,6 +726,10 @@ def balauca_mcp(phi, qubits, ctrl_state):
726726

727727
temp.delete()
728728

729+
if check_for_tracing_mode():
730+
jasp_balauca_mcp(phi, qubits, ctrl_state)
731+
return
732+
729733
n = len(qubits)
730734

731735
if not isinstance(ctrl_state, str):

tests/jax_tests/test_dynamic_mcx.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,43 @@ def main(j):
5757
assert k[1]
5858
else:
5959
assert not k[1]
60+
61+
# Test dynamic mcp
62+
63+
@jaspify
64+
def main(phi, i):
65+
66+
qv = QuantumFloat(i)
67+
68+
x(qv[:qv.size-1])
69+
70+
with conjugate(h)(qv[qv.size-1]):
71+
mcp(phi, qv)
72+
73+
return measure(qv)
74+
75+
assert main(np.pi, 5) == 31
6076

61-
77+
@jaspify
78+
def main(phi, i, j):
79+
80+
qv = QuantumFloat(i)
81+
82+
with conjugate(h)(qv[qv.size-1]):
83+
mcp(phi, qv, ctrl_state = j)
84+
85+
return measure(qv)
86+
87+
assert main(np.pi, 5, 0) == 16
88+
89+
@jaspify
90+
def main(phi, i, j):
91+
92+
qv = QuantumFloat(i)
93+
94+
with conjugate(h)(qv[qv.size-1]):
95+
mcp(phi, [qv[i] for i in range(5)], ctrl_state = j)
96+
97+
return measure(qv)
98+
99+
assert main(np.pi, 5, 0) == 16

0 commit comments

Comments
 (0)