Skip to content

Commit

Permalink
Merge pull request #53 from samuelstjean/split_b0s
Browse files Browse the repository at this point in the history
Split b0s
  • Loading branch information
samuelstjean authored Oct 20, 2017
2 parents 2cbf104 + b7c2bee commit d91a786
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 78 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- PIESNO will now warn if less than 1% of noisy voxels were identified, which might indicate that something have gone wrong during the noise estimation.
- On python >= 3.4, --mp_method [a_valid_start_method](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) can now be used to control behavior in the multiprocessing loop.
- A new option --split_b0s can be specified to split the b0s equally amongst the training data.
- Fixed crash in option --noise_est local_std when --cores 1 was also supplied.
- setup.py and requirements.txt will now fetch spams v2.6, with patches for numpy 1.12 support.
- The GSL library and associated headers are now bundled for all platforms.
Expand Down
17 changes: 15 additions & 2 deletions nlsam/angular_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,28 @@ def angular_neighbors(vec, n):
# Sort the values and only keep the n closest neighbors.
# The first angle is always 0, since _angle always
# computes the angle between the vector and itself.
# Therefore we pick the rest of n+1 vectors excluding the first one.
return np.argsort(_angle(vec))[:, 1:n + 1]
# Therefore we pick the rest of n+1 vectors and exclude the index
# itself if it was picked, which can happen if we have N repetition of dwis
# but want n < N angular neighbors
arr = np.argsort(_angle(vec))[:, :n+1]

# We only want n elements - either we remove an index and return the remainder
# or we don't and only return the n first indexes.
output = np.zeros((arr.shape[0], n), dtype=np.int32)
for i in range(arr.shape[0]):
cond = i != arr[i]
output[i] = arr[i, cond][:n]

return output


def _angle(vec):
"""
Inner function that finds the angle between all vectors of the input.
The diagonal is the angle between each vector and itself, thus 0 everytime.
It should not be called as is, since it serves mainly as a shortcut for other functions.
arccos(0) = pi/2, so b0s are always far from everyone in this formulation.
"""

vec = np.array(vec)
Expand Down
132 changes: 60 additions & 72 deletions nlsam/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging

from time import time
from itertools import cycle

from nlsam.utils import im2col_nd, col2im_nd
from nlsam.angular_tools import angular_neighbors
Expand All @@ -22,7 +23,7 @@


def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
mask=None, is_symmetric=False, n_cores=None,
mask=None, is_symmetric=False, n_cores=None, split_b0s=False,
subsample=True, n_iter=10, b0_threshold=10, verbose=False, mp_method=None):
"""Main nlsam denoising function which sets up everything nicely for the local
block denoising.
Expand Down Expand Up @@ -51,6 +52,9 @@ def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
n_cores : int, default None
Number of processes to use for the denoising. Default is to use
all available cores.
split_b0s : bool, default False
If True and the dataset contains multiple b0s, a different b0 will be used for
each run of the denoising. If False, the b0s are averaged and the average b0 is used instead.
subsample : bool, default True
If True, find the smallest subset of indices required to process each
dwi at least once.
Expand Down Expand Up @@ -85,102 +89,86 @@ def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
raise ValueError('Block shape {} and data shape {} are not of the same '
'length'.format(data.shape, block_size.shape))

b0_loc = tuple(np.where(bvals <= b0_threshold)[0])
b0_loc = np.where(bvals <= b0_threshold)[0]
dwis = np.where(bvals > b0_threshold)[0]
num_b0s = len(b0_loc)
variance = sigma**2
orig_shape = data.shape

# We also convert bvecs associated with b0s to exactly (0,0,0), which
# is not always the case when we hack around with the scanner.
bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

logger.info("Found {} b0s at position {}".format(str(num_b0s), str(b0_loc)))

# Average multiple b0s, and just use the average for the rest of the script
# patching them in at the end
if num_b0s > 1:
mean_b0 = np.mean(data[..., b0_loc], axis=-1)
dwis = tuple(np.where(bvals > b0_threshold)[0])
data = data[..., dwis]
bvals = np.take(bvals, dwis, axis=0)
bvecs = np.take(bvecs, dwis, axis=0)

rest_of_b0s = b0_loc[1:]
b0_loc = b0_loc[0]

data = np.insert(data, b0_loc, mean_b0, axis=-1)
bvals = np.insert(bvals, b0_loc, [0.], axis=0)
bvecs = np.insert(bvecs, b0_loc, [0., 0., 0.], axis=0)
b0_loc = tuple([b0_loc])
# Average all b0s if we don't split them in the training set
if num_b0s > 1 and not split_b0s:
num_b0s = 1
else:
rest_of_b0s = None
data[..., b0_loc] = np.mean(data[..., b0_loc], axis=-1, keepdims=True)

# Split the b0s in a cyclic fashion along the training data
# If we only had one, cycle just return b0_loc indefinitely,
# else we go through all indexes.
np.random.shuffle(b0_loc)
split_b0s_idx = cycle(b0_loc)

# Double bvecs to find neighbors with assumed symmetry if needed
if is_symmetric:
logger.info('Data is assumed to be already symmetric.')
sym_bvecs = np.delete(bvecs, b0_loc, axis=0)
sym_bvecs = bvecs
else:
sym_bvecs = np.vstack((np.delete(bvecs, b0_loc, axis=0), np.delete(-bvecs, b0_loc, axis=0)))
sym_bvecs = np.vstack((bvecs, -bvecs))

neighbors = (angular_neighbors(sym_bvecs, block_size[-1] - num_b0s) % (data.shape[-1] - num_b0s))[:data.shape[-1] - num_b0s]
neighbors = angular_neighbors(sym_bvecs, block_size[-1] - 1) % data.shape[-1]
neighbors = neighbors[:data.shape[-1]] # everything was doubled for symmetry

# Full overlap for dictionary learning
overlap = np.array(block_size, dtype=np.int16) - 1
b0 = np.squeeze(data[..., b0_loc])
data = np.delete(data, b0_loc, axis=-1)

indexes = [(i,) + tuple(neighbors[i]) for i in range(len(neighbors))]
full_indexes = [(dwi,) + tuple(neighbors[dwi]) for dwi in range(data.shape[-1]) if dwi in dwis]

if subsample:
indexes = greedy_set_finder(indexes)
indexes = greedy_set_finder(full_indexes)
else:
indexes = full_indexes

b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + num_b0s,))
# If we have more b0s than indexes, then we have to add a few more blocks since
# we won't do a full cycle. If we have more b0s than indexes after that, then it breaks.
if num_b0s > len(indexes):
the_rest = [rest for rest in full_indexes if rest not in indexes]
indexes += the_rest[:(num_b0s - len(indexes))]

denoised_shape = data.shape[:-1] + (data.shape[-1] + num_b0s,)
data_denoised = np.zeros(denoised_shape, np.float32)
if num_b0s > len(indexes):
error = ('Seems like you still have more b0s {} than available blocks {},'
' either average them or deactivate subsampling.'.format(num_b0s, len(indexes)))
raise ValueError(error)

b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + 1,))
data_denoised = np.zeros(data.shape, np.float32)
divider = np.zeros(data.shape[-1])

# Put all idx + b0 in this array in each iteration
to_denoise = np.empty(data.shape[:-1] + (block_size[-1] + 1,), dtype=np.float64)

for i, idx in enumerate(indexes):
dwi_idx = tuple(np.where(idx <= b0_loc, idx, np.array(idx) + num_b0s))
logger.info('Now denoising volumes {} / block {} out of {}.'.format(idx, i + 1, len(indexes)))

to_denoise[..., 0] = np.copy(b0)
for i, idx in enumerate(indexes, start=1):
b0_loc = tuple((next(split_b0s_idx),))
to_denoise[..., 0] = data[..., b0_loc].squeeze()
to_denoise[..., 1:] = data[..., idx]

data_denoised[..., b0_loc + dwi_idx] += local_denoise(to_denoise,
b0_block_size,
overlap,
variance,
n_iter=n_iter,
mask=mask,
dtype=np.float64,
n_cores=n_cores,
verbose=verbose,
mp_method=mp_method)

divider = np.bincount(np.array(indexes, dtype=np.int16).ravel())
divider = np.insert(divider, b0_loc, len(indexes))

data_denoised = data_denoised[:orig_shape[0],
:orig_shape[1],
:orig_shape[2],
:orig_shape[3]] / divider

# Put back the original number of b0s
if rest_of_b0s is not None:

b0_denoised = np.squeeze(data_denoised[..., b0_loc])
data_denoised_insert = np.empty(orig_shape, dtype=np.float32)
n = 0

for i in range(orig_shape[-1]):
if i in rest_of_b0s:
data_denoised_insert[..., i] = b0_denoised
n += 1
else:
data_denoised_insert[..., i] = data_denoised[..., i - n]

data_denoised = data_denoised_insert

divider[list(b0_loc + idx)] += 1

logger.info('Now denoising volumes {} / block {} out of {}.'.format(b0_loc + idx, i, len(indexes)))

data_denoised[..., b0_loc + idx] += local_denoise(to_denoise,
b0_block_size,
overlap,
variance,
n_iter=n_iter,
mask=mask,
dtype=np.float64,
n_cores=n_cores,
verbose=verbose,
mp_method=mp_method)

data_denoised /= divider
return data_denoised


Expand Down
2 changes: 1 addition & 1 deletion nlsam/tests/test_angular_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_angular_neighbors():
[-1, -2, -3]]
neighbors = angular_neighbors(vectors, 2)
true_neighbors = np.array([[1, 2],
[1, 2],
[0, 2],
[0, 1],
[0, 1]])

Expand Down
12 changes: 9 additions & 3 deletions scripts/nlsam_denoising
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def buildArgsParser():
metavar='int', default=10, type=int,
help='Lowest b-value to be considered as a b0. Default 10.')

p.add_argument('--split_b0s', action='store_true', dest='split_b0s',
help='If set and multiple b0s are present, they are split amongst the '
'training data.')

p.add_argument('--block_size', action='store', dest='spatial_block_size',
metavar='tuple', type=literal_eval, default=(3, 3, 3),
help='Size of the 3D spatial patch to be denoised. Default : 3, 3, 3')
Expand Down Expand Up @@ -182,6 +186,7 @@ def main():
is_symmetric = args.is_symmetric
n_iter = args.iterations
b0_threshold = args.b0_threshold
split_b0s = args.split_b0s
mp_method = args.mp_method
block_size = np.array(args.spatial_block_size + (args.angular_block_size,))

Expand Down Expand Up @@ -211,9 +216,9 @@ def main():
# Load up data and do some sanity checks
##########################################

overwritable_files = [args.output,
args.save_sigma,
args.save_piesno_mask,
overwritable_files = [args.output,
args.save_sigma,
args.save_piesno_mask,
args.save_stab]

for f in overwritable_files:
Expand Down Expand Up @@ -371,6 +376,7 @@ def main():
mask=mask,
is_symmetric=is_symmetric,
n_cores=n_cores,
split_b0s=split_b0s,
subsample=subsample,
n_iter=n_iter,
b0_threshold=b0_threshold,
Expand Down

0 comments on commit d91a786

Please sign in to comment.