Skip to content

Commit

Permalink
Add bunch of comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed May 25, 2020
1 parent 6df9a7f commit 9f1f1cf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
6 changes: 2 additions & 4 deletions src/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def random_rotate(data, masks, degree_range=(-15, 15)):
degrees = np.arange(*degree_range, 1)
degrees = np.random.choice(a=degrees, size=1)
rot_deg = degrees[0]
if rot_deg != 0:
data = ndimage.rotate(data, rot_deg, reshape=False, axes=(1,2))
masks = ndimage.rotate(masks, rot_deg, reshape=False, axes=(1,2))
return data, masks
data = ndimage.rotate(data, rot_deg, reshape=False, axes=(1,2))
masks = ndimage.rotate(masks, rot_deg, reshape=False, axes=(1,2))
return data, masks
9 changes: 8 additions & 1 deletion src/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@


def get_layers(conv_type, dropout_type, mode="3D"):
"""
Raises:
ValueError: If conv_type is not one of "2D" or "3D"
ValueError: If dropout_type is not one of "spatial" or "standard"
"""

if conv_type == "2D":
conv = Conv2D
spatial_dropout = SpatialDropout2D
Expand All @@ -24,6 +30,7 @@ def get_layers(conv_type, dropout_type, mode="3D"):

return {'conv': conv, 'dropout': dropout}


def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False):
layers = get_layers(conv_type, dropout_type, mode=conv_type)
conv = layers['conv']
Expand Down Expand Up @@ -52,7 +59,7 @@ def dilate_conv_block(x, n_filters, max_dilation_rate, conv_kwds, activation, dr
dropout = layers['dropout']

dilates = []
for i in range(math.ceil(math.log(max_dilation_rate, 2))):
for i in range(math.ceil(math.log(max_dilation_rate, 2))+1):
x = conv(filters=n_filters, dilation_rate=2**i, **conv_kwds)(x)
if batchnorm:
x = BatchNormalization()(x)
Expand Down
31 changes: 30 additions & 1 deletion src/unet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose, Input, Activation, MaxPool2D, MaxPool3D, Concatenate
from tensorflow.keras.regularizers import l2

from blocks import conv_block, dilate_conv_block


class Unet:
"""
The class of U-Net architecture [1].
Attributes:
n_classes (int): Unique classes in the output mask.
input_shape: Tensor of shape [x, y, channels]/[x, y, z, channels]
activation (str): A tensorflow.keras.activations.Activation to use.
n_base_filters (int): Convolutional filters in the initial convolutional block. Will be doubled every block.
batchnorm (bool): Use Batch Normalisation or not
dropout_prob (float): The probobility to eluminate a nerone after the initial convolutional block. Set to 0. to turn Dropout off
dropout_type (one of "spatial" or "standard"): Type of Dropout to apply.
dropout_prob_shift (float between 0. and 1.): Factor to add to the Dropout after each conv block.
batch_size (int): The subset size of a training sample.
model_depth (int): The number of blocks in decoder and encoder.
dilate (bool): Set to True to use dilated convolution.
bottleneck_depth (int): Number of layers in the bottleneck. Not matter if dilate is True.
max_dilation_rate (int): Num of holes in the last conv layer in the bottleneck. Will set the number of layers in the bottleneck equal to ceil(log2(max_dilation_rate)-1.
name (str): Name of assembled model.
mode (one of "2D" or "3D"): Set type of U-Net.
Returns:
model (tensorflow.keras.models.Model): The built U-Net.
[1]: https://arxiv.org/abs/1505.04597
"""

def __init__(self,
n_classes,
input_shape,
Expand Down Expand Up @@ -60,27 +86,31 @@ def __set_layers_prms(self):
"activation": None,
"padding": "same",
"kernel_initializer": "he_normal",
"kernel_regularizer": l2(0.001)
}

self.conv_transpose_kwds = {
"kernel_size": (2, 2),
"strides": 2,
"padding": "same",
"kernel_initializer": "he_normal",
"kernel_regularizer": l2(0.001)
}
elif self.mode == "3D":
self.conv_kwds = {
"kernel_size": (3, 3, 3),
"activation": None,
"padding": "same",
"kernel_initializer": "he_normal",
"kernel_regularizer": l2(0.001)
}

self.conv_transpose_kwds = {
"kernel_size": (2, 2, 2),
"strides": 2,
"padding": "same",
"kernel_initializer": "he_normal",
"kernel_regularizer": l2(0.001)
}
else:
raise ValueError(f"'mode' must be one of ['2D', '3D'], but got {self.mode}")
Expand Down Expand Up @@ -161,7 +191,6 @@ def decoder(self, x):
x = self.conv(filters=self.n_classes, kernel_size=1)(x)
return x


def build_model(self):
inputs = Input(shape=self.input_shape, batch_size=self.batch_size)
x = self.encoder(inputs)
Expand Down

0 comments on commit 9f1f1cf

Please sign in to comment.