Skip to content

Commit

Permalink
Add preprocessing for 2D data
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed May 9, 2020
1 parent d33b3d2 commit 1e1ae39
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 94 deletions.
68 changes: 0 additions & 68 deletions src/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,74 +3,6 @@
from scipy import ndimage


def crop(data, masks, depth=None, slice_shape=None):
    """Center-crop a sample and its ground truth for a neural network input.

    The in-plane axes are (0, 1) for a 3D sample laid out as
    [x, y, channels] and (1, 2) for a 4D sample. For 4D input the last axis
    is treated as depth (z), i.e. the volume is effectively
    [channels, x, y, z].

    Args:
        data (`numpy.array`): Scan data, 3D or 4D as described above.
        masks (`numpy.array`): Ground truth data with the same layout as
            `data`. May be None, in which case only `data` is cropped.
        depth (int): New z of a sample. Only applied to 4D input and only
            when it is smaller than the current depth.
        slice_shape (tuple): New xy shape of a sample.

    Returns:
        data (`numpy.array`): Cropped scan data.
        masks (`numpy.array`): Ground truth cropped with the same window.

    Raises:
        RuntimeError: If `slice_shape` is given and `data` is neither 3D
            nor 4D.
    """
    n_dims = len(data.shape)

    if slice_shape:
        if n_dims == 3:
            row_axis, col_axis = 0, 1
        elif n_dims == 4:
            row_axis, col_axis = 1, 2
        else:
            raise RuntimeError("unexpected dimension")

        vertical_shift = int((data.shape[row_axis] - slice_shape[0]) // 2)
        horizontal_shift = int((data.shape[col_axis] - slice_shape[1]) // 2)

        # Apply the window to the same axes the shifts were computed from;
        # every other axis is kept whole.
        window = [slice(None)] * n_dims
        window[row_axis] = slice(vertical_shift, slice_shape[0] + vertical_shift)
        window[col_axis] = slice(horizontal_shift, slice_shape[1] + horizontal_shift)
        window = tuple(window)

        data = data[window]
        # Keep the ground truth aligned with the cropped scan.
        if masks is not None:
            masks = masks[window]

    if depth and n_dims == 4 and depth < data.shape[-1]:
        depth_shift = int((data.shape[-1] - depth) // 2)
        data = data[:, :, :, depth_shift:depth + depth_shift]
        if masks is not None:
            masks = masks[:, :, :, depth_shift:depth + depth_shift]

    return data, masks


def pad(data, masks, prev_shape, shape, n_channels, n_classes):
    """Zero-pad a sample and its ground truth, centering the original content.

    Both outputs are allocated channels-first as `(n, *shape)` and the input
    is copied into the centered window of spatial size `prev_shape`.

    Args:
        data (`numpy.array`): Scan data, [channels, x, y] when `shape` has
            two entries or [channels, x, y, z] when it has three.
        masks (`numpy.array`): Ground truth data with the same spatial
            layout as `data`.
        prev_shape (tuple): Old spatial shape of a sample.
        shape (tuple): New spatial shape of a sample.
        n_channels (int): The number of channels/modalities in `data`.
        n_classes (int): The number of class channels in `masks`.

    Returns:
        new_data (`numpy.array`): Padded scan data, `(n_channels, *shape)`.
        new_masks (`numpy.array`): Padded ground truth, `(n_classes, *shape)`.

    Raises:
        RuntimeError: If `shape` has neither two nor three entries.
    """
    new_data = np.zeros((n_channels, *shape))
    new_masks = np.zeros((n_classes, *shape))
    start = (np.array(shape) / 2. - np.array(prev_shape) / 2.).astype(int)
    end = start + np.array([int(dim) for dim in prev_shape], dtype=int)

    if len(shape) == 2:
        # Axis 0 is the channel axis, so the spatial window starts at axis 1.
        new_data[:, start[0]:end[0], start[1]:end[1]] = data[:, :, :]
        new_masks[:, start[0]:end[0], start[1]:end[1]] = masks[:, :, :]
    elif len(shape) == 3:
        new_data[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = data[:, :, :, :]
        new_masks[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = masks[:, :, :, :]
    else:
        raise RuntimeError("unexpected dimension")
    return new_data, new_masks


def augment(data, masks, params):
"""Augment samples.
Expand Down
171 changes: 146 additions & 25 deletions src/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import nibabel as nib
import numpy as np
from tqdm import tqdm


def fill_labels(img, slice_nums):
Expand All @@ -19,7 +21,7 @@ def preprocess_label(mask, output_classes=['ed'], merge_classes=False, out_shape
Args:
mask (numpy.array):
Ground truth numpy arrays [x, y, z, classes]. Whole volumes, channels of a case.
Ground truth numpy arrays [classes, x, y, z]. Whole volumes, channels of a case.
output_classes (:obj:`list` of :obj:`str`): classes to sepatare.
merge_classes (bool): Merge output_classes into one or not.
out_shape (tuple): Shape for scaling ground truth labels.
Expand Down Expand Up @@ -52,11 +54,85 @@ def preprocess_label(mask, output_classes=['ed'], merge_classes=False, out_shape
output += label
output = [output]
else:
masks = output
output = masks

return np.array(output, dtype=np.uint8)


def crop(data, masks, depth=None, slice_shape=None):
    """Center-crop a channels-first sample and its ground truth.

    Args:
        data (`numpy.array`): Scan data, [channels, x, y, z] or
            [channels, x, y].
        masks (`numpy.array`): Ground truth data with the same layout as
            `data`. May be None, in which case only `data` is cropped.
        depth (int): New z of a sample. Only applied to 4D input and only
            when it is smaller than the current depth.
        slice_shape (tuple): New xy shape of a sample. An entry larger than
            the current size leaves that axis untouched.

    Returns:
        data (`numpy.array`): Cropped scan data.
        masks (`numpy.array`): Ground truth cropped with the same window.

    Raises:
        RuntimeError: If `slice_shape` is given and `data` is neither 3D
            nor 4D.
    """
    n_dims = len(data.shape)

    if slice_shape:
        if n_dims not in (3, 4):
            raise RuntimeError(f"Got unexpected dimension {len(data.shape)}")

        # Axis 0 holds the channels for both layouts, so x and y are always
        # axes 1 and 2. Clamp at 0 so an oversized target does not produce a
        # negative start index.
        vertical_shift = max(int((data.shape[1] - slice_shape[0]) // 2), 0)
        horizontal_shift = max(int((data.shape[2] - slice_shape[1]) // 2), 0)
        window = (
            slice(None),
            slice(vertical_shift, slice_shape[0] + vertical_shift),
            slice(horizontal_shift, slice_shape[1] + horizontal_shift),
        )

        data = data[window]
        # Keep the ground truth aligned with the cropped scan.
        if masks is not None:
            masks = masks[window]

    if depth and n_dims == 4 and depth < data.shape[-1]:
        depth_shift = int((data.shape[-1] - depth) // 2)
        data = data[:, :, :, depth_shift:depth + depth_shift]
        if masks is not None:
            masks = masks[:, :, :, depth_shift:depth + depth_shift]

    return data, masks


def pad(data, masks, prev_shape, shape, n_channels, n_classes):
    """Embed a sample and its ground truth in the middle of zero arrays.

    Args:
        data (`numpy.array`): Scan data, [channels, x, y, z] or
            [channels, x, y].
        masks (`numpy.array`): Ground truth data, [channels, x, y, z] or
            [channels, x, y].
        prev_shape (tuple): Old spatial shape of a sample.
        shape (tuple): New spatial shape of a sample (two or three entries).
        n_channels (int): Size of the leading axis of the padded scan data.
        n_classes (int): Size of the leading axis of the padded ground truth.

    Returns:
        new_data (`numpy.array`): Padded scan data, `(n_channels, *shape)`.
        new_masks (`numpy.array`): Padded ground truth, `(n_classes, *shape)`.

    Raises:
        RuntimeError: If `shape` has neither two nor three entries.
    """
    new_data = np.zeros((n_channels, *shape))
    new_masks = np.zeros((n_classes, *shape))

    # Offset of the original content inside the enlarged array, per axis.
    half_new = np.array(shape) / 2.
    half_old = np.array(prev_shape) / 2.
    start = (half_new - half_old).astype(int)
    end = start + np.array(list(map(int, prev_shape)), dtype=int)

    n_spatial = len(shape)
    if n_spatial not in (2, 3):
        raise RuntimeError(f"Got unexpected dimension {len(shape)}")

    # Leading axis (channels/classes) is taken whole; the spatial axes are
    # restricted to the centered window.
    lows = [start[axis] for axis in range(n_spatial)]
    highs = [end[axis] for axis in range(n_spatial)]
    window = (slice(None),) + tuple(slice(lo, hi) for lo, hi in zip(lows, highs))
    everything = (slice(None),) * (n_spatial + 1)

    new_data[window] = data[everything]
    new_masks[window] = masks[everything]
    return new_data, new_masks



def prepare(data_paths: dict, dataset_name: str, preprocessed_dist: str, mode="3D"):
    """Run `preprocesse` on every entry of `data_paths` with a progress bar.

    Args:
        data_paths: Iterable whose items are passed one by one to
            `preprocesse` as its `imgs` argument.
        dataset_name (str): Forwarded unchanged to `preprocesse`.
        preprocessed_dist (str): Forwarded unchanged to `preprocesse` as the
            destination directory path.
        mode (str): Forwarded unchanged to `preprocesse`.
    """
    progress = tqdm(data_paths)
    for imgs in progress:
        preprocesse(imgs, dataset_name, preprocessed_dist, mode)


# Source: https://github.com/sacmehta/3D-ESPNet/blob/master/utils.py
def cropVolume(img, data=False):
'''
Expand Down Expand Up @@ -118,7 +194,39 @@ def cropVolumes(img1, img2, img3, img4):
return wi_st, wi_en, hi_st, hi_en, ch_st, ch_en


def preprocesse(imgs, dataset_name, dist_dir_path):
def save_nifti(imgs2save):
    """Save every entry of `imgs2save` with `nib.save`.

    Args:
        imgs2save: Iterable of argument tuples; each tuple is unpacked into
            the positional arguments of `nib.save`.
    """
    for save_args in imgs2save:
        nib.save(*save_args)


def save_npy(imgs2save):
    """Save the labelled slice range of every modality as per-slice .npy files.

    The first and last axes of each volume are swapped so that the original
    last axis becomes the slice index. The range spanning the first to the
    last slice of the "seg" volume that contains any non-zero value
    (inclusive) is then written for every entry, one file per slice, named
    `<path>/<i>.npy` with `i` counted from the start of that range.
    Nothing is written when "seg" has no non-zero slice.

    Args:
        imgs2save (dict): Maps a name to `{"modality": array, "path": str}`.
            Must contain a "seg" entry, which defines the slice range.
    """
    seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1)

    # Indices of the slices that carry at least one labelled voxel.
    labelled = np.flatnonzero(seg.any(axis=tuple(range(1, seg.ndim))))
    if labelled.size == 0:
        return
    first_slice = int(labelled[0])
    # Exclusive stop, so the last labelled slice is kept.
    stop_slice = int(labelled[-1]) + 1

    for data in imgs2save.values():
        path = data["path"]
        slices = np.swapaxes(data["modality"], 0, -1)[first_slice:stop_slice]
        # Create the target directory once per modality rather than per slice.
        os.makedirs(path, exist_ok=True)
        for i, curr_slice in enumerate(slices):
            with open(os.path.join(path, f"{i}.npy"), "wb") as f:
                np.save(f, curr_slice)


def preprocesse(imgs, dataset_name, dist_dir_path, mode="3D"):
"""Preprocesse nii.gz data.
Args:
Expand Down Expand Up @@ -155,11 +263,11 @@ def preprocesse(imgs, dataset_name, dist_dir_path):
affine_gth = gth.affine
header_gth = gth.header

img_flair = img_flair.get_data()
img_t1 = img_t1.get_data()
img_t1ce = img_t1ce.get_data()
img_t2 = img_t2.get_data()
gth = gth.get_data()
img_flair = np.asanyarray(img_flair.dataobj)
img_t1 = np.asanyarray(img_t1.dataobj)
img_t1ce = np.asanyarray(img_t1ce.dataobj)
img_t2 = np.asanyarray(img_t2.dataobj)
gth = np.asanyarray(gth.dataobj)


# Crop the volumes
Expand All @@ -170,24 +278,37 @@ def preprocesse(imgs, dataset_name, dist_dir_path):
img_t1ce = img_t1ce[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]
img_t2 = img_t2[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]
gth = gth[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]


# save the cropped volumes
flair_cropped = nib.Nifti1Image(img_flair, affine_flair, header_flair)
t1_cropped = nib.Nifti1Image(img_t1, affine_t1, header_t1)
t1ce_cropped = nib.Nifti1Image(img_t1ce, affine_t1ce, header_t1ce)
t2_cropped = nib.Nifti1Image(img_t2, affine_t2, header_t2)
gth_cropped = nib.Nifti1Image(gth, affine_gth, header_gth)


# create the directories if they do not exist
dist_dir_path = dist_dir_path + os.sep + imgs['flair'].split('/')[1]
dist_dir_path = dist_dir_path + os.sep + imgs['flair'].split('/')[-2]
if not os.path.isdir(dist_dir_path):
os.makedirs(dist_dir_path)

nib.save(flair_cropped, dist_dir_path + os.sep + imgs['flair'].split('/')[-1])
nib.save(t1_cropped, dist_dir_path + os.sep + imgs['t1'].split('/')[-1])
nib.save(t1ce_cropped, dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1])
nib.save(t2_cropped, dist_dir_path + os.sep + imgs['t2'].split('/')[-1])
nib.save(gth_cropped, dist_dir_path + os.sep + imgs['seg'].split('/')[-1])

os.makedirs(dist_dir_path)

if mode=="3D":
# save the cropped volumes
flair_cropped = nib.Nifti1Image(img_flair, affine_flair, header_flair)
t1_cropped = nib.Nifti1Image(img_t1, affine_t1, header_t1)
t1ce_cropped = nib.Nifti1Image(img_t1ce, affine_t1ce, header_t1ce)
t2_cropped = nib.Nifti1Image(img_t2, affine_t2, header_t2)
gth_cropped = nib.Nifti1Image(gth, affine_gth, header_gth)

imgs2save = [
(flair_cropped, dist_dir_path + os.sep + imgs['flair'].split('/')[-1]),
(t1_cropped, dist_dir_path + os.sep + imgs['t1'].split('/')[-1]),
(t1ce_cropped, dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1]),
(t2_cropped, dist_dir_path + os.sep + imgs['t2'].split('/')[-1]),
(gth_cropped, dist_dir_path + os.sep + imgs['seg'].split('/')[-1])
]
save_nifti(imgs2save)
elif mode=="2D":
imgs2save = {
"flair": {"modality": img_flair, "path": dist_dir_path + os.sep + imgs['flair'].split('/')[-1].split('.')[-3]},
"t1": {"modality": img_t1, "path": dist_dir_path + os.sep + imgs['t1'].split('/')[-1].split('.')[-3]},
"t1ce": {"modality": img_t1ce, "path": dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1].split('.')[-3]},
"t2": {"modality": img_t2, "path": dist_dir_path + os.sep + imgs['t2'].split('/')[-1].split('.')[-3]},
"seg": {"modality": gth, "path": dist_dir_path + os.sep + imgs['seg'].split('/')[-1].split('.')[-3]}
}
save_npy(imgs2save)
else:
raise ValueError(f"mode must be one of ['2D', '3D'], got {mode}")
53 changes: 52 additions & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm


def get_fpaths(data_dir):
def get_3Dfpaths(data_dir):
'''Parse all the filenames and create a dictionary for each patient with structure:
{
't1': <path to t1 MRI file>
Expand Down Expand Up @@ -35,6 +35,57 @@ def get_fpaths(data_dir):
return data_paths


def get_2Dfpaths(data_dir):
    '''Parse all the slice filenames and create a dictionary for each patient with structure:
    {
        't1': list(<paths to t1 MRI file>)
        't2': list(<paths to t2 MRI>)
        'flair': list(<paths to FLAIR MRI file>)
        't1ce': list(<paths to t1ce MRI file>)
        'seg': list(<paths to Ground Truth file>)
    }

    Every entry of `data_dir` is treated as one case. Inside a case the
    `.npy` files are collected from the sub-directories whose names end in
    `t1`, `t2`, `flair`, `t1ce` and `seg`. The dictionary key of a path is
    the word that follows the last underscore in that path. The per-modality
    lists are sorted as strings and truncated to the shortest one, so the
    i-th paths of all modalities belong together.

    Args:
        data_dir (str): Directory that contains one sub-directory per case.

    Returns:
        list of dict: One dictionary per case (empty when a case has no
        complete set of slices).
    '''

    # Raw string: `\w` is a regex class, not a string escape.
    pat = re.compile(r'.*_(\w*)')
    data_paths = []
    for case in glob.glob(os.path.join(data_dir, '*')):
        # Get a list of files for all modalities individually
        t1 = sorted(glob.glob(os.path.join(case, '*t1/*.npy')))
        t2 = sorted(glob.glob(os.path.join(case, '*t2/*.npy')))
        flair = sorted(glob.glob(os.path.join(case, '*flair/*.npy')))
        t1ce = sorted(glob.glob(os.path.join(case, '*t1ce/*.npy')))
        seg = sorted(glob.glob(os.path.join(case, '*seg/*.npy')))  # Ground Truth

        data = {}
        for items in zip(t1, t2, t1ce, flair, seg):
            for item in items:
                data.setdefault(pat.findall(item)[0], []).append(item)

        data_paths.append(data)

    return data_paths


def unpack_2D_fpaths(packed_data_paths):
    '''Flatten per-case path lists into one dictionary per slice.

    Args:
        packed_data_paths (list of dict): Each dictionary maps 't1', 't2',
            'flair', 't1ce' and 'seg' to a list of paths.

    Returns:
        list of dict: For every case, the i-th paths of all five lists are
        grouped into one dictionary. The key of a path is the word that
        follows the last underscore in that path. Lists of unequal length
        are truncated to the shortest one.

    Raises:
        KeyError: If a case dictionary lacks one of the five keys.
    '''
    # Compiled once for all cases; raw string because `\w` is a regex class,
    # not a string escape.
    pat = re.compile(r'.*_(\w*)')

    upacked_data_paths = []
    for paths in packed_data_paths:
        t1 = paths['t1']
        t2 = paths['t2']
        flair = paths['flair']
        t1ce = paths['t1ce']
        seg = paths['seg']

        for items in zip(t1, t2, t1ce, flair, seg):
            upacked_data_paths.append(
                {pat.findall(item)[0]: item for item in items}
            )
    return upacked_data_paths


def change_orientation(img):
img = np.moveaxis(img, 0, -1)
img = np.swapaxes(img, 0, -2)
Expand Down

0 comments on commit 1e1ae39

Please sign in to comment.