Skip to content

Commit 3da6344

Browse files
committed
fixed an issue in the jax converter that preventer more complex parameter expressions to be evaluated
1 parent a33780c commit 3da6344

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/qrisp/jasp/interpreter_tools/interpreters/catalyst_interpreter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from functools import lru_cache
2020

21+
from sympy import lambdify, symbols
22+
2123
from jax import make_jaxpr, jit
2224
from jax.core import ClosedJaxpr
2325
from jax.lax import fori_loop, cond, while_loop
@@ -31,6 +33,7 @@
3133
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues, Measurement_p, get_qubit_p,
3234
get_size_p, delete_qubits_p)
3335

36+
greek_letters = symbols('alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega')
3437

3538
# Name translator from Qrisp gate naming to Catalyst gate naming
3639
op_name_translation_dic = {"cx" : "CNOT",
@@ -232,7 +235,11 @@ def exec_qrisp_op(op, catalyst_qbs, param_dict):
232235

233236
catalyst_name = op_name_translation_dic[op_name]
234237

235-
param_list = [param_dict[symb] for symb in op.params]
238+
jax_values = list(param_dict.values())
239+
240+
param_list = [lambdify(greek_letters[:len(op.abstract_params)], expr)(*jax_values) for expr in op.params]
241+
# param_list = [param_dict[symb] for symb in op.params]
242+
236243
res_qbs = qinst_p.bind(*(catalyst_qbs+param_list),
237244
op = catalyst_name,
238245
qubits_len = op.num_qubits,

0 commit comments

Comments
 (0)