Skip to content

Commit a4b72f0

Browse files
committed
implemented reset for catalyst backend
1 parent d281b86 commit a4b72f0

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

src/qrisp/jasp/interpreter_tools/interpreters/catalyst_interpreter.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020

2121
from jax import make_jaxpr, jit
2222
from jax.core import ClosedJaxpr
23-
from jax.lax import fori_loop
23+
from jax.lax import fori_loop, cond, while_loop
2424
from jax._src.linear_util import wrap_init
2525
import jax.numpy as jnp
2626

2727
from catalyst.jax_primitives import (AbstractQreg, qinst_p, qmeasure_p,
2828
qextract_p, qinsert_p, while_p, cond_p, func_p)
2929

3030
from qrisp.jasp import (QuantumPrimitive, OperationPrimitive, AbstractQuantumCircuit, AbstractQubitArray,
31-
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues)
31+
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues, Measurement_p, get_qubit_p,
32+
get_size_p, delete_qubits_p)
3233

3334

3435
# Name translator from Qrisp gate naming to Catalyst gate naming
@@ -98,6 +99,8 @@ def catalyst_eqn_evaluator(eqn, context_dic):
9899
elif eqn.primitive.name == "jasp.delete_qubits":
99100
# Not available in Catalyst
100101
context_dic[outvars[0]] = context_dic[invars[0]]
102+
elif eqn.primitive.name == "jasp.reset":
103+
process_reset(eqn, context_dic)
101104
elif isinstance(eqn.primitive, OperationPrimitive):
102105
process_op(eqn.primitive, invars, outvars, context_dic)
103106
else:
@@ -417,7 +420,6 @@ def process_cond(eqn, context_dic):
417420
else:
418421
unflattened_outvalues.append(outvalues.pop(0))
419422

420-
421423
insert_outvalues(eqn, context_dic, unflattened_outvalues)
422424

423425
@lru_cache(maxsize = int(1E5))
@@ -464,4 +466,53 @@ def process_pjit(eqn, context_dic):
464466
unflattened_outvalues.append(outvalues.pop(0))
465467

466468
insert_outvalues(eqn, context_dic, unflattened_outvalues)
469+
470+
# Function to reset and delete a qubit array
471+
def reset_qubit_array(abs_qc, qb_array):
472+
from qrisp.circuit import XGate
473+
474+
def body_func(arg_tuple):
475+
476+
abs_qc, qb_array, i = arg_tuple
477+
478+
abs_qb = get_qubit_p.bind(qb_array, i)
479+
abs_qc, meas_bl = Measurement_p.bind(abs_qc, abs_qb)
480+
481+
def true_fun(arg_tuple):
482+
abs_qc, qb = arg_tuple
483+
abs_qc = OperationPrimitive(XGate()).bind(abs_qc, qb)
484+
return (abs_qc, qb)
485+
486+
def false_fun(arg_tuple):
487+
return arg_tuple
488+
489+
abs_qc, qb = cond(meas_bl, true_fun, false_fun, (abs_qc, abs_qb))
490+
491+
i += 1
492+
493+
return (abs_qc, qb_array, i)
494+
495+
def cond_fun(arg_tuple):
496+
return arg_tuple[-1] < get_size_p.bind(arg_tuple[1])
497+
498+
499+
abs_qc, qb_array, i = while_loop(cond_fun,
500+
body_func,
501+
(abs_qc, qb_array, jnp.array(0, dtype = jnp.int32))
502+
)
503+
504+
abs_qc = delete_qubits_p.bind(abs_qc, qb_array)
505+
506+
return abs_qc
507+
508+
reset_jaxpr = make_jaxpr(reset_qubit_array)(AbstractQuantumCircuit(), AbstractQubitArray())
509+
510+
def process_reset(eqn, context_dic):
511+
512+
invalues = extract_invalues(eqn, context_dic)
513+
outvalues = eval_jaxpr(reset_jaxpr.jaxpr, eqn_evaluator = catalyst_eqn_evaluator)(*invalues)
514+
insert_outvalues(eqn, context_dic, outvalues)
515+
516+
517+
467518

src/qrisp/jasp/rus.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def extract_boolean_digit(integer, digit):
212212
@jax.jit
213213
def reset_qubit_array(abs_qc, qb_array):
214214

215-
216215
def body_func(arg_tuple):
217216

218217
abs_qc, qb_array, i = arg_tuple

0 commit comments

Comments
 (0)