Skip to content

Commit

Permalink
Update nanodet tutorial too support configurable number of classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Idan-BenAmi committed Mar 17, 2024
1 parent a5b5365 commit 977fe12
Showing 1 changed file with 15 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,16 @@ def nanodet_generate_anchors(batch_size, featmap_sizes, strides):
anchors_list.append(anchors)
return np.concatenate(anchors_list, axis=1, dtype=float)

def nanodet_plus_head(n, feat_channels=128):
def nanodet_plus_head(n, feat_channels=128, num_classes=80):
feat_out = num_classes + 32
h = n
for idx in range(4):
h[idx] = depthwise_conv_module(n[idx], out_channels=feat_channels, stride=1, name_prefix='head.cls_convs.' + str(idx) + '.0')
h[idx] = depthwise_conv_module(h[idx], out_channels=feat_channels, stride=1, name_prefix='head.cls_convs.' + str(idx) + '.1')
h[idx] = Conv2D(112, 1, name='head.gfl_cls.' + str(idx))(h[idx])
h[idx] = Conv2D(feat_out, 1, name='head.gfl_cls.' + str(idx))(h[idx])
return h

def nanodet_box_decoding(h, res):
def nanodet_box_decoding(h, res, num_classes=80):
strides = [8, 16, 32, 64]
batch_size = 1
featmap_sizes = [(np.ceil(res / stride), np.ceil(res / stride)) for stride in strides]
Expand All @@ -309,7 +310,7 @@ def nanodet_box_decoding(h, res):
h_bbox = []
for idx in range(4):
# Split to 80 classes and 4 * 8 bounding boxes regression
cls, regr = tf.split(h[idx], [80, 32],-1)
cls, regr = tf.split(h[idx], [num_classes, 32],-1)
ndet = cls.shape[1] * cls.shape[2]

# Distributed Focal loss integral
Expand All @@ -323,7 +324,7 @@ def nanodet_box_decoding(h, res):
bbox = tf.stack([bbox1, bbox0, bbox3, bbox2], -1)
bbox = tf.expand_dims(bbox,2)

cls = tf.reshape(cls, [-1, ndet, 80])
cls = tf.reshape(cls, [-1, ndet, num_classes])
h_cls.append(cls)
h_bbox.append(bbox)
classes = Concatenate(axis=1, name='bb_dec_class')([h_cls[0], h_cls[1], h_cls[2], h_cls[3]])
Expand All @@ -332,24 +333,25 @@ def nanodet_box_decoding(h, res):
return classes, boxes

# Nanodet-Plus model definition
def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels):
def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels, num_classes=80):
"""
Create the Nanodet-Plus object detection model.
Args:
input_shape (tuple): The shape of input images (height, width, channels).
scale_factor (float): Scale factor for ShuffleNetV2 backbone.
bottleneck_ratio (float): Bottleneck ratio for ShuffleNetV2 backbone.
scale_factor (float): Scale factor for the ShuffleNetV2 backbone.
bottleneck_ratio (float): Bottleneck ratio for the ShuffleNetV2 backbone.
feat_channels (int): Number of feature channels.
num_classes (int): Number of output classes.
Returns:
tf.keras.Model: The Nanodet-Plus model.
Configuration options:
nanodet-plus-m-1.5x-416: input_shape = (416,416,3), scale_factor=1.5, feat_channels=128
nanodet-plus-m-1.5x-320: input_shape = (320,320,3), scale_factor=1.5, feat_channels=128
nanodet-plus-m-416: input_shape = (416,416,3), scale_factor=1.0, feat_channels=96
nanodet-plus-m-320: input_shape = (320,320,3), scale_factor=1.0, feat_channels=96
nanodet-plus-m-1.5x-416: input_shape = (416,416,3), scale_factor=1.5, bottleneck_ratio=0.5, feat_channels=128, num_classes=80
nanodet-plus-m-1.5x-320: input_shape = (320,320,3), scale_factor=1.5, bottleneck_ratio=0.5, feat_channels=128, num_classes=80
nanodet-plus-m-416: input_shape = (416,416,3), scale_factor=1.0, bottleneck_ratio=0.5, feat_channels=96, num_classes=80
nanodet-plus-m-320: input_shape = (320,320,3), scale_factor=1.0, bottleneck_ratio=0.5, feat_channels=96, num_classes=80
"""
# Nanodet backbone
Expand All @@ -359,7 +361,7 @@ def nanodet_plus_m(input_shape, scale_factor, bottleneck_ratio, feat_channels):
x = nanodet_ghostpan(x, out_channels=feat_channels, res=input_shape[0])

# Nanodet head
x = nanodet_plus_head(x, feat_channels=feat_channels)
x = nanodet_plus_head(x, feat_channels=feat_channels, num_classes=num_classes)

# Define Keras model
return Model(inputs, x, name=f'Nanodet_plus_m_{scale_factor}x_{input_shape[0]}')

0 comments on commit 977fe12

Please sign in to comment.