Skip to content

Commit

Permalink
replace numba parallel vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Dec 5, 2024
1 parent d126a6e commit 9a2827b
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 115 deletions.
254 changes: 172 additions & 82 deletions quimb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from numbers import Integral

import numba
import numpy as np
import scipy.sparse as sp

Expand Down Expand Up @@ -39,33 +40,13 @@ def prod(iterable):

_NUM_THREAD_WORKERS = psutil.cpu_count(logical=False)

if "NUMBA_NUM_THREADS" in os.environ:
if int(os.environ["NUMBA_NUM_THREADS"]) != _NUM_THREAD_WORKERS:
import warnings

warnings.warn(
"'NUMBA_NUM_THREADS' has been set elsewhere and doesn't match the "
"value 'quimb' has tried to set - "
f"{os.environ['NUMBA_NUM_THREADS']} vs {_NUM_THREAD_WORKERS}."
)
else:
os.environ["NUMBA_NUM_THREADS"] = str(_NUM_THREAD_WORKERS)

# need to set NUMBA_NUM_THREADS first
import numba # noqa

_NUMBA_CACHE = {
"TRUE": True,
"ON": True,
"FALSE": False,
"OFF": False,
}[os.environ.get("QUIMB_NUMBA_CACHE", "True").upper()]
_NUMBA_PAR = {
"TRUE": True,
"ON": True,
"FALSE": False,
"OFF": False,
}[os.environ.get("QUIMB_NUMBA_PARALLEL", "True").upper()]

njit = functools.partial(numba.njit, cache=_NUMBA_CACHE)
"""Numba no-python jit, but obeying cache setting.
Expand All @@ -75,15 +56,6 @@ def prod(iterable):
"""Numba vectorize, but obeying cache setting.
"""

pvectorize = functools.partial(
numba.vectorize,
cache=_NUMBA_CACHE,
target="parallel" if _NUMBA_PAR else "cpu",
)
"""Numba vectorize, but obeying cache setting, with optional parallel
target, depending on environment variable 'QUIMB_NUMBA_PARALLEL'.
"""


class CacheThreadPool(object):
""" """
Expand Down Expand Up @@ -505,7 +477,7 @@ def threading_choose_num_blocks(size_total, target_block_size, num_threads):
# target blocks actually close to size target_block_size, for
# cyclically distributing work with potentially varying costs
target_block_size = -target_block_size
num_blocks = math.ceil(size_total / target_block_size)
num_blocks = np.ceil(size_total / target_block_size)
if num_blocks > num_threads:
# round to nearest multiple of num_threads
num_blocks = num_threads * round(num_blocks / num_threads)
Expand Down Expand Up @@ -548,29 +520,91 @@ def maybe_multithread(
pool.submit(
fn,
*args,
trank=trank,
thread_rank=thread_rank,
num_threads=num_threads,
target_block_size=target_block_size,
**kwargs,
)
for trank in range(num_threads)
for thread_rank in range(num_threads)
)


def _nb_complex_base(real, imag): # pragma: no cover
return real + 1j * imag

@njit(nogil=True)
def _complex_array_numba(
    x, y, out, thread_rank=0, num_threads=1, target_block_size=2**15
):  # pragma: no cover
    """Write ``complex(x[i], y[i])`` into ``out[i]`` for the blocks that
    belong to a single thread rank.

    Parameters
    ----------
    x, y : array
        Real and imaginary parts. Both are read with a single flat index,
        over the range ``0 .. x.size``.
    out : array
        Destination, written with the same flat index.
    thread_rank : int, optional
        Index of the first block this call handles.
    num_threads : int, optional
        Stride between the blocks this call handles.
    target_block_size : int, optional
        Passed to ``threading_choose_num_blocks`` to decide the block layout.
    """
    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        x.size, target_block_size, num_threads
    )
    # this rank takes blocks thread_rank, thread_rank + num_threads, ...
    for b in range(thread_rank, num_blocks, num_threads):
        istart, istop = threading_get_block_range(
            b, base_block_size, block_remainder
        )
        for i in range(istart, istop):
            out[i] = complex(x[i], y[i])


def complex_array(real, imag):
def complex_array(x, y, num_threads=None, target_block_size=2**15):
    """Accelerated creation of the complex array ``x + 1j * y``.

    Parameters
    ----------
    x : numpy.ndarray
        Real parts. Its dtype selects the output precision: ``float32``
        gives ``complex64``, anything else gives ``complex128``.
    y : array_like
        Imaginary parts. If its shape differs from ``x`` the two are
        broadcast against each other first.
    num_threads : int, optional
        Forwarded to ``maybe_multithread``, which decides how ``None`` is
        resolved.
    target_block_size : int, optional
        Forwarded to ``maybe_multithread`` and on to the kernel.

    Returns
    -------
    numpy.ndarray
        Complex array with the (broadcast) shape of the inputs.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` cannot be broadcast to a common shape.
    """
    if np.shape(x) != np.shape(y):
        # the kernel walks ``y`` with ``x``'s flat indices, so the two must
        # describe the same number of elements before it runs
        x, y = np.broadcast_arrays(x, y)

    if x.dtype == "float32":
        dtype = "complex64"
    else:
        dtype = "complex128"

    # allocate with the final shape; the kernel only uses single-index
    # access, so it is handed flat versions of all three arrays
    out = np.empty(x.shape, dtype=dtype)

    maybe_multithread(
        _complex_array_numba,
        np.ravel(x),
        np.ravel(y),
        out.ravel(),  # a view: ``out`` is freshly allocated and contiguous
        size_total=x.size,
        target_block_size=target_block_size,
        num_threads=num_threads,
    )
    return out


@njit(nogil=True)
def _phase_to_complex_numba(
    x, out, thread_rank=0, num_threads=1, target_block_size=2**10
):  # pragma: no cover
    """Write ``cos(x[i]) + 1j * sin(x[i])`` into ``out[i]`` for the blocks
    that belong to a single thread rank.

    ``x`` and ``out`` are both accessed with one flat index. The block
    layout comes from ``threading_choose_num_blocks``; this call handles
    blocks ``thread_rank, thread_rank + num_threads, ...``.
    """
    size = x.size

    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        size, target_block_size, num_threads
    )
    for block in range(thread_rank, num_blocks, num_threads):
        lo, hi = threading_get_block_range(
            block, base_block_size, block_remainder
        )
        for idx in range(lo, hi):
            phase = x[idx]
            out[idx] = complex(np.cos(phase), np.sin(phase))


def phase_to_complex(x, num_threads=None, target_block_size=2**10):
    """Convert an array of phases to actual complex numbers.

    The result has the same shape as ``x`` and holds
    ``cos(x) + 1j * sin(x)`` elementwise. It is ``complex64`` when ``x`` is
    ``float32`` and ``complex128`` otherwise. ``num_threads`` and
    ``target_block_size`` are forwarded to ``maybe_multithread``.
    """
    dtype = "complex64" if x.dtype == "float32" else "complex128"

    # allocate in the final shape; the kernel fills it through a flat view
    out = np.empty(x.shape, dtype=dtype)
    maybe_multithread(
        _phase_to_complex_numba,
        x.ravel(),
        out.ravel(),
        size_total=x.size,
        target_block_size=target_block_size,
        num_threads=num_threads,
    )
    return out


@ensure_qarray
Expand Down Expand Up @@ -606,52 +640,108 @@ def mul(x, y):
return mul_dense(x, y)


def _nb_subtract_update_base(X, c, Z): # pragma: no cover
return X - c * Z
@njit(nogil=True)
def _subtract_update_2d_numba(
    X, c, Y, thread_rank=0, num_threads=1, target_block_size=2**14
):  # pragma: no cover
    """Apply ``X[i, j] -= c * Y[i, j]`` in place over the row blocks that
    belong to a single thread rank.

    Blocks partition the first axis of ``X``; this call handles blocks
    ``thread_rank, thread_rank + num_threads, ...`` and every column of
    each row it visits.
    """
    num_rows, num_cols = X.shape
    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        num_rows, target_block_size, num_threads
    )
    for block in range(thread_rank, num_blocks, num_threads):
        row_lo, row_hi = threading_get_block_range(
            block, base_block_size, block_remainder
        )
        for row in range(row_lo, row_hi):
            for col in range(num_cols):
                X[row, col] -= c * Y[row, col]


_sbtrct_sigs = [
"float32(float32, float32, float32)",
"float32(float32, float64, float32)",
"float64(float64, float64, float64)",
"complex64(complex64, float32, complex64)",
"complex64(complex64, float64, complex64)",
"complex128(complex128, float64, complex128)",
]
_nb_subtract_update_seq = vectorize(_sbtrct_sigs)(_nb_subtract_update_base)
_nb_subtract_update_par = pvectorize(_sbtrct_sigs)(_nb_subtract_update_base)
@njit(nogil=True)
def _subtract_update_1d_numba(
    X, c, Y, thread_rank=0, num_threads=1, target_block_size=2**14
):  # pragma: no cover
    """Apply ``X[i] -= c * Y[i]`` in place over the blocks that belong to a
    single thread rank.

    ``X`` must be one-dimensional (its shape is unpacked as a 1-tuple).
    This call handles blocks ``thread_rank, thread_rank + num_threads, ...``.
    """
    (length,) = X.shape
    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        length, target_block_size, num_threads
    )
    for block in range(thread_rank, num_blocks, num_threads):
        lo, hi = threading_get_block_range(
            block, base_block_size, block_remainder
        )
        for idx in range(lo, hi):
            X[idx] -= c * Y[idx]


def subtract_update_(X, c, Y):
def subtract_update_(X, c, Y, num_threads=None, target_block_size=2**14):
    """Accelerated inplace computation of ``X -= c * Y``. This is mainly
    for Lanczos iteration.

    Parameters
    ----------
    X : array
        Updated in place. A 2D ``X`` uses the row-blocked kernel; anything
        else is handed to the 1D kernel, which unpacks ``X.shape`` as a
        1-tuple.
    c : scalar
        Multiplier applied to ``Y``.
    Y : array
        Read with the same indices as ``X``.
    num_threads : int, optional
        Forwarded to ``maybe_multithread``, which decides how ``None`` is
        resolved.
    target_block_size : int, optional
        Forwarded to ``maybe_multithread`` and on to the kernel. Blocks are
        counted along the first axis of ``X``.
    """
    if X.ndim == 2:
        fn = _subtract_update_2d_numba
    else:
        fn = _subtract_update_1d_numba

    # work is split along the first axis only, hence ``X.shape[0]``
    maybe_multithread(
        fn,
        X,
        c,
        Y,
        size_total=X.shape[0],
        target_block_size=target_block_size,
        num_threads=num_threads,
    )

def _nb_divide_update_base(X, c): # pragma: no cover
return X / c

@njit(nogil=True)
def _divide_update_2d_numba(
    X, c, out, thread_rank=0, num_threads=1, target_block_size=2**14
):  # pragma: no cover
    """Write ``X[i, j] / c`` into ``out[i, j]`` over the row blocks that
    belong to a single thread rank.

    Blocks partition the first axis of ``X``; this call handles blocks
    ``thread_rank, thread_rank + num_threads, ...`` and every column of
    each row it visits.
    """
    num_rows, num_cols = X.shape
    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        num_rows, target_block_size, num_threads
    )
    for block in range(thread_rank, num_blocks, num_threads):
        row_lo, row_hi = threading_get_block_range(
            block, base_block_size, block_remainder
        )
        for row in range(row_lo, row_hi):
            for col in range(num_cols):
                out[row, col] = X[row, col] / c

_divd_sigs = [
"float32(float32, float32)",
"float64(float64, float64)",
"complex64(complex64, float32)",
"complex128(complex128, float64)",
]
_nb_divide_update_seq = vectorize(_divd_sigs)(_nb_divide_update_base)
_nb_divide_update_par = pvectorize(_divd_sigs)(_nb_divide_update_base)

@njit(nogil=True)
def _divide_update_1d_numba(
    X, c, out, thread_rank=0, num_threads=1, target_block_size=2**14
):  # pragma: no cover
    """Write ``X[i] / c`` into ``out[i]`` over the blocks that belong to a
    single thread rank.

    ``X`` must be one-dimensional (its shape is unpacked as a 1-tuple).
    This call handles blocks ``thread_rank, thread_rank + num_threads, ...``.
    """
    (length,) = X.shape
    num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
        length, target_block_size, num_threads
    )
    for block in range(thread_rank, num_blocks, num_threads):
        lo, hi = threading_get_block_range(
            block, base_block_size, block_remainder
        )
        for idx in range(lo, hi):
            out[idx] = X[idx] / c


def divide_update_(X, c, out):
def divide_update_(X, c, out, num_threads=None, target_block_size=2**14):
    """Accelerated computation of ``X / c`` into ``out``.

    Parameters
    ----------
    X : array
        Numerator. A 2D ``X`` uses the row-blocked kernel; anything else is
        handed to the 1D kernel, which unpacks ``X.shape`` as a 1-tuple.
    c : scalar
        Divisor.
    out : array
        Destination, written with the same indices as ``X``.
    num_threads : int, optional
        Forwarded to ``maybe_multithread``, which decides how ``None`` is
        resolved.
    target_block_size : int, optional
        Forwarded to ``maybe_multithread`` and on to the kernel. Blocks are
        counted along the first axis of ``X``.
    """
    if X.ndim == 2:
        fn = _divide_update_2d_numba
    else:
        fn = _divide_update_1d_numba

    # work is split along the first axis only, hence ``X.shape[0]``
    maybe_multithread(
        fn,
        X,
        c,
        out,
        size_total=X.shape[0],
        target_block_size=target_block_size,
        num_threads=num_threads,
    )


@njit(nogil=True) # pragma: no cover
Expand All @@ -661,7 +751,7 @@ def _dot_csr_matvec_numba(
indices,
vec,
out,
trank=0,
thread_rank=0,
num_threads=1,
target_block_size=-1024,
):
Expand All @@ -674,7 +764,7 @@ def _dot_csr_matvec_numba(
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, num_blocks, num_threads):
for b in range(thread_rank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
Expand Down Expand Up @@ -786,13 +876,13 @@ def rdot(a, b): # pragma: no cover

@njit(nogil=True)
def _l_diag_dot_dense_par(
l, A, out, trank=0, num_threads=1, target_block_size=128
l, A, out, thread_rank=0, num_threads=1, target_block_size=128
): # pragma: no cover
N, M = A.shape
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, num_blocks, num_threads):
for b in range(thread_rank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
Expand Down Expand Up @@ -853,13 +943,13 @@ def ldmul(diag, mat):

@njit(nogil=True)
def _r_diag_dot_dense_par(
A, l, out, trank=0, num_threads=1, target_block_size=128
A, l, out, thread_rank=0, num_threads=1, target_block_size=128
): # pragma: no cover
N, M = A.shape
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, num_blocks, num_threads):
for b in range(thread_rank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
Expand Down Expand Up @@ -918,12 +1008,12 @@ def rdmul(mat, diag):

@njit(nogil=True)
def _outer_par(
x, y, out, m, n, trank=0, num_threads=1, target_block_size=128
x, y, out, m, n, thread_rank=0, num_threads=1, target_block_size=128
): # pragma: no cover
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
m, target_block_size, num_threads
)
for b in range(trank, num_blocks, num_threads):
for b in range(thread_rank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
Expand Down Expand Up @@ -974,15 +1064,15 @@ def _kron_dense_numba(
n,
p,
q,
trank=0,
thread_rank=0,
num_threads=1,
target_block_size=128,
): # pragma: no cover
N = m * p
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, num_blocks, num_threads):
for b in range(thread_rank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
Expand Down
Loading

0 comments on commit 9a2827b

Please sign in to comment.