Skip to content

Commit

Permalink
fixed a bug that caused the static argument specification to start co…
Browse files Browse the repository at this point in the history
…unting at 1 instead of 0
  • Loading branch information
positr0nium committed Jan 3, 2025
1 parent a682c1e commit b940557
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/qrisp/core/gate_application_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def benchmark_mcx(n, methods):
if n == 0:
return controls, target
elif n == 1:
method = "gray"
append_operation(
std_ops.MCXGate(len(qubits_0), ctrl_state, method=method),
qubits_0 + qubits_1,
Expand Down
11 changes: 11 additions & 0 deletions src/qrisp/jasp/jasp_expression/centerclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,17 @@ def ammended_function(abs_qc, *args, **kwargs):
# enter/exit primitives but as primitive that "call" a certain Jaspr.
return Jaspr.from_cache(collect_environments(jaxpr))

# Since we are calling the "ammended function", where the first parameter
# is the AbstractQuantumCircuit, we need to move the static_argnums indicator.
if "static_argnums" in jax_kwargs:
jax_kwargs = dict(jax_kwargs)
if isinstance(jax_kwargs["static_argnums"], list):
jax_kwargs["static_argnums"] = list(jax_kwargs["static_argnums"])
for i in range(len(jax_kwargs["static_argnums"])):
jax_kwargs["static_argnums"][i] += 1
else:
jax_kwargs["static_argnums"] += 1

return jaspr_creator


Expand Down
4 changes: 2 additions & 2 deletions src/qrisp/jasp/rus.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ def return_function(*trial_args):
argname_list = inspect.getfullargspec(trial_function)
for i in range(len(argname_list)):
if argname_list[i] in jit_kwargs["static_argnames"]:
static_argnums.append(i+1)
static_argnums.append(i)

new_trial_args = []

for i in range(len(trial_args)):
if i+1 not in static_argnums:
if i not in static_argnums:
new_trial_args.append(trial_args[i])

trial_args = new_trial_args
Expand Down
27 changes: 26 additions & 1 deletion src/qrisp/jasp/tracing_logic/qaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from qrisp.core import recursive_qv_search

from qrisp.jasp.primitives import AbstractQuantumCircuit
from qrisp.jasp.tracing_logic import TracingQuantumSession, check_for_tracing_mode

def qache(*func, **kwargs):
Expand Down Expand Up @@ -254,6 +255,17 @@ def ammended_function(abs_qc, *args, **kwargs):

# Return the result and the result AbstractQuantumCircuit.
return new_abs_qc, res

# Since we are calling the "ammended function", where the first parameter
# is the AbstractQuantumCircuit, we need to move the static_argnums indicator.
if "static_argnums" in jax_kwargs:
jax_kwargs = dict(jax_kwargs)
if isinstance(jax_kwargs["static_argnums"], list):
jax_kwargs["static_argnums"] = list(jax_kwargs["static_argnums"])
for i in range(len(jax_kwargs["static_argnums"])):
jax_kwargs["static_argnums"][i] += 1
else:
jax_kwargs["static_argnums"] += 1

# Modify the name of the ammended function to reflect the input
ammended_function.__name__ = func.__name__
Expand Down Expand Up @@ -294,7 +306,20 @@ def return_function(*args, **kwargs):
# Convert the jaxpr from the traced equation in to a Jaspr
from qrisp.jasp import Jaspr
eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
eqn.params["jaxpr"] = jax.core.ClosedJaxpr(Jaspr.from_cache(eqn.params["jaxpr"].jaxpr), eqn.params["jaxpr"].consts)
jaxpr = eqn.params["jaxpr"].jaxpr

if not isinstance(eqn.invars[0].aval, AbstractQuantumCircuit):
for i in range(len(eqn.invars)):
if isinstance(eqn.invars[i].aval, AbstractQuantumCircuit):
eqn.invars[0], eqn.invars[i] = eqn.invars[i], eqn.invars[0]
break
if not isinstance(jaxpr.invars[0].aval, AbstractQuantumCircuit):
for i in range(len(jaxpr.invars)):
if isinstance(jaxpr.invars[i].aval, AbstractQuantumCircuit):
jaxpr.invars[0], jaxpr.invars[i] = jaxpr.invars[i], jaxpr.invars[0]
break

eqn.params["jaxpr"] = jax.core.ClosedJaxpr(Jaspr.from_cache(jaxpr), eqn.params["jaxpr"].consts)

# Update the AbstractQuantumCircuit of the TracingQuantumSession
abs_qs.abs_qc = abs_qc_new
Expand Down
2 changes: 1 addition & 1 deletion src/qrisp/operators/qubit/qubit_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def adjoint(self):
#
# Simulation
#
@custom_control(static_argnums = 1)
@custom_control(static_argnums = 0)
def simulate(self, coeff, qv, ctrl = None):

from qrisp import h, cx, rz, mcp, conjugate, control, QuantumBool, mcx, x, p, s, QuantumEnvironment, gphase, QuantumVariable, find_qs
Expand Down
2 changes: 1 addition & 1 deletion tests/jax_tests/test_rus.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def state_prep_half(qv):
# Specify the corresponding arguments of the block encoding as "static",
# i.e. compile time constants.

@RUS(static_argnums = [2,3])
@RUS(static_argnums = [1,2])
def block_encoding(return_size, state_preparation, case_functions):

# This QuantumFloat will be returned
Expand Down

0 comments on commit b940557

Please sign in to comment.