Skip to content

Commit

Permalink
Add a bunch of fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed May 17, 2020
1 parent a2c7701 commit 78e6b8d
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 108 deletions.
45 changes: 26 additions & 19 deletions src/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ def augment(data, masks, params):
Args:
data (:obj:`numpy.array` of :obj:`np.float32`):
(x pathways) of numpy arrays [x, y, z, channels]. Scan data.
(x pathways) of numpy arrays [channels, x, y, z]. Scan data.
masks (:obj:`numpy.array` of :obj:`np.int8`):
numpy arrays [x, y, z, channels]. Ground truth data.
numpy arrays [channels, x, y, z]. Ground truth data.
params (dict): None or Dictionary, with parameters of each augmentation type.
Returns:
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes]
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z]
"""

if params['hist_dist']:
Expand All @@ -32,13 +32,13 @@ def random_flip(data, masks, n_dimensions):
Args:
data (:obj:`numpy.array` of :obj:`np.float32`):
(x pathways) of np arrays [x, y, z, channels]. Scan data.
(x pathways) of np arrays [channels, x, y, z]. Scan data.
masks (:obj:`numpy.array` of :obj:`np.int8`):
numpy arrays [x, y, z, channels]. Ground truth data.
numpy arrays [channels, x, y, z]. Ground truth data.
n_dimensions (int): the number of dimensions
Returns:
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes]
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z]
"""

axis = [dim for dim in range(1, n_dimensions) if np.random.choice([True, False])]
Expand All @@ -54,31 +54,38 @@ def random_histogram_distortion(data: np.array, shift={'mu': 0.0, 'std': 0}, sca
Args:
data (:obj:`numpy.array` of :obj:`np.float32`):
(x pathways) of np arrays [x, y, z, channels]. Scan data.
(x pathways) of np arrays [channels, x, y, z]. Scan data.
shift (:obj:`dict` of :obj:`dict`): {'mu': 0.0, 'std':0.}
params (:obj:`dict` of :obj:`dict`): {'mu': 1.0, 'std': '0.'}
Returns:
data (:obj:`numpy.array` of :obj:`np.float32`):
(x pathways) of numpy arrays [x, y, z, channels]
(x pathways) of numpy arrays [channels, x, y, z]
References:
Adapted from https://github.com/deepmedic/deepmedic/blob/f937eaa79debf001db2df697ddb14d94e7757b9f/deepmedic/dataManagement/augmentSample.py#L23
"""

n_channs = data[0].shape[-1]
n_channs = data[0].shape[0]
if len(data[0].shape) == 3:
axis2distort = [n_channs, 1, 1]
elif len(data[0].shape) == 4:
axis2distort = [n_channs, 1, 1, 1]
else:
raise RuntimeError(f"Got unexpected dimension {len(data[0].shape)}")

if shift is None:
shift_per_chan = 0.
elif shift['std'] != 0: # np.random.normal does not work for an std==0.
shift_per_chan = np.random.normal(shift['mu'], shift['std'], [1, 1, 1, n_channs])
shift_per_chan = np.random.normal(shift['mu'], shift['std'], axis2distort)
else:
shift_per_chan = np.ones([1, 1, 1, n_channs], dtype="float32") * shift['mu']
shift_per_chan = np.ones(axis2distort, dtype="float32") * shift['mu']

if scale is None:
scale_per_chan = 1.
elif scale['std'] != 0:
scale_per_chan = np.random.normal(scale['mu'], scale['std'], [1, 1, 1, n_channs])
scale_per_chan = np.random.normal(scale['mu'], scale['std'], axis2distort)
else:
scale_per_chan = np.ones([1, 1, 1, n_channs], dtype="float32") * scale['mu']
scale_per_chan = np.ones(axis2distort, dtype="float32") * scale['mu']

# Intensity augmentation
for path_idx in range(len(data)):
Expand All @@ -92,13 +99,13 @@ def random_rotate(data, masks, degrees=[-15, -10, -5, 0, 5, 10, 15]):
Args:
data (:obj:`numpy.array` of :obj:`np.float32`):
(x pathways) of np arrays [x, y, z, channels]. Scan data.
(x pathways) of np arrays [channels, x, y, z]. Scan data.
masks (:obj:`numpy.array` of :obj:`np.int8`):
numpy arrays [x, y, z, channels]. Ground truth data.
numpy arrays [channels, x, y, z]. Ground truth data.
degrees (:obj:`numpy.array` of :obj:`int`): list of possible angle of rotation in degrees.
Returns:
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [x, y, z, channels]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [x,y,z, classes]
data (:obj:`numpy.array` of :obj:`np.float32`): (x pathways) of np arrays [channels, x, y, z]
masks (:obj:`numpy.array` of :obj:`np.int8`): np array of shape [classes, x, y, z]
"""

degrees = np.random.choice(a=degrees, size=1)
Expand Down
53 changes: 42 additions & 11 deletions src/blocks.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
from tensorflow.keras.layers import Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout
import math
from tensorflow.keras.layers import add, Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout


def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False):
if conv_type == "2D":
def get_layers(conv_type, dropout_type, mode="3D"):
if conv_type == "2D":
conv = Conv2D
spatial_dropout = SpatialDropout2D
elif conv_type == "3D":
elif conv_type == "3D":
conv = Conv3D
spatial_dropout = SpatialDropout3D
else:
else:
raise ValueError(f"conv_type must be one of ['2D', '3D'], but got {conv_type}")

if dropout_type == "standard":
dropout = Dropout
elif dropout_type == "spatial":
dropout = spatial_dropout
if dropout_type == "standard":
dropout = Dropout
elif dropout_type == "spatial":
dropout = spatial_dropout
else:
if dropout_type:
raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}")
else:
if dropout_type:
raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}")
dropout = None

return {'conv': conv, 'dropout': dropout}

def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False):
layers = get_layers(conv_type, dropout_type, mode=conv_type)
conv = layers['conv']
dropout = layers['dropout']

# first layer
x = conv(filters=n_filters, **conv_kwds)(inputs)
Expand All @@ -34,3 +44,24 @@ def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type
x = Activation(activation)(x)

return x


def dilate_conv_block(x, n_filters, max_dilation_rate, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False):
layers = get_layers(conv_type, dropout_type, mode="3D")
conv = layers['conv']
dropout = layers['dropout']

dilates = []
for i in range(math.ceil(math.log(max_dilation_rate, 2))):
x = conv(filters=n_filters, dilation_rate=2**i, **conv_kwds)(x)
if batchnorm:
x = BatchNormalization()(x)
x = Activation(activation)(x)
if dropout_type and dropout_prob > 0.0:
x = dropout(dropout_prob)(x)
dilates.append(x)
x = conv(filters=n_filters, dilation_rate=max_dilation_rate, **conv_kwds)(x)
dilates.append(x)

return add(dilates)

6 changes: 4 additions & 2 deletions src/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ def __data_generation(self, list_fpaths_temp):

# Generate data
for i, imgs in enumerate(list_fpaths_temp):
modalities = np.array([np.load(imgs[m]) for m in self.scan_types])
masks = preprocess_label(np.load(imgs['seg']),
curr_slice = imgs['seg'][1]
modalities = np.array([np.load(imgs[m][0])[curr_slice] for m in self.scan_types])
masks = preprocess_label(np.load(imgs['seg'][0])[curr_slice],
output_classes=self.output_classes,
merge_classes=self.merge_classes)

Expand All @@ -189,4 +190,5 @@ def __data_generation(self, list_fpaths_temp):
X[i] = (X[i] - np.mean(X[i])) / np.std(X[i])

return X, y


46 changes: 24 additions & 22 deletions src/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,34 +196,36 @@ def cropVolumes(img1, img2, img3, img4):

def save_nifti(imgs2save):
for imgs in imgs2save:
nib.save(*imgs)
nib.save(*imgs)


# def save_npy(imgs2save):
# frst_slice = 0
# last_slice = 0
# seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1)
# for i in range(seg.shape[0]):
# curr_slice = seg[i, :, :]
# if np.sum(curr_slice) == 0:
# if last_slice <= frst_slice:
# frst_slice = i
# else:
# last_slice = i
# frst_slice += 1

# for name, data in imgs2save.items():
# modality = data["modality"]
# modality = np.swapaxes(modality, 0, -1)
# modality = modality[frst_slice:last_slice]
# with open(f"{data['path']}.npy", "wb") as f:
# np.save(f, modality)


def save_npy(imgs2save):
frst_slice = 0
last_slice = 0
seg = np.swapaxes(imgs2save["seg"]["modality"], 0, -1)
for i in range(seg.shape[0]):
curr_slice = seg[i, :, :]
if np.sum(curr_slice) == 0:
if last_slice <= frst_slice:
frst_slice = i
else:
last_slice = i
frst_slice += 1

for name, data in imgs2save.items():
modality = data["modality"]
path = data["path"]
modality = np.swapaxes(modality, 0, -1)
modality = modality[frst_slice:last_slice]
for i in range(modality.shape[0]):
curr_slice = modality[i, :, :]
if not os.path.isdir(path):
os.makedirs(path)
slice_dist_path = path + os.sep + str(i)
with open(f"{slice_dist_path}.npy", "wb") as f:
np.save(f, curr_slice)
with open(f"{data['path']}.npy", "wb") as f:
np.save(f, modality)


def preprocesse(imgs, dataset_name, dist_dir_path, mode="3D"):
Expand Down
1 change: 1 addition & 0 deletions src/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self,
self.dropout_prob_shift = dropout_prob_shift
self.batch_size = batch_size
self.model_depth = model_depth
self.bottleneck_depth = bottleneck_depth
self.dilate = dilate
self.max_dilation_rate = max_dilation_rate
self.name = name
Expand Down
110 changes: 58 additions & 52 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import glob
import re
import os
import random
import numpy as np
import nibabel as nib
from tqdm import tqdm


def get_3Dfpaths(data_dir):
def get_fpaths(data_dir, mode="3D"):
'''Parse all the filenames and create a dictionary for each patient with structure:
{
't1': <path to t1 MRI file>
Expand All @@ -18,13 +19,18 @@ def get_3Dfpaths(data_dir):
'''

# Get a list of files for all modalities individually
t1 = glob.glob(os.path.join(data_dir, '*/*t1.nii.gz'))
t2 = glob.glob(os.path.join(data_dir, '*/*t2.nii.gz'))
flair = glob.glob(os.path.join(data_dir, '*/*flair.nii.gz'))
t1ce = glob.glob(os.path.join(data_dir, '*/*t1ce.nii.gz'))
seg = glob.glob(os.path.join(data_dir, '*/*seg.nii.gz')) # Ground Truth
if mode == "3D":
ext = 'nii.gz'
pat = re.compile('.*_(\w*)\.nii\.gz')
elif mode == "2D":
ext = 'npy'
pat = re.compile('.*_(\w*)\.npy')

pat = re.compile('.*_(\w*)\.nii\.gz')
t1 = glob.glob(os.path.join(data_dir, f'*/*t1.{ext}'))
t2 = glob.glob(os.path.join(data_dir, f'*/*t2.{ext}'))
flair = glob.glob(os.path.join(data_dir, f'*/*flair.{ext}'))
t1ce = glob.glob(os.path.join(data_dir, f'*/*t1ce.{ext}'))
seg = glob.glob(os.path.join(data_dir, f'*/*seg.{ext}')) # Ground Truth

data_paths = [{
pat.findall(item)[0]:item
Expand All @@ -35,54 +41,36 @@ def get_3Dfpaths(data_dir):
return data_paths


def get_2Dfpaths(data_dir):
'''Parse all the filenames and create a dictionary for each patient with structure:
{
't1': list(<paths to t1 MRI file>)
't2': list(<paths to t2 MRI>)
'flair': list(<paths to FLAIR MRI file>)
't1ce': list(<paths to t1ce MRI file>)
'seg': list(<paths to Ground Truth file>)
}
'''

pat = re.compile('.*_(\w*)')
data_paths = []
for case in glob.glob(os.path.join(data_dir, '*')):
# Get a list of files for all modalities individually
t1 = sorted(glob.glob(os.path.join(case, '*t1/*.npy')))
t2 = sorted(glob.glob(os.path.join(case, '*t2/*.npy')))
flair = sorted(glob.glob(os.path.join(case, '*flair/*.npy')))
t1ce = sorted(glob.glob(os.path.join(case, '*t1ce/*.npy')))
seg = sorted(glob.glob(os.path.join(case, '*seg/*.npy'))) # Ground Truth

data = {}
for items in list(zip(t1, t2, t1ce, flair, seg)):
for item in items:
data[pat.findall(item)[0]] = data.get(pat.findall(item)[0], []) + [item]

data_paths.append(data)

return data_paths


def unpack_2D_fpaths(packed_data_paths):
def unpack_2D_fpaths(packed_data_paths, only_with_mask=True):
upacked_data_paths = []
mod_names = packed_data_paths[0].keys()
for paths in packed_data_paths:
t1 = paths['t1']
t2 = paths['t2']
flair = paths['flair']
t1ce = paths['t1ce']
seg = paths['seg']

frst_slice = 0
last_slice = 0
img = np.load(paths['seg'])
if only_with_mask:
for i in range(img.shape[0]):
curr_slice = img[i, :, :]
if np.sum(curr_slice) == 0:
if last_slice <= frst_slice:
frst_slice = i
else:
last_slice = i
frst_slice += 1
else:
last_slice = img.shape[0]
depth = 0
modalities = {}
for name, path in paths.items():
data = np.load(path)
data = data[frst_slice:last_slice]
depth = data.shape[0]
for i in range(depth):
modalities[name] = modalities.get(name, []) + [(path, i)]

for i in range(depth):
upacked_data_paths.append({name: modalities[name][i] for name in mod_names})

pat = re.compile('.*_(\w*)')

upacked_data_paths.extend([{
pat.findall(item)[0]:item
for item in items
}
for items in list(zip(t1, t2, t1ce, flair, seg))])
return upacked_data_paths


Expand Down Expand Up @@ -150,3 +138,21 @@ def get_preprocessed_data(data_paths:dict, scan_types=['t1', 'seg']):
data.append(scans)

return data


def get_dataset_split(data_paths, train_ratio=0.7, seed=42, shuffle=True):
random.seed(seed)

n_samples = len(data_paths)
n_train = int(n_samples*train_ratio)
n_test = int(n_samples*(train_ratio-1)/2)

if shuffle:
random.shuffle(data_paths)

train_data_paths = data_paths[:n_train]
test_data_paths = data_paths[n_train:]
val_data_paths = test_data_paths[:n_test]
test_data_paths = test_data_paths[n_test:]

return train_data_paths, test_data_paths, val_data_paths
Loading

0 comments on commit 78e6b8d

Please sign in to comment.