Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@

# Performance Upgrades To FABLE
- Vectorized / Numba JIT-compiled Walsh-Hadamard transforms -> 100-400x speedups
- Vectorized Gray code permutations -> up to 20-30x speedups
- Lookup tables for computing control qubits -> 50x speedups

- Overall, at least a 30% speedup over the base version.



# Fast Approximate BLock Encodings (FABLE)

FABLE can synthesize quantum circuits for approximate block-encodings of matrices. A block-encoding is the embedding of a matrix in the leading block of a larger unitary matrix.
Expand Down
6 changes: 4 additions & 2 deletions py/fable/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Fable"""
from .fable import fable
from .fable import fable,faster_fable

__all__ = [
'fable'
'fable',
'faster_fable'
]



__version__ = '1.0.1'
__author__ = '''Daan Camps'''
__maintainer__ = 'Daan Camps'
Expand Down
134 changes: 134 additions & 0 deletions py/fable/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def gray_permutation(a):
b[i] = a[gray_code(i)]
return b

def gray_permutation_vectorized(a):
    """Return the entries of ``a`` reordered by Gray code.

    Position ``i`` of the result holds ``a[i ^ (i >> 1)]``. All source
    positions are computed in one vectorized step and gathered with NumPy
    integer-array indexing, so no Python-level loop is needed.

    Args:
        a (np.ndarray): Array to permute along its first axis.

    Returns:
        np.ndarray: New array with the entries of ``a`` in Gray-code order.
    """
    positions = np.arange(a.shape[0])
    gray_codes = np.bitwise_xor(positions, np.right_shift(positions, 1))
    return a[gray_codes]


def sfwht(a):
'''Scaled Fast Walsh-Hadamard transform of input vector a.
Expand All @@ -51,7 +56,56 @@ def sfwht(a):
a[j + 2**h] = (x - y) / 2
return a

from numba import njit

@njit
def sfwht_numba(a):
    """Scaled fast Walsh-Hadamard transform, JIT-compiled with Numba.

    The butterflies are applied directly to ``a``: the input buffer is
    overwritten and the same array is returned. Every stage halves the
    sums and differences it produces.

    Args:
        a (np.ndarray): Input vector of size 2^n; overwritten in place.

    Returns:
        np.ndarray: ``a`` itself, holding its scaled Walsh-Hadamard transform.
    """
    size = a.shape[0]
    stages = int(np.log2(size))
    for stage in range(stages):
        half = 1 << stage      # distance between the two butterfly inputs
        span = half << 1       # width of one butterfly group
        for start in range(0, size, span):
            for lo in range(start, start + half):
                hi = lo + half
                total = a[lo] + a[hi]
                delta = a[lo] - a[hi]
                a[lo] = total / 2
                a[hi] = delta / 2
    return a

def sfwht_optimized_vectorized(a):
    """Scaled fast Walsh-Hadamard transform using vectorized NumPy stages.

    Each butterfly stage is applied to a 2-D view of a private working
    buffer, so only the loop over the ``log2(N)`` stages runs in Python.
    The caller's array is never modified, and integer or boolean input is
    promoted to ``float64`` so the final scaling by ``1 / N`` is exact.

    Args:
        a (np.ndarray): Input vector of size 2^n.

    Returns:
        np.ndarray: New array holding the scaled Walsh-Hadamard transform
        of ``a``.

    Raises:
        ValueError: If the length of ``a`` is not a positive power of two.
    """
    a = np.asarray(a)
    N = a.shape[0]
    if N < 1 or N & (N - 1):
        raise ValueError(
            "input length must be a positive power of two, got %d" % N
        )
    n = N.bit_length() - 1

    # Work on one contiguous copy: keeps the caller's data intact and
    # guarantees that every reshape below is a view, not a copy.
    dtype = a.dtype if np.issubdtype(a.dtype, np.inexact) else np.float64
    out = np.array(a, dtype=dtype, order='C', copy=True)

    for h in range(n):
        m = 1 << h
        blocks = out.reshape(-1, 2 * m)
        lo = blocks[:, :m]
        hi = blocks[:, m:]
        lo += hi        # lo = x + y
        hi *= -2        # hi = -2y
        hi += lo        # hi = (x + y) - 2y = x - y

    # Every stage carries a factor 1/2; apply all of them at once.
    out /= N
    return out

def compute_control(i, n):
'''Compute the control qubit index based on the index i and size n.'''
if i == 4**n:
Expand Down Expand Up @@ -101,3 +155,83 @@ def compressed_uniform_rotation(a, ry=True):
circ.cx(j, 0)

return circ


def count_trailing_ones_optimized(x):
    """Count the consecutive 1 bits at the least-significant end of ``x``.

    For a non-negative integer, ``x ^ (x + 1)`` marks the trailing ones of
    ``x`` together with the zero bit just above them; shifting that mask
    right by one leaves a run of ones whose bit length is the answer.

    Args:
        x (int): Integer to inspect.

    Returns:
        int: Number of trailing one bits; 0 when ``x`` is 0.
    """
    if x == 0:
        return 0
    carry_mask = x ^ (x + 1)
    return (carry_mask >> 1).bit_length()

def compute_control_optimized(i, n, max_i):
    """Compute the control qubit index for position ``i``.

    Args:
        i (int): Position in the rotation sequence.
        n (int): Size parameter; control indices are measured from ``2*n``.
        max_i (int): Final position, which always maps to control qubit 1.

    Returns:
        int: 1 when ``i == max_i``, otherwise ``2*n`` minus the number of
        trailing one bits of ``i - 1``.
    """
    if i != max_i:
        return 2 * n - count_trailing_ones_optimized(i - 1)
    return 1

def precompute_control_lut(n):
    """Build the lookup table from a trailing-ones count to a control index.

    Entry ``t`` of the table equals ``2*n - t`` for ``t = 0 .. 2*n``, so
    the table is simply the integers from ``2*n`` down to 0.

    Args:
        n (int): Size parameter; the table has ``2*n + 1`` entries.

    Returns:
        list: ``[2*n, 2*n - 1, ..., 1, 0]``.
    """
    highest = 2 * n
    return list(range(highest, -1, -1))

from functools import lru_cache

# @lru_cache(maxsize=None)
def compressed_uniform_rotation_with_lut(a, ry=True):
    """Compressed uniform rotation circuit built with a control lookup table.

    Walks through the angle vector, emits a rotation on qubit 0 for every
    non-zero angle, and merges the CNOTs belonging to a run of zero angles
    into a single parity word so that cancelling CNOT pairs are never
    emitted.

    Args:
        a (np.ndarray): Thresholded vector of dimension 2^(2n).
        ry (bool): Apply RY if True, RZ otherwise.

    Returns:
        QuantumCircuit: Qiskit circuit on ``2n + 1`` qubits representing
        the rotation.
    """
    from qiskit import QuantumCircuit

    max_i = a.shape[0]
    n = int(np.log2(max_i) // 2)
    num_controls = 2 * n
    circ = QuantumCircuit(num_controls + 1)

    # control_lut[t] is the control qubit for an index with t trailing ones.
    control_lut = precompute_control_lut(n)
    # Only qubits 1 .. 2n exist as controls. For a length-1 input (n = 0)
    # the closing toggle still sets bit 0, so the parity word must be masked
    # before CNOTs are emitted; otherwise cx(1, 0) would address a qubit that
    # is not part of the circuit.
    control_mask = (1 << num_controls) - 1

    i = 0
    while i < max_i:
        parity_check = 0

        # Rotation for the current angle (skipped when thresholded to zero).
        if a[i] != 0:
            if ry:
                circ.ry(a[i], 0)
            else:
                circ.rz(a[i], 0)

        # Accumulate the controls of this position and of every directly
        # following zero angle; a control toggled twice cancels out.
        while True:
            if i + 1 == max_i:
                # The last position always closes with control qubit 1.
                ctrl = 1
            else:
                ctrl = control_lut[count_trailing_ones_optimized(i)]

            parity_check ^= 1 << (ctrl - 1)
            i += 1

            if i >= max_i or a[i] != 0:
                break

        # Emit one CNOT per surviving control, lowest qubit first; the scan
        # stops as soon as no set bits remain.
        parity_check &= control_mask
        qubit = 1
        while parity_check:
            if parity_check & 1:
                circ.cx(qubit, 0)
            parity_check >>= 1
            qubit += 1

    return circ
99 changes: 97 additions & 2 deletions py/fable/fable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,101 @@
# -*- coding: utf-8 -*-
import numpy as np
from qiskit import QuantumCircuit
from ._util import compressed_uniform_rotation, sfwht, gray_permutation
from ._util import compressed_uniform_rotation, sfwht, gray_permutation,sfwht_numba,gray_permutation_vectorized
from ._util import compressed_uniform_rotation_with_lut


def _thresholded_angles(theta, epsilon):
    '''Transform, Gray-permute and threshold a vector of rotation angles.'''
    angles = gray_permutation_vectorized(sfwht_numba(theta))
    if epsilon:
        angles[abs(angles) <= epsilon] = 0
    return angles


def faster_fable(a, epsilon=None):
    '''FABLE - Fast Approximate BLock Encodings (accelerated variant).

    Same construction as `fable`, but built on the Numba Walsh-Hadamard
    transform, the vectorized Gray permutation and the lookup-table based
    compressed uniform rotation.

    Args:
        a: array
            matrix to be block encoded. Non-square matrices and sizes that
            are not a power of two are zero-padded. Integer or boolean
            input is converted to float64.
        epsilon: float >= 0
            (optional) compression threshold.
    Returns:
        circuit: qiskit circuit
            circuit that block encodes A
        alpha: float
            subnormalization factor
    '''
    a = np.asarray(a)
    # np.finfo only accepts floating/complex dtypes; promote everything else.
    if not np.issubdtype(a.dtype, np.inexact):
        a = a.astype(np.float64)

    epsm = np.finfo(a.dtype).eps
    alpha = np.linalg.norm(np.ravel(a), np.inf)
    if alpha > 1:
        # Scale slightly past the largest entry so arccos stays in range.
        alpha = alpha + np.sqrt(epsm)
        a = a/alpha
    else:
        alpha = 1.0

    # Pad to a square matrix whose size is a power of two.
    n, m = a.shape
    if n != m:
        k = max(n, m)
        a = np.pad(a, ((0, k - n), (0, k - m)))
        n = k
    logn = int(np.ceil(np.log2(n)))
    if n < 2**logn:
        a = np.pad(a, ((0, 2**logn - n), (0, 2**logn - n)))
        n = 2**logn

    a = np.ravel(a)

    if np.all(np.abs(np.imag(a)) < epsm):  # real data
        OA = compressed_uniform_rotation_with_lut(
            _thresholded_angles(2.0 * np.arccos(np.real(a)), epsilon)
        )
    else:  # complex data
        # magnitude
        a_m = _thresholded_angles(2.0 * np.arccos(np.abs(a)), epsilon)
        # phase
        a_p = _thresholded_angles(-2.0 * np.angle(a), epsilon)

        # compute circuit
        OA = compressed_uniform_rotation_with_lut(a_m).compose(
            compressed_uniform_rotation_with_lut(a_p, ry=False)
        )

    circ = QuantumCircuit(2*logn + 1)

    # diffusion on row indices
    for i in range(logn):
        circ.h(i+1)

    # matrix oracle
    circ = circ.compose(OA)

    # swap register
    for i in range(logn):
        circ.swap(i+1, i+logn+1)

    # diffusion on row indices
    for i in range(logn):
        circ.h(i+1)

    # reverse bits because of little-endianness
    circ = circ.reverse_bits()

    return circ, alpha


def fable(a, epsilon=None):
Expand All @@ -20,7 +114,8 @@ def fable(a, epsilon=None):
subnormalization factor
'''
epsm = np.finfo(a.dtype).eps
alpha = np.linalg.norm(np.ravel(a), np.inf)
# alpha = np.linalg.norm(np.ravel(a), np.inf)
alpha = 1.0
if alpha > 1:
alpha = alpha + np.sqrt(epsm)
a = a/alpha
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ python_requires = >=3.6
install_requires =
qiskit>=0.19.1
numpy>=1.20.3
numba

[options.packages.find]
where = py
where = py