Skip to content

Commit

Permalink
updated the boolean simulator with a prototypical memory management s…
Browse files Browse the repository at this point in the history
…ystem
  • Loading branch information
positr0nium committed Jan 8, 2025
1 parent 44fad77 commit c8052e6
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 112 deletions.
28 changes: 16 additions & 12 deletions src/qrisp/jasp/boolean_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

import jax.numpy as jnp
from jax import jit
from jax.core import eval_jaxpr

from qrisp.jasp import make_jaspr

from qrisp.jasp.interpreter_tools.interpreters.cl_func_interpreter import jaspr_to_cl_func_jaxpr
from qrisp.jasp.interpreter_tools import Jlist, eval_jaxpr

def boolean_simulation(*func, bit_array_padding = 2**20):
def boolean_simulation(*func, bit_array_padding = 2**16):
"""
Decorator to simulate Jasp functions containing only classical logic (like X, CX, CCX etc.).
This decorator transforms the function into a Jax-Expression without any
Expand Down Expand Up @@ -164,24 +164,28 @@ def main(i, j):
if bit_array_padding < 64:
raise Exception("Tried to initialize boolean_simulation with less than 64 bits")

@jit
@jit
def return_function(*args):

jaspr = make_jaspr(func)(*args)
jaspr = make_jaspr(func, garbage_collection="manual")(*args)

cl_func_jaxpr = jaspr_to_cl_func_jaxpr(jaspr.flatten_environments(), bit_array_padding)

aval = cl_func_jaxpr.invars[0].aval
res = eval_jaxpr(cl_func_jaxpr,
[],
jnp.zeros(aval.shape, dtype = aval.dtype),
jnp.array(0, dtype = jnp.int64), *args)

bit_array = jnp.zeros(aval.shape, dtype = aval.dtype)
free_qubit_list = Jlist(jnp.arange(bit_array_padding), max_size = bit_array_padding).flatten()[0]
boolean_quantum_circuit = (bit_array, *free_qubit_list)


res = eval_jaxpr(cl_func_jaxpr)(*boolean_quantum_circuit,
*args)

if len(res) == 3:
return res[2]
elif len(res) == 2:
if len(res) == 4:
return res[3]
elif len(res) == 3:
return None
else:
return res[2:]
return res[3:]

return return_function
1 change: 1 addition & 0 deletions src/qrisp/jasp/interpreter_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
********************************************************************************/
"""

from qrisp.jasp.interpreter_tools.dynamic_list import *
from qrisp.jasp.interpreter_tools.abstract_interpreter import *
from qrisp.jasp.interpreter_tools.interpreters import *
163 changes: 163 additions & 0 deletions src/qrisp/jasp/interpreter_tools/dynamic_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
\********************************************************************************
* Copyright (c) 2023 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Jlist:

fill_value = 0

def __init__(self, init_val = None, max_size = int(2**10)):
self.max_size = max_size
self.array, self.counter = self._create_dynamic_array(init_val)

def _create_dynamic_array(self, init_val):
jax_array = jnp.zeros(self.max_size, dtype = jnp.int64)

n = 0

if init_val is not None:

if isinstance(init_val, list):
n = len(init_val)
else:
n = init_val.size

# Create an index array for updating
idx = jnp.arange(min(n, jax_array.size), dtype = jnp.int64)

# Use JAX's index_update to fill the array
jax_array = jax_array.at[idx].set(jnp.array(init_val[:jax_array.size], dtype = jnp.int64), indices_are_sorted = True)

return jax_array, jnp.array(min(n, self.max_size), dtype = jnp.int64)

def append(self, value):
self.array, self.counter = self._append(value)
return self

@jax.jit
def _append(self, value):
new_array = self.array.at[self.counter].set(value)
new_counter = jnp.minimum(self.counter + 1, self.array.shape[0])
return new_array, new_counter


def pop(self):
self.counter, value = self._pop()
return value

@jax.jit
def _pop(self):
new_counter = self.counter - 1
value = self.array[new_counter]
return new_counter, value


def extend(self, values):
self.array, self.counter = self._extend(self.array, self.counter, values)
return self

@jax.jit
def _extend(self, array, counter, values):
def body_fun(i, state):
curr_array, curr_counter = state
new_array = curr_array.at[curr_counter].set(values[i])
new_counter = jnp.minimum(curr_counter + 1, self.max_size)
return new_array, new_counter

return jax.lax.fori_loop(0, values.counter, body_fun, (array, counter))

@jax.jit
def clear(self):
self.array, self.counter = self._clear(self.array, self.counter)
return self

@staticmethod
def _clear(array, counter):
return array, jnp.array(0)

def __getitem__(self, key):
if isinstance(key, slice):

if key.start is None:
start = 0
else:
start = jnp.maximum(key.start, 0)

if key.stop is None:
stop = self.counter
else:
stop = jnp.minimum(key.stop, self.counter)

length = stop - start

def body_fun(i, state):
new_array, old_array = state
new_array = new_array.at[i].set(old_array[i+start])
return new_array, old_array

new_array = jnp.zeros(self.max_size, dtype = jnp.int64)

new_array, _ = jax.lax.fori_loop(0, length, body_fun, (new_array, self.array))

res = Jlist.__new__(Jlist)
res.array = new_array
res.counter = length
res.max_size = self.max_size

return res
else:
return self.array[key]

@jax.jit
def _slice(array, counter, start, end):
start = jnp.maximum(0, start)
end = jnp.minimum(counter, end)
return array[start:end]

def __len__(self):
return int(self.counter)

def flatten(self):
"""
Flatten the DynamicJaxArray into a tuple of arrays and auxiliary data.
This is useful for JAX transformations and serialization.
"""
return (self.array, self.counter), tuple()

@classmethod
def unflatten(cls, aux_data, children):
"""
Recreate a DynamicJaxArray from flattened data.
"""
array, counter = children
obj = cls()
obj.array = array
obj.counter = counter
return obj

# Add this method to make the class compatible with jax.tree_util
def tree_flatten(self):
return self.flatten()

# Add this class method to make the class compatible with jax.tree_util
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls.unflatten(aux_data, children)
Loading

0 comments on commit c8052e6

Please sign in to comment.