Skip to content

Commit

Permalink
Merge pull request #380 from dchorel/Fix_concatenate_3D_with_4D_images
Browse files Browse the repository at this point in the history
[WIP] Concatenate 3d with 4d images into a 4d image
  • Loading branch information
arnaudbore authored Feb 15, 2021
2 parents abcf770 + 53dbd94 commit c8f5d2a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
36 changes: 27 additions & 9 deletions scilpy/image/operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-

"""
Utility operations provided for scil_image_math.py and scil_connectivity_math.py
Utility operations provided for scil_image_math.py
and scil_connectivity_math.py
They basically act as wrappers around numpy to avoid installing MRtrix/FSL
to apply simple operations on nibabel images or numpy arrays.
"""
Expand Down Expand Up @@ -54,7 +55,7 @@ def get_image_ops():
"""Get a dictionary of all functions relating to image operations"""
image_ops = get_array_ops()
image_ops.update(OrderedDict([
('concatenate', concat),
('concatenate', concatenate),
('dilation', dilation),
('erosion', erosion),
('closing', closing),
Expand Down Expand Up @@ -83,6 +84,13 @@ def _validate_imgs(*imgs):
raise ValueError('Not all inputs have the same shape!')


def _validate_imgs_concat(*imgs):
"""Make sure that all inputs are images."""
for img in imgs:
if not isinstance(img, nib.Nifti1Image):
raise ValueError('Inputs are not all images')


def _validate_length(input_list, length, at_least=False):
"""Make sure the the input list has the right number of arguments
(length)."""
Expand Down Expand Up @@ -499,20 +507,30 @@ def invert(input_list, ref_img):
return output_data


def concat(input_list, ref_img):
def concatenate(input_list, ref_img):
"""
concat: IMGs
Concatenate a list of 3D images into a single 4D image.
concatenate: IMGs
Concatenate a list of 3D and 4D images into a single 4D image.
"""
_validate_imgs(*input_list, ref_img)
if len(input_list[0].header.get_data_shape()) != 3:
raise ValueError('Concatenate require 3D arrays.')

_validate_imgs_concat(*input_list, ref_img)
if len(input_list[0].header.get_data_shape()) > 4:
raise ValueError('Concatenate require 3D or 4D arrays.')

input_data = []
for img in input_list:

data = img.get_fdata(dtype=np.float64)
input_data.append(data)

if len(img.header.get_data_shape()) == 4:
data = np.rollaxis(data, 3)
for i in range(0, len(data)):
input_data.append(data[i])
else:
input_data.append(data)

img.uncache()

return np.rollaxis(np.stack(input_data), axis=0, start=4)


Expand Down
14 changes: 13 additions & 1 deletion scripts/scil_image_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ def main():
found_ref = True
break

# If there's a 4D image, replace the previous 3D image with
# this one for reference
for input_arg in args.in_images:
if not is_float(input_arg):
ref_img = nib.load(input_arg)
if len(ref_img.shape) == 4:
mask = np.zeros(ref_img.shape)
break

if not found_ref:
raise ValueError('Requires at least one nifti image.')

Expand Down Expand Up @@ -137,7 +146,10 @@ def main():

if isinstance(img, nib.Nifti1Image):
data = img.get_fdata(dtype=np.float64)
mask[data > 0] = 1
if data.ndim == 4:
mask[np.sum(data, axis=3).astype(bool) > 0] = 1
else:
mask[data > 0] = 1
img.uncache()
input_img.append(img)

Expand Down
14 changes: 13 additions & 1 deletion scripts/tests/test_image_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_execution_low_mult(script_runner):
assert ret.success


def test_execution_concat(script_runner):
def test_execution_concatenate(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_img_1 = os.path.join(get_home(), 'atlas', 'ids', '10.nii.gz')
in_img_2 = os.path.join(get_home(), 'atlas', 'ids', '11.nii.gz')
Expand All @@ -57,3 +57,15 @@ def test_execution_concat(script_runner):
in_img_1, in_img_2, in_img_3, in_img_4, in_img_5,
in_img_6, 'concat_ids.nii.gz')
assert ret.success


def test_execution_concatenate_4D(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_img_1 = os.path.join(get_home(), 'atlas', 'ids', '10.nii.gz')
in_img_2 = os.path.join(get_home(), 'atlas', 'ids', '8_10.nii.gz')
in_img_3 = os.path.join(get_home(), 'atlas', 'ids', '12.nii.gz')
in_img_4 = os.path.join(get_home(), 'atlas', 'ids', '8_10.nii.gz')
ret = script_runner.run('scil_image_math.py', 'concatenate',
in_img_1, in_img_2, in_img_3, in_img_4,
'concat_ids_4d.nii.gz')
assert ret.success

0 comments on commit c8f5d2a

Please sign in to comment.