From 739e1c906aeeab405ea717c2fc8d7e85f8cde765 Mon Sep 17 00:00:00 2001 From: positr0nium Date: Sun, 22 Dec 2024 20:18:50 +0100 Subject: [PATCH] fixed a bug in the quantum variable identification procedure of qache --- src/qrisp/core/quantum_variable.py | 2 -- src/qrisp/environments/conjugation_environment.py | 4 ++-- src/qrisp/jasp/qaching.py | 9 ++++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/qrisp/core/quantum_variable.py b/src/qrisp/core/quantum_variable.py index f76016bb..60c816dd 100644 --- a/src/qrisp/core/quantum_variable.py +++ b/src/qrisp/core/quantum_variable.py @@ -1511,8 +1511,6 @@ def flatten_qv(qv): return tuple(children), (QuantumVariableIdentityContainer(qv),) def unflatten_qv(aux_data, children): - qs = TracingQuantumSession.get_instance() - qv = aux_data[0].qv qv.reg = DynamicQubitArray(children[0]) for i in range(len(qv.traced_attributes)): diff --git a/src/qrisp/environments/conjugation_environment.py b/src/qrisp/environments/conjugation_environment.py index 51d6fa19..6b7b4edf 100644 --- a/src/qrisp/environments/conjugation_environment.py +++ b/src/qrisp/environments/conjugation_environment.py @@ -149,7 +149,7 @@ def __enter__(self): if check_for_tracing_mode(): with QuantumEnvironment(): - res = qache(self.conjugation_function)(*self.args, **self.kwargs) + res = qache(self.conjugation_function)(*list(self.args), **self.kwargs) return res @@ -189,7 +189,7 @@ def __exit__(self, exception_type, exception_value, traceback): else: from qrisp.environments import invert with invert(): - qache(self.conjugation_function)(*self.args, **self.kwargs) + qache(self.conjugation_function)(*list(self.args), **self.kwargs) QuantumEnvironment.__exit__(self, exception_type, exception_value, traceback) diff --git a/src/qrisp/jasp/qaching.py b/src/qrisp/jasp/qaching.py index 565d918d..8b5182ed 100644 --- a/src/qrisp/jasp/qaching.py +++ b/src/qrisp/jasp/qaching.py @@ -17,6 +17,9 @@ """ import jax + +from qrisp.core import recursive_qv_search + from qrisp.jasp import TracingQuantumSession, check_for_tracing_mode def qache(*func, **kwargs): @@ -184,9 +187,9 @@ def return_function(*args, **kwargs): # tracers of the jit trace. To reverse this, we store the current tracers # by flattening each QuantumVariable in the signature. flattened_qvs = [] - for arg in args: - if isinstance(arg, QuantumVariable): - flattened_qvs.append(flatten_qv(arg)) + + for qv in recursive_qv_search(args): + flattened_qvs.append(flatten_qv(qv)) # Get the AbstractQuantumCircuit for tracing abs_qs = TracingQuantumSession.get_instance()