diff --git a/tutorials/mct_model_garden/models_keras/nanodet/nanodet_keras_model.py b/tutorials/mct_model_garden/models_keras/nanodet/nanodet_keras_model.py index 66181b234..c1cad1284 100644 --- a/tutorials/mct_model_garden/models_keras/nanodet/nanodet_keras_model.py +++ b/tutorials/mct_model_garden/models_keras/nanodet/nanodet_keras_model.py @@ -290,15 +290,16 @@ def nanodet_generate_anchors(batch_size, featmap_sizes, strides): anchors_list.append(anchors) return np.concatenate(anchors_list, axis=1, dtype=float) -def nanodet_plus_head(n, feat_channels=128): +def nanodet_plus_head(n, feat_channels=128, num_classes=80): + feat_out = num_classes + 32 h = n for idx in range(4): h[idx] = depthwise_conv_module(n[idx], out_channels=feat_channels, stride=1, name_prefix='head.cls_convs.' + str(idx) + '.0') h[idx] = depthwise_conv_module(h[idx], out_channels=feat_channels, stride=1, name_prefix='head.cls_convs.' + str(idx) + '.1') - h[idx] = Conv2D(112, 1, name='head.gfl_cls.' + str(idx))(h[idx]) + h[idx] = Conv2D(feat_out, 1, name='head.gfl_cls.' + str(idx))(h[idx]) return h -def nanodet_box_decoding(h, res): +def nanodet_box_decoding(h, res, num_classes=80): strides = [8, 16, 32, 64] batch_size = 1 featmap_sizes = [(np.ceil(res / stride), np.ceil(res / stride)) for stride in strides] @@ -309,7 +310,7 @@ def nanodet_box_decoding(h, res): h_bbox = [] for idx in range(4): # Split to 80 classes and 4 * 8 bounding boxes regression - cls, regr = tf.split(h[idx], [80, 32],-1) + cls, regr = tf.split(h[idx], [num_classes, 32],-1) ndet = cls.shape[1] * cls.shape[2] # Distributed Focal loss integral @@ -323,7 +324,7 @@ def nanodet_box_decoding(h, res): bbox = tf.stack([bbox1, bbox0, bbox3, bbox2], -1) bbox = tf.expand_dims(bbox,2) - cls = tf.reshape(cls, [-1, ndet, 80]) + cls = tf.reshape(cls, [-1, ndet, num_classes]) h_cls.append(cls) h_bbox.append(bbox) classes = Concatenate(axis=1, name='bb_dec_class')([h_cls[0], h_cls[1], h_cls[2], h_cls[3]]) @@ -332,24 +333,25 @@ def nanodet_box_decoding(h, res): return classes, boxes # Nanodet-Plus model definition -def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels): +def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels, num_classes=80): """ Create the Nanodet-Plus object detection model. Args: input_shape (tuple): The shape of input images (height, width, channels). - scale_factor (float): Scale factor for ShuffleNetV2 backbone. - bottleneck_ratio (float): Bottleneck ratio for ShuffleNetV2 backbone. + scale_factor (float): Scale factor for the ShuffleNetV2 backbone. + bottleneck_ratio (float): Bottleneck ratio for the ShuffleNetV2 backbone. feat_channels (int): Number of feature channels. + num_classes (int): Number of output classes. Returns: tf.keras.Model: The Nanodet-Plus model. Configuration options: - nanodet-plus-m-1.5x-416: input_shape = (416,416,3), scale_factor=1.5, feat_channels=128 - nanodet-plus-m-1.5x-320: input_shape = (320,320,3), scale_factor=1.5, feat_channels=128 - nanodet-plus-m-416: input_shape = (416,416,3), scale_factor=1.0, feat_channels=96 - nanodet-plus-m-320: input_shape = (320,320,3), scale_factor=1.0, feat_channels=96 + nanodet-plus-m-1.5x-416: input_shape = (416,416,3), scale_factor=1.5, bottleneck_ratio=0.5, feat_channels=128, num_classes=80 + nanodet-plus-m-1.5x-320: input_shape = (320,320,3), scale_factor=1.5, bottleneck_ratio=0.5, feat_channels=128, num_classes=80 + nanodet-plus-m-416: input_shape = (416,416,3), scale_factor=1.0, bottleneck_ratio=0.5, feat_channels=96, num_classes=80 + nanodet-plus-m-320: input_shape = (320,320,3), scale_factor=1.0, bottleneck_ratio=0.5, feat_channels=96, num_classes=80 """ # Nanodet backbone @@ -359,7 +361,7 @@ def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels): x = nanodet_ghostpan(x, out_channels=feat_channels, res=input_shape[0]) # Nanodet head - x = nanodet_plus_head(x, feat_channels=feat_channels) + x = nanodet_plus_head(x, feat_channels=feat_channels, num_classes=num_classes) # Define Keras model return Model(inputs, x, name=f'Nanodet_plus_m_{scale_factor}x_{input_shape[0]}')