diff --git a/paz/models/segmentation/unet.py b/paz/models/segmentation/unet.py
index 4fb7d92a4..378c62e1f 100644
--- a/paz/models/segmentation/unet.py
+++ b/paz/models/segmentation/unet.py
@@ -1,3 +1,6 @@
+from tensorflow.keras.applications import ConvNeXtTiny, ConvNeXtSmall
+from tensorflow.keras.applications import ConvNeXtBase, ConvNeXtLarge
+from tensorflow.keras.applications import ConvNeXtXLarge
 from tensorflow.keras.layers import Conv2DTranspose, Concatenate, UpSampling2D
 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation
 from tensorflow.keras.layers import MaxPooling2D, Input
@@ -6,6 +9,36 @@
 from tensorflow.keras.applications import ResNet50V2
 
 
+def compute_upsampling_size(first_layer, second_layer):
+    """Computes the upsampling factor from a decoder to a branch tensor.
+
+    # Arguments
+        first_layer: Branch tensor whose size along axis 1 is the target.
+        second_layer: Decoder tensor that is going to be upsampled.
+
+    # Returns
+        Integer. Size of ``first_layer`` divided by the size of
+            ``second_layer``, both taken along axis 1.
+
+    # Raises
+        ValueError: If any of the two sizes is ``None`` or if the size of
+            ``first_layer`` is not an exact multiple of ``second_layer``'s.
+    """
+    branch_size = first_layer.shape[1]
+    decoder_size = second_layer.shape[1]
+    if branch_size is None or decoder_size is None:
+        raise ValueError(
+            'Upsampling size needs static sizes, got {} and {}'.format(
+                branch_size, decoder_size))
+    # Integer arithmetic avoids float rounding; inexact ratios are errors.
+    size, remainder = divmod(branch_size, decoder_size)
+    if remainder != 0:
+        raise ValueError(
+            'Branch size {} is not a multiple of decoder size {}'.format(
+                branch_size, decoder_size))
+    return size
+
+
 def convolution_block(inputs, filters, kernel_size=3, activation='relu'):
     """UNET convolution block containing Conv2D -> BatchNorm -> Activation
 
@@ -25,7 +58,7 @@ def convolution_block(inputs, filters, kernel_size=3, activation='relu'):
     return x
 
 
-def upsample_block(x, filters, branch):
+def upsample_block(x, filters, branch, size=2):
     """UNET upsample block. This block upsamples ``x``, concatenates a
     ``branch`` tensor and applies two convolution blocks:
     Upsample -> Concatenate -> 2 x ConvBlock.
@@ -38,7 +71,8 @@ def upsample_block(x, filters, branch):
     # Returns
         A Keras tensor.
     """
-    x = UpSampling2D(size=2)(x)
+    # ``size`` is the integer upsampling factor; 2 keeps the old behavior.
+    x = UpSampling2D(size)(x)
     x = Concatenate(axis=3)([x, branch])
     x = convolution_block(x, filters)
     x = convolution_block(x, filters)
@@ -146,7 +180,8 @@ def build_UNET(num_classes, backbone, branch_tensors,
         x = convolution_block(x, 512)
 
     for branch, filters in zip(branch_tensors, decoder_filters):
-        x = decoder(x, filters, branch)
+        size = compute_upsampling_size(branch, x)
+        x = decoder(x, filters, branch, size)
 
     kwargs = {'use_bias': True, 'kernel_initializer': 'glorot_uniform'}
     x = Conv2D(num_classes, 3, (1, 1), 'same', **kwargs)(x)
@@ -289,3 +324,148 @@ def UNET_RESNET50(num_classes=1, input_shape=(224, 224, 3), weights='imagenet',
     return UNET(input_shape, num_classes, RESNET50_branches, ResNet50V2,
                 weights, freeze_backbone, activation, decoder_type,
                 decode_filters, input_tensor, 'UNET-RESNET50')
+
+
+def UNET_ConvNeXtTiny(num_classes=1, input_shape=(224, 224, 3),
+                      weights='imagenet', freeze_backbone=False,
+                      activation='sigmoid', decoder_type='upsample',
+                      decode_filters=[256, 128, 64, 32, 16]):
+    """Build a UNET model with a ``ConvNeXtTiny`` backbone.
+
+    # Arguments
+        num_classes: Integer used for output number of channels.
+        input_shape: List of integers: ``(H, W, num_channels)``.
+        weights: String indicating backbone weights e.g.
+            ``'imagenet'``, ``None``.
+        freeze_backbone: Boolean. If True the backbone updates are frozen.
+        activation: Output activation of the model.
+        decoder_type: String indicating decoding function e.g.
+            ``'upsample'``, ``'transpose'``.
+        decode_filters: List of integers used in each application of decoder.
+
+    # Returns
+        A UNET-ConvNeXtTiny Keras/tensorflow model.
+    """
+    ConvNeXtTiny_branches = ['convnext_tiny_stage_2_block_8_identity',
+                             'convnext_tiny_stage_1_block_2_identity',
+                             'convnext_tiny_stage_0_block_2_identity',
+                             'convnext_tiny_prestem_normalization']
+    return UNET(input_shape, num_classes, ConvNeXtTiny_branches, ConvNeXtTiny,
+                weights, freeze_backbone, activation, decoder_type,
+                decode_filters, name='UNET-ConvNeXtTiny')
+
+
+def UNET_ConvNeXtSmall(num_classes=1, input_shape=(224, 224, 3),
+                       weights='imagenet', freeze_backbone=False,
+                       activation='sigmoid', decoder_type='upsample',
+                       decode_filters=[256, 128, 64, 32, 16]):
+    """Build a UNET model with a ``ConvNeXtSmall`` backbone.
+
+    # Arguments
+        num_classes: Integer used for output number of channels.
+        input_shape: List of integers: ``(H, W, num_channels)``.
+        weights: String indicating backbone weights e.g.
+            ``'imagenet'``, ``None``.
+        freeze_backbone: Boolean. If True the backbone updates are frozen.
+        activation: Output activation of the model.
+        decoder_type: String indicating decoding function e.g.
+            ``'upsample'``, ``'transpose'``.
+        decode_filters: List of integers used in each application of decoder.
+
+    # Returns
+        A UNET-ConvNeXtSmall Keras/tensorflow model.
+    """
+    ConvNeXtSmall_branches = ['convnext_small_stage_2_block_8_identity',
+                              'convnext_small_stage_1_block_2_identity',
+                              'convnext_small_stage_0_block_2_identity',
+                              'convnext_small_prestem_normalization']
+    return UNET(input_shape, num_classes, ConvNeXtSmall_branches,
+                ConvNeXtSmall, weights, freeze_backbone, activation,
+                decoder_type, decode_filters, name='UNET-ConvNeXtSmall')
+
+
+def UNET_ConvNeXtBase(num_classes=1, input_shape=(224, 224, 3),
+                      weights='imagenet', freeze_backbone=False,
+                      activation='sigmoid', decoder_type='upsample',
+                      decode_filters=[256, 128, 64, 32, 16]):
+    """Build a UNET model with a ``ConvNeXtBase`` backbone.
+
+    # Arguments
+        num_classes: Integer used for output number of channels.
+        input_shape: List of integers: ``(H, W, num_channels)``.
+        weights: String indicating backbone weights e.g.
+            ``'imagenet'``, ``None``.
+        freeze_backbone: Boolean. If True the backbone updates are frozen.
+        activation: Output activation of the model.
+        decoder_type: String indicating decoding function e.g.
+            ``'upsample'``, ``'transpose'``.
+        decode_filters: List of integers used in each application of decoder.
+
+    # Returns
+        A UNET-ConvNeXtBase Keras/tensorflow model.
+    """
+    ConvNeXtBase_branches = ['convnext_base_stage_2_block_26_identity',
+                             'convnext_base_stage_1_block_2_identity',
+                             'convnext_base_stage_0_block_2_identity',
+                             'convnext_base_prestem_normalization']
+    return UNET(input_shape, num_classes, ConvNeXtBase_branches, ConvNeXtBase,
+                weights, freeze_backbone, activation, decoder_type,
+                decode_filters, name='UNET-ConvNeXtBase')
+
+
+def UNET_ConvNeXtLarge(num_classes=1, input_shape=(224, 224, 3),
+                       weights='imagenet', freeze_backbone=False,
+                       activation='sigmoid', decoder_type='upsample',
+                       decode_filters=[256, 128, 64, 32, 16]):
+    """Build a UNET model with a ``ConvNeXtLarge`` backbone.
+
+    # Arguments
+        num_classes: Integer used for output number of channels.
+        input_shape: List of integers: ``(H, W, num_channels)``.
+        weights: String indicating backbone weights e.g.
+            ``'imagenet'``, ``None``.
+        freeze_backbone: Boolean. If True the backbone updates are frozen.
+        activation: Output activation of the model.
+        decoder_type: String indicating decoding function e.g.
+            ``'upsample'``, ``'transpose'``.
+        decode_filters: List of integers used in each application of decoder.
+
+    # Returns
+        A UNET-ConvNeXtLarge Keras/tensorflow model.
+    """
+    ConvNeXtLarge_branches = ['convnext_large_stage_2_block_26_identity',
+                              'convnext_large_stage_1_block_2_identity',
+                              'convnext_large_stage_0_block_2_identity',
+                              'convnext_large_prestem_normalization']
+    return UNET(input_shape, num_classes, ConvNeXtLarge_branches,
+                ConvNeXtLarge, weights, freeze_backbone, activation,
+                decoder_type, decode_filters, name='UNET-ConvNeXtLarge')
+
+
+def UNET_ConvNeXtXLarge(num_classes=1, input_shape=(224, 224, 3),
+                        weights='imagenet', freeze_backbone=False,
+                        activation='sigmoid', decoder_type='upsample',
+                        decode_filters=[256, 128, 64, 32, 16]):
+    """Build a UNET model with a ``ConvNeXtXLarge`` backbone.
+
+    # Arguments
+        num_classes: Integer used for output number of channels.
+        input_shape: List of integers: ``(H, W, num_channels)``.
+        weights: String indicating backbone weights e.g.
+            ``'imagenet'``, ``None``.
+        freeze_backbone: Boolean. If True the backbone updates are frozen.
+        activation: Output activation of the model.
+        decoder_type: String indicating decoding function e.g.
+            ``'upsample'``, ``'transpose'``.
+        decode_filters: List of integers used in each application of decoder.
+
+    # Returns
+        A UNET-ConvNeXtXLarge Keras/tensorflow model.
+    """
+    ConvNeXtXLarge_branches = ['convnext_xlarge_stage_2_block_26_identity',
+                               'convnext_xlarge_stage_1_block_2_identity',
+                               'convnext_xlarge_stage_0_block_2_identity',
+                               'convnext_xlarge_prestem_normalization']
+    return UNET(input_shape, num_classes, ConvNeXtXLarge_branches,
+                ConvNeXtXLarge, weights, freeze_backbone, activation,
+                decoder_type, decode_filters, name='UNET-ConvNeXtXLarge')