Skip to content

Commit

Permalink
implemented the static argnums feature for RUS
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Dec 31, 2024
1 parent 5714c1e commit 3e26aa9
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 55 deletions.
239 changes: 189 additions & 50 deletions src/qrisp/jasp/rus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""
import inspect

from jax.lax import while_loop, cond
import jax
Expand All @@ -25,7 +26,7 @@
from qrisp.jasp.primitives import Measurement_p, OperationPrimitive, get_qubit_p, get_size_p, delete_qubits_p, reset_p


def RUS(trial_function):
def RUS(*trial_function, **jit_kwargs):
r"""
Decorator to deploy repeat-until-success (RUS) components. At the core,
RUS repeats a given quantum subroutine followed by a qubit measurement until
Expand All @@ -48,6 +49,15 @@ def RUS(trial_function):
trial_function : callable
A function returning a boolean value as the first return value. More
return values are possible.
static_argnums : int or list[int], optional
A list of integers specifying which arguments are considered static in
the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
The first argument is indicated by 1, the second by 2, etc. The default
is ``[]``.
static_argnames : str or list[str], optional
A list of strings specifying which arguments are considered static in
the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
The default is ``[]``.
Returns
-------
Expand Down Expand Up @@ -100,9 +110,155 @@ def call_RUS_example():
jaspr = make_jaspr(call_RUS_example)()
print(jaspr())
# Yields, 31 which is the decimal version of 11111
**Static arguments**
To demonstrate the specification of static arguments, we will realize implement a
simple `linear combination of unitaries <https://arxiv.org/abs/1202.5822>`_.
Our implementation initializes a state of the form
.. math::
\left( \sum_{i = 0}^N c_i U_i \right) \ket{0}.
We achieve this by specifying a set of unitaries $U_i$ in the form of a
tuple of functions, each processing a :ref:`QuantumFloat`.
The coefficients $c_i$ are specified through a function preparing the state
.. math::
\ket{\psi} = \sum_{i = 0}^N c_i \ket{i}
For the state preparation function we specify two options to experiment with.
A two qubit uniform superposition and a function that brings only the first
qubit into superpostion.
::
def state_prep_full(qv):
h(qv[0])
h(qv[1])
def state_prep_half(qv):
h(qv[0])
For the first one we have $c_0 = c_1 = c_2 = c_3 = \sqrt{0.25}$. The second one
gives $c_0 = c_1 = \sqrt{0.5}$ and $c_2 = c_3 = 0$.
The next step is to define the unitaries $U_i$ in the form of a tuple
of functions.
::
from qrisp.jasp import *
from qrisp import *
def case_function_0(x):
x += 3
def case_function_1(x):
x += 4
def case_function_2(x):
x += 5
def case_function_3(x):
x += 6
case_functions = (case_function_0,
case_function_1,
case_function_2,
case_function_3)
These functions each represent the unitary:
.. math::
U_i \ket{0} = \ket{i+3}
Executing a linear combination of unitaries therefore gives
.. math::
\left( \sum_{i = 0}^N c_i U_i \right) \ket{0} = \sum_{i = 0}^N c_i \ket{i+3}
Now we implement the LCU procedure.
::
# Specify the corresponding arguments of the block encoding as "static",
# i.e. compile time constants.
@RUS(static_argnums = [2,3])
def block_encoding(return_size, state_preparation, case_functions):
"""
# This QuantumFloat will be returned
qf = QuantumFloat(return_size)
# Specify the QuantumVariable that indicates, which
# case to execute
n = int(np.ceil(np.log2(len(case_functions))))
case_indicator = QuantumFloat(n)
# Turn into a list of qubits
case_indicator_qubits = [case_indicator[i] for i in range(n)]
# Perform the LCU protocoll
with conjugate(state_preparation)(case_indicator):
for i in range(len(case_functions)):
with control(case_indicator_qubits, ctrl_state = i):
case_functions[i](qf)
# Compute the success condition
success_bool = (measure(case_indicator) == 0)
return success_bool, qf
Finally, evaluate via the :ref:`terminal_sampling <terminal_sampling>`
feature:
::
@terminal_sampling
def main():
return block_encoding(4, state_prep_full, case_functions)
print(main())
# Yields: {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25}
Evaluate the other state preparation function
::
@terminal_sampling
def main():
return block_encoding(4, state_prep_half, case_functions)
print(main())
# Yields: {3.0: 0.5, 4.0: 0.5}
As expected, the full state preparation function yields a state proportional
to
.. math::
\ket{3} + \ket{4} + \ket{5} + \ket{6}.
The second state preparation gives us
.. math::
\ket{3} + \ket{4}.
"""
if len(trial_function) == 0:
return lambda x : RUS(x, **jit_kwargs)
else:
trial_function = trial_function[0]

# The idea for implementing this feature is to execute the function once
# to collect the output QuantumVariable object.
Expand All @@ -111,19 +267,43 @@ def call_RUS_example():
def return_function(*trial_args):

# Execute the function
first_iter_res = qache(trial_function)(*trial_args)

# Flatten the arguments and the res values
arg_vals, arg_tree_def = jax.tree.flatten(trial_args)
res_vals, res_tree_def = jax.tree.flatten(first_iter_res)

first_iter_res = qache(trial_function, **jit_kwargs)(*trial_args)

# Extract the jaspr
from qrisp.jasp import make_jaspr
eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
ammended_trial_func_jaspr = eqn.params["jaxpr"].jaxpr

from qrisp.jasp import collect_environments

ammended_trial_func_jaspr = make_jaspr(trial_function)(*trial_args)
ammended_trial_func_jaspr = collect_environments(ammended_trial_func_jaspr)
ammended_trial_func_jaspr = ammended_trial_func_jaspr.flatten_environments()

# Filter out the static arguments
if "static_argnums" in jit_kwargs:
static_argnums = jit_kwargs["static_argnums"]
if isinstance(static_argnums, int):
static_argnums = [static_argnums]
else:
static_argnums = []

if "static_argnames" in jit_kwargs:
argname_list = inspect.getfullargspec(trial_function)
for i in range(len(argname_list)):
if argname_list[i] in jit_kwargs["static_argnames"]:
static_argnums.append(i+1)

new_trial_args = []

for i in range(len(trial_args)):
if i+1 not in static_argnums:
new_trial_args.append(trial_args[i])

trial_args = new_trial_args

# Flatten the arguments and the res values
arg_vals, arg_tree_def = jax.tree.flatten(trial_args)
res_vals, res_tree_def = jax.tree.flatten(first_iter_res)

# Next we construct the body of the loop
# In order to work with the while_loop interface from jax
Expand Down Expand Up @@ -204,44 +384,3 @@ def false_fun(combined_args):

return return_function


@jax.jit
def extract_boolean_digit(integer, digit):
return jnp.bool((integer>>digit & 1))
# Function to reset and delete a qubit array
@jax.jit
def reset_qubit_array(abs_qc, qb_array):

def body_func(arg_tuple):

abs_qc, qb_array, i = arg_tuple

abs_qb = get_qubit_p.bind(qb_array, i)
abs_qc, meas_bl = Measurement_p.bind(abs_qc, abs_qb)

def true_fun(arg_tuple):
abs_qc, qb = arg_tuple
abs_qc = OperationPrimitive(XGate()).bind(abs_qc, qb)
return (abs_qc, qb)

def false_fun(arg_tuple):
return arg_tuple

abs_qc, qb = cond(meas_bl, true_fun, false_fun, (abs_qc, abs_qb))

i += 1

return (abs_qc, qb_array, i)

def cond_fun(arg_tuple):
return arg_tuple[-1] < get_size_p.bind(arg_tuple[1])


abs_qc, qb_array, i = while_loop(cond_fun,
body_func,
(abs_qc, qb_array, 0)
)

abs_qc = delete_qubits_p.bind(abs_qc, qb_array)

return abs_qc
2 changes: 1 addition & 1 deletion src/qrisp/jasp/terminal_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def eqn_evaluator(eqn, context_dic):
# Round to prevent floating point errors of the simulation
norm = 0
for k, v in meas_res_dic.items():
meas_res_dic[k] = np.round(v, decimals = 8)
meas_res_dic[k] = np.round(v, decimals = 7)
norm += meas_res_dic[k]

for k, v in meas_res_dic.items():
Expand Down
6 changes: 3 additions & 3 deletions src/qrisp/jasp/tracing_logic/qaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def main():
"""

if len(kwargs):
if len(kwargs) and len(func) == 0:
return lambda x : qache_helper(x, kwargs)
elif len(kwargs) and len(func):
return qache_helper(func[0], kwargs)
else:
return qache_helper(func[0], {})

Expand Down Expand Up @@ -293,8 +295,6 @@ def return_function(*args, **kwargs):
from qrisp.jasp import Jaspr
eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
eqn.params["jaxpr"] = jax.core.ClosedJaxpr(Jaspr.from_cache(eqn.params["jaxpr"].jaxpr), eqn.params["jaxpr"].consts)
if eqn.params["name"] == "gidney_mcx_inv":
print(id(eqn.params["jaxpr"].jaxpr))

# Update the AbstractQuantumCircuit of the TracingQuantumSession
abs_qs.abs_qc = abs_qc_new
Expand Down
2 changes: 1 addition & 1 deletion src/qrisp/jasp/tracing_logic/tracing_quantum_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def append(self, operation, qubits = [], clbits = [], param_tracers = []):

from qrisp.core import QuantumVariable

if isinstance(qubits[0], QuantumVariable):
if isinstance(qubits[0], (QuantumVariable, DynamicQubitArray)):

from qrisp.jasp import jrange

Expand Down
62 changes: 62 additions & 0 deletions tests/jax_tests/test_rus.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,68 @@ def main():

assert main() in [3,4,5,6]

# Test static arguments

def case_function_0(x):
x += 3

def case_function_1(x):
x += 4

def case_function_2(x):
x += 5

def case_function_3(x):
x += 6

case_functions = (case_function_0,
case_function_1,
case_function_2,
case_function_3)

def state_prep_full(qv):
h(qv[0])
h(qv[1])

def state_prep_half(qv):
h(qv[0])

# Specify the corresponding arguments of the block encoding as "static",
# i.e. compile time constants.

@RUS(static_argnums = [2,3])
def block_encoding(return_size, state_preparation, case_functions):

# This QuantumFloat will be returned
qf = QuantumFloat(return_size)

# Specify the QuantumVariable that indicates, which
# case to execute
n = int(np.ceil(np.log2(len(case_functions))))
case_indicator = QuantumFloat(n)

# Turn into a list of qubits
case_indicator_qubits = [case_indicator[i] for i in range(n)]

# Perform the LCU protocoll
with conjugate(state_preparation)(case_indicator):
for i in range(len(case_functions)):
with control(case_indicator_qubits, ctrl_state = i):
case_functions[i](qf)

# Compute the success condition
success_bool = (measure(case_indicator) == 0)

return success_bool, qf

@terminal_sampling
def main():
return block_encoding(4, state_prep_full, case_functions)

assert main() == {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25}






0 comments on commit 3e26aa9

Please sign in to comment.