diff --git a/damo/base_models/backbones/tinynas_csp.py b/damo/base_models/backbones/tinynas_csp.py index 618b93b..560a21e 100644 --- a/damo/base_models/backbones/tinynas_csp.py +++ b/damo/base_models/backbones/tinynas_csp.py @@ -243,7 +243,7 @@ def __init__(self, convstem, act='relu', reparam=False, with_spp=False): 2, act=self.act) if self.with_spp: - self.spp = SPPBottleneck(hidden_dim * 2, hidden_dim * 2) + self.spp = SPPBottleneck(hidden_dim * 2, hidden_dim * 2, activation=self.act) if len(self.convstem) > 0: self.conv_start = ConvKXBNRELU(hidden_dim * 2, hidden_dim, diff --git a/damo/base_models/backbones/tinynas_mob.py b/damo/base_models/backbones/tinynas_mob.py index 286a502..3be5494 100644 --- a/damo/base_models/backbones/tinynas_mob.py +++ b/damo/base_models/backbones/tinynas_mob.py @@ -205,7 +205,7 @@ def __init__(self, self.block_list.append(the_block) if block_id == 0 and with_spp: self.block_list.append( - SPPBottleneck(out_channels, out_channels)) + SPPBottleneck(out_channels, out_channels, activation=act)) def forward(self, x): output = x diff --git a/damo/base_models/backbones/tinynas_res.py b/damo/base_models/backbones/tinynas_res.py index 73cf5a0..e39487f 100644 --- a/damo/base_models/backbones/tinynas_res.py +++ b/damo/base_models/backbones/tinynas_res.py @@ -128,7 +128,7 @@ def __init__(self, self.block_list.append(the_block) if block_id == 0 and with_spp: self.block_list.append( - SPPBottleneck(out_channels, out_channels)) + SPPBottleneck(out_channels, out_channels, activation=act)) def forward(self, x): output = x diff --git a/damo/base_models/core/ops.py b/damo/base_models/core/ops.py index 3bf3747..02ec062 100644 --- a/damo/base_models/core/ops.py +++ b/damo/base_models/core/ops.py @@ -39,6 +39,8 @@ def get_activation(name='silu', inplace=True): module = nn.SiLU(inplace=inplace) elif name == 'relu': module = nn.ReLU(inplace=inplace) + elif name == 'relu6': + module = nn.ReLU6(inplace=inplace) elif name == 'lrelu': module = nn.LeakyReLU(0.1, inplace=inplace) elif name == 'swish': diff --git a/damo/utils/visualize.py b/damo/utils/visualize.py index cfba883..7c2f78c 100644 --- a/damo/utils/visualize.py +++ b/damo/utils/visualize.py @@ -8,7 +8,7 @@ def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None): - + img = np.copy(img) # https://github.com/opencv/opencv/issues/24522 for i in range(len(boxes)): box = boxes[i] cls_id = int(cls_ids[i])