From c8052e6deb766dfaa5fe139f261acdce2ffa9d15 Mon Sep 17 00:00:00 2001 From: positr0nium Date: Wed, 8 Jan 2025 15:25:21 +0100 Subject: [PATCH] updated the boolean simulator with a prototypical memory management system --- src/qrisp/jasp/boolean_simulation.py | 28 ++- src/qrisp/jasp/interpreter_tools/__init__.py | 1 + .../jasp/interpreter_tools/dynamic_list.py | 163 +++++++++++++ .../interpreters/cl_func_interpreter.py | 224 ++++++++++-------- 4 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 src/qrisp/jasp/interpreter_tools/dynamic_list.py diff --git a/src/qrisp/jasp/boolean_simulation.py b/src/qrisp/jasp/boolean_simulation.py index 66ab82b7..a60b5609 100644 --- a/src/qrisp/jasp/boolean_simulation.py +++ b/src/qrisp/jasp/boolean_simulation.py @@ -18,13 +18,13 @@ import jax.numpy as jnp from jax import jit -from jax.core import eval_jaxpr from qrisp.jasp import make_jaspr from qrisp.jasp.interpreter_tools.interpreters.cl_func_interpreter import jaspr_to_cl_func_jaxpr +from qrisp.jasp.interpreter_tools import Jlist, eval_jaxpr -def boolean_simulation(*func, bit_array_padding = 2**20): +def boolean_simulation(*func, bit_array_padding = 2**16): """ Decorator to simulate Jasp functions containing only classical logic (like X, CX, CCX etc.). This decorator transforms the function into a Jax-Expression without any @@ -164,24 +164,28 @@ def main(i, j): if bit_array_padding < 64: raise Exception("Tried to initialize boolean_simulation with less than 64 bits") - @jit + @jit def return_function(*args): - jaspr = make_jaspr(func)(*args) + jaspr = make_jaspr(func, garbage_collection="manual")(*args) + cl_func_jaxpr = jaspr_to_cl_func_jaxpr(jaspr.flatten_environments(), bit_array_padding) aval = cl_func_jaxpr.invars[0].aval - res = eval_jaxpr(cl_func_jaxpr, - [], - jnp.zeros(aval.shape, dtype = aval.dtype), - jnp.array(0, dtype = jnp.int64), *args) + bit_array = jnp.zeros(aval.shape, dtype = aval.dtype) + free_qubit_list = Jlist(jnp.arange(bit_array_padding), max_size = bit_array_padding).flatten()[0] + boolean_quantum_circuit = (bit_array, *free_qubit_list) + + + res = eval_jaxpr(cl_func_jaxpr)(*boolean_quantum_circuit, + *args) - if len(res) == 3: - return res[2] - elif len(res) == 2: + if len(res) == 4: + return res[3] + elif len(res) == 3: return None else: - return res[2:] + return res[3:] return return_function diff --git a/src/qrisp/jasp/interpreter_tools/__init__.py b/src/qrisp/jasp/interpreter_tools/__init__.py index ae3d7193..bda9bd6e 100644 --- a/src/qrisp/jasp/interpreter_tools/__init__.py +++ b/src/qrisp/jasp/interpreter_tools/__init__.py @@ -16,5 +16,6 @@ ********************************************************************************/ """ +from qrisp.jasp.interpreter_tools.dynamic_list import * from qrisp.jasp.interpreter_tools.abstract_interpreter import * from qrisp.jasp.interpreter_tools.interpreters import * diff --git a/src/qrisp/jasp/interpreter_tools/dynamic_list.py b/src/qrisp/jasp/interpreter_tools/dynamic_list.py new file mode 100644 index 00000000..97adcd0f --- /dev/null +++ b/src/qrisp/jasp/interpreter_tools/dynamic_list.py @@ -0,0 +1,163 @@ +""" +\******************************************************************************** +* Copyright (c) 2023 the Qrisp authors +* +* This program and the accompanying materials are made available under the +* terms of the Eclipse Public License 2.0 which is available at +* http://www.eclipse.org/legal/epl-2.0. +* +* This Source Code may also be made available under the following Secondary +* Licenses when the conditions for such availability set forth in the Eclipse +* Public License, v. 2.0 are satisfied: GNU General Public License, version 2 +* with the GNU Classpath Exception which is +* available at https://www.gnu.org/software/classpath/license.html. +* +* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 +********************************************************************************/ +""" + +import jax +import jax.numpy as jnp + +@jax.tree_util.register_pytree_node_class +class Jlist: + + fill_value = 0 + + def __init__(self, init_val = None, max_size = int(2**10)): + self.max_size = max_size + self.array, self.counter = self._create_dynamic_array(init_val) + + def _create_dynamic_array(self, init_val): + jax_array = jnp.zeros(self.max_size, dtype = jnp.int64) + + n = 0 + + if init_val is not None: + + if isinstance(init_val, list): + n = len(init_val) + else: + n = init_val.size + + # Create an index array for updating + idx = jnp.arange(min(n, jax_array.size), dtype = jnp.int64) + + # Use JAX's index_update to fill the array + jax_array = jax_array.at[idx].set(jnp.array(init_val[:jax_array.size], dtype = jnp.int64), indices_are_sorted = True) + + return jax_array, jnp.array(min(n, self.max_size), dtype = jnp.int64) + + def append(self, value): + self.array, self.counter = self._append(value) + return self + + @jax.jit + def _append(self, value): + new_array = self.array.at[self.counter].set(value) + new_counter = jnp.minimum(self.counter + 1, self.array.shape[0]) + return new_array, new_counter + + + def pop(self): + self.counter, value = self._pop() + return value + + @jax.jit + def _pop(self): + new_counter = self.counter - 1 + value = self.array[new_counter] + return new_counter, value + + + def extend(self, values): + self.array, self.counter = self._extend(self.array, self.counter, values) + return self + + @jax.jit + def _extend(self, array, counter, values): + def body_fun(i, state): + curr_array, curr_counter = state + new_array = curr_array.at[curr_counter].set(values[i]) + new_counter = jnp.minimum(curr_counter + 1, self.max_size) + return new_array, new_counter + + return jax.lax.fori_loop(0, values.counter, body_fun, (array, counter)) + + @jax.jit + def clear(self): + self.array, self.counter = self._clear(self.array, self.counter) + return self + + @staticmethod + def _clear(array, counter): + return array, jnp.array(0) + + def __getitem__(self, key): + if isinstance(key, slice): + + if key.start is None: + start = 0 + else: + start = jnp.maximum(key.start, 0) + + if key.stop is None: + stop = self.counter + else: + stop = jnp.minimum(key.stop, self.counter) + + length = stop - start + + def body_fun(i, state): + new_array, old_array = state + new_array = new_array.at[i].set(old_array[i+start]) + return new_array, old_array + + new_array = jnp.zeros(self.max_size, dtype = jnp.int64) + + new_array, _ = jax.lax.fori_loop(0, length, body_fun, (new_array, self.array)) + + res = Jlist.__new__(Jlist) + res.array = new_array + res.counter = length + res.max_size = self.max_size + + return res + else: + return self.array[key] + + @jax.jit + def _slice(array, counter, start, end): + start = jnp.maximum(0, start) + end = jnp.minimum(counter, end) + return array[start:end] + + def __len__(self): + return int(self.counter) + + def flatten(self): + """ + Flatten the DynamicJaxArray into a tuple of arrays and auxiliary data. + This is useful for JAX transformations and serialization. + """ + return (self.array, self.counter), tuple() + + @classmethod + def unflatten(cls, aux_data, children): + """ + Recreate a DynamicJaxArray from flattened data. + """ + array, counter = children + obj = cls() + obj.array = array + obj.counter = counter + return obj + + # Add this method to make the class compatible with jax.tree_util + def tree_flatten(self): + return self.flatten() + + # Add this class method to make the class compatible with jax.tree_util + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls.unflatten(aux_data, children) diff --git a/src/qrisp/jasp/interpreter_tools/interpreters/cl_func_interpreter.py b/src/qrisp/jasp/interpreter_tools/interpreters/cl_func_interpreter.py index a4b2c13a..1a8759e4 100644 --- a/src/qrisp/jasp/interpreter_tools/interpreters/cl_func_interpreter.py +++ b/src/qrisp/jasp/interpreter_tools/interpreters/cl_func_interpreter.py @@ -27,7 +27,7 @@ from qrisp.circuit import ControlledOperation from qrisp.jasp import (QuantumPrimitive, OperationPrimitive, AbstractQuantumCircuit, AbstractQubitArray, -AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues) +AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues, Jlist) def cl_func_eqn_evaluator(eqn, context_dic): @@ -72,15 +72,34 @@ def process_create_qubits(invars, outvars, context_dic): # The first invar of the create_qubits primitive is an AbstractQuantumCircuit # which is represented by an AbstractQreg and an integer - qreg, stack_size = context_dic[invars[0]] + qreg, free_qubits = context_dic[invars[0]] + + size = context_dic[invars[1]] # We create the new QubitArray representation by putting the appropriate tuple # in the context_dic - context_dic[outvars[1]] = (stack_size, context_dic[invars[1]]) + + reg_qubits = Jlist() + + def loop_body(i, val_tuple): + free_qubits, reg_qubits = val_tuple + reg_qubits.append(free_qubits.pop()) + return free_qubits, reg_qubits + + @jit + def make_tracer(x): + return x + + size = make_tracer(size) + + + free_qubits, reg_qubits = fori_loop(0, size, loop_body, (free_qubits, reg_qubits)) + + context_dic[outvars[1]] = reg_qubits # Furthermore we create the updated AbstractQuantumCircuit representation. # The new stack size is the old stask size + the size of the QubitArray - context_dic[outvars[0]] = (qreg, stack_size + context_dic[invars[1]]) + context_dic[outvars[0]] = (qreg, free_qubits) def process_get_qubit(invars, outvars, context_dic): @@ -89,28 +108,22 @@ def process_get_qubit(invars, outvars, context_dic): # For that we add the Qubit index (in the QubitArray) to the QubitArray # starting index. - qubit_array_starting_index = context_dic[invars[0]][0] - qubit_index = context_dic[invars[1]] - context_dic[outvars[0]] = qubit_array_starting_index + qubit_index + reg_qubits = context_dic[invars[0]] + index = context_dic[invars[1]] + context_dic[outvars[0]] = reg_qubits[index] def process_slice(invars, outvars, context_dic): - base_qubit_array_starting_index = context_dic[invars[0]][0] - base_qubit_array_ending_index = base_qubit_array_starting_index + context_dic[invars[0]][1] - - new_starting_index = context_dic[invars[1]] + base_qubit_array_starting_index - - new_max_index = jnp.min(jnp.array([base_qubit_array_ending_index, - base_qubit_array_starting_index + context_dic[invars[2]]])) - - new_size = new_max_index - new_starting_index + reg_qubits = context_dic[invars[0]] + start = context_dic[invars[1]] + stop = context_dic[invars[2]] - context_dic[outvars[0]] = (new_starting_index, new_size) + context_dic[outvars[0]] = reg_qubits[start:stop] def process_get_size(invars, outvars, context_dic): # The size is simply the second entry of the QubitArray representation - context_dic[outvars[0]] = context_dic[invars[0]][1] + context_dic[outvars[0]] = context_dic[invars[0]].counter def process_op(op_prim, invars, outvars, context_dic): @@ -136,7 +149,6 @@ def process_op(op_prim, invars, outvars, context_dic): context_dic[outvars[0]] = context_dic[invars[0]] return else: - print(type(op)) raise Exception(f"Classical function simulator can't process gate {op.name}") bit_array = cl_multi_cx(bit_array, ctrl_state, qb_pos) @@ -171,12 +183,10 @@ def process_measurement(invars, outvars, context_dic): if isinstance(invars[1].aval, AbstractQubitArray): # Retrieve the start and the endpoint indices of the QubitArray - qubit_array_data = context_dic[invars[1]] - start = qubit_array_data[0] - stop = start + qubit_array_data[1] + qubit_reg = context_dic[invars[1]] # The multi measurement logic is outsourced into a dedicated function - bit_array, meas_res = exec_multi_measurement(bit_array, start, stop) + bit_array, meas_res = exec_multi_measurement(bit_array, qubit_reg) # The singular Qubit case else: @@ -190,7 +200,7 @@ def process_measurement(invars, outvars, context_dic): context_dic[outvars[1]] = meas_res -def exec_multi_measurement(bit_array, start, stop): +def exec_multi_measurement(bit_array, qubit_reg): # This function performs the measurement of multiple qubits at once, returning # an integer. The qubits to be measured sit in one consecutive interval, # starting at index "start" and ending at "stop". @@ -201,13 +211,13 @@ def exec_multi_measurement(bit_array, start, stop): # loop def loop_body(i, arg_tuple): - acc, bit_array = arg_tuple - res_bl = get_bit_array(bit_array, i) - acc = acc + (jnp.asarray(1, dtype = "int64")<<(i-start))*res_bl - i += jnp.asarray(1, dtype = "int64") - return (acc, bit_array) + acc, bit_array, qubit_reg = arg_tuple + qb_index = qubit_reg[i] + res_bl = get_bit_array(bit_array, qb_index) + acc = acc + (jnp.asarray(1, dtype = "int64")<>6 int_index = (index ^ (array_index << 6)) return (bit_array[array_index] >> int_index) & 1 + +def unflatten_signature(values, variables): + values = list(values) + unflattened_values = [] + for var in variables: + if isinstance(var.aval, AbstractQuantumCircuit): + bit_array = values.pop(0) + jlist_tuple = (values.pop(0), values.pop(0)) + unflattened_values.append((bit_array, Jlist.unflatten([], jlist_tuple))) + elif isinstance(var.aval, AbstractQubitArray): + jlist_tuple = (values.pop(0), values.pop(0)) + unflattened_values.append(Jlist.unflatten([], jlist_tuple)) + else: + unflattened_values.append(values.pop(0)) + + return unflattened_values + +def flatten_signature(values, variables): + values = list(values) + flattened_values = [] + for i in range(len(variables)): + var = variables[i] + value = values.pop(0) + if isinstance(var.aval, AbstractQuantumCircuit): + flattened_values.extend((value[0], *value[1].flatten()[0])) + elif isinstance(var.aval, AbstractQubitArray): + flattened_values.extend(value.flatten()[0]) + else: + flattened_values.append(value) + + return flattened_values + +def ensure_conversion(jaxpr, invalues): + + bit_array_padding = 0 + convert = False + for i in range(len(jaxpr.invars)): + invar = jaxpr.invars[i] + if isinstance(invar.aval, (AbstractQuantumCircuit, AbstractQubitArray, AbstractQubit)): + convert = True + if isinstance(invar.aval, AbstractQuantumCircuit): + bit_array_padding = invalues[i][0].shape[0]*64 + + if convert: + return jaspr_to_cl_func_jaxpr(jaxpr, bit_array_padding) + return ClosedJaxpr(jaxpr, []) \ No newline at end of file