Skip to content

Commit

Permalink
fixed a bug in the quantum variable identification procedure of qache
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Dec 22, 2024
1 parent a7f93ac commit 739e1c9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 0 additions & 2 deletions src/qrisp/core/quantum_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,8 +1511,6 @@ def flatten_qv(qv):
return tuple(children), (QuantumVariableIdentityContainer(qv),)

def unflatten_qv(aux_data, children):
qs = TracingQuantumSession.get_instance()

qv = aux_data[0].qv
qv.reg = DynamicQubitArray(children[0])
for i in range(len(qv.traced_attributes)):
Expand Down
4 changes: 2 additions & 2 deletions src/qrisp/environments/conjugation_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __enter__(self):

if check_for_tracing_mode():
with QuantumEnvironment():
res = qache(self.conjugation_function)(*self.args, **self.kwargs)
res = qache(self.conjugation_function)(*list(self.args), **self.kwargs)
return res


Expand Down Expand Up @@ -189,7 +189,7 @@ def __exit__(self, exception_type, exception_value, traceback):
else:
from qrisp.environments import invert
with invert():
qache(self.conjugation_function)(*self.args, **self.kwargs)
qache(self.conjugation_function)(*list(self.args), **self.kwargs)

QuantumEnvironment.__exit__(self, exception_type, exception_value, traceback)

Expand Down
9 changes: 6 additions & 3 deletions src/qrisp/jasp/qaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"""

import jax

from qrisp.core import recursive_qv_search

from qrisp.jasp import TracingQuantumSession, check_for_tracing_mode

def qache(*func, **kwargs):
Expand Down Expand Up @@ -184,9 +187,9 @@ def return_function(*args, **kwargs):
# tracers of the jit trace. To reverse this, we store the current tracers
# by flattening each QuantumVariable in the signature.
flattened_qvs = []
for arg in args:
if isinstance(arg, QuantumVariable):
flattened_qvs.append(flatten_qv(arg))

for qv in recursive_qv_search(args):
flattened_qvs.append(flatten_qv(qv))

# Get the AbstractQuantumCircuit for tracing
abs_qs = TracingQuantumSession.get_instance()
Expand Down

0 comments on commit 739e1c9

Please sign in to comment.