Skip to content

Commit f081ca4

Browse files
committed
fixed a bug that prevented proper compilation of balauca mcx in traced mode of static lists of controls
1 parent b5e751b commit f081ca4

File tree

5 files changed

+64
-12
lines changed

5 files changed

+64
-12
lines changed

src/qrisp/alg_primitives/mcx_algs/balauca.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from qrisp.alg_primitives.mcx_algs.circuit_library import reduced_maslov_qc, margolus_qc, reduced_margolus_qc
3030
from qrisp.alg_primitives.mcx_algs.gidney import GidneyLogicalAND
3131
from qrisp.environments import invert, control, conjugate
32-
from qrisp.jasp import check_for_tracing_mode, AbstractQubit, qache, jrange, make_tracer
32+
from qrisp.jasp import check_for_tracing_mode, AbstractQubit, qache, jrange, make_tracer, jlen
3333

3434
# Ancilla supported multi controlled X with logarithmic depth based on
3535
# https://www.iccs-meeting.org/archive/iccs2022/papers/133530169.pdf
@@ -447,19 +447,28 @@ def extract_boolean_digit(integer, digit):
447447

448448

449449
def ctrl_state_conjugator(ctrls, ctrl_state):
450-
for i in jrange(ctrls.size):
450+
451+
if isinstance(ctrls, list):
452+
xrange = range
453+
else:
454+
xrange = jrange
455+
456+
N = jlen(ctrls)
457+
458+
for i in xrange(N):
451459
with control(~extract_boolean_digit(ctrl_state, i)):
452460
x(ctrls[i])
453461

454462

455463
@qache
456464
def jasp_balauca_mcx(ctrls, target, ctrl_state):
457465

466+
N = jlen(ctrls)
467+
458468
from qrisp import mcx
459469
ctrl_state = jnp.int64(ctrl_state)
460-
ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**ctrls.size, lambda x : x, ctrl_state)
470+
ctrl_state = cond(ctrl_state == -1, lambda x : x + 2**N, lambda x : x, ctrl_state)
461471

462-
N = ctrls.size
463472

464473
with conjugate(ctrl_state_conjugator)(ctrls, ctrl_state):
465474

@@ -506,23 +515,31 @@ def jasp_balauca_mcp(phi, ctrls, ctrl_state):
506515
def jasp_balauca_helper(ctrls, balauca_anc):
507516
from qrisp import mcx
508517

509-
N = ctrls.size
518+
if isinstance(ctrls, list):
519+
xrange = range
520+
import numpy as jnp
521+
else:
522+
xrange = jrange
523+
import jax.numpy as jnp
524+
525+
N = jlen(ctrls)
526+
510527
n = jnp.int64(jnp.ceil(jnp.log2(N)))
511528

512-
for i in jrange(N//2):
529+
for i in xrange(N//2):
513530
mcx([ctrls[2*i], ctrls[2*i+1]], balauca_anc[i])
514531

515-
with control(jnp.bool(N%2)):
532+
with control(N%2 != 0):
516533
cx(ctrls[N-1], balauca_anc[N//2-1+N%2])
517534

518535
n = jnp.int64(jnp.ceil(jnp.log2(N)))
519536

520537
l = make_tracer(0)
521538
k = N
522-
for i in jrange(n-2):
539+
for i in xrange(n-2):
523540
k = jnp.int64(jnp.ceil(k/2))
524541

525-
for j in jrange(k//2):
542+
for j in xrange(k//2):
526543
mcx([balauca_anc[l+2*j], balauca_anc[l+2*j+1]],
527544
balauca_anc[l+k+j],
528545
method = "gidney")

src/qrisp/environments/control_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,9 @@ def control(*args, **kwargs):
474474
args[0] = [args[0]]
475475

476476
if check_for_tracing_mode():
477-
if all(isinstance(obj, AbstractQubit) for obj in [x.aval for x in args[0]]):
477+
if all(isinstance(obj, bool) for obj in [x for x in args[0]]):
478+
return ClControlEnvironment(*args, **kwargs)
479+
elif all(isinstance(obj, AbstractQubit) for obj in [x.aval for x in args[0]]):
478480
return ControlEnvironment(*args, **kwargs)
479481
elif all(isinstance(obj, ShapedArray) for obj in [x.aval for x in args[0]]):
480482
return ClControlEnvironment(*args, **kwargs)

src/qrisp/jasp/control_flow/jrange_iterator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,4 +295,11 @@ def make_tracer(x):
295295
def tracerizer():
296296
return jnp.array(x, dtype)
297297

298-
return jit(tracerizer)()
298+
return jit(tracerizer)()
299+
300+
def jlen(x):
301+
if isinstance(x, list):
302+
return len(x)
303+
else:
304+
return x.size
305+

src/qrisp/jasp/jasp_expression/centerclass.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,13 @@ def qjit(self, *args, function_name = "jaspr_function"):
516516

517517
from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_qjit
518518
qjit_obj = jaspr_to_catalyst_qjit(flattened_jaspr, function_name = function_name)
519-
return qjit_obj.compiled_function(*args)
519+
res = qjit_obj.compiled_function(*args)
520+
if not isinstance(res, (tuple,list)):
521+
return res
522+
elif len(res) == 1:
523+
return res[0]
524+
else:
525+
return res
520526

521527
@classmethod
522528
@lru_cache(maxsize = int(1E5))

tests/jax_tests/test_dynamic_mcx.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,24 @@ def main(i, j):
3838
assert k[1]
3939
else:
4040
assert not k[1]
41+
42+
# Test static list behavior
43+
@terminal_sampling
44+
def main(j):
45+
qf = QuantumFloat(5)
46+
h(qf)
47+
qbl = QuantumBool()
48+
qb_list = [qf[i] for i in range(5)]
49+
mcx(qf.reg, qbl[0], method = "balauca", ctrl_state = j)
50+
return qf, qbl
51+
52+
for j in range(2**5):
53+
res_dict = main(j)
54+
55+
for k in res_dict.keys():
56+
if k[0] == j:
57+
assert k[1]
58+
else:
59+
assert not k[1]
60+
4161

0 commit comments

Comments
 (0)