21
21
import jax .numpy as jnp
22
22
23
23
from qrisp .circuit import XGate
24
- from qrisp .jasp import TracingQuantumSession , AbstractQubitArray , DynamicQubitArray
25
-
24
+ from qrisp .jasp import TracingQuantumSession , AbstractQubitArray , DynamicQubitArray , qache
25
+ from qrisp . jasp . primitives import Measurement_p , OperationPrimitive , get_qubit_p , get_size_p , delete_qubits_p , reset_p
26
26
27
27
28
28
def RUS (trial_function ):
@@ -110,31 +110,18 @@ def call_RUS_example():
110
110
111
111
def return_function (* trial_args ):
112
112
113
-
114
- def ammended_function (* args ):
115
- abs_qs = TracingQuantumSession .get_instance ()
116
- initial_qv_list = list (abs_qs .qv_list )
117
-
118
- res = trial_function (* args )
119
-
120
- created_qvs = list (set (abs_qs .qv_list ) - set (initial_qv_list ))
121
- created_qvs .sort (key = lambda x : x .creation_time )
122
-
123
- return res , tuple (created_qvs )
124
-
125
113
# Execute the function
126
- first_iter_res , created_qvs = ammended_function (* trial_args )
114
+ first_iter_res = qache ( trial_function ) (* trial_args )
127
115
128
116
# Flatten the arguments and the res values
129
117
arg_vals , arg_tree_def = jax .tree .flatten (trial_args )
130
118
res_vals , res_tree_def = jax .tree .flatten (first_iter_res )
131
- created_qvs_vals , created_qvs_tree_def = jax .tree .flatten (created_qvs )
132
119
133
120
134
121
# Extract the jaspr
135
122
from qrisp .jasp import make_jaspr
136
123
137
- ammended_trial_func_jaspr = make_jaspr (ammended_function )(* trial_args )
124
+ ammended_trial_func_jaspr = make_jaspr (trial_function )(* trial_args )
138
125
ammended_trial_func_jaspr = ammended_trial_func_jaspr .flatten_environments ()
139
126
140
127
@@ -150,22 +137,21 @@ def ammended_function(*args):
150
137
# And the final section are trial function arguments
151
138
152
139
abs_qs = TracingQuantumSession .get_instance ()
153
- combined_args = tuple ([abs_qs .abs_qc ] + list (arg_vals ) + list (res_vals ) + list ( created_qvs_vals ) )
140
+ combined_args = tuple ([abs_qs .abs_qc ] + list (arg_vals ) + list (res_vals ))
154
141
155
142
n_res_vals = len (res_vals )
156
143
n_arg_vals = len (arg_vals )
157
- n_created_qv_vals = len (created_qvs_vals )
158
144
159
145
def body_fun (args ):
160
146
# We now need to deallocate the AbstractQubitArrays from the previous
161
147
# iteration since they are no longer needed.
162
- created_qvs_vals = args [- n_created_qv_vals :]
148
+ res_qv_vals = args [- n_res_vals :]
163
149
164
150
abs_qc = args [0 ]
165
- for res_val in created_qvs_vals :
151
+ for res_val in res_qv_vals :
166
152
if isinstance (res_val .aval , AbstractQubitArray ):
167
- # abs_qc = delete_qubits_p .bind(abs_qc, res_val)
168
- abs_qc = reset_qubit_array (abs_qc , res_val )
153
+ abs_qc = reset_p .bind (abs_qc , res_val )
154
+ abs_qc = delete_qubits_p . bind (abs_qc , res_val )
169
155
170
156
# Next we evaluate the trial function by evaluating the corresponding jaspr
171
157
# Prepare the arguments tuple
@@ -226,7 +212,6 @@ def extract_boolean_digit(integer, digit):
226
212
@jax .jit
227
213
def reset_qubit_array (abs_qc , qb_array ):
228
214
229
- from qrisp .jasp .primitives import Measurement_p , OperationPrimitive , get_qubit_p , get_size_p , delete_qubits_p
230
215
231
216
def body_func (arg_tuple ):
232
217
0 commit comments