Skip to content

Commit

Permalink
implemented the catalyst function calling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Dec 30, 2024
1 parent d4536ff commit d281b86
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import jax.numpy as jnp
from jax import jit

from qrisp.jasp import jrange
from qrisp.jasp import jrange, qache
from qrisp.core import x, cx, QuantumVariable, mcx
from qrisp.environments import control, custom_control

@jit
@qache
def extract_boolean_digit(integer, digit):
return jnp.bool((integer>>digit & 1))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,38 @@
********************************************************************************/
"""

from jax import make_jaxpr
from functools import lru_cache

from jax import make_jaxpr, jit
from jax.core import ClosedJaxpr
from jax.lax import fori_loop
from jax._src.linear_util import wrap_init
import jax.numpy as jnp

from catalyst.jax_primitives import AbstractQreg, qinst_p, qmeasure_p, qextract_p, qinsert_p, while_p, cond_p
from catalyst.jax_primitives import (AbstractQreg, qinst_p, qmeasure_p,
qextract_p, qinsert_p, while_p, cond_p, func_p)

from qrisp.jasp import (QuantumPrimitive, OperationPrimitive, AbstractQuantumCircuit, AbstractQubitArray,
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues)


# Name translator from Qrisp gate naming to Catalyst gate naming
op_name_translation_dic = {"cx" : "CNOT",
"cy" : "CY",
"cz" : "CZ",
"crx" : "CRX",
"crz" : "CRZ",
"swap" : "SWAP",
"x" : "PauliX",
"y" : "PauliY",
"z" : "PauliZ",
"h" : "Hadamard",
"rx" : "RX",
"ry" : "RY",
"rz" : "RZ",
"s" : "S",
"t" : "T",
"p" : "RZ"}
"cy" : "CY",
"cz" : "CZ",
"crx" : "CRX",
"crz" : "CRZ",
"swap" : "SWAP",
"x" : "PauliX",
"y" : "PauliY",
"z" : "PauliZ",
"h" : "Hadamard",
"rx" : "RX",
"ry" : "RY",
"rz" : "RZ",
"s" : "S",
"t" : "T",
"p" : "RZ"}


def catalyst_eqn_evaluator(eqn, context_dic):
Expand Down Expand Up @@ -414,13 +419,49 @@ def process_cond(eqn, context_dic):


insert_outvalues(eqn, context_dic, unflattened_outvalues)

@lru_cache(maxsize = int(1E5))
def get_traced_fun(jaxpr):

from jax.core import eval_jaxpr

if isinstance(jaxpr, Jaspr):
catalyst_jaxpr = jaxpr.to_catalyst_jaxpr()
else:
catalyst_jaxpr = ClosedJaxpr(jaxpr, [])

@jit
def jitted_fun(*args):
return eval_jaxpr(catalyst_jaxpr.jaxpr, [], *args)

return jitted_fun


def process_pjit(eqn, context_dic):

invalues = extract_invalues(eqn, context_dic)
outvalues = eval_jaxpr(eqn.params["jaxpr"], eqn_evaluator = catalyst_eqn_evaluator)(*invalues)
if len(eqn.params["jaxpr"].jaxpr.outvars) == 1:
outvalues = [outvalues]
insert_outvalues(eqn, context_dic, outvalues)

flattened_invalues = []
for value in invalues:
if isinstance(value, tuple):
flattened_invalues.extend(value)
else:
flattened_invalues.append(value)

jaxpr = eqn.params["jaxpr"]
traced_fun = get_traced_fun(jaxpr.jaxpr)

outvalues = func_p.bind(wrap_init(traced_fun), *flattened_invalues, fn=traced_fun)

outvalues = list(outvalues)
unflattened_outvalues = []
for outvar in eqn.outvars:
if isinstance(outvar.aval, AbstractQuantumCircuit):
unflattened_outvalues.append((outvalues.pop(0), outvalues.pop(0)))
elif isinstance(outvar.aval, AbstractQubitArray):
unflattened_outvalues.append((outvalues.pop(0), outvalues.pop(0)))
else:
unflattened_outvalues.append(outvalues.pop(0))

insert_outvalues(eqn, context_dic, unflattened_outvalues)

13 changes: 13 additions & 0 deletions src/qrisp/jasp/tracing_logic/qaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import jax
import jax.numpy as jnp

from qrisp.core import recursive_qv_search

Expand Down Expand Up @@ -230,6 +231,18 @@ def ammended_function(abs_qc, *args, **kwargs):
abs_qs.register_qv(qv, None)
flattened_qvs.extend(list(flatten_qv(qv)[0]))

# Make sure literals are 32 bit
args = list(args)
for i in range(len(args)):
if isinstance(args[i], bool):
args[i] = jnp.array(args[i], dtype = jnp.bool)
elif isinstance(args[i], int):
args[i] = jnp.array(args[i], dtype = jnp.int32)
elif isinstance(args[i], float):
args[i] = jnp.array(args[i], dtype = jnp.float32)
elif isinstance(args[i], complex):
args[i] = jnp.array(args[i], dtype = jnp.complex)

# Execute the function
res = func(*args, **kwargs)
new_abs_qc = abs_qs.abs_qc
Expand Down

0 comments on commit d281b86

Please sign in to comment.