diff --git a/src/augmentation.py b/src/augmentation.py
index e9521b5..ea608bd 100644
--- a/src/augmentation.py
+++ b/src/augmentation.py
@@ -3,74 +3,6 @@ from scipy import ndimage
 
 
-def crop(data, masks, depth=None, slice_shape=None):
-    """Crop samples for a neural network input.
-
-    Args:
-        data (`numpy.array`):
-            numpy arrays [x, y, z, channels]/[x, y, channels]. Scan data.
-        masks (`numpy.array`):
-            numpy arrays [x, y, z, channels]/[x, y, channels]. Ground truth data.
-        depth (int): New z of a sample.
-        slice_shape (tuple): New xy shape of a sample.
-    Returns:
-        data (`numpy.array`): Croped numpy arrays [x, y, z, channels]
-    """
-
-
-    if slice_shape:
-        if len(data.shape) == 3:
-            vertical_shift = int((data.shape[0] - slice_shape[0]) // 2)
-            horizontal_shift = int((data.shape[1] - slice_shape[1]) // 2)
-            data = data[vertical_shift:slice_shape[0]+vertical_shift,horizontal_shift:slice_shape[1]+horizontal_shift,:]
-        elif len(data.shape) == 4:
-            vertical_shift = int((data.shape[1] - slice_shape[0]) // 2)
-            horizontal_shift = int((data.shape[2] - slice_shape[1]) // 2)
-            data = data[vertical_shift:slice_shape[0]+vertical_shift,horizontal_shift:slice_shape[1]+horizontal_shift,:]
-        else:
-            raise RuntimeError("unexpected dimension")
-
-    if depth:
-        if depth < data.shape[-1]:
-            if len(data.shape) == 4:
-                depth_shift = int((data.shape[-1] - depth) // 2)
-                data = data[:, :,:,depth_shift:depth+depth_shift]
-                masks = masks[:, :,:,depth_shift:depth+depth_shift]
-
-    return data, masks
-
-
-def pad(data, masks, prev_shape, shape, n_channels, n_classes):
-    """Pad samples for a neural network input.
-
-    Args:
-        data (`numpy.array`):
-            numpy arrays [x, y, z, channels]/[x, y, channels]. Scan data.
-        masks (`numpy.array`):
-            numpy arrays [x, y, z, channels]/[x, y, channels]. Ground truth data.
-        prev_shape (tuple): Old shape of a sample
-        shape (tuple): New shape of a sample.
-        n_channels (int): The number of a case's channels/modalities/classes.
-    Returns:
-        new_data (`numpy.array`): Padded numpy data [x, y, z, channels]
-        new_masks (`numpy.array`): Padded numpy ground truth [x, y, z, channels]
-    """
-
-    new_data = np.zeros((n_channels, *shape))
-    new_masks = np.zeros((n_classes, *shape))
-    start = (np.array(shape) / 2. - np.array(prev_shape) / 2.).astype(int)
-    end = start + np.array([int(dim) for dim in prev_shape], dtype=int)
-    if len(shape) == 2:
-        new_data[start[0]:end[0], start[1]:end[1]] = data[:, :]
-        new_masks[start[0]:end[0], start[1]:end[1]] = masks[:, :]
-    elif len(shape) == 3:
-        new_data[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = data[:, :, :, :]
-        new_masks[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = masks[:, :, :, :]
-    else:
-        raise RuntimeError("unexpected dimension")
-    return new_data, new_masks
-
-
 def augment(data, masks, params):
     """Augment samples.
diff --git a/src/preprocessing.py b/src/preprocessing.py
index 30f8476..f7359a6 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -1,5 +1,7 @@
+import os
 import nibabel as nib
 import numpy as np
+from tqdm import tqdm
 
 
 def fill_labels(img, slice_nums):
@@ -19,7 +21,7 @@ def preprocess_label(mask, output_classes=['ed'], merge_classes=False, out_shape
 
     Args:
         mask (numpy.array):
-            Ground truth numpy arrays [x, y, z, classes]. Whole volumes, channels of a case.
+            Ground truth numpy arrays [classes, x, y, z]. Whole volumes, channels of a case.
         output_classes (:obj:`list` of :obj:`str`): classes to sepatare.
         merge_classes (bool): Merge output_classes into one or not.
         out_shape (tuple): Shape for scaling ground truth labels.
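The docstring fix above tracks the repo-wide switch to channels-first arrays. A minimal sketch of what that reordering means for a ground-truth volume (shapes and names are illustrative, not taken from the patch):

```python
import numpy as np

# Channels-last ground truth, as the old docstring described: [x, y, z, classes]
mask_last = np.zeros((240, 240, 155, 3), dtype=np.uint8)

# Channels-first, as preprocess_label now documents: [classes, x, y, z]
mask_first = np.moveaxis(mask_last, -1, 0)
assert mask_first.shape == (3, 240, 240, 155)
```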
@@ -52,11 +54,85 @@
             output += label
         output = [output]
     else:
-        masks = output
+        output = masks
 
     return np.array(output, dtype=np.uint8)
+
+
+def crop(data, masks, depth=None, slice_shape=None):
+    """Crop samples for a neural network input.
+
+    Args:
+        data (`numpy.array`):
+            numpy arrays [channels, x, y, z]/[channels, x, y]. Scan data.
+        masks (`numpy.array`):
+            numpy arrays [channels, x, y, z]/[channels, x, y]. Ground truth data.
+        depth (int): New z of a sample.
+        slice_shape (tuple): New xy shape of a sample.
+    Returns:
+        data (`numpy.array`): Cropped numpy arrays [channels, x, y, z]
+        masks (`numpy.array`): Cropped numpy ground truth [channels, x, y, z]
+    """
+
+    # Crop masks alongside data so the pair stays spatially aligned.
+    if slice_shape:
+        if len(data.shape) == 3:
+            # Channels-first: the spatial xy extents live on axes 1 and 2.
+            vertical_shift = int((data.shape[1] - slice_shape[0]) // 2)
+            horizontal_shift = int((data.shape[2] - slice_shape[1]) // 2)
+            data = data[:, vertical_shift:slice_shape[0] + vertical_shift, horizontal_shift:slice_shape[1] + horizontal_shift]
+            masks = masks[:, vertical_shift:slice_shape[0] + vertical_shift, horizontal_shift:slice_shape[1] + horizontal_shift]
+        elif len(data.shape) == 4:
+            vertical_shift = int((data.shape[1] - slice_shape[0]) // 2)
+            horizontal_shift = int((data.shape[2] - slice_shape[1]) // 2)
+            data = data[:, vertical_shift:slice_shape[0] + vertical_shift, horizontal_shift:slice_shape[1] + horizontal_shift, :]
+            masks = masks[:, vertical_shift:slice_shape[0] + vertical_shift, horizontal_shift:slice_shape[1] + horizontal_shift, :]
+        else:
+            raise RuntimeError(f"Got unexpected dimension {len(data.shape)}")
+
+    if depth and depth < data.shape[-1]:
+        if len(data.shape) == 4:
+            depth_shift = int((data.shape[-1] - depth) // 2)
+            data = data[:, :, :, depth_shift:depth + depth_shift]
+            masks = masks[:, :, :, depth_shift:depth + depth_shift]
+
+    return data, masks
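A quick sanity check of the center-crop arithmetic in the relocated crop(), assuming the channels-first shapes its docstring documents and the corrected indexing above (sizes are made up):

```python
import numpy as np
from src.preprocessing import crop

data = np.random.rand(4, 240, 240, 155)   # 4 modalities, BraTS-like volume
masks = np.random.rand(3, 240, 240, 155)  # 3 ground-truth channels

data, masks = crop(data, masks, depth=128, slice_shape=(192, 192))
assert data.shape == (4, 192, 192, 128)
assert masks.shape == (3, 192, 192, 128)
```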
+
+
+def pad(data, masks, prev_shape, shape, n_channels, n_classes):
+    """Pad samples for a neural network input.
+
+    Args:
+        data (`numpy.array`):
+            numpy arrays [channels, x, y, z]/[channels, x, y]. Scan data.
+        masks (`numpy.array`):
+            numpy arrays [channels, x, y, z]/[channels, x, y]. Ground truth data.
+        prev_shape (tuple): Old shape of a sample.
+        shape (tuple): New shape of a sample.
+        n_channels (int): The number of a case's channels/modalities.
+        n_classes (int): The number of a case's ground truth classes.
+    Returns:
+        new_data (`numpy.array`): Padded numpy data [channels, x, y, z]
+        new_masks (`numpy.array`): Padded numpy ground truth [channels, x, y, z]
+    """
+
+    new_data = np.zeros((n_channels, *shape))
+    new_masks = np.zeros((n_classes, *shape))
+    # Center the old spatial extent inside the new, larger shape.
+    start = (np.array(shape) / 2. - np.array(prev_shape) / 2.).astype(int)
+    end = start + np.array([int(dim) for dim in prev_shape], dtype=int)
+    if len(shape) == 2:
+        new_data[:, start[0]:end[0], start[1]:end[1]] = data[:, :, :]
+        new_masks[:, start[0]:end[0], start[1]:end[1]] = masks[:, :, :]
+    elif len(shape) == 3:
+        new_data[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = data[:, :, :, :]
+        new_masks[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]] = masks[:, :, :, :]
+    else:
+        raise RuntimeError(f"Got unexpected dimension {len(shape)}")
+    return new_data, new_masks
+
+
+def prepare(data_paths: list, dataset_name: str, preprocessed_dist: str, mode="3D"):
+    for imgs in tqdm(data_paths):
+        preprocesse(imgs, dataset_name, preprocessed_dist, mode)
+
 
 # Source: https://github.com/sacmehta/3D-ESPNet/blob/master/utils.py
 def cropVolume(img, data=False):
     '''
@@ -118,7 +194,39 @@ def cropVolumes(img1, img2, img3, img4):
     return wi_st, wi_en, hi_st, hi_en, ch_st, ch_en
 
 
-def preprocesse(imgs, dataset_name, dist_dir_path):
+def save_nifti(imgs2save):
+    for imgs in imgs2save:
+        nib.save(*imgs)
+
+
+def save_npy(imgs2save):
+    # Trim leading/trailing axial slices whose ground truth is empty, then save
+    # every remaining slice of every modality as its own .npy file.
+    seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1)
+    labelled = np.where(seg.reshape(seg.shape[0], -1).any(axis=1))[0]
+    frst_slice, last_slice = labelled[0], labelled[-1]
+
+    for data in imgs2save.values():
+        modality = np.swapaxes(data["modality"], 0, -1)
+        modality = modality[frst_slice:last_slice + 1]
+        path = data["path"]
+        if not os.path.isdir(path):
+            os.makedirs(path)
+        for i in range(modality.shape[0]):
+            slice_dist_path = path + os.sep + str(i)
+            with open(f"{slice_dist_path}.npy", "wb") as f:
+                np.save(f, modality[i, :, :])
+
+
+def preprocesse(imgs, dataset_name, dist_dir_path, mode="3D"):
     """Preprocesse nii.gz data.
 
     Args:
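save_npy above keeps only the axial range whose ground truth actually contains labels before writing per-slice .npy files. A self-contained check of that trimming on a toy volume (toy shapes, not BraTS-sized):

```python
import numpy as np

seg = np.zeros((10, 4, 4), dtype=np.uint8)  # [z, y, x], as after the swapaxes
seg[3:7] = 1                                # labels only in slices 3..6

labelled = np.where(seg.reshape(seg.shape[0], -1).any(axis=1))[0]
frst_slice, last_slice = labelled[0], labelled[-1]
assert (frst_slice, last_slice) == (3, 6)
assert seg[frst_slice:last_slice + 1].shape[0] == 4  # empty ends dropped
```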
@@ -155,11 +263,11 @@ def preprocesse(imgs, dataset_name, dist_dir_path):
     affine_gth = gth.affine
     header_gth = gth.header
 
-    img_flair = img_flair.get_data()
-    img_t1 = img_t1.get_data()
-    img_t1ce = img_t1ce.get_data()
-    img_t2 = img_t2.get_data()
-    gth = gth.get_data()
+    img_flair = np.asanyarray(img_flair.dataobj)
+    img_t1 = np.asanyarray(img_t1.dataobj)
+    img_t1ce = np.asanyarray(img_t1ce.dataobj)
+    img_t2 = np.asanyarray(img_t2.dataobj)
+    gth = np.asanyarray(gth.dataobj)
 
     # Crop the volumes
@@ -170,24 +278,37 @@
     img_t1ce = img_t1ce[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]
     img_t2 = img_t2[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]
     gth = gth[wi_st:wi_en, hi_st:hi_en, ch_st:ch_en]
-
-
-    # save the cropped volumes
-    flair_cropped = nib.Nifti1Image(img_flair, affine_flair, header_flair)
-    t1_cropped = nib.Nifti1Image(img_t1, affine_t1, header_t1)
-    t1ce_cropped = nib.Nifti1Image(img_t1ce, affine_t1ce, header_t1ce)
-    t2_cropped = nib.Nifti1Image(img_t2, affine_t2, header_t2)
-    gth_cropped = nib.Nifti1Image(gth, affine_gth, header_gth)
 
     # create the directories if they do not exist
-    dist_dir_path = dist_dir_path + os.sep + imgs['flair'].split('/')[1]
+    dist_dir_path = dist_dir_path + os.sep + imgs['flair'].split('/')[-2]
     if not os.path.isdir(dist_dir_path):
-        os.makedirs(dist_dir_path)
-
-    nib.save(flair_cropped, dist_dir_path + os.sep + imgs['flair'].split('/')[-1])
-    nib.save(t1_cropped, dist_dir_path + os.sep + imgs['t1'].split('/')[-1])
-    nib.save(t1ce_cropped, dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1])
-    nib.save(t2_cropped, dist_dir_path + os.sep + imgs['t2'].split('/')[-1])
-    nib.save(gth_cropped, dist_dir_path + os.sep + imgs['seg'].split('/')[-1])
-    
\ No newline at end of file
+        os.makedirs(dist_dir_path)
+
+    if mode == "3D":
+        # save the cropped volumes
+        flair_cropped = nib.Nifti1Image(img_flair, affine_flair, header_flair)
+        t1_cropped = nib.Nifti1Image(img_t1, affine_t1, header_t1)
+        t1ce_cropped = nib.Nifti1Image(img_t1ce, affine_t1ce, header_t1ce)
+        t2_cropped = nib.Nifti1Image(img_t2, affine_t2, header_t2)
+        gth_cropped = nib.Nifti1Image(gth, affine_gth, header_gth)
+
+        imgs2save = [
+            (flair_cropped, dist_dir_path + os.sep + imgs['flair'].split('/')[-1]),
+            (t1_cropped, dist_dir_path + os.sep + imgs['t1'].split('/')[-1]),
+            (t1ce_cropped, dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1]),
+            (t2_cropped, dist_dir_path + os.sep + imgs['t2'].split('/')[-1]),
+            (gth_cropped, dist_dir_path + os.sep + imgs['seg'].split('/')[-1])
+        ]
+        save_nifti(imgs2save)
+    elif mode == "2D":
+        # One directory per modality, named after the file without its .nii.gz suffix
+        imgs2save = {
+            "flair": {"modality": img_flair, "path": dist_dir_path + os.sep + imgs['flair'].split('/')[-1].split('.')[-3]},
+            "t1": {"modality": img_t1, "path": dist_dir_path + os.sep + imgs['t1'].split('/')[-1].split('.')[-3]},
+            "t1ce": {"modality": img_t1ce, "path": dist_dir_path + os.sep + imgs['t1ce'].split('/')[-1].split('.')[-3]},
+            "t2": {"modality": img_t2, "path": dist_dir_path + os.sep + imgs['t2'].split('/')[-1].split('.')[-3]},
+            "seg": {"modality": gth, "path": dist_dir_path + os.sep + imgs['seg'].split('/')[-1].split('.')[-3]}
+        }
+        save_npy(imgs2save)
+    else:
+        raise ValueError(f"mode must be one of ['2D', '3D'], got {mode}")
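Context for the get_data() hunk above: Image.get_data() is deprecated in recent nibabel releases, and np.asanyarray(img.dataobj) is the replacement nibabel recommends. A minimal sketch (the file path is hypothetical):

```python
import nibabel as nib
import numpy as np

img = nib.load("BraTS_case_flair.nii.gz")  # hypothetical path
vol = np.asanyarray(img.dataobj)           # array proxy -> plain ndarray
print(vol.shape, img.affine)               # voxel data plus the affine kept for saving
```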
diff --git a/src/utils.py b/src/utils.py
index 33e7e67..a79add4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 
 
-def get_fpaths(data_dir):
+def get_3Dfpaths(data_dir):
     '''Parse all the filenames and create a dictionary for each patient with structure:
         {
             't1':
@@ -35,6 +35,57 @@ def get_fpaths(data_dir):
     return data_paths
 
 
+def get_2Dfpaths(data_dir):
+    '''Parse all the filenames and create a dictionary for each patient with structure:
+        {
+            't1': list()
+            't2': list()
+            'flair': list()
+            't1ce': list()
+            'seg': list()
+        }
+    '''
+
+    pat = re.compile(r'.*_(\w*)')
+    data_paths = []
+    for case in glob.glob(os.path.join(data_dir, '*')):
+        # Get a list of files for all modalities individually
+        t1 = sorted(glob.glob(os.path.join(case, '*t1/*.npy')))
+        t2 = sorted(glob.glob(os.path.join(case, '*t2/*.npy')))
+        flair = sorted(glob.glob(os.path.join(case, '*flair/*.npy')))
+        t1ce = sorted(glob.glob(os.path.join(case, '*t1ce/*.npy')))
+        seg = sorted(glob.glob(os.path.join(case, '*seg/*.npy')))  # Ground Truth
+
+        data = {}
+        for items in zip(t1, t2, t1ce, flair, seg):
+            for item in items:
+                data[pat.findall(item)[0]] = data.get(pat.findall(item)[0], []) + [item]
+
+        data_paths.append(data)
+
+    return data_paths
+
+
+def unpack_2D_fpaths(packed_data_paths):
+    pat = re.compile(r'.*_(\w*)')
+    unpacked_data_paths = []
+    for paths in packed_data_paths:
+        t1 = paths['t1']
+        t2 = paths['t2']
+        flair = paths['flair']
+        t1ce = paths['t1ce']
+        seg = paths['seg']
+
+        unpacked_data_paths.extend([{
+            pat.findall(item)[0]: item
+            for item in items
+        }
+            for items in zip(t1, t2, t1ce, flair, seg)])
+    return unpacked_data_paths
+
+
 def change_orientation(img):
     img = np.moveaxis(img, 0, -1)
     img = np.swapaxes(img, 0, -2)
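unpack_2D_fpaths flattens the per-patient slice lists from get_2Dfpaths into one dict per slice index, which is the unit a 2D data loader consumes. A toy run with made-up paths (import path assumed to be src/utils.py):

```python
from src.utils import unpack_2D_fpaths

packed = [{
    m: [f'case/BraTS_x_{m}/{i}.npy' for i in range(2)]
    for m in ('t1', 't2', 'flair', 't1ce', 'seg')
}]
flat = unpack_2D_fpaths(packed)
assert len(flat) == 2  # one dict per slice index
assert set(flat[0]) == {'t1', 't2', 'flair', 't1ce', 'seg'}
```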