Skip to content

Commit

Permalink
Jaspr.flatten_environments is now automatically called by make_jaspr
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Jan 8, 2025
1 parent c8052e6 commit 62ec91c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 11 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
"flask",
"waitress",
"pyyaml",
"requests"]
"requests",
"psutil",
"jax"]


with open("README.md", "r", encoding="utf-8") as fh:
Expand Down
15 changes: 10 additions & 5 deletions src/qrisp/jasp/jasp_expression/centerclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def example_function(i):
"""
from qrisp import QuantumCircuit
jaspr = self.flatten_environments()
jaspr = self

def eqn_evaluator(eqn, context_dic):
if eqn.primitive.name == "pjit" and isinstance(eqn.params["jaxpr"].jaxpr, Jaspr):
Expand Down Expand Up @@ -433,7 +433,7 @@ def __call__(self, *args):
args = [BufferedQuantumState()] + list(args)

from qrisp.jasp import extract_invalues, insert_outvalues, eval_jaxpr
flattened_jaspr = self.flatten_environments()
flattened_jaspr = self

def eqn_evaluator(eqn, context_dic):
if eqn.primitive.name == "pjit":
Expand Down Expand Up @@ -511,7 +511,7 @@ def qjit(self, *args, function_name = "jaspr_function"):
The values returned by the compiled, executed function.
"""
flattened_jaspr = self.flatten_environments()
flattened_jaspr = self

from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_qjit
qjit_obj = jaspr_to_catalyst_qjit(flattened_jaspr, function_name = function_name)
Expand Down Expand Up @@ -987,7 +987,7 @@ def example_function(i):



def make_jaspr(fun, garbage_collection = "auto", **jax_kwargs):
def make_jaspr(fun, garbage_collection = "auto", flatten_envs = True, **jax_kwargs):
from qrisp.jasp import AbstractQuantumCircuit, TracingQuantumSession, check_for_tracing_mode
from qrisp.core.quantum_variable import QuantumVariable, flatten_qv, unflatten_qv
from qrisp.core import recursive_qv_search
Expand Down Expand Up @@ -1036,7 +1036,12 @@ def ammended_function(abs_qc, *args, **kwargs):
# Collect the environments
# This means that the quantum environments no longer appear as
# enter/exit primitives but as primitive that "call" a certain Jaspr.
return Jaspr.from_cache(collect_environments(jaxpr))
res = Jaspr.from_cache(collect_environments(jaxpr))

if flatten_envs:
res = res.flatten_environments()

return res

# Since we are calling the "ammended function", where the first parameter
# is the AbstractQuantumCircuit, we need to move the static_argnums indicator.
Expand Down
2 changes: 1 addition & 1 deletion src/qrisp/jasp/terminal_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def terminal_sampling_helper_2(*meas_tuples):

# Make the jaspr and flatten the environments
jaspr = make_jaspr(ammended_function)(*args, **kwargs)
flattened_jaspr = jaspr.flatten_environments()
flattened_jaspr = jaspr

# This dictionary will contain the integer measurement resuts
meas_res_dic = {}
Expand Down
5 changes: 4 additions & 1 deletion tests/jax_tests/test_basic_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def test_function():

compare_jaxpr(make_jaspr(test_function)(),
['jasp.create_qubits',
'jasp.q_env',
'jasp.get_qubit',
'jasp.h',
'jasp.get_qubit',
'jasp.cx',
'jasp.get_qubit',
'jasp.measure',
'jasp.reset',
Expand Down
6 changes: 3 additions & 3 deletions tests/jax_tests/test_control_flow_capturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_f(a, b):

return measure(qv)

jaspr = make_jaspr(test_f)(1,1)
jaspr = make_jaspr(test_f, flatten_envs = False)(1,1)

try:
jaspr(4,5)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_f(a, b):

return measure(qv)

jaspr = make_jaspr(test_f)(1,1)
jaspr = make_jaspr(test_f, flatten_envs = False)(1,1)

try:
jaspr(4,5)
Expand Down Expand Up @@ -260,7 +260,7 @@ def test_f(i):

return measure(c)

jaspr = make_jaspr(test_f)(1)
jaspr = make_jaspr(test_f, flatten_envs = False)(1)

try:
jaspr(4)
Expand Down

0 comments on commit 62ec91c

Please sign in to comment.