20
20
21
21
from jax import make_jaxpr , jit
22
22
from jax .core import ClosedJaxpr
23
- from jax .lax import fori_loop
23
+ from jax .lax import fori_loop , cond , while_loop
24
24
from jax ._src .linear_util import wrap_init
25
25
import jax .numpy as jnp
26
26
27
27
from catalyst .jax_primitives import (AbstractQreg , qinst_p , qmeasure_p ,
28
28
qextract_p , qinsert_p , while_p , cond_p , func_p )
29
29
30
30
from qrisp .jasp import (QuantumPrimitive , OperationPrimitive , AbstractQuantumCircuit , AbstractQubitArray ,
31
- AbstractQubit , eval_jaxpr , Jaspr , extract_invalues , insert_outvalues )
31
+ AbstractQubit , eval_jaxpr , Jaspr , extract_invalues , insert_outvalues , Measurement_p , get_qubit_p ,
32
+ get_size_p , delete_qubits_p )
32
33
33
34
34
35
# Name translator from Qrisp gate naming to Catalyst gate naming
@@ -98,6 +99,8 @@ def catalyst_eqn_evaluator(eqn, context_dic):
98
99
elif eqn .primitive .name == "jasp.delete_qubits" :
99
100
# Not available in Catalyst
100
101
context_dic [outvars [0 ]] = context_dic [invars [0 ]]
102
+ elif eqn .primitive .name == "jasp.reset" :
103
+ process_reset (eqn , context_dic )
101
104
elif isinstance (eqn .primitive , OperationPrimitive ):
102
105
process_op (eqn .primitive , invars , outvars , context_dic )
103
106
else :
@@ -417,7 +420,6 @@ def process_cond(eqn, context_dic):
417
420
else :
418
421
unflattened_outvalues .append (outvalues .pop (0 ))
419
422
420
-
421
423
insert_outvalues (eqn , context_dic , unflattened_outvalues )
422
424
423
425
@lru_cache (maxsize = int (1E5 ))
@@ -464,4 +466,53 @@ def process_pjit(eqn, context_dic):
464
466
unflattened_outvalues .append (outvalues .pop (0 ))
465
467
466
468
insert_outvalues (eqn , context_dic , unflattened_outvalues )
469
+
470
+ # Function to reset and delete a qubit array
471
+ def reset_qubit_array (abs_qc , qb_array ):
472
+ from qrisp .circuit import XGate
473
+
474
+ def body_func (arg_tuple ):
475
+
476
+ abs_qc , qb_array , i = arg_tuple
477
+
478
+ abs_qb = get_qubit_p .bind (qb_array , i )
479
+ abs_qc , meas_bl = Measurement_p .bind (abs_qc , abs_qb )
480
+
481
+ def true_fun (arg_tuple ):
482
+ abs_qc , qb = arg_tuple
483
+ abs_qc = OperationPrimitive (XGate ()).bind (abs_qc , qb )
484
+ return (abs_qc , qb )
485
+
486
+ def false_fun (arg_tuple ):
487
+ return arg_tuple
488
+
489
+ abs_qc , qb = cond (meas_bl , true_fun , false_fun , (abs_qc , abs_qb ))
490
+
491
+ i += 1
492
+
493
+ return (abs_qc , qb_array , i )
494
+
495
+ def cond_fun (arg_tuple ):
496
+ return arg_tuple [- 1 ] < get_size_p .bind (arg_tuple [1 ])
497
+
498
+
499
+ abs_qc , qb_array , i = while_loop (cond_fun ,
500
+ body_func ,
501
+ (abs_qc , qb_array , jnp .array (0 , dtype = jnp .int32 ))
502
+ )
503
+
504
+ abs_qc = delete_qubits_p .bind (abs_qc , qb_array )
505
+
506
+ return abs_qc
507
+
508
+ reset_jaxpr = make_jaxpr (reset_qubit_array )(AbstractQuantumCircuit (), AbstractQubitArray ())
509
+
510
+ def process_reset (eqn , context_dic ):
511
+
512
+ invalues = extract_invalues (eqn , context_dic )
513
+ outvalues = eval_jaxpr (reset_jaxpr .jaxpr , eqn_evaluator = catalyst_eqn_evaluator )(* invalues )
514
+ insert_outvalues (eqn , context_dic , outvalues )
515
+
516
+
517
+
467
518
0 commit comments