Skip to content

Commit

Permalink
Add 2D data generator
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed May 9, 2020
1 parent 1e1ae39 commit e796811
Showing 1 changed file with 101 additions and 11 deletions.
112 changes: 101 additions & 11 deletions src/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
import nibabel as nib
from tensorflow.keras.utils import Sequence

from augmentation import pad, crop, augment
from preprocessing import preprocess_label
from augmentation import augment
from preprocessing import pad, crop, preprocess_label


class MedDataGenerator(Sequence):
class Med3DDataGenerator(Sequence):
'Generates data for Keras'
def __init__(self, list_fpaths,
batch_size=8, dim=(144, 192, 160),
n_channels=1,
n_classes=1, scan_types=['t1'],
scan_types=['t1'],
output_classes=['ncr'],
merge_classes=False,
shuffle=True,
shuffle=False,
hist_dist=False,
flip=False,
rand_rot=False):
Expand All @@ -23,11 +22,13 @@ def __init__(self, list_fpaths,
self.dim = dim
self.dim_before_axes_swap = (dim[-1], dim[1], dim[0])
self.list_fpaths = list_fpaths
self.n_channels = n_channels
self.n_classes = n_classes
self.merge_classes = merge_classes
self.scan_types = scan_types
self.output_classes = output_classes
self.n_channels = len(scan_types)
self.n_classes = len(output_classes)
if self.merge_classes:
self.n_classes = 1
self.shuffle = shuffle
self.augment_params = {'hist_dist': None,
'flip': flip,
Expand Down Expand Up @@ -82,7 +83,6 @@ def __data_generation(self, list_fpaths_temp):
merge_classes=self.merge_classes)

modalities, masks = crop(modalities, masks, depth=self.dim[0])

modalities, masks = pad(modalities, masks, masks.shape[1:],
self.dim_before_axes_swap,
self.n_channels,
Expand All @@ -96,7 +96,97 @@ def __data_generation(self, list_fpaths_temp):
X[i] = np.swapaxes(modalities, 0, -2)
y[i] = np.swapaxes(masks, 0, -2)

# X[i] /= 255.
X[i] = (X[i] - np.mean(X[i])) / np.std(X[i])

return X, y
return X, y


class Med2DDataGenerator(Sequence):
'Generates data for Keras'
def __init__(self, list_fpaths,
batch_size=8, dim=(192, 160),
scan_types=['t1'],
output_classes=['ncr'],
merge_classes=False,
shuffle=False,
hist_dist=False,
flip=False,
rand_rot=False):
'Initialization'
self.batch_size = batch_size
self.dim = dim
self.list_fpaths = list_fpaths
self.merge_classes = merge_classes
self.scan_types = scan_types
self.output_classes = output_classes
self.n_channels = len(scan_types)
self.n_classes = len(output_classes)
if self.merge_classes:
self.n_classes = 1
self.shuffle = shuffle
self.augment_params = {'hist_dist': None,
'flip': flip,
'rand_rot': rand_rot}
if hist_dist:
self.augment_params['hist_dist'] = {
'shift': {
'mu': 0,
'std': 0.25
},
'scale': {
'mu': 1,
'std': 0.25
}
}
self.on_epoch_end()

def __len__(self):
'Denotes the number of batches per epoch'
return int(np.floor(len(self.list_fpaths) / self.batch_size))

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

# Find list of IDs
list_fpaths_temp = [self.list_fpaths[k] for k in indexes]

# Generate data
X, y = self.__data_generation(list_fpaths_temp)

return X, y

def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.list_fpaths))
if self.shuffle == True:
np.random.shuffle(self.indexes)

def __data_generation(self, list_fpaths_temp):
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
# Initialization
X = np.empty((self.batch_size, *self.dim, self.n_channels))
y = np.empty((self.batch_size, *self.dim, self.n_classes), dtype=np.float32)

# Generate data
for i, imgs in enumerate(list_fpaths_temp):
modalities = np.array([np.load(imgs[m]) for m in self.scan_types])
masks = preprocess_label(np.load(imgs['seg']),
output_classes=self.output_classes,
merge_classes=self.merge_classes)

modalities, masks = pad(modalities, masks, masks.shape[1:],
self.dim,
self.n_channels,
self.n_classes)

modalities, masks = augment(modalities, masks, self.augment_params)

X[i] = np.moveaxis(modalities, 0, -1)
y[i] = np.moveaxis(masks, 0, -1)

X[i] = (X[i] - np.mean(X[i])) / np.std(X[i])

return X, y

0 comments on commit e796811

Please sign in to comment.