diff --git a/src/blocks.py b/src/blocks.py index ec34d38..7c0b1e2 100644 --- a/src/blocks.py +++ b/src/blocks.py @@ -1,13 +1,26 @@ -from tensorflow.keras.layers import Conv3D, BatchNormalization, Activation, SpatialDropout3D, Dropout +from tensorflow.keras.layers import Conv2D, Conv3D, BatchNormalization, Activation, SpatialDropout2D, SpatialDropout3D, Dropout -def conv3d_block(inputs, n_filters, conv_kwds, activation, dropout_prob, dropout_type=None, batchnorm=False): + +def conv_block(inputs, n_filters, conv_kwds, activation, dropout_prob, conv_type="3D", dropout_type=None, batchnorm=False): + if conv_type == "2D": + conv = Conv2D + spatial_dropout = SpatialDropout2D + elif conv_type == "3D": + conv = Conv3D + spatial_dropout = SpatialDropout3D + else: + raise ValueError(f"conv_type must be one of ['2D', '3D'], but got {conv_type}") + if dropout_type == "standard": dropout = Dropout elif dropout_type == "spatial": - dropout = SpatialDropout3D + dropout = spatial_dropout + else: + if dropout_type: + raise ValueError(f"dropout_type must be one of ['standard', 'spatial', None], but got {dropout_type}") # first layer - x = Conv3D(filters=n_filters, **conv_kwds)(inputs) + x = conv(filters=n_filters, **conv_kwds)(inputs) if batchnorm: x = BatchNormalization()(x) x = Activation(activation)(x) @@ -15,7 +28,7 @@ def conv3d_block(inputs, n_filters, conv_kwds, activation, dropout_prob, dropout x = dropout(dropout_prob)(x) # second layer - x = Conv3D(filters=n_filters, **conv_kwds)(x) + x = conv(filters=n_filters, **conv_kwds)(x) if batchnorm: x = BatchNormalization()(x) x = Activation(activation)(x) diff --git a/src/unet.py b/src/unet.py index f71f577..a1b82d9 100644 --- a/src/unet.py +++ b/src/unet.py @@ -1,22 +1,23 @@ from tensorflow.keras import Model -from tensorflow.keras.layers import Conv3D, Conv3DTranspose, Input, Activation, MaxPool3D, Concatenate +from tensorflow.keras.layers import Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose, Input, Activation, MaxPool2D, MaxPool3D, Concatenate -from blocks import conv3d_block +from blocks import conv_block -class Unet3D: +class Unet: def __init__(self, n_classes, input_shape, activation="relu", - n_base_filters=8, + n_base_filters=64, batchnorm=False, dropout_prob=0.2, dropout_type="spatial", dropout_prob_shift=0.1, batch_size=None, - model_depth=5, - name="3DUnet"): + model_depth=4, + name="Unet", + mode="3D"): self.n_classes = n_classes self.input_shape = input_shape self.activation = activation @@ -28,53 +29,114 @@ def __init__(self, self.batch_size = batch_size self.model_depth = model_depth self.name = name - + self.mode = mode + self.skips = [] + self.__set_layers() + self.__set_layers_prms() - self.conv_kwds = { - "kernel_size": (3, 3, 3), + def __set_layers(self): + if self.mode == "2D": + self.conv = Conv2D + self.transpose = Conv2DTranspose + self.maxpool = MaxPool2D + elif self.mode == "3D": + self.conv = Conv3D + self.transpose = Conv3DTranspose + self.maxpool = MaxPool3D + else: + raise ValueError(f"'mode' must be one of ['2D', '3D'], but got {self.mode}") + + def __set_layers_prms(self): + if self.mode == "2D": + self.conv_kwds = { + "kernel_size": (3, 3), "activation": None, "padding": "same", "kernel_initializer": "he_normal", - # 'kernel_regularizer': tf.keras.regularizers.l2(0.001), } - self.conv_transpose_kwds = { - "kernel_size": (2, 2, 2), - "strides": 2, + self.conv_transpose_kwds = { + "kernel_size": (2, 2), + "strides": 2, + "padding": "same", + "kernel_initializer": "he_normal", + } + elif self.mode == "3D": + self.conv_kwds = { + "kernel_size": (3, 3, 3), + "activation": None, "padding": "same", "kernel_initializer": "he_normal", - # 'kernel_regularizer': tf.keras.regularizers.l2(0.001), } + self.conv_transpose_kwds = { + "kernel_size": (2, 2, 2), + "strides": 2, + "padding": "same", + "kernel_initializer": "he_normal", + } + else: + raise ValueError(f"'mode' must be one of ['2D', '3D'], but got {self.mode}") + + def encoder(self, inputs): x = inputs for depth in range(self.model_depth): filters = self.n_base_filters * (2**depth) - x = conv3d_block(x, filters, self.conv_kwds, self.activation, self.dropout_prob, self.dropout_type, self.batchnorm) - if depth < self.model_depth - 1: + x = conv_block(inputs=x, + n_filters=filters, + conv_kwds=self.conv_kwds, + activation=self.activation, + dropout_prob=self.dropout_prob, + conv_type=self.mode, + dropout_type=self.dropout_type, + batchnorm=self.batchnorm) + if depth < self.model_depth: self.skips.append(x) - x = MaxPool3D(2)(x) + x = self.maxpool(2)(x) self.dropout_prob += self.dropout_prob_shift return x + def bottleneck(self, x): + filters = self.n_base_filters * (2**self.model_depth) + x = conv_block(inputs=x, + n_filters=filters, + conv_kwds=self.conv_kwds, + activation=self.activation, + dropout_prob=self.dropout_prob, + conv_type=self.mode, + dropout_type=self.dropout_type, + batchnorm=self.batchnorm) + return x + def decoder(self, x): - for depth in range(self.model_depth-1, 0, -1): - filters = self.n_base_filters * (2**depth) + for depth in range(self.model_depth, 0, -1): + filters_upsampling = self.n_base_filters * (2**depth) + filters_conv = self.n_base_filters * (2**(depth-1)) self.dropout_prob -= self.dropout_prob_shift - x = Conv3DTranspose(filters, **self.conv_transpose_kwds)(x) + + x = self.transpose(filters_upsampling, **self.conv_transpose_kwds)(x) x = Concatenate(axis=-1)([self.skips[depth-1], x]) - x = conv3d_block(x, filters, self.conv_kwds, self.activation, self.dropout_prob, self.dropout_type, self.batchnorm) + x = conv_block(inputs=x, + n_filters=filters_conv, + conv_kwds=self.conv_kwds, + activation=self.activation, + dropout_prob=self.dropout_prob, + conv_type=self.mode, + dropout_type=self.dropout_type, + batchnorm=self.batchnorm) - x = Conv3D(filters=self.n_classes, kernel_size=1)(x) + x = self.conv(filters=self.n_classes, kernel_size=1)(x) return x def build_model(self): inputs = Input(shape=self.input_shape, batch_size=self.batch_size) x = self.encoder(inputs) + x = self.bottleneck(x) x = self.decoder(x) final_activation = "sigmoid" if self.n_classes == 1 else "softmax"