From 807cc50aa35efb3a8000d3947f0f115a3512228c Mon Sep 17 00:00:00 2001 From: positr0nium Date: Thu, 27 Jun 2024 16:51:42 +0200 Subject: [PATCH] renamed AbstractQuantumSession to TracingQuantumSession --- src/qrisp/core/library.py | 4 ++-- src/qrisp/core/quantum_variable.py | 8 ++++---- src/qrisp/environments/quantum_environments.py | 6 +++--- src/qrisp/jax/qaching.py | 6 +++--- ...m_session.py => tracing_quantum_session.py} | 18 +++++++++--------- src/qrisp/misc/utility.py | 4 ++-- 6 files changed, 23 insertions(+), 23 deletions(-) rename src/qrisp/jax/{abstract_quantum_session.py => tracing_quantum_session.py} (89%) diff --git a/src/qrisp/core/library.py b/src/qrisp/core/library.py index cb1f1481..9f0ae77c 100644 --- a/src/qrisp/core/library.py +++ b/src/qrisp/core/library.py @@ -1044,10 +1044,10 @@ def measure(qubits, clbits=None): """ from qrisp import find_qs - from qrisp.jax import AbstractQuantumSession + from qrisp.jax import TracingQuantumSession qs = find_qs(qubits) - if not isinstance(qs, AbstractQuantumSession): + if not isinstance(qs, TracingQuantumSession): if clbits is None: clbits = [] if hasattr(qubits, "__len__"): diff --git a/src/qrisp/core/quantum_variable.py b/src/qrisp/core/quantum_variable.py index 59d945ec..8e7b30f9 100644 --- a/src/qrisp/core/quantum_variable.py +++ b/src/qrisp/core/quantum_variable.py @@ -234,10 +234,10 @@ def __init__(self, size, qs=None, name=None): # Store quantum session from qrisp.core import QuantumSession, merge_sessions - from qrisp.jax import check_for_tracing_mode, get_abstract_qs + from qrisp.jax import check_for_tracing_mode, get_tracing_qs if check_for_tracing_mode(): - self.qs = get_abstract_qs() + self.qs = get_tracing_qs() else: if qs is not None: self.qs = qs @@ -1453,7 +1453,7 @@ def plot_histogram(outcome_labels, counts, filename=None): from jax import tree_util -from qrisp.jax.abstract_quantum_session import get_abstract_qs +from qrisp.jax.abstract_quantum_session import get_tracing_qs from builtins import id @@ -1470,7 +1470,7 @@ def unflatten_qv(aux_data, children): res.reg = children[0] res.size = children[1] res.name = aux_data[1] - res.qs = get_abstract_qs() + res.qs = get_tracing_qs() return res diff --git a/src/qrisp/environments/quantum_environments.py b/src/qrisp/environments/quantum_environments.py index 73698fb7..4166d29f 100644 --- a/src/qrisp/environments/quantum_environments.py +++ b/src/qrisp/environments/quantum_environments.py @@ -56,7 +56,7 @@ from qrisp.circuit import QubitAlloc, QubitDealloc, fast_append from qrisp.core.quantum_session import QuantumSession -from qrisp.jax import QuantumPrimitive, AbstractQuantumCircuit, get_abstract_qs +from qrisp.jax import QuantumPrimitive, AbstractQuantumCircuit, get_tracing_qs class QuantumEnvironment(QuantumPrimitive): """ @@ -379,7 +379,7 @@ def stop_dumping(self): # Method to enter the environment def __enter__(self): - abs_qs = get_abstract_qs() + abs_qs = get_tracing_qs() if abs_qs is not None: abs_qs.abs_qc = self.bind(abs_qs.abs_qc, stage = "enter") return @@ -428,7 +428,7 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): - abs_qs = get_abstract_qs() + abs_qs = get_tracing_qs() if abs_qs is not None: abs_qs.abs_qc = self.bind(abs_qs.abs_qc, stage = "exit") return diff --git a/src/qrisp/jax/qaching.py b/src/qrisp/jax/qaching.py index 5d02ad42..6603a566 100644 --- a/src/qrisp/jax/qaching.py +++ b/src/qrisp/jax/qaching.py @@ -17,13 +17,13 @@ """ from jax import jit -from qrisp.jax import get_abstract_qs +from qrisp.jax import get_tracing_qs def qache(func): def ammended_function(abs_qc, *args): - qs = get_abstract_qs() + qs = get_tracing_qs() qs.abs_qc = abs_qc res = func(*args) @@ -36,7 +36,7 @@ def ammended_function(abs_qc, *args): def return_function(*args): - abs_qs = get_abstract_qs() + abs_qs = get_tracing_qs() abs_qc_new, res = ammended_function(abs_qs.abs_qc, *args) diff --git a/src/qrisp/jax/abstract_quantum_session.py b/src/qrisp/jax/tracing_quantum_session.py similarity index 89% rename from src/qrisp/jax/abstract_quantum_session.py rename to src/qrisp/jax/tracing_quantum_session.py index 98a364a3..339d6c26 100644 --- a/src/qrisp/jax/abstract_quantum_session.py +++ b/src/qrisp/jax/tracing_quantum_session.py @@ -22,10 +22,10 @@ from qrisp.jax import qdef_p, create_qubits, delete_qubits_p -abstract_qs = [lambda : None] +tr_qs_container = [lambda : None] -class AbstractQuantumSession: +class TracingQuantumSession: def __init__(self): @@ -85,18 +85,18 @@ def check_for_tracing_mode(): def check_live(tracer): return bool(tracer._trace.main.jaxpr_stack) -def get_abstract_qs(): - res = abstract_qs[0]() +def get_tracing_qs(): + res = tr_qs_container[0]() if check_for_tracing_mode(): if res is None: - res = AbstractQuantumSession() - abstract_qs[0] = weakref.ref(res) + res = TracingQuantumSession() + tr_qs_container[0] = weakref.ref(res) return res if not check_live(res.abs_qc): - res = AbstractQuantumSession() - abstract_qs[0] = weakref.ref(res) + res = TracingQuantumSession() + tr_qs_container[0] = weakref.ref(res) return res else: if res is not None: - abstract_qs[0] = lambda : None + tr_qs_container[0] = lambda : None return res \ No newline at end of file diff --git a/src/qrisp/misc/utility.py b/src/qrisp/misc/utility.py index 8b917d8a..214229e0 100644 --- a/src/qrisp/misc/utility.py +++ b/src/qrisp/misc/utility.py @@ -611,8 +611,8 @@ def wrapped_function( def find_qs(args): - from qrisp.jax import get_abstract_qs - abs_qs = get_abstract_qs() + from qrisp.jax import get_tracing_qs + abs_qs = get_tracing_qs() if abs_qs is not None: return abs_qs