From 78e6b8db63ae790a53d0bf46ba127347e088ec91 Mon Sep 17 00:00:00 2001 From: ViiSkor Date: Mon, 18 May 2020 02:10:35 +0300 Subject: [PATCH] Add a bunch of fixes --- src/augmentation.py | 45 ++++++++++-------- src/blocks.py | 53 ++++++++++++++++----- src/data_gen.py | 6 ++- src/preprocessing.py | 46 +++++++++--------- src/unet.py | 1 + src/utils.py | 110 +++++++++++++++++++++++-------------------- src/visualization.py | 34 ++++++++++++- 7 files changed, 187 insertions(+), 108 deletions(-) diff --git a/src/augmentation.py b/src/augmentation.py index ea608bd..fd08144 100644 --- a/src/augmentation.py +++ b/src/augmentation.py @@ -8,13 +8,13 @@ def augment(data, masks, params): Args: data (:obj:`numpy.array` of :obj:`np.float32`): - (x pathways) of numpy arrays [x, y, z, channels]. Scan data. + (x pathways) of numpy arrays [channels, x, y, z]. Scan data. masks (:obj:`numpy.array` of :obj:`np.int8`): - numpy arrays [x, y, z, channels]. Ground truth data. + numpy arrays [channels, x, y, z]. Ground truth data. params (dict): None or Dictionary, with parameters of each augmentation type. Returns: - data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels] - masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes] + data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z] + masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z] """ if params['hist_dist']: @@ -32,13 +32,13 @@ def random_flip(data, masks, n_dimensions): Args: data (:obj:`numpy.array` of :obj:`np.float32`): - (x pathways) of np arrays [x, y, z, channels]. Scan data. + (x pathways) of np arrays [channels, x, y, z]. Scan data. masks (:obj:`numpy.array` of :obj:`np.int8`): - numpy arrays [x, y, z, channels]. Ground truth data. + numpy arrays [channels, x, y, z]. Ground truth data. n_dimensions (int): the number of dimensions Returns: - data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels] - masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes] + data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z] + masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z] """ axis = [dim for dim in range(1, n_dimensions) if np.random.choice([True, False])] @@ -54,31 +54,38 @@ def random_histogram_distortion(data: np.array, shift={'mu': 0.0, 'std': 0}, sca Args: data (:obj:`numpy.array` of :obj:`np.float32`): - (x pathways) of np arrays [x, y, z, channels]. Scan data. + (x pathways) of np arrays [channels, x, y, z]. Scan data. shift (:obj:`dict` of :obj:`dict`): {'mu': 0.0, 'std':0.} params (:obj:`dict` of :obj:`dict`): {'mu': 1.0, 'std': '0.'} Returns: data (:obj:`numpy.array` of :obj:`np.float32`): - (x pathways) of numpy arrays [x, y, z, channels] + (x pathways) of numpy arrays [channels, x, y, z] References: Adapted from https://github.com/deepmedic/deepmedic/blob/f937eaa79debf001db2df697ddb14d94e7757b9f/deepmedic/dataManagement/augmentSample.py#L23 """ - n_channs = data[0].shape[-1] + n_channs = data[0].shape[0] + if len(data[0].shape) == 3: + axis2distort = [n_channs, 1, 1] + elif len(data[0].shape) == 4: + axis2distort = [n_channs, 1, 1, 1] + else: + raise RuntimeError(f"Got unexpected dimension {len(data[0].shape)}") + if shift is None: shift_per_chan = 0. elif shift['std'] != 0: # np.random.normal does not work for an std==0. - shift_per_chan = np.random.normal(shift['mu'], shift['std'], [1, 1, 1, n_channs]) + shift_per_chan = np.random.normal(shift['mu'], shift['std'], axis2distort) else: - shift_per_chan = np.ones([1, 1, 1, n_channs], dtype="float32") * shift['mu'] + shift_per_chan = np.ones(axis2distort, dtype="float32") * shift['mu'] if scale is None: scale_per_chan = 1. elif scale['std'] != 0: - scale_per_chan = np.random.normal(scale['mu'], scale['std'], [1, 1, 1, n_channs]) + scale_per_chan = np.random.normal(scale['mu'], scale['std'], axis2distort) else: - scale_per_chan = np.ones([1, 1, 1, n_channs], dtype="float32") * scale['mu'] + scale_per_chan = np.ones(axis2distort, dtype="float32") * scale['mu'] # Intensity augmentation for path_idx in range(len(data)): @@ -92,13 +99,13 @@ def random_rotate(data, masks, degrees=[-15, -10, -5, 0, 5, 10, 15]): Args: data (:obj:`numpy.array` of :obj:`np.float32`): - (x pathways) of np arrays [x, y, z, channels]. Scan data. + (x pathways) of np arrays [channels, x, y, z]. Scan data. masks (:obj:`numpy.array` of :obj:`np.int8`): - numpy arrays [x, y, z, channels]. Ground truth data. + numpy arrays [channels, x, y, z]. Ground truth data. degrees (:obj:`numpy.array` of :obj:`int`): list of possible angle of rotation in degrees. Returns: - data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels] - masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes] + data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z] + masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z] """ degrees = np.random.choice(a=degrees, size=1) diff --git a/src/blocks.py b/src/blocks.py index 7c0b1e2..a6fdd61 100644 --- a/src/blocks.py +++ b/src/blocks.py @@ -1,23 +1,33 @@ -from tensorflow.keras.layers import Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout +import math +from tensorflow.keras.layers import add, Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout -def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False): - if conv_type == "2D": +def get_layers(conv_type, dropout_type, mode="3D"): + if conv_type == "2D": conv = Conv2D spatial_dropout = SpatialDropout2D - elif conv_type == "3D": + elif conv_type == "3D": conv = Conv3D spatial_dropout = SpatialDropout3D - else: + else: raise ValueError(f"conv_type must be one of ['2D', '3D'], but got {conv_type}") - if dropout_type == "standard": - dropout = Dropout - elif dropout_type == "spatial": - dropout = spatial_dropout + if dropout_type == "standard": + dropout = Dropout + elif dropout_type == "spatial": + dropout = spatial_dropout + else: + if dropout_type: + raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}") else: - if dropout_type: - raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}") + dropout = None + + return {'conv': conv, 'dropout': dropout} + +def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False): + layers = get_layers(conv_type, dropout_type, mode=conv_type) + conv = layers['conv'] + dropout = layers['dropout'] # first layer x = conv(filters=n_filters, **conv_kwds)(inputs) @@ -34,3 +44,24 @@ def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type x = Activation(activation)(x) return x + + +def dilate_conv_block(x, n_filters, max_dilation_rate, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False): + layers = get_layers(conv_type, dropout_type, mode="3D") + conv = layers['conv'] + dropout = layers['dropout'] + + dilates = [] + for i in range(math.ceil(math.log(max_dilation_rate, 2))): + x = conv(filters=n_filters, dilation_rate=2**i, **conv_kwds)(x) + if batchnorm: + x = BatchNormalization()(x) + x = Activation(activation)(x) + if dropout_type and dropout_prob > 0.0: + x = dropout(dropout_prob)(x) + dilates.append(x) + x = conv(filters=n_filters, dilation_rate=max_dilation_rate, **conv_kwds)(x) + dilates.append(x) + + return add(dilates) + diff --git a/src/data_gen.py b/src/data_gen.py index 69bbab5..f65303b 100644 --- a/src/data_gen.py +++ b/src/data_gen.py @@ -171,8 +171,9 @@ def __data_generation(self, list_fpaths_temp): # Generate data for i, imgs in enumerate(list_fpaths_temp): - modalities = np.array([np.load(imgs[m]) for m in self.scan_types]) - masks = preprocess_label(np.load(imgs['seg']), + curr_slice = imgs['seg'][1] + modalities = np.array([np.load(imgs[m][0])[curr_slice] for m in self.scan_types]) + masks = preprocess_label(np.load(imgs['seg'][0])[curr_slice], output_classes=self.output_classes, merge_classes=self.merge_classes) @@ -189,4 +190,5 @@ def __data_generation(self, list_fpaths_temp): X[i] = (X[i] - np.mean(X[i])) / np.std(X[i]) return X, y + \ No newline at end of file diff --git a/src/preprocessing.py b/src/preprocessing.py index f7359a6..6da45d2 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -196,34 +196,36 @@ def cropVolumes(img1, img2, img3, img4): def save_nifti(imgs2save): for imgs in imgs2save: - nib.save(*imgs) + nib.save(*imgs) + + +# def save_npy(imgs2save): +# frst_slice = 0 +# last_slice = 0 +# seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1) +# for i in range(seg.shape[0]): +# curr_slice = seg[i, :, :] +# if np.sum(curr_slice) == 0: +# if last_slice <= frst_slice: +# frst_slice = i +# else: +# last_slice = i +# frst_slice += 1 + +# for name, data in imgs2save.items(): +# modality = data["modality"] +# modality = np.swapaxes(modality, 0, -1) +# modality = modality[frst_slice:last_slice] +# with open(f"{data['path']}.npy", "wb") as f: +# np.save(f, modality) def save_npy(imgs2save): - frst_slice = 0 - last_slice = 0 - seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1) - for i in range(seg.shape[0]): - curr_slice = seg[i, :, :] - if np.sum(curr_slice) == 0: - if last_slice <= frst_slice: - frst_slice = i - else: - last_slice = i - frst_slice += 1 - for name, data in imgs2save.items(): modality = data["modality"] - path = data["path"] modality = np.swapaxes(modality, 0, -1) - modality = modality[frst_slice:last_slice] - for i in range(modality.shape[0]): - curr_slice = modality[i, :, :] - if not os.path.isdir(path): - os.makedirs(path) - slice_dist_path = path + os.sep + str(i) - with open(f"{slice_dist_path}.npy", "wb") as f: - np.save(f, curr_slice) + with open(f"{data['path']}.npy", "wb") as f: + np.save(f, modality) def preprocesse(imgs, dataset_name, dist_dir_path, mode="3D"): diff --git a/src/unet.py b/src/unet.py index 9e5d625..aaaa511 100644 --- a/src/unet.py +++ b/src/unet.py @@ -31,6 +31,7 @@ def __init__(self, self.dropout_prob_shift = dropout_prob_shift self.batch_size = batch_size self.model_depth = model_depth + self.bottleneck_depth = bottleneck_depth self.dilate = dilate self.max_dilation_rate = max_dilation_rate self.name = name diff --git a/src/utils.py b/src/utils.py index a79add4..a0f4963 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,12 +1,13 @@ import glob import re import os +import random import numpy as np import nibabel as nib from tqdm import tqdm -def get_3Dfpaths(data_dir): +def get_fpaths(data_dir, mode="3D"): '''Parse all the filenames and create a dictionary for each patient with structure: { 't1': @@ -18,13 +19,18 @@ def get_3Dfpaths(data_dir): ''' # Get a list of files for all modalities individually - t1 = glob.glob(os.path.join(data_dir, '*/*t1.nii.gz')) - t2 = glob.glob(os.path.join(data_dir, '*/*t2.nii.gz')) - flair = glob.glob(os.path.join(data_dir, '*/*flair.nii.gz')) - t1ce = glob.glob(os.path.join(data_dir, '*/*t1ce.nii.gz')) - seg = glob.glob(os.path.join(data_dir, '*/*seg.nii.gz')) # Ground Truth + if mode == "3D": + ext = 'nii.gz' + pat = re.compile('.*_(\w*)\.nii\.gz') + elif mode == "2D": + ext = 'npy' + pat = re.compile('.*_(\w*)\.npy') - pat = re.compile('.*_(\w*)\.nii\.gz') + t1 = glob.glob(os.path.join(data_dir, f'*/*t1.{ext}')) + t2 = glob.glob(os.path.join(data_dir, f'*/*t2.{ext}')) + flair = glob.glob(os.path.join(data_dir, f'*/*flair.{ext}')) + t1ce = glob.glob(os.path.join(data_dir, f'*/*t1ce.{ext}')) + seg = glob.glob(os.path.join(data_dir, f'*/*seg.{ext}')) # Ground Truth data_paths = [{ pat.findall(item)[0]:item @@ -35,54 +41,36 @@ def get_3Dfpaths(data_dir): return data_paths -def get_2Dfpaths(data_dir): - '''Parse all the filenames and create a dictionary for each patient with structure: - { - 't1': list() - 't2': list() - 'flair': list() - 't1ce': list() - 'seg': list() - } - ''' - - pat = re.compile('.*_(\w*)') - data_paths = [] - for case in glob.glob(os.path.join(data_dir, '*')): - # Get a list of files for all modalities individually - t1 = sorted(glob.glob(os.path.join(case, '*t1/*.npy'))) - t2 = sorted(glob.glob(os.path.join(case, '*t2/*.npy'))) - flair = sorted(glob.glob(os.path.join(case, '*flair/*.npy'))) - t1ce = sorted(glob.glob(os.path.join(case, '*t1ce/*.npy'))) - seg = sorted(glob.glob(os.path.join(case, '*seg/*.npy'))) # Ground Truth - - data = {} - for items in list(zip(t1, t2, t1ce, flair, seg)): - for item in items: - data[pat.findall(item)[0]] = data.get(pat.findall(item)[0], []) + [item] - - data_paths.append(data) - - return data_paths - - -def unpack_2D_fpaths(packed_data_paths): +def unpack_2D_fpaths(packed_data_paths, only_with_mask=True): upacked_data_paths = [] + mod_names = packed_data_paths[0].keys() for paths in packed_data_paths: - t1 = paths['t1'] - t2 = paths['t2'] - flair = paths['flair'] - t1ce = paths['t1ce'] - seg = paths['seg'] - + frst_slice = 0 + last_slice = 0 + img = np.load(paths['seg']) + if only_with_mask: + for i in range(img.shape[0]): + curr_slice = img[i, :, :] + if np.sum(curr_slice) == 0: + if last_slice <= frst_slice: + frst_slice = i + else: + last_slice = i + frst_slice += 1 + else: + last_slice = img.shape[0] + depth = 0 + modalities = {} + for name, path in paths.items(): + data = np.load(path) + data = data[frst_slice:last_slice] + depth = data.shape[0] + for i in range(depth): + modalities[name] = modalities.get(name, []) + [(path, i)] + + for i in range(depth): + upacked_data_paths.append({name: modalities[name][i] for name in mod_names}) - pat = re.compile('.*_(\w*)') - - upacked_data_paths.extend([{ - pat.findall(item)[0]:item - for item in items - } - for items in list(zip(t1, t2, t1ce, flair, seg))]) return upacked_data_paths @@ -150,3 +138,21 @@ def get_preprocessed_data(data_paths:dict, scan_types=['t1', 'seg']): data.append(scans) return data + + +def get_dataset_split(data_paths, train_ratio=0.7, seed=42, shuffle=True): + random.seed(seed) + + n_samples = len(data_paths) + n_train = int(n_samples*train_ratio) + n_test = int(n_samples*(train_ratio-1)/2) + + if shuffle: + random.shuffle(data_paths) + + train_data_paths = data_paths[:n_train] + test_data_paths = data_paths[n_train:] + val_data_paths = test_data_paths[:n_test] + test_data_paths = test_data_paths[n_test:] + + return train_data_paths, test_data_paths, val_data_paths diff --git a/src/visualization.py b/src/visualization.py index 067529c..fd3ace8 100644 --- a/src/visualization.py +++ b/src/visualization.py @@ -3,6 +3,9 @@ from matplotlib import animation +plt.style.use('seaborn-pastel') + + def animate_scan(scan, mask): fig = plt.figure(figsize=(16, 8)) ax1 = fig.add_subplot(1,2,1) @@ -34,7 +37,6 @@ def show_class_frequency(classes_freq, classes2show, pixel2class): plt.show() -# Source https://gist.github.com/soply/f3eec2e79c165e39c9d540e916142ae1 def show_images(images, cols = 1, scale=4, titles = None): """Display a list of images in a single figure with matplotlib. @@ -47,7 +49,11 @@ def show_images(images, cols = 1, scale=4, titles = None): titles: List of titles corresponding to each image. Must have the same length as titles. + + References: + Adapted from https://gist.github.com/soply/f3eec2e79c165e39c9d540e916142ae1 """ + assert((titles is None)or (len(images) == len(titles))) n_images = len(images) if titles is None: titles = ['Image (%d)' % i for i in range(1,n_images + 1)] @@ -61,4 +67,28 @@ def show_images(images, cols = 1, scale=4, titles = None): plt.imshow(image) #a.set_title(title) fig.set_size_inches(np.array(fig.get_size_inches()) * n_images / scale) - plt.show() \ No newline at end of file + plt.show() + + +def show_depth_hist(data): + depths = [] + for modalities in data: + depths.append(modalities['t1'].shape[0]) + + plt.figure(figsize=(16,8)) + plt.title("Modalities' depth distribution", fontdict={'fontsize': 20}) + plt.hist(depths, bins=len(depths)//9) + plt.show() + + +def show_modalities(imgs, slice_num, scan_types): + fig, ax = plt.subplots(nrows=1, ncols=len(scan_types), figsize=(18, 9)) + + for ax, scan in zip(ax.flat, scan_types): + img = imgs[scan] + ax.imshow(img[slice_num, :, :, 0], cmap='gray') + ax.set_title(scan) + ax.axis('off') + + plt.tight_layout(True) + plt.show()