Skip to content

Commit 151879c

Browse files
committed
fixed an issue that prevent the proper conversion of literals to i32 within custom_control
1 parent a4b72f0 commit 151879c

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

src/qrisp/environments/custom_control_environment.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020

2121
import jax
22+
import jax.numpy as jnp
2223

2324
from qrisp.environments.quantum_environments import QuantumEnvironment
2425
from qrisp.environments.gate_wrap_environment import GateWrapEnvironment
@@ -206,6 +207,19 @@ def adaptive_control_function(*args, **kwargs):
206207
res = func(*args, ctrl = control_qb, **kwargs)
207208

208209
else:
210+
211+
args = list(args)
212+
if func.__name__ == "extract_boolean_digit":
213+
print(args)
214+
for i in range(len(args)):
215+
if isinstance(args[i], bool):
216+
args[i] = jnp.array(args[i], dtype = jnp.bool)
217+
elif isinstance(args[i], int):
218+
args[i] = jnp.array(args[i], dtype = jnp.int32)
219+
elif isinstance(args[i], float):
220+
args[i] = jnp.array(args[i], dtype = jnp.float32)
221+
elif isinstance(args[i], complex):
222+
args[i] = jnp.array(args[i], dtype = jnp.complex)
209223

210224
# Call the (qached) function
211225
res = func(*args, **kwargs)

src/qrisp/jasp/tracing_logic/qaching.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,18 +231,6 @@ def ammended_function(abs_qc, *args, **kwargs):
231231
abs_qs.register_qv(qv, None)
232232
flattened_qvs.extend(list(flatten_qv(qv)[0]))
233233

234-
# Make sure literals are 32 bit
235-
args = list(args)
236-
for i in range(len(args)):
237-
if isinstance(args[i], bool):
238-
args[i] = jnp.array(args[i], dtype = jnp.bool)
239-
elif isinstance(args[i], int):
240-
args[i] = jnp.array(args[i], dtype = jnp.int32)
241-
elif isinstance(args[i], float):
242-
args[i] = jnp.array(args[i], dtype = jnp.float32)
243-
elif isinstance(args[i], complex):
244-
args[i] = jnp.array(args[i], dtype = jnp.complex)
245-
246234
# Execute the function
247235
res = func(*args, **kwargs)
248236
new_abs_qc = abs_qs.abs_qc
@@ -283,6 +271,18 @@ def return_function(*args, **kwargs):
283271
# Get the AbstractQuantumCircuit for tracing
284272
abs_qs = TracingQuantumSession.get_instance()
285273
abs_qs.start_tracing(abs_qs.abs_qc)
274+
275+
# Make sure literals are 32 bit
276+
args = list(args)
277+
for i in range(len(args)):
278+
if isinstance(args[i], bool):
279+
args[i] = jnp.array(args[i], dtype = jnp.bool)
280+
elif isinstance(args[i], int):
281+
args[i] = jnp.array(args[i], dtype = jnp.int32)
282+
elif isinstance(args[i], float):
283+
args[i] = jnp.array(args[i], dtype = jnp.float32)
284+
elif isinstance(args[i], complex):
285+
args[i] = jnp.array(args[i], dtype = jnp.complex)
286286

287287
# Excecute the function
288288
abs_qc_new, res = ammended_function(abs_qs.abs_qc, *args, **kwargs)

0 commit comments

Comments
 (0)