Skip to content

Commit 1a95155

Browse files
committed
adapted RUS to use the garbage collection system
1 parent 325f9f2 commit 1a95155

File tree

1 file changed

+9
-24
lines changed

1 file changed

+9
-24
lines changed

src/qrisp/jasp/rus.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import jax.numpy as jnp
2222

2323
from qrisp.circuit import XGate
24-
from qrisp.jasp import TracingQuantumSession, AbstractQubitArray, DynamicQubitArray
25-
24+
from qrisp.jasp import TracingQuantumSession, AbstractQubitArray, DynamicQubitArray, qache
25+
from qrisp.jasp.primitives import Measurement_p, OperationPrimitive, get_qubit_p, get_size_p, delete_qubits_p, reset_p
2626

2727

2828
def RUS(trial_function):
@@ -110,31 +110,18 @@ def call_RUS_example():
110110

111111
def return_function(*trial_args):
112112

113-
114-
def ammended_function(*args):
115-
abs_qs = TracingQuantumSession.get_instance()
116-
initial_qv_list = list(abs_qs.qv_list)
117-
118-
res = trial_function(*args)
119-
120-
created_qvs = list(set(abs_qs.qv_list) - set(initial_qv_list))
121-
created_qvs.sort(key = lambda x : x.creation_time)
122-
123-
return res, tuple(created_qvs)
124-
125113
# Execute the function
126-
first_iter_res, created_qvs = ammended_function(*trial_args)
114+
first_iter_res = qache(trial_function)(*trial_args)
127115

128116
# Flatten the arguments and the res values
129117
arg_vals, arg_tree_def = jax.tree.flatten(trial_args)
130118
res_vals, res_tree_def = jax.tree.flatten(first_iter_res)
131-
created_qvs_vals, created_qvs_tree_def = jax.tree.flatten(created_qvs)
132119

133120

134121
# Extract the jaspr
135122
from qrisp.jasp import make_jaspr
136123

137-
ammended_trial_func_jaspr = make_jaspr(ammended_function)(*trial_args)
124+
ammended_trial_func_jaspr = make_jaspr(trial_function)(*trial_args)
138125
ammended_trial_func_jaspr = ammended_trial_func_jaspr.flatten_environments()
139126

140127

@@ -150,22 +137,21 @@ def ammended_function(*args):
150137
# And the final section are trial function arguments
151138

152139
abs_qs = TracingQuantumSession.get_instance()
153-
combined_args = tuple([abs_qs.abs_qc] + list(arg_vals) + list(res_vals) + list(created_qvs_vals))
140+
combined_args = tuple([abs_qs.abs_qc] + list(arg_vals) + list(res_vals))
154141

155142
n_res_vals = len(res_vals)
156143
n_arg_vals = len(arg_vals)
157-
n_created_qv_vals = len(created_qvs_vals)
158144

159145
def body_fun(args):
160146
# We now need to deallocate the AbstractQubitArrays from the previous
161147
# iteration since they are no longer needed.
162-
created_qvs_vals = args[-n_created_qv_vals:]
148+
res_qv_vals = args[-n_res_vals:]
163149

164150
abs_qc = args[0]
165-
for res_val in created_qvs_vals:
151+
for res_val in res_qv_vals:
166152
if isinstance(res_val.aval, AbstractQubitArray):
167-
# abs_qc = delete_qubits_p.bind(abs_qc, res_val)
168-
abs_qc = reset_qubit_array(abs_qc, res_val)
153+
abs_qc = reset_p.bind(abs_qc, res_val)
154+
abs_qc = delete_qubits_p.bind(abs_qc, res_val)
169155

170156
# Next we evaluate the trial function by evaluating the corresponding jaspr
171157
# Prepare the arguments tuple
@@ -226,7 +212,6 @@ def extract_boolean_digit(integer, digit):
226212
@jax.jit
227213
def reset_qubit_array(abs_qc, qb_array):
228214

229-
from qrisp.jasp.primitives import Measurement_p, OperationPrimitive, get_qubit_p, get_size_p, delete_qubits_p
230215

231216
def body_func(arg_tuple):
232217

0 commit comments

Comments
 (0)