Skip to content

Commit

Permalink
renamed AbstractQuantumSession to TracingQuantumSession
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Jun 27, 2024
1 parent bff8465 commit 807cc50
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/qrisp/core/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,10 +1044,10 @@ def measure(qubits, clbits=None):
"""
from qrisp import find_qs
from qrisp.jax import AbstractQuantumSession
from qrisp.jax import TracingQuantumSession
qs = find_qs(qubits)

if not isinstance(qs, AbstractQuantumSession):
if not isinstance(qs, TracingQuantumSession):
if clbits is None:
clbits = []
if hasattr(qubits, "__len__"):
Expand Down
8 changes: 4 additions & 4 deletions src/qrisp/core/quantum_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ def __init__(self, size, qs=None, name=None):

# Store quantum session
from qrisp.core import QuantumSession, merge_sessions
from qrisp.jax import check_for_tracing_mode, get_abstract_qs
from qrisp.jax import check_for_tracing_mode, get_tracing_qs

if check_for_tracing_mode():
self.qs = get_abstract_qs()
self.qs = get_tracing_qs()
else:
if qs is not None:
self.qs = qs
Expand Down Expand Up @@ -1453,7 +1453,7 @@ def plot_histogram(outcome_labels, counts, filename=None):


from jax import tree_util
from qrisp.jax.abstract_quantum_session import get_abstract_qs
from qrisp.jax.abstract_quantum_session import get_tracing_qs
from builtins import id


Expand All @@ -1470,7 +1470,7 @@ def unflatten_qv(aux_data, children):
res.reg = children[0]
res.size = children[1]
res.name = aux_data[1]
res.qs = get_abstract_qs()
res.qs = get_tracing_qs()

return res

Expand Down
6 changes: 3 additions & 3 deletions src/qrisp/environments/quantum_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from qrisp.circuit import QubitAlloc, QubitDealloc, fast_append
from qrisp.core.quantum_session import QuantumSession

from qrisp.jax import QuantumPrimitive, AbstractQuantumCircuit, get_abstract_qs
from qrisp.jax import QuantumPrimitive, AbstractQuantumCircuit, get_tracing_qs

class QuantumEnvironment(QuantumPrimitive):
"""
Expand Down Expand Up @@ -379,7 +379,7 @@ def stop_dumping(self):
# Method to enter the environment
def __enter__(self):

abs_qs = get_abstract_qs()
abs_qs = get_tracing_qs()
if abs_qs is not None:
abs_qs.abs_qc = self.bind(abs_qs.abs_qc, stage = "enter")
return
Expand Down Expand Up @@ -428,7 +428,7 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):


abs_qs = get_abstract_qs()
abs_qs = get_tracing_qs()
if abs_qs is not None:
abs_qs.abs_qc = self.bind(abs_qs.abs_qc, stage = "exit")
return
Expand Down
6 changes: 3 additions & 3 deletions src/qrisp/jax/qaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"""

from jax import jit
from qrisp.jax import get_abstract_qs
from qrisp.jax import get_tracing_qs

def qache(func):

def ammended_function(abs_qc, *args):

qs = get_abstract_qs()
qs = get_tracing_qs()
qs.abs_qc = abs_qc

res = func(*args)
Expand All @@ -36,7 +36,7 @@ def ammended_function(abs_qc, *args):

def return_function(*args):

abs_qs = get_abstract_qs()
abs_qs = get_tracing_qs()

abs_qc_new, res = ammended_function(abs_qs.abs_qc, *args)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@

from qrisp.jax import qdef_p, create_qubits, delete_qubits_p

abstract_qs = [lambda : None]
tr_qs_container = [lambda : None]


class AbstractQuantumSession:
class TracingQuantumSession:

def __init__(self):

Expand Down Expand Up @@ -85,18 +85,18 @@ def check_for_tracing_mode():
def check_live(tracer):
return bool(tracer._trace.main.jaxpr_stack)

def get_abstract_qs():
res = abstract_qs[0]()
def get_tracing_qs():
res = tr_qs_container[0]()
if check_for_tracing_mode():
if res is None:
res = AbstractQuantumSession()
abstract_qs[0] = weakref.ref(res)
res = TracingQuantumSession()
tr_qs_container[0] = weakref.ref(res)
return res
if not check_live(res.abs_qc):
res = AbstractQuantumSession()
abstract_qs[0] = weakref.ref(res)
res = TracingQuantumSession()
tr_qs_container[0] = weakref.ref(res)
return res
else:
if res is not None:
abstract_qs[0] = lambda : None
tr_qs_container[0] = lambda : None
return res
4 changes: 2 additions & 2 deletions src/qrisp/misc/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,8 @@ def wrapped_function(


def find_qs(args):
from qrisp.jax import get_abstract_qs
abs_qs = get_abstract_qs()
from qrisp.jax import get_tracing_qs
abs_qs = get_tracing_qs()
if abs_qs is not None:
return abs_qs

Expand Down

0 comments on commit 807cc50

Please sign in to comment.