Skip to content

Commit

Permalink
cleaned the catalyst converter a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Jan 7, 2025
1 parent 3da6344 commit 44fad77
Showing 1 changed file with 13 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -337,15 +337,8 @@ def process_while(eqn, context_dic):
else:
return True

body_jaxpr = eqn.params["body_jaxpr"]
cond_jaxpr = eqn.params["cond_jaxpr"]

from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_jaxpr

if isinstance(body_jaxpr.jaxpr.invars[0].aval, AbstractQuantumCircuit):
body_jaxpr = jaspr_to_catalyst_jaxpr(body_jaxpr.jaxpr)
if isinstance(cond_jaxpr.jaxpr.invars[0].aval, AbstractQuantumCircuit):
cond_jaxpr = jaspr_to_catalyst_jaxpr(cond_jaxpr.jaxpr)
body_jaxpr = ensure_conversion(eqn.params["body_jaxpr"].jaxpr)
cond_jaxpr = ensure_conversion(eqn.params["cond_jaxpr"].jaxpr)

invalues = extract_invalues(eqn, context_dic)

Expand Down Expand Up @@ -389,16 +382,9 @@ def process_while(eqn, context_dic):

def process_cond(eqn, context_dic):

false_jaxpr = eqn.params["branches"][0]
true_jaxpr = eqn.params["branches"][1]
false_jaxpr = ensure_conversion(eqn.params["branches"][0].jaxpr)
true_jaxpr = ensure_conversion(eqn.params["branches"][1].jaxpr)

from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_jaxpr

if isinstance(false_jaxpr.jaxpr.invars[0].aval, AbstractQuantumCircuit):
false_jaxpr = jaspr_to_catalyst_jaxpr(false_jaxpr.jaxpr)
if isinstance(true_jaxpr.jaxpr.invars[0].aval, AbstractQuantumCircuit):
true_jaxpr = jaspr_to_catalyst_jaxpr(true_jaxpr.jaxpr)

invalues = extract_invalues(eqn, context_dic)

# Contrary to the jax cond primitive, the catalyst cond primitive
Expand Down Expand Up @@ -437,13 +423,9 @@ def process_cond(eqn, context_dic):
@lru_cache(maxsize = int(1E5))
def get_traced_fun(jaxpr):

from jax.core import eval_jaxpr

if isinstance(jaxpr, Jaspr):
catalyst_jaxpr = jaxpr.to_catalyst_jaxpr()
else:
catalyst_jaxpr = ClosedJaxpr(jaxpr, [])
catalyst_jaxpr = ensure_conversion(jaxpr)

from jax.core import eval_jaxpr
@jit
def jitted_fun(*args):
return eval_jaxpr(catalyst_jaxpr.jaxpr, [], *args)
Expand Down Expand Up @@ -525,6 +507,13 @@ def process_reset(eqn, context_dic):
outvalues = eval_jaxpr(reset_jaxpr.jaxpr, eqn_evaluator = catalyst_eqn_evaluator)(*invalues)
insert_outvalues(eqn, context_dic, outvalues)


def ensure_conversion(jaxpr):
from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_jaxpr
for invar in jaxpr.invars:
if isinstance(invar.aval, (AbstractQuantumCircuit, AbstractQubitArray, AbstractQubit)):
return jaspr_to_catalyst_jaxpr(jaxpr)
return ClosedJaxpr(jaxpr, [])



0 comments on commit 44fad77

Please sign in to comment.