Skip to content

Commit

Permalink
Add support for 2D U-Net
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed May 8, 2020
1 parent 4848891 commit d33b3d2
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 27 deletions.
23 changes: 18 additions & 5 deletions src/blocks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,34 @@
from tensorflow.keras.layers import Conv3D, BatchNormalization, Activation, SpatialDropout3D, Dropout
from tensorflow.keras.layers import Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout

def conv3d_block(inputs, n_filters, conv_kwds, activation, dropout_prob, dropout_type=None, batchnorm=False):

def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False):
if conv_type == "2D":
conv = Conv2D
spatial_dropout = SpatialDropout2D
elif conv_type == "3D":
conv = Conv3D
spatial_dropout = SpatialDropout3D
else:
raise ValueError(f"conv_type must be one of ['2D', '3D'], but got {conv_type}")

if dropout_type == "standard":
dropout = Dropout
elif dropout_type == "spatial":
dropout = SpatialDropout3D
dropout = spatial_dropout
else:
if dropout_type:
raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}")

# first layer
x = Conv3D(filters=n_filters, **conv_kwds)(inputs)
x = conv(filters=n_filters, **conv_kwds)(inputs)
if batchnorm:
x = BatchNormalization()(x)
x = Activation(activation)(x)
if dropout_type and dropout_prob > 0.0:
x = dropout(dropout_prob)(x)

# second layer
x = Conv3D(filters=n_filters, **conv_kwds)(x)
x = conv(filters=n_filters, **conv_kwds)(x)
if batchnorm:
x = BatchNormalization()(x)
x = Activation(activation)(x)
Expand Down
106 changes: 84 additions & 22 deletions src/unet.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, Input, Activation, MaxPool3D, Concatenate
from tensorflow.keras.layers import Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose, Input, Activation, MaxPool2D, MaxPool3D, Concatenate

from blocks import conv3d_block
from blocks import conv_block


class Unet3D:
class Unet:
def __init__(self,
n_classes,
input_shape,
activation="relu",
n_base_filters=8,
n_base_filters=64,
batchnorm=False,
dropout_prob=0.2,
dropout_type="spatial",
dropout_prob_shift=0.1,
batch_size=None,
model_depth=5,
name="3DUnet"):
model_depth=4,
name="Unet",
mode="3D"):
self.n_classes = n_classes
self.input_shape = input_shape
self.activation = activation
Expand All @@ -28,53 +29,114 @@ def __init__(self,
self.batch_size = batch_size
self.model_depth = model_depth
self.name = name

self.mode = mode

self.skips = []
self.__set_layers()
self.__set_layers_prms()

self.conv_kwds = {
"kernel_size": (3, 3, 3),
def __set_layers(self):
if self.mode == "2D":
self.conv = Conv2D
self.transpose = Conv2DTranspose
self.maxpool = MaxPool2D
elif self.mode == "3D":
self.conv = Conv3D
self.transpose = Conv3DTranspose
self.maxpool = MaxPool3D
else:
raise ValueError(f"'mode' must be one of ['2D', '3D'], but got {self.mode}")

def __set_layers_prms(self):
if self.mode == "2D":
self.conv_kwds = {
"kernel_size": (3, 3),
"activation": None,
"padding": "same",
"kernel_initializer": "he_normal",
# 'kernel_regularizer': tf.keras.regularizers.l2(0.001),
}

self.conv_transpose_kwds = {
"kernel_size": (2, 2, 2),
"strides": 2,
self.conv_transpose_kwds = {
"kernel_size": (2, 2),
"strides": 2,
"padding": "same",
"kernel_initializer": "he_normal",
}
elif self.mode == "3D":
self.conv_kwds = {
"kernel_size": (3, 3, 3),
"activation": None,
"padding": "same",
"kernel_initializer": "he_normal",
# 'kernel_regularizer': tf.keras.regularizers.l2(0.001),
}

self.conv_transpose_kwds = {
"kernel_size": (2, 2, 2),
"strides": 2,
"padding": "same",
"kernel_initializer": "he_normal",
}
else:
raise ValueError(f"'mode' must be one of ['2D', '3D'], but got {self.mode}")


def encoder(self, inputs):
x = inputs
for depth in range(self.model_depth):
filters = self.n_base_filters * (2**depth)
x = conv3d_block(x, filters, self.conv_kwds, self.activation, self.dropout_prob, self.dropout_type, self.batchnorm)
if depth < self.model_depth - 1:
x = conv_block(inputs=x,
n_filters=filters,
conv_kwds=self.conv_kwds,
activation=self.activation,
dropout_prob=self.dropout_prob,
conv_type=self.mode,
dropout_type=self.dropout_type,
batchnorm=self.batchnorm)
if depth < self.model_depth:
self.skips.append(x)
x = MaxPool3D(2)(x)
x = self.maxpool(2)(x)

self.dropout_prob += self.dropout_prob_shift

return x

def bottleneck(self, x):
filters = self.n_base_filters * (2**self.model_depth)
x = conv_block(inputs=x,
n_filters=filters,
conv_kwds=self.conv_kwds,
activation=self.activation,
dropout_prob=self.dropout_prob,
conv_type=self.mode,
dropout_type=self.dropout_type,
batchnorm=self.batchnorm)
return x

def decoder(self, x):
for depth in range(self.model_depth-1, 0, -1):
filters = self.n_base_filters * (2**depth)
for depth in range(self.model_depth, 0, -1):
filters_upsampling = self.n_base_filters * (2**depth)
filters_conv = self.n_base_filters * (2**(depth-1))
self.dropout_prob -= self.dropout_prob_shift
x = Conv3DTranspose(filters, **self.conv_transpose_kwds)(x)

x = self.transpose(filters_upsampling, **self.conv_transpose_kwds)(x)
x = Concatenate(axis=-1)([self.skips[depth-1], x])
x = conv3d_block(x, filters, self.conv_kwds, self.activation, self.dropout_prob, self.dropout_type, self.batchnorm)
x = conv_block(inputs=x,
n_filters=filters_conv,
conv_kwds=self.conv_kwds,
activation=self.activation,
dropout_prob=self.dropout_prob,
conv_type=self.mode,
dropout_type=self.dropout_type,
batchnorm=self.batchnorm)

x = Conv3D(filters=self.n_classes, kernel_size=1)(x)
x = self.conv(filters=self.n_classes, kernel_size=1)(x)
return x


def build_model(self):
inputs = Input(shape=self.input_shape, batch_size=self.batch_size)
x = self.encoder(inputs)
x = self.bottleneck(x)
x = self.decoder(x)

final_activation = "sigmoid" if self.n_classes == 1 else "softmax"
Expand Down

0 comments on commit d33b3d2

Please sign in to comment.