revamp multiprocessing (#117)

samuelstjean authored Jun 29, 2023
1 parent ce4ea08 commit 8b5d9c4
Showing 13 changed files with 380 additions and 292 deletions.
53 changes: 33 additions & 20 deletions .github/workflows/release.yml
@@ -3,26 +3,39 @@ name: Upload Python Package
 on:
   release:
     types: [created]
+  workflow_run:
+    workflows: [build_wheels, build_sdist]
+    types:
+      - completed

 jobs:
-  upload_pypi:
-    needs: [build_wheels, build_sdist]
-    runs-on: ubuntu-latest
-    # upload to PyPI on every tag starting with 'v'
-    # if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
-    # alternatively, to publish when a GitHub Release is created, use the following rule:
-    if: github.event_name == 'release' && github.event.action == 'published'
-    steps:
-      - uses: actions/download-artifact@v3
-        with:
-          name: artifact
-          path: dist
+  publish_artifacts:
+    runs-on: ubuntu-latest
+    # upload to PyPI on every tag starting with 'v'
+    # if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
+    # alternatively, to publish when a GitHub Release is created, use the following rule:
+    if: ${{ github.event.workflow_run.conclusion == 'success' }} && github.event_name == 'release' && github.event.action == 'published'
+    steps:
+      - name: Download builds
+        uses: actions/download-artifact@v3
+        with:
+          name: artifact
+          path: dist

-      - uses: pypa/gh-action-pypi-publish@release/v1
-        with:
-          verbose: true
-          print-hash: true
-          user: ${{ secrets.PYPI_USERNAME }}
-          password: ${{ secrets.PYPI_PASSWORD }}
-          # password: ${{ secrets.testpypi_password }}
-          # repository_url: https://test.pypi.org/legacy/
+      - name: upload to pypi
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          verbose: true
+          print-hash: true
+          user: ${{ secrets.PYPI_USERNAME }}
+          password: ${{ secrets.PYPI_PASSWORD }}
+          # password: ${{ secrets.testpypi_password }}
+          # repository_url: https://test.pypi.org/legacy/
+
+      - name: publish to github release
+        uses: softprops/action-gh-release@v0.1.15
+        # permissions:
+        #   contents: write
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: dist/*
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,13 @@
 # Changelog

+## [0.7.1] - unreleased
+
+- Some speed improvements internally
+- Some more functions in parallel
+- A new progress bar with tqdm
+- New non-frozen builds for the standalone versions
+- Mac M1/M2 arm64 binary wheels now available
+
 ## [0.7] - 2023-05-20

 - **Breaking changes in the command line parser**
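The "new progress bar with tqdm" entry above refers to the tqdm.autonotebook import used in the diffs below; it selects the notebook widget when running under Jupyter and falls back to a plain console bar otherwise. A minimal sketch (the loop body is a stand-in for real work):

    import time
    from tqdm.autonotebook import tqdm  # auto-selects notebook widget or console bar

    for _ in tqdm(range(100), desc='denoising'):
        time.sleep(0.01)  # stand-in for one slice of actual processing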
91 changes: 38 additions & 53 deletions nlsam/bias_correction.py
@@ -1,24 +1,19 @@
 import numpy as np
 import logging

-from nlsam.stabilizer import fixed_point_finder, chi_to_gauss, root_finder, xi
+from nlsam.stabilizer import root_finder_loop, multiprocess_stabilization
 from joblib import Parallel, delayed
+from tqdm.autonotebook import tqdm

 logger = logging.getLogger('nlsam')

-# Vectorised versions of the above, so we can use implicit broadcasting and stuff
-vec_fixed_point_finder = np.vectorize(fixed_point_finder, [np.float64])
-vec_chi_to_gauss = np.vectorize(chi_to_gauss, [np.float64])
-vec_xi = np.vectorize(xi, [np.float64])
-vec_root_finder = np.vectorize(root_finder, [np.float64])
-

 def stabilization(data, m_hat, sigma, N, mask=None, clip_eta=True, return_eta=False, n_cores=-1, verbose=False):

     data = np.asarray(data)
     m_hat = np.asarray(m_hat)
-    sigma = np.atleast_3d(sigma)
-    N = np.atleast_3d(N)
+    sigma = np.atleast_3d(sigma).astype(np.float32)
+    N = np.atleast_3d(N).astype(np.float32)

     if mask is None:
         mask = np.ones(data.shape[:-1], dtype=bool)
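The deleted vec_* wrappers relied on np.vectorize, which provides ndarray broadcasting but still calls the scalar function once per element in a Python-level loop; the revamp replaces them with batch routines imported from nlsam.stabilizer (root_finder_loop, multiprocess_stabilization). A small illustration of that limitation, with a hypothetical scalar function:

    import numpy as np

    def scalar_f(x, n):
        # hypothetical stand-in for a scalar routine such as root_finder
        return x / np.sqrt(n)

    vec_f = np.vectorize(scalar_f, [np.float64])  # same form as the removed vec_* helpers

    x = np.random.rand(1000)
    out = vec_f(x, 4)  # broadcasts like a ufunc, but loops in Python under the hood
    assert np.allclose(out, x / 2)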
@@ -41,55 +36,32 @@ def stabilization(data, m_hat, sigma, N, mask=None, clip_eta=True, return_eta=False, n_cores=-1, verbose=False):
     if (data.shape != m_hat.shape):
         raise ValueError(f'data shape {data.shape} is not compatible with m_hat shape {m_hat.shape}')

-    arglist = ((data[..., idx, :],
-                m_hat[..., idx, :],
-                mask[..., idx],
-                sigma[..., idx, :],
-                N[..., idx, :],
-                clip_eta)
-               for idx in range(data.shape[-2]))
+    slicer = [np.index_exp[..., k] for k in range(data.shape[-1])]

     # Did we ask for verbose at the module level?
     if not verbose:
         verbose = logger.getEffectiveLevel() <= 20  # Info or debug level
+    if verbose:
+        slicer = tqdm(slicer)

-    output = Parallel(n_jobs=n_cores,
-                      verbose=verbose)(delayed(multiprocess_stabilization)(*args) for args in arglist)
+    with Parallel(n_jobs=n_cores, prefer='threads') as parallel:
+        output = parallel(delayed(multiprocess_stabilization)(data[current_slice],
+                                                              m_hat[current_slice],
+                                                              mask,
+                                                              sigma[current_slice],
+                                                              N[current_slice],
+                                                              clip_eta) for current_slice in slicer)

     data_stabilized = np.zeros_like(data, dtype=np.float32)
     eta = np.zeros_like(data, dtype=np.float32)

     for idx, content in enumerate(output):
-        data_stabilized[..., idx, :] = content[0]
-        eta[..., idx, :] = content[1]
+        data_stabilized[..., idx] = content[0]
+        eta[..., idx] = content[1]

     if return_eta:
         return data_stabilized, eta
     return data_stabilized

-
-def multiprocess_stabilization(data, m_hat, mask, sigma, N, clip_eta):
-    """Helper function for multiprocessing the stabilization part."""
-
-    if mask.ndim == (sigma.ndim - 1):
-        mask = mask[..., None]
-
-    mask = np.logical_and(sigma > 0, mask)
-    out = np.zeros_like(data, dtype=np.float32)
-    eta = np.zeros_like(data, dtype=np.float32)
-
-    eta[mask] = vec_fixed_point_finder(m_hat[mask], sigma[mask], N[mask], clip_eta=clip_eta)
-    out[mask] = vec_chi_to_gauss(data[mask], eta[mask], sigma[mask], N[mask], use_nan=False)
-
-    return out, eta
-

 def corrected_sigma(eta, sigma, N, mask=None):
     logger.warning('The function nlsam.bias_correction.corrected_sigma was replaced by nlsam.bias_correction.root_finder_sigma')
     return root_finder_sigma(eta, sigma, N, mask=mask)

-
-def root_finder_sigma(data, sigma, N, mask=None):
+def root_finder_sigma(data, sigma, N, mask=None, verbose=False, n_cores=-1):
     """Compute the local corrected standard deviation for the adaptive nonlocal
     means according to the correction factor xi.
@@ -103,6 +75,10 @@ def root_finder_sigma(data, sigma, N, mask=None):
         Number of coils of the acquisition (N=1 for Rician noise)
     mask : ndarray, optional
         Compute only the corrected sigma value inside the mask.
+    verbose : bool, optional
+        displays a progress bar if True
+    n_cores : int, optional
+        number of cores to use for parallel processing

     Return
     --------
@@ -114,7 +90,7 @@
     N = np.array(N)

     if mask is None:
-        mask = np.ones_like(sigma, dtype=bool)
+        mask = np.ones(data.shape[:-1], dtype=bool)
     else:
         mask = np.array(mask, dtype=bool)

@@ -127,13 +103,22 @@

     corrected_sigma = np.zeros_like(data, dtype=np.float32)

-    # To not murder people ram, we process it slice by slice and reuse the arrays in a for loop
-    gaussian_SNR = np.zeros(np.count_nonzero(mask), dtype=np.float32)
-    theta = np.zeros_like(gaussian_SNR)
-
-    for idx in range(data.shape[-1]):
-        theta[:] = data[..., idx][mask] / sigma[..., idx][mask]
-        gaussian_SNR[:] = vec_root_finder(theta, N[..., idx][mask])
-        corrected_sigma[..., idx][mask] = sigma[..., idx][mask] / np.sqrt(vec_xi(gaussian_SNR, 1, N[..., idx][mask]))
+    # The mask is only 3D, so this will make a 1D array to loop through
+    data = data[mask]
+    sigma = sigma[mask]
+    N = N[mask]
+
+    slicer = [np.index_exp[..., k] for k in range(data.shape[-1])]
+
+    if verbose:
+        slicer = tqdm(slicer)
+
+    with Parallel(n_jobs=n_cores, prefer='threads') as parallel:
+        output = parallel(delayed(root_finder_loop)(data[current_slice],
+                                                    sigma[current_slice],
+                                                    N[current_slice]) for current_slice in slicer)
+
+    for idx, content in enumerate(output):
+        corrected_sigma[mask, idx] = content

     return corrected_sigma
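Both stabilization and root_finder_sigma now share the same pattern: build one np.index_exp slice per volume along the last axis, optionally wrap the list in tqdm, and dispatch with joblib's threading backend. Threads work here because the heavy lifting happens in NumPy calls that release the GIL, and they avoid copying the large arrays to subprocesses. A minimal sketch of the pattern, with toy_worker as a hypothetical stand-in for multiprocess_stabilization or root_finder_loop:

    import numpy as np
    from joblib import Parallel, delayed
    from tqdm.autonotebook import tqdm

    def toy_worker(vol, sigma_vol):
        # stand-in for the real workers: NumPy-heavy math that releases the GIL
        return vol / np.maximum(sigma_vol, 1e-12)

    def process_by_volume(data, sigma, n_cores=-1, verbose=False):
        # one slice object per 3D volume along the last axis
        slicer = [np.index_exp[..., k] for k in range(data.shape[-1])]

        if verbose:
            slicer = tqdm(slicer)  # progress advances as tasks are dispatched

        # prefer='threads' shares memory instead of pickling arrays to subprocesses
        with Parallel(n_jobs=n_cores, prefer='threads') as parallel:
            output = parallel(delayed(toy_worker)(data[s], sigma[s]) for s in slicer)

        out = np.zeros_like(data, dtype=np.float32)
        for idx, content in enumerate(output):
            out[..., idx] = content
        return out

    result = process_by_volume(np.random.rand(16, 16, 8, 5), np.full((16, 16, 8, 5), 0.1))

The write-back in root_finder_sigma uses the same idea with a boolean mask: data[mask] flattens the masked voxels of the 4D array to a (n_voxels, n_volumes) 2D array, and corrected_sigma[mask, idx] scatters one column of results back into place.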
58 changes: 35 additions & 23 deletions nlsam/denoiser.py
@@ -9,12 +9,12 @@
 from autodmri.blocks import extract_patches

 from joblib import Parallel, delayed
+from tqdm.autonotebook import tqdm

 import spams

 logger = logging.getLogger('nlsam')

-
 def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
                   mask=None, is_symmetric=False, n_cores=-1, split_b0s=False, split_shell=False,
                   subsample=True, n_iter=10, b0_threshold=10, bval_threshold=25, dtype=np.float64, verbose=False):
@@ -181,7 +181,6 @@ def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
     data_denoised /= divider
     return data_denoised

-
 def local_denoise(data, block_size, overlap, variance, n_iter=10, mask=None,
                   dtype=np.float64, n_cores=-1, verbose=False):
     if verbose:
@@ -227,37 +226,51 @@ def local_denoise(data, block_size, overlap, variance, n_iter=10, mask=None,
     param_alpha['numThreads'] = 1
     param_D['numThreads'] = 1

-    arglist = ((data[:, :, k:k + block_size[2]],
-                mask[:, :, k:k + block_size[2]],
-                variance[:, :, k:k + block_size[2]],
-                block_size,
-                overlap,
-                param_alpha,
-                param_D,
-                dtype,
-                n_iter)
-               for k in range(data.shape[2] - block_size[2] + 1))
+    slicer = [np.index_exp[:, :, k:k + block_size[2]] for k in range((data.shape[2] - block_size[2] + 1))]
+
+    if verbose:
+        progress_slicer = tqdm(slicer)  # This is because tqdm consumes the (generator) slicer, but we also need it later :/
+    else:
+        progress_slicer = slicer

     time_multi = time()
-    data_denoised = Parallel(n_jobs=n_cores,
-                             verbose=verbose)(delayed(processer)(*args) for args in arglist)
+
+    data_denoised = Parallel(n_jobs=n_cores)(delayed(processer)(data,
+                                                                mask,
+                                                                variance,
+                                                                block_size,
+                                                                overlap,
+                                                                param_alpha,
+                                                                param_D,
+                                                                current_slice,
+                                                                dtype,
+                                                                n_iter)
+                                             for current_slice in progress_slicer)

     logger.info(f'Multiprocessing done in {(time() - time_multi) / 60:.2f} mins.')

     # Put together the multiprocessed results
     data_subset = np.zeros_like(data, dtype=np.float32)
     divider = np.zeros_like(data, dtype=np.int16)

-    for k, content in enumerate(data_denoised):
-        data_subset[:, :, k:k + block_size[2]] += content
-        divider[:, :, k:k + block_size[2]] += 1
+    for current_slice, content in zip(slicer, data_denoised):
+        data_subset[current_slice] += content
+        divider[current_slice] += 1

     data_subset /= divider
     return data_subset

-
-def processer(data, mask, variance, block_size, overlap, param_alpha, param_D,
+def processer(data, mask, variance, block_size, overlap, param_alpha, param_D, current_slice,
               dtype=np.float64, n_iter=10, gamma=3, tau=1, tolerance=1e-5):

+    # Fetch the current slice for parallel processing since now the arrays are dumped and read from disk
+    # instead of passed around as smaller slices by the function to 'increase performance'
+    data = data[current_slice]
+    mask = mask[current_slice]
+    variance = variance[current_slice]
+
     orig_shape = data.shape
     mask_array = im2col_nd(mask, block_size[:-1], overlap[:-1])
     train_idx = np.sum(mask_array, axis=0) > (mask_array.shape[0] / 2)
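processer now receives the full arrays plus a current_slice instead of pre-cut chunks; as the new comment says, joblib dumps large inputs to disk and the workers read them back. With joblib's default process backend, any array larger than max_nbytes (1M by default) is converted to a shared memmap once rather than pickled into every task. A rough sketch of that behaviour, with hypothetical names:

    import numpy as np
    from joblib import Parallel, delayed

    def worker(volume, current_slice):
        # slice inside the worker: each task reads only its window of the memmap
        return volume[current_slice].sum()

    big = np.random.rand(64, 64, 64)
    slices = [np.index_exp[:, :, k:k + 4] for k in range(big.shape[2] - 4 + 1)]

    # 'big' exceeds max_nbytes, so joblib dumps it to a shared memmap once
    # instead of pickling it separately for each of the 61 tasks
    sums = Parallel(n_jobs=2, max_nbytes='1M')(delayed(worker)(big, s) for s in slices)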
@@ -280,7 +293,7 @@ def processer(data, mask, variance, block_size, overlap, param_alpha, param_D,
     temp = np.zeros([alpha.shape[0], 1], dtype=dtype)

     DtD = np.asfortranarray(D.T @ D)
-    DtX = D.T @ X
+    DtX = np.asfortranarray(D.T @ X)
     DtXW = np.empty_like(DtX, order='F')
     DtDW = np.empty_like(DtD, order='F')

@@ -291,7 +304,7 @@

     xi = np.random.randn(X.shape[0], X.shape[1]) * var_mat
     var_mat *= (X.shape[0] + gamma * np.sqrt(2 * X.shape[0]))
-    eps = np.max(np.abs(np.dot(D.T, xi)), axis=0)
+    eps = np.max(np.abs(D.T @ xi), axis=0)

     for _ in range(n_iter):
         DtXW[:, not_converged] = DtX[:, not_converged] / W[:, not_converged]
@@ -314,11 +327,10 @@
         alpha_old[:] = arr
         W[:] = 1 / (np.abs(alpha_old**tau) + eps)

-    weights = np.ones(X_full_shape[1], dtype=dtype, order='F')
+    weights = np.ones(X_full_shape[1], dtype=dtype)
     weights[train_idx] = 1 / (np.sum(alpha != 0, axis=0) + 1)

     X = np.zeros(X_full_shape, dtype=dtype, order='F')
     X[:, train_idx] = D @ arr
     out = col2im_nd(X, block_size, orig_shape, overlap, weights)

     del X, W, alpha, alpha_old, DtX, DtXW, DtDW
     return out
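The small DtX and weights changes in this file are about memory layout: spams operates on Fortran-ordered (column-major) arrays, so building DtX with np.asfortranarray up front presumably avoids a layout conversion on every call, while order='F' is a no-op for the 1-D weights array and was dropped. A quick check of both facts:

    import numpy as np

    D = np.random.rand(100, 50)
    X = np.random.rand(100, 2000)

    DtX = np.asfortranarray(D.T @ X)  # column-major, the layout spams works with
    print(DtX.flags['F_CONTIGUOUS'])  # True

    weights = np.ones(X.shape[1])  # 1-D arrays are both C- and F-contiguous already
    print(weights.flags['C_CONTIGUOUS'] and weights.flags['F_CONTIGUOUS'])  # True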
(9 more changed files not shown)
