Skip to content

Commit

Permalink
Merge pull request #143 from eclipse-qrisp/jit_based_rt_simulation
Browse files Browse the repository at this point in the history
Jit based rt simulation
  • Loading branch information
positr0nium authored Feb 28, 2025
2 parents 61d5f20 + 0b862de commit 867e78a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 8 deletions.
41 changes: 40 additions & 1 deletion src/qrisp/jasp/evaluation_tools/jaspification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
********************************************************************************/
"""

from functools import lru_cache

import jax
from jax.tree_util import tree_flatten, tree_unflatten

from qrisp.jasp.interpreter_tools import extract_invalues, insert_outvalues, eval_jaxpr
from qrisp.jasp.evaluation_tools.buffered_quantum_state import BufferedQuantumState
from qrisp.jasp.primitives import OperationPrimitive, AbstractQuantumCircuit, AbstractQubitArray, AbstractQubit
from qrisp.core import recursive_qv_search
from qrisp.circuit import fast_append


def jaspify(func = None, terminal_sampling = False):
"""
This simulator is the established Qrisp simulator linked to the Jasp infrastructure.
Expand Down Expand Up @@ -261,6 +266,7 @@ def eqn_evaluator(eqn, context_dic):
if eqn.primitive.name == "pjit":

function_name = eqn.params["name"]
jaxpr = eqn.params["jaxpr"]

if terminal_sampling:

Expand All @@ -275,6 +281,31 @@ def eqn_evaluator(eqn, context_dic):
return

invalues = extract_invalues(eqn, context_dic)

# If there are only classical values, we attempt to compile using the jax pipeline
for var in jaxpr.jaxpr.invars + jaxpr.jaxpr.outvars:
if isinstance(var.aval, (AbstractQuantumCircuit, AbstractQubitArray, AbstractQubit)):
break
else:


compiled_function, is_executable = compile_cl_func(jaxpr.jaxpr, function_name)

# Functions with purely classical inputs/outputs can still contain
# kernelized quantum functions. This will raise an NotImplementedError
# when attempting to compile. Since the compile_cl_func is lru_cached
# we can store this information to avoid further attempts at compiling
# such a function.
if is_executable[0]:
try:
outvalues = compiled_function(*(invalues + jaxpr.consts))
if len(jaxpr.jaxpr.outvars) > 1:
insert_outvalues(eqn, context_dic, outvalues)
else:
insert_outvalues(eqn, context_dic, [outvalues])
return False
except NotImplementedError:
is_executable[0] = False

# We simulate the inverse Gidney mcx via the non-hybrid version because
# the hybrid version prevents the simulator from fusing gates, which
Expand All @@ -287,6 +318,8 @@ def eqn_evaluator(eqn, context_dic):
if not isinstance(outvalues, (list, tuple)):
outvalues = [outvalues]
insert_outvalues(eqn, context_dic, outvalues)


elif eqn.primitive.name == "jasp.quantum_kernel":
insert_outvalues(eqn, context_dic, BufferedQuantumState(simulator))
else:
Expand All @@ -298,4 +331,10 @@ def eqn_evaluator(eqn, context_dic):
if len(jaspr.outvars) == 2:
return res[1]
else:
return res[1:]
return res[1:]

@lru_cache(maxsize = int(1E5))
def compile_cl_func(jaxpr, function_name):
return jax.jit(eval_jaxpr(jaxpr)), [True]


Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
********************************************************************************/
"""

from functools import lru_cache
from random import shuffle

import numpy as np

import jax.numpy as jnp
import jax

from qrisp.jasp.interpreter_tools.abstract_interpreter import eval_jaxpr, extract_invalues, insert_outvalues, exec_eqn
from qrisp.jasp.interpreter_tools.interpreters.control_flow_interpretation import evaluate_while_loop
Expand Down Expand Up @@ -236,6 +238,9 @@ def sampling_body_eqn_evaluator(eqn, context_dic):
elif sampling_res_type == "dict":
sampling_res = {}

# Compile the decoder
decoder = decoder_compiler(eqn.params["jaxpr"], eqn_evaluator)

# Iterate through the sampled values
for k, v in meas_res_dic.items():

Expand All @@ -255,12 +260,14 @@ def sampling_body_eqn_evaluator(eqn, context_dic):
j += return_signature[i]

# Evaluate the decoder
outvalues = eval_jaxpr(eqn.params["jaxpr"], eqn_evaluator = eqn_evaluator)(*new_invalues)
#outvalues = eval_jaxpr(eqn.params["jaxpr"], eqn_evaluator = eqn_evaluator)(*new_invalues)

outvalues = decoder(*new_invalues)


# We now build the key for the result dic
# For that we turn the jax types into the corresponding
# Python types.

if not isinstance(outvalues, tuple):
if sampling_res_type == "ev":
sampling_res += outvalues*v
Expand Down Expand Up @@ -293,10 +300,11 @@ def sampling_body_eqn_evaluator(eqn, context_dic):
else:
raise
sampling_res[tuple(x.item() for x in outvalues)] = v

if sampling_res_type == "array":
shuffle(sampling_res)
sampling_res = jnp.array(sampling_res)
sampling_res = np.array(sampling_res)
np.random.shuffle(sampling_res)

elif sampling_res_type == "ev":
sampling_res = sampling_res/shots
if sampling_res.shape[0] == 1:
Expand All @@ -312,12 +320,19 @@ def sampling_body_eqn_evaluator(eqn, context_dic):

# Execute the above defined interpreter
sampling_body_jaxpr = eqn.params["jaxpr"].jaxpr

outvalues = eval_jaxpr(sampling_body_jaxpr, eqn_evaluator = sampling_body_eqn_evaluator)(*invalues)



if not isinstance(outvalues, (list, tuple)):
outvalues = [outvalues]

insert_outvalues(eqn, context_dic, decoded_meas_res)

return sampling_eqn_evaluator
return sampling_eqn_evaluator


@lru_cache(maxsize = int(1E5))
def decoder_compiler(jaxpr, eqn_evaluator):
return jax.jit(eval_jaxpr(jaxpr, eqn_evaluator = eqn_evaluator))
6 changes: 6 additions & 0 deletions tests/jax_tests/test_HHL_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,18 @@ def main():
# quantum variables, while most other evaluation modes require
# classical return values.
return qrisp.measure(x)

try:
import catalyst
except ImportError:
return

jaspr = qrisp.make_jaspr(main)()
qir_str = jaspr.to_qir()
# Print only the first few lines - the whole string is very long.
print(qir_str[:200])


############################################################

A = np.array([[3 / 8, 1 / 8], [1 / 8, 3 / 8]])
Expand Down

0 comments on commit 867e78a

Please sign in to comment.