diff --git a/src/qrisp/core/quantum_variable.py b/src/qrisp/core/quantum_variable.py index f76016bb..60c816dd 100644 --- a/src/qrisp/core/quantum_variable.py +++ b/src/qrisp/core/quantum_variable.py @@ -1511,8 +1511,6 @@ def flatten_qv(qv): return tuple(children), (QuantumVariableIdentityContainer(qv),) def unflatten_qv(aux_data, children): - qs = TracingQuantumSession.get_instance() - qv = aux_data[0].qv qv.reg = DynamicQubitArray(children[0]) for i in range(len(qv.traced_attributes)): diff --git a/src/qrisp/environments/conjugation_environment.py b/src/qrisp/environments/conjugation_environment.py index 51d6fa19..6b7b4edf 100644 --- a/src/qrisp/environments/conjugation_environment.py +++ b/src/qrisp/environments/conjugation_environment.py @@ -149,7 +149,7 @@ def __enter__(self): if check_for_tracing_mode(): with QuantumEnvironment(): - res = qache(self.conjugation_function)(*self.args, **self.kwargs) + res = qache(self.conjugation_function)(*list(self.args), **self.kwargs) return res @@ -189,7 +189,7 @@ def __exit__(self, exception_type, exception_value, traceback): else: from qrisp.environments import invert with invert(): - qache(self.conjugation_function)(*self.args, **self.kwargs) + qache(self.conjugation_function)(*list(self.args), **self.kwargs) QuantumEnvironment.__exit__(self, exception_type, exception_value, traceback) diff --git a/src/qrisp/jasp/qaching.py b/src/qrisp/jasp/qaching.py index 565d918d..8b5182ed 100644 --- a/src/qrisp/jasp/qaching.py +++ b/src/qrisp/jasp/qaching.py @@ -17,6 +17,9 @@ """ import jax + +from qrisp.core import recursive_qv_search + from qrisp.jasp import TracingQuantumSession, check_for_tracing_mode def qache(*func, **kwargs): @@ -184,9 +187,9 @@ def return_function(*args, **kwargs): # tracers of the jit trace. To reverse this, we store the current tracers # by flattening each QuantumVariable in the signature. flattened_qvs = [] - for arg in args: - if isinstance(arg, QuantumVariable): - flattened_qvs.append(flatten_qv(arg)) + + for qv in recursive_qv_search(args): + flattened_qvs.append(flatten_qv(qv)) # Get the AbstractQuantumCircuit for tracing abs_qs = TracingQuantumSession.get_instance()