From 258a929cbc2c2596f8035272a3dec062dbac4a7a Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 3 Mar 2023 20:55:18 +0800 Subject: [PATCH 01/13] support rtmdet ins training --- ...tmdet-ins_l_syncbn_fast_8xb32-300e_coco.py | 336 ++++++++++++++++++ ...tmdet-ins_s_syncbn_fast_8xb32-300e_coco.py | 33 ++ mmyolo/models/dense_heads/__init__.py | 4 +- mmyolo/models/dense_heads/rtmdet_ins_head.py | 179 +++++++++- mmyolo/models/dense_heads/yolov5_head.py | 6 +- .../assigners/batch_dsl_assigner.py | 7 +- 6 files changed, 554 insertions(+), 11 deletions(-) create mode 100644 configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py create mode 100644 configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py diff --git a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py new file mode 100644 index 000000000..cf45b9a99 --- /dev/null +++ b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py @@ -0,0 +1,336 @@ +_base_ = ['../_base_/default_runtime.py'] + +# ========================Frequently modified parameters====================== +# -----data related----- +data_root = 'data/coco100/' +# Path of train annotation file +train_ann_file = 'annotations/instances_train2017.json' +train_data_prefix = 'train2017/' # Prefix of train image path +# Path of val annotation file +val_ann_file = 'annotations/instances_val2017.json' +val_data_prefix = 'val2017/' # Prefix of val image path + +num_classes = 80 # Number of classes for classification +# Batch size of a single GPU during training +train_batch_size_per_gpu = 32 +# Worker to pre-fetch data for each single GPU during training +train_num_workers = 10 +# persistent_workers must be False if num_workers is 0. +persistent_workers = True + +# -----train val related----- +# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs +base_lr = 0.004 +max_epochs = 300 # Maximum training epochs +# Change train_pipeline for final 20 epochs (stage 2) +num_epochs_stage2 = 20 + +model_test_cfg = dict( + # The config of multi-label for multi-class prediction. + multi_label=True, + # The number of boxes before NMS + nms_pre=1000, + score_thr=0.05, # Threshold to filter out boxes. + nms=dict(type='nms', iou_threshold=0.6), # NMS type and threshold + max_per_img=100, # Max number of detections of each image + mask_thr_binary=0.5) # Threshold of binary mask + +# ========================Possible modified parameters======================== +# -----data related----- +img_scale = (640, 640) # width, height +# ratio range for random resize +random_resize_ratio_range = (0.1, 2.0) +# Cached images number in mosaic +mosaic_max_cached_images = 40 +# Number of cached images in mixup +mixup_max_cached_images = 20 +# Dataset type, this will be used to define the dataset +dataset_type = 'YOLOv5CocoDataset' +# Batch size of a single GPU during validation +val_batch_size_per_gpu = 32 +# Worker to pre-fetch data for each single GPU during validation +val_num_workers = 10 +use_mask2refine = True +copypaste_prob = 0.3 + +# Config of batch shapes. Only on val. +batch_shapes_cfg = dict( + type='BatchShapePolicy', + batch_size=val_batch_size_per_gpu, + img_size=img_scale[0], + size_divisor=32, + extra_pad_ratio=0.5) + +# -----model related----- +# The scaling factor that controls the depth of the network structure +deepen_factor = 1.0 +# The scaling factor that controls the width of the network structure +widen_factor = 1.0 +# Strides of multi-scale prior box +strides = [8, 16, 32] + +norm_cfg = dict(type='BN') # Normalization config + +# -----train val related----- +lr_start_factor = 1.0e-5 +dsl_topk = 13 # Number of bbox selected in each level +loss_cls_weight = 1.0 +loss_bbox_weight = 2.0 +loss_mask_weight = 2.0 +qfl_beta = 2.0 # beta of QualityFocalLoss +weight_decay = 0.05 + +# Save model checkpoint and validation intervals +save_checkpoint_intervals = 10 +# validation intervals in stage 2 +val_interval_stage2 = 1 +# The maximum checkpoints to keep. +max_keep_ckpts = 3 +# single-scale training is recommended to +# be turned on, which can speed up training. +env_cfg = dict(cudnn_benchmark=True) + +# ===============================Unmodified in most cases==================== +model = dict( + type='YOLODetector', + data_preprocessor=dict( + type='YOLOv5DetDataPreprocessor', + mean=[103.53, 116.28, 123.675], + std=[57.375, 57.12, 58.395], + bgr_to_rgb=False), + backbone=dict( + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=deepen_factor, + widen_factor=widen_factor, + channel_attention=True, + norm_cfg=norm_cfg, + act_cfg=dict(type='SiLU', inplace=True)), + neck=dict( + type='CSPNeXtPAFPN', + deepen_factor=deepen_factor, + widen_factor=widen_factor, + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + expand_ratio=0.5, + norm_cfg=norm_cfg, + act_cfg=dict(type='SiLU', inplace=True)), + bbox_head=dict( + type='RTMDetInsHead', + head_module=dict( + type='RTMDetInsSepBNHeadModule', + num_classes=num_classes, + in_channels=256, + stacked_convs=2, + feat_channels=256, + norm_cfg=norm_cfg, + act_cfg=dict(type='SiLU', inplace=True), + share_conv=True, + pred_kernel_size=1, + featmap_strides=strides), + prior_generator=dict( + type='mmdet.MlvlPointGenerator', offset=0, strides=strides), + bbox_coder=dict(type='DistancePointBBoxCoder'), + loss_cls=dict( + type='mmdet.QualityFocalLoss', + use_sigmoid=True, + beta=qfl_beta, + loss_weight=loss_cls_weight), + loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=loss_bbox_weight), + loss_mask=dict( + type='mmdet.DiceLoss', + loss_weight=loss_mask_weight, + eps=5e-6, + reduction='mean')), + train_cfg=dict( + assigner=dict( + type='BatchDynamicSoftLabelAssigner', + num_classes=num_classes, + topk=dsl_topk, + iou_calculator=dict(type='mmdet.BboxOverlaps2D')), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=model_test_cfg, +) + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + mask2bbox=use_mask2refine), + dict( + type='Mosaic', + img_scale=img_scale, + use_cached=True, + max_cached_images=mosaic_max_cached_images, + pad_val=114.0), + dict(type='YOLOv5CopyPaste', prob=copypaste_prob), + dict( + type='mmdet.RandomResize', + # img_scale is (width, height) + scale=(img_scale[0] * 2, img_scale[1] * 2), + ratio_range=random_resize_ratio_range, + resize_type='mmdet.Resize', + keep_ratio=True), + dict( + type='mmdet.RandomCrop', + crop_size=img_scale, + recompute_bbox=True, + allow_negative_crop=True), + dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict(type='mmdet.RandomFlip', prob=0.5), + dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict( + type='YOLOv5MixUp', + use_cached=True, + max_cached_images=mixup_max_cached_images), + dict(type='mmdet.PackDetInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + mask2bbox=use_mask2refine), + dict( + type='mmdet.RandomResize', + scale=img_scale, + ratio_range=random_resize_ratio_range, + resize_type='mmdet.Resize', + keep_ratio=True), + dict( + type='mmdet.RandomCrop', + crop_size=img_scale, + recompute_bbox=True, + allow_negative_crop=True), + dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict(type='mmdet.RandomFlip', prob=0.5), + dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict(type='mmdet.PackDetInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict(type='YOLOv5KeepRatioResize', scale=img_scale), + dict( + type='LetterResize', + scale=img_scale, + allow_scale_up=False, + pad_val=dict(img=114)), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + _scope_='mmdet'), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'pad_param')) +] + +train_dataloader = dict( + batch_size=train_batch_size_per_gpu, + num_workers=train_num_workers, + persistent_workers=persistent_workers, + pin_memory=True, + collate_fn=dict(type='yolov5_collate'), + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=train_ann_file, + data_prefix=dict(img=train_data_prefix), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=val_batch_size_per_gpu, + num_workers=val_num_workers, + persistent_workers=persistent_workers, + pin_memory=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=val_ann_file, + data_prefix=dict(img=val_data_prefix), + test_mode=True, + batch_shapes_cfg=batch_shapes_cfg, + pipeline=test_pipeline)) + +test_dataloader = val_dataloader + +# Reduce evaluation time +val_evaluator = dict( + type='mmdet.CocoMetric', + proposal_nums=(100, 1, 10), + ann_file=data_root + val_ann_file, + metric=['bbox', 'segm']) +test_evaluator = val_evaluator + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=lr_start_factor, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=save_checkpoint_intervals, + max_keep_ckpts=max_keep_ckpts # only keep latest 3 checkpoints + )) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + strict_load=False, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - num_epochs_stage2, + switch_pipeline=train_pipeline_stage2) +] + +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_interval=save_checkpoint_intervals, + dynamic_intervals=[(max_epochs - num_epochs_stage2, val_interval_stage2)]) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py new file mode 100644 index 000000000..4acbde6cb --- /dev/null +++ b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py @@ -0,0 +1,33 @@ +_base_ = './rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py' +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa + +# ========================modified parameters====================== +deepen_factor = 0.33 +widen_factor = 0.5 +img_scale = _base_.img_scale + +# ratio range for random resize +random_resize_ratio_range = (0.5, 2.0) +# Number of cached images in mosaic +mosaic_max_cached_images = 40 +# Number of cached images in mixup +mixup_max_cached_images = 20 + +# =======================Unmodified in most cases================== +model = dict( + backbone=dict( + deepen_factor=deepen_factor, + widen_factor=widen_factor, + # Since the checkpoint includes CUDA:0 data, + # it must be forced to set map_location. + # Once checkpoint is fixed, it can be removed. + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint=checkpoint, + map_location='cpu')), + neck=dict( + deepen_factor=deepen_factor, + widen_factor=widen_factor, + ), + bbox_head=dict(head_module=dict(widen_factor=widen_factor))) diff --git a/mmyolo/models/dense_heads/__init__.py b/mmyolo/models/dense_heads/__init__.py index a95abd611..fec1ee05e 100644 --- a/mmyolo/models/dense_heads/__init__.py +++ b/mmyolo/models/dense_heads/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule -from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule +from .rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHeadModule from .rtmdet_rotated_head import (RTMDetRotatedHead, RTMDetRotatedSepBNHeadModule) from .yolov5_head import YOLOv5Head, YOLOv5HeadModule @@ -15,6 +15,6 @@ 'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead', 'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule', 'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule', - 'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead', + 'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsHead', 'RTMDetInsSepBNHeadModule' ] diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 1d0562aad..aa2e61d55 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -9,9 +9,10 @@ from mmcv.cnn import ConvModule, is_norm from mmcv.ops import batched_nms from mmdet.models.utils import filter_scores_and_topk -from mmdet.structures.bbox import get_box_tensor, get_box_wh, scale_boxes +from mmdet.structures.bbox import (distance2bbox, get_box_tensor, get_box_wh, + scale_boxes) from mmdet.utils import (ConfigType, InstanceList, OptConfigType, - OptInstanceList, OptMultiConfig) + OptInstanceList, OptMultiConfig, reduce_mean) from mmengine import ConfigDict from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, normal_init) @@ -19,6 +20,7 @@ from torch import Tensor from mmyolo.registry import MODELS +from ..utils import gt_instances_preprocess from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule @@ -185,7 +187,7 @@ def _init_layers(self): norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.cls_convs.append(cls_convs) - self.reg_convs.append(cls_convs) + self.reg_convs.append(reg_convs) self.kernel_convs.append(kernel_convs) self.rtm_cls.append( @@ -212,6 +214,7 @@ def _init_layers(self): for i in range(self.stacked_convs): self.cls_convs[n][i].conv = self.cls_convs[0][i].conv self.reg_convs[n][i].conv = self.reg_convs[0][i].conv + self.kernel_convs[n][i].conv = self.kernel_convs[0][i].conv self.mask_head = MaskFeatModule( in_channels=self.in_channels, @@ -286,7 +289,7 @@ def forward(self, feats: Tuple[Tensor, ...]) -> tuple: @MODELS.register_module() -class RTMDetInsSepBNHead(RTMDetHead): +class RTMDetInsHead(RTMDetHead): """RTMDet Instance Segmentation head. Args: @@ -343,6 +346,7 @@ def __init__(self, if isinstance(self.head_module, RTMDetInsSepBNHeadModule): assert self.use_sigmoid_cls == self.head_module.use_sigmoid_cls self.loss_mask = MODELS.build(loss_mask) + self.mask_loss_stride = 4 def predict_by_feat(self, cls_scores: List[Tensor], @@ -428,7 +432,7 @@ def predict_by_feat(self, # flatten cls_scores, bbox_preds flatten_cls_scores = [ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, - self.num_classes) + self.cls_out_channels) for cls_score in cls_scores ] flatten_bbox_preds = [ @@ -719,7 +723,170 @@ def loss_by_feat( self, cls_scores: List[Tensor], bbox_preds: List[Tensor], + kernel_preds: List[Tensor], + mask_feats: Tensor, batch_gt_instances: InstanceList, + batch_gt_masks: Tensor, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None) -> dict: - raise NotImplementedError + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_gt_masks (list[Tensor]): Batch of gt masks. Has shape + (num_instance, H, W). + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs) + gt_labels = gt_info[:, :, :1] + gt_bboxes = gt_info[:, :, 1:] # xyxy + pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0) + # downsample gt masks + batch_gt_masks = batch_gt_masks[:, self.mask_loss_stride // + 2::self.mask_loss_stride, + self.mask_loss_stride // + 2::self.mask_loss_stride] + + device = cls_scores[0].device + + # If the shape does not equal, generate new one + if featmap_sizes != self.featmap_sizes_train: + self.featmap_sizes_train = featmap_sizes + mlvl_priors_with_stride = self.prior_generator.grid_priors( + featmap_sizes, device=device, with_stride=True) + self.flatten_priors_train = torch.cat( + mlvl_priors_with_stride, dim=0) + + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1).contiguous() + + flatten_bboxes = torch.cat([ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in bbox_preds + ], 1) + flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1, + None] + flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2], + flatten_bboxes) + flatten_kernels = torch.cat([ + kernel_pred.permute(0, 2, 3, 1).reshape( + num_imgs, -1, self.head_module.num_gen_params) + for kernel_pred in kernel_preds + ], 1) + + assigned_result = self.assigner(flatten_bboxes.detach(), + flatten_cls_scores.detach(), + self.flatten_priors_train, gt_labels, + gt_bboxes, pad_bbox_flag.float()) + + labels = assigned_result['assigned_labels'].reshape(-1) + label_weights = assigned_result['assigned_labels_weights'].reshape(-1) + bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4) + assign_metrics = assigned_result['assign_metrics'].reshape(-1) + cls_preds = flatten_cls_scores.reshape(-1, self.num_classes) + bbox_preds = flatten_bboxes.reshape(-1, 4) + kernels = flatten_kernels.reshape(-1, self.head_module.num_gen_params) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item() + + loss_cls = self.loss_cls( + cls_preds, (labels, assign_metrics), + label_weights, + avg_factor=avg_factor) + + if len(pos_inds) > 0: + loss_bbox = self.loss_bbox( + bbox_preds[pos_inds], + bbox_targets[pos_inds], + weight=assign_metrics[pos_inds], + avg_factor=avg_factor) + else: + loss_bbox = bbox_preds.sum() * 0 + + # --------mask loss-------- + num_pos = len(pos_inds) + num_pos = reduce_mean(mask_feats.new_tensor([num_pos + ])).clamp_(min=1).item() + if len(pos_inds) > 0: + + pos_kernels = kernels[pos_inds] + matched_gt_inds = assigned_result['assigned_gt_inds'] + batch_index = assigned_result['assigned_batch_index'] + + if num_imgs > 1: + # remapping the padded batch index to the original index + index_shift = pad_bbox_flag.int().sum((1, 2)).cumsum(0) + index_shift = torch.cat( + [index_shift.new_zeros(1), index_shift[:-1]]) + all_index_shift = ( + pad_bbox_flag * + index_shift[:, None, None])[batch_index, + matched_gt_inds].reshape(-1) + matched_gt_inds = matched_gt_inds + all_index_shift + mask_targets = batch_gt_masks[matched_gt_inds] + pos_mask_feats = mask_feats[batch_index] + pos_priors = self.flatten_priors_train.repeat(num_imgs, + 1)[pos_inds] + + h, w = pos_mask_feats.size()[-2:] + coord = self.prior_generator.single_level_grid_priors( + (h, w), level_idx=0, + device=pos_mask_feats.device).reshape(1, -1, 2) + num_inst = pos_priors.shape[0] + points = pos_priors[:, :2].reshape(-1, 1, 2) + strides = pos_priors[:, 2:].reshape(-1, 1, 2) + relative_coord = (points - coord).permute(0, 2, 1) / ( + strides[..., 0].reshape(-1, 1, 1) * 8) + relative_coord = relative_coord.reshape(num_inst, 2, h, w) + + pos_mask_feats = torch.cat([relative_coord, pos_mask_feats], dim=1) + weights, biases = self.parse_dynamic_params(pos_kernels) + + n_layers = len(weights) + x = pos_mask_feats.reshape(1, -1, h, w) + for i, (weight, bias) in enumerate(zip(weights, biases)): + x = F.conv2d( + x, weight, bias=bias, stride=1, padding=0, groups=num_inst) + if i < n_layers - 1: + x = F.relu(x) + pos_mask_logits = x.reshape(num_inst, h, w) + scale = self.prior_generator.strides[0][0] // self.mask_loss_stride + pos_mask_logits = F.interpolate( + pos_mask_logits.unsqueeze(0), + scale_factor=scale, + mode='bilinear', + align_corners=False).squeeze(0) + loss_mask = self.loss_mask( + pos_mask_logits, mask_targets, weight=None, avg_factor=num_pos) + + else: + loss_mask = mask_feats.sum() * 0 + + return dict( + loss_cls=loss_cls, loss_bbox=loss_bbox, loss_mask=loss_mask) diff --git a/mmyolo/models/dense_heads/yolov5_head.py b/mmyolo/models/dense_heads/yolov5_head.py index c49d08518..93ff9e041 100644 --- a/mmyolo/models/dense_heads/yolov5_head.py +++ b/mmyolo/models/dense_heads/yolov5_head.py @@ -460,8 +460,10 @@ def loss(self, x: Tuple[Tensor], batch_data_samples: Union[list, else: outs = self(x) # Fast version - loss_inputs = outs + (batch_data_samples['bboxes_labels'], - batch_data_samples['img_metas']) + loss_inputs = outs + (batch_data_samples['bboxes_labels'], ) + if 'masks' in batch_data_samples: + loss_inputs += (batch_data_samples['masks'], ) + loss_inputs += (batch_data_samples['img_metas'], ) losses = self.loss_by_feat(*loss_inputs) return losses diff --git a/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py b/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py index 5ae0f8023..2dbbfd700 100644 --- a/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py +++ b/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py @@ -224,7 +224,9 @@ def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor, assigned_labels=assigned_labels, assigned_labels_weights=assigned_labels_weights, assigned_bboxes=assigned_bboxes, - assign_metrics=assign_metrics) + assign_metrics=assign_metrics, + assigned_gt_inds=matched_gt_inds, + assigned_batch_index=batch_index) def dynamic_k_matching( self, cost_matrix: Tensor, pairwise_ious: Tensor, @@ -269,4 +271,7 @@ def dynamic_k_matching( matched_pred_ious = (matching_matrix * pairwise_ious).sum(2)[fg_mask_inboxes] matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + + # pad_bbox_flag.sum(-2).cumsum(0).reshape(-1) + return matched_pred_ious, matched_gt_inds, fg_mask_inboxes From fea6489a60fe131e7bcb1667c8f90d2a22b654f8 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 6 Mar 2023 14:24:57 +0800 Subject: [PATCH 02/13] use einsum --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 42 ++++++++++++++++---- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index aa2e61d55..a5269aed9 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -22,6 +22,7 @@ from mmyolo.registry import MODELS from ..utils import gt_instances_preprocess from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule +from mmengine.utils.dl_utils import TimeCounter class MaskFeatModule(BaseModule): @@ -214,7 +215,8 @@ def _init_layers(self): for i in range(self.stacked_convs): self.cls_convs[n][i].conv = self.cls_convs[0][i].conv self.reg_convs[n][i].conv = self.reg_convs[0][i].conv - self.kernel_convs[n][i].conv = self.kernel_convs[0][i].conv + # TODO: verify whether it is correct + # self.kernel_convs[n][i].conv = self.kernel_convs[0][i].conv self.mask_head = MaskFeatModule( in_channels=self.in_channels, @@ -718,6 +720,29 @@ def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple: bias_splits[i] = bias_splits[i].reshape(n_inst) return weight_splits, bias_splits + + def parse_dynamic_params2(self, flatten_kernels: Tensor) -> tuple: + """split kernel head prediction to conv weight and bias.""" + n_inst = flatten_kernels.size(0) + n_layers = len(self.head_module.weight_nums) + params_splits = list( + torch.split_with_sizes( + flatten_kernels, + self.head_module.weight_nums + self.head_module.bias_nums, + dim=1)) + weight_splits = params_splits[:n_layers] + bias_splits = params_splits[n_layers:] + for i in range(n_layers): + if i < n_layers - 1: + weight_splits[i] = weight_splits[i].reshape( + n_inst, self.head_module.dyconv_channels, -1) + bias_splits[i] = bias_splits[i].reshape( + n_inst, self.head_module.dyconv_channels) + else: + weight_splits[i] = weight_splits[i].reshape(n_inst, 1, -1) + bias_splits[i] = bias_splits[i].reshape(n_inst, 1) + + return weight_splits, bias_splits def loss_by_feat( self, @@ -832,7 +857,7 @@ def loss_by_feat( # --------mask loss-------- num_pos = len(pos_inds) num_pos = reduce_mean(mask_feats.new_tensor([num_pos - ])).clamp_(min=1).item() + ])).clamp_(min=1).item() if len(pos_inds) > 0: pos_kernels = kernels[pos_inds] @@ -852,7 +877,7 @@ def loss_by_feat( mask_targets = batch_gt_masks[matched_gt_inds] pos_mask_feats = mask_feats[batch_index] pos_priors = self.flatten_priors_train.repeat(num_imgs, - 1)[pos_inds] + 1)[pos_inds] h, w = pos_mask_feats.size()[-2:] coord = self.prior_generator.single_level_grid_priors( @@ -864,24 +889,25 @@ def loss_by_feat( relative_coord = (points - coord).permute(0, 2, 1) / ( strides[..., 0].reshape(-1, 1, 1) * 8) relative_coord = relative_coord.reshape(num_inst, 2, h, w) - pos_mask_feats = torch.cat([relative_coord, pos_mask_feats], dim=1) - weights, biases = self.parse_dynamic_params(pos_kernels) + weights, biases = self.parse_dynamic_params2(pos_kernels) n_layers = len(weights) - x = pos_mask_feats.reshape(1, -1, h, w) + x = pos_mask_feats for i, (weight, bias) in enumerate(zip(weights, biases)): - x = F.conv2d( - x, weight, bias=bias, stride=1, padding=0, groups=num_inst) + x = torch.einsum('nij,njhw->nihw', weight, x) + x = x + bias[:, :, None, None] if i < n_layers - 1: x = F.relu(x) pos_mask_logits = x.reshape(num_inst, h, w) + scale = self.prior_generator.strides[0][0] // self.mask_loss_stride pos_mask_logits = F.interpolate( pos_mask_logits.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + loss_mask = self.loss_mask( pos_mask_logits, mask_targets, weight=None, avg_factor=num_pos) From fa50e55df13eb1327a19b35c629ae0cb273323e1 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 6 Mar 2023 14:30:52 +0800 Subject: [PATCH 03/13] update s config --- ...tmdet-ins_s_syncbn_fast_8xb32-300e_coco.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py index 4acbde6cb..505b47e6c 100644 --- a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py @@ -31,3 +31,80 @@ widen_factor=widen_factor, ), bbox_head=dict(head_module=dict(widen_factor=widen_factor))) + + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + mask2bbox=_base_.use_mask2refine), + dict( + type='Mosaic', + img_scale=img_scale, + use_cached=True, + max_cached_images=mosaic_max_cached_images, + pad_val=114.0), + dict( + type='mmdet.RandomResize', + # img_scale is (width, height) + scale=(img_scale[0] * 2, img_scale[1] * 2), + ratio_range=random_resize_ratio_range, + resize_type='mmdet.Resize', + keep_ratio=True), + dict( + type='mmdet.RandomCrop', + crop_size=img_scale, + recompute_bbox=_base_.use_mask2refine, + allow_negative_crop=True), + dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict(type='mmdet.RandomFlip', prob=0.5), + dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict( + type='YOLOv5MixUp', + use_cached=True, + max_cached_images=mixup_max_cached_images), + dict(type='mmdet.PackDetInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + mask2bbox=_base_.use_mask2refine), + dict( + type='mmdet.RandomResize', + scale=img_scale, + ratio_range=random_resize_ratio_range, + resize_type='mmdet.Resize', + keep_ratio=True), + dict( + type='mmdet.RandomCrop', + crop_size=img_scale, + recompute_bbox=_base_.use_mask2refine, + allow_negative_crop=True), + dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict(type='mmdet.RandomFlip', prob=0.5), + dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict(type='mmdet.PackDetInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + strict_load=False, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2, + switch_pipeline=train_pipeline_stage2) +] From d8bf81adc51029c73da6ae1fa0e78ab85381bc70 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 6 Mar 2023 20:52:32 +0800 Subject: [PATCH 04/13] downsample in pipeline --- ...tmdet-ins_s_syncbn_fast_8xb32-300e_coco.py | 1 + mmyolo/datasets/transforms/transforms.py | 30 +++++++++++++++++++ mmyolo/datasets/utils.py | 4 +-- mmyolo/models/dense_heads/rtmdet_ins_head.py | 10 +++---- 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py index 505b47e6c..a9a9d1305 100644 --- a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py @@ -66,6 +66,7 @@ type='YOLOv5MixUp', use_cached=True, max_cached_images=mixup_max_cached_images), + dict(type='Mask2Tensor', downsample_stride=4), dict(type='mmdet.PackDetInputs') ] diff --git a/mmyolo/datasets/transforms/transforms.py b/mmyolo/datasets/transforms/transforms.py index d5179fba3..5aad80b8b 100644 --- a/mmyolo/datasets/transforms/transforms.py +++ b/mmyolo/datasets/transforms/transforms.py @@ -1555,3 +1555,33 @@ def transform(self, results: dict) -> dict: results['gt_bboxes'] = self.box_type( results['gt_bboxes'].regularize_boxes(self.angle_version)) return results + + +@TRANSFORMS.register_module() +class Mask2Tensor(BaseTransform): + """Convert mask to tensor. + + Required Keys: + + - gt_masks (Masks) + + Modified Keys: + + - gt_masks + """ + def __init__(self, downsample_stride=1) -> None: + assert downsample_stride >= 1 + # downsample_stride should be divisible by 2 + assert downsample_stride % 2 == 0 + self.downsample_stride = downsample_stride + + + def transform(self, results: dict) -> dict: + mask = results['gt_masks'].to_tensor(dtype=torch.bool, device='cpu') + if self.downsample_stride > 1: + mask = mask[:, self.downsample_stride // + 2::self.downsample_stride, + self.downsample_stride // + 2::self.downsample_stride] + results['gt_masks'] = mask + return results diff --git a/mmyolo/datasets/utils.py b/mmyolo/datasets/utils.py index 62fe5484b..5a76936cb 100644 --- a/mmyolo/datasets/utils.py +++ b/mmyolo/datasets/utils.py @@ -28,8 +28,8 @@ def yolov5_collate(data_batch: Sequence, gt_bboxes = datasamples.gt_instances.bboxes.tensor gt_labels = datasamples.gt_instances.labels if 'masks' in datasamples.gt_instances: - masks = datasamples.gt_instances.masks.to_tensor( - dtype=torch.bool, device=gt_bboxes.device) + masks = datasamples.gt_instances.masks#.to_tensor( + #dtype=torch.bool, device=gt_bboxes.device) batch_masks.append(masks) batch_idx = gt_labels.new_full((len(gt_labels), 1), i) bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes), diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index a5269aed9..d7c086243 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -785,11 +785,11 @@ def loss_by_feat( gt_labels = gt_info[:, :, :1] gt_bboxes = gt_info[:, :, 1:] # xyxy pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0) - # downsample gt masks - batch_gt_masks = batch_gt_masks[:, self.mask_loss_stride // - 2::self.mask_loss_stride, - self.mask_loss_stride // - 2::self.mask_loss_stride] + # # downsample gt masks + # batch_gt_masks = batch_gt_masks[:, self.mask_loss_stride // + # 2::self.mask_loss_stride, + # self.mask_loss_stride // + # 2::self.mask_loss_stride] device = cls_scores[0].device From 2af73c30ed1e264c7ac706c1c86dff97c6cdb786 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 6 Mar 2023 21:04:56 +0800 Subject: [PATCH 05/13] add mask2tensor to l cfg --- configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py index cf45b9a99..72f9c5fca 100644 --- a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py @@ -2,7 +2,7 @@ # ========================Frequently modified parameters====================== # -----data related----- -data_root = 'data/coco100/' +data_root = 'data/coco/' # Path of train annotation file train_ann_file = 'annotations/instances_train2017.json' train_data_prefix = 'train2017/' # Prefix of train image path @@ -190,6 +190,7 @@ type='YOLOv5MixUp', use_cached=True, max_cached_images=mixup_max_cached_images), + dict(type='Mask2Tensor', downsample_stride=4), dict(type='mmdet.PackDetInputs') ] From a4a0af7311a4bd9c4808dcb16eb98afff78dcd2b Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Mar 2023 14:13:09 +0800 Subject: [PATCH 06/13] refactor mask process --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 109 +++++++++---------- 1 file changed, 49 insertions(+), 60 deletions(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index d7c086243..21b8d205b 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -17,12 +17,12 @@ from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, normal_init) from mmengine.structures import InstanceData +from mmengine.utils.dl_utils import TimeCounter from torch import Tensor from mmyolo.registry import MODELS from ..utils import gt_instances_preprocess from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule -from mmengine.utils.dl_utils import TimeCounter class MaskFeatModule(BaseModule): @@ -423,13 +423,7 @@ def predict_by_feat(self, with_stride=True) self.featmap_sizes = featmap_sizes flatten_priors = torch.cat(self.mlvl_priors) - - mlvl_strides = [ - flatten_priors.new_full( - (featmap_size.numel() * self.num_base_priors, ), stride) for - featmap_size, stride in zip(featmap_sizes, self.featmap_strides) - ] - flatten_stride = torch.cat(mlvl_strides) + flatten_stride = flatten_priors[:, -1] # flatten cls_scores, bbox_preds flatten_cls_scores = [ @@ -612,9 +606,9 @@ def _bbox_mask_post_process( results = results[:cfg.max_per_img] # process masks - mask_logits = self._mask_predict_by_feat(mask_feat, - results.kernels, - results.priors) + mask_logits = self._mask_predict_by_feat( + mask_feat.repeat(len(results), 1, 1, 1), results.kernels, + results.priors) stride = self.prior_generator.strides[0][0] mask_logits = F.interpolate( @@ -664,6 +658,7 @@ def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor, Tensor: Instance segmentation masks for each instance. Has shape (num_instance, H, W). """ + # import ipdb; ipdb.set_trace() num_inst = kernels.shape[0] h, w = mask_feat.size()[-2:] if num_inst < 1: @@ -671,28 +666,20 @@ def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor, size=(num_inst, h, w), dtype=mask_feat.dtype, device=mask_feat.device) - if len(mask_feat.shape) < 4: - mask_feat.unsqueeze(0) - - coord = self.prior_generator.single_level_grid_priors( - (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2) - num_inst = priors.shape[0] - points = priors[:, :2].reshape(-1, 1, 2) - strides = priors[:, 2:].reshape(-1, 1, 2) - relative_coord = (points - coord).permute(0, 2, 1) / ( - strides[..., 0].reshape(-1, 1, 1) * 8) - relative_coord = relative_coord.reshape(num_inst, 2, h, w) - - mask_feat = torch.cat( - [relative_coord, - mask_feat.repeat(num_inst, 1, 1, 1)], dim=1) - weights, biases = self.parse_dynamic_params(kernels) + + coord = self.mlvl_priors[0][:, :2] + relative_coord = (priors[:, None, :2] - coord[None, ...]) / ( + priors[:, -1, None, None] * 8) + relative_coord = relative_coord.permute(0, 2, + 1).reshape(num_inst, 2, h, w) + mask_feat = torch.cat([relative_coord, mask_feat], dim=1) + weights, biases = self.parse_dynamic_params2(kernels) n_layers = len(weights) - x = mask_feat.reshape(1, -1, h, w) + x = mask_feat for i, (weight, bias) in enumerate(zip(weights, biases)): - x = F.conv2d( - x, weight, bias=bias, stride=1, padding=0, groups=num_inst) + x = torch.einsum('nij,njhw->nihw', weight, x) + x = x + bias[:, :, None, None] if i < n_layers - 1: x = F.relu(x) x = x.reshape(num_inst, h, w) @@ -720,7 +707,7 @@ def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple: bias_splits[i] = bias_splits[i].reshape(n_inst) return weight_splits, bias_splits - + def parse_dynamic_params2(self, flatten_kernels: Tensor) -> tuple: """split kernel head prediction to conv weight and bias.""" n_inst = flatten_kernels.size(0) @@ -796,10 +783,12 @@ def loss_by_feat( # If the shape does not equal, generate new one if featmap_sizes != self.featmap_sizes_train: self.featmap_sizes_train = featmap_sizes - mlvl_priors_with_stride = self.prior_generator.grid_priors( - featmap_sizes, device=device, with_stride=True) - self.flatten_priors_train = torch.cat( - mlvl_priors_with_stride, dim=0) + self.mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device, + with_stride=True) + self.flatten_priors_train = torch.cat(self.mlvl_priors, dim=0) flatten_cls_scores = torch.cat([ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, @@ -857,10 +846,12 @@ def loss_by_feat( # --------mask loss-------- num_pos = len(pos_inds) num_pos = reduce_mean(mask_feats.new_tensor([num_pos - ])).clamp_(min=1).item() + ])).clamp_(min=1).item() if len(pos_inds) > 0: pos_kernels = kernels[pos_inds] + pos_priors = self.flatten_priors_train.repeat(num_imgs, + 1)[pos_inds] matched_gt_inds = assigned_result['assigned_gt_inds'] batch_index = assigned_result['assigned_batch_index'] @@ -876,38 +867,36 @@ def loss_by_feat( matched_gt_inds = matched_gt_inds + all_index_shift mask_targets = batch_gt_masks[matched_gt_inds] pos_mask_feats = mask_feats[batch_index] - pos_priors = self.flatten_priors_train.repeat(num_imgs, - 1)[pos_inds] - - h, w = pos_mask_feats.size()[-2:] - coord = self.prior_generator.single_level_grid_priors( - (h, w), level_idx=0, - device=pos_mask_feats.device).reshape(1, -1, 2) - num_inst = pos_priors.shape[0] - points = pos_priors[:, :2].reshape(-1, 1, 2) - strides = pos_priors[:, 2:].reshape(-1, 1, 2) - relative_coord = (points - coord).permute(0, 2, 1) / ( - strides[..., 0].reshape(-1, 1, 1) * 8) - relative_coord = relative_coord.reshape(num_inst, 2, h, w) - pos_mask_feats = torch.cat([relative_coord, pos_mask_feats], dim=1) - weights, biases = self.parse_dynamic_params2(pos_kernels) - - n_layers = len(weights) - x = pos_mask_feats - for i, (weight, bias) in enumerate(zip(weights, biases)): - x = torch.einsum('nij,njhw->nihw', weight, x) - x = x + bias[:, :, None, None] - if i < n_layers - 1: - x = F.relu(x) - pos_mask_logits = x.reshape(num_inst, h, w) + pos_mask_logits = self._mask_predict_by_feat( + pos_mask_feats, pos_kernels, pos_priors) scale = self.prior_generator.strides[0][0] // self.mask_loss_stride pos_mask_logits = F.interpolate( pos_mask_logits.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + + # # visualize mask and gt mask + # from mmcv import imshow + # import numpy as np + # import cv2 + # for idx, (mask, gt_mask) in enumerate(zip(pos_mask_logits, mask_targets)): + # print('instance_id:', idx) + # print('batch_idx:', batch_index[idx]) + # mask = mask.sigmoid().detach().cpu().numpy() * 255 + # mask = mask.astype(np.uint8) + # mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + + # gt_mask = gt_mask.detach().cpu().numpy().astype(np.uint8) * 255 + # gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_GRAY2BGR) + # gt_bbox = bbox_targets[pos_inds][idx] / 4 + # cv2.rectangle(gt_mask, (int(gt_bbox[0]), int(gt_bbox[1])), (int(gt_bbox[2]), int(gt_bbox[3])), (0, 0, 255), 1) + # concat_mask = np.concatenate([mask, gt_mask], axis=1) + # imshow(concat_mask, win_name='mask_and_gt', wait_time=0) + + loss_mask = self.loss_mask( pos_mask_logits, mask_targets, weight=None, avg_factor=num_pos) From fe715b9bd005b12549be1bedc289c36a2938ba8a Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Mar 2023 14:34:36 +0800 Subject: [PATCH 07/13] add viz code --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 21b8d205b..4e7458ce0 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -881,9 +881,20 @@ def loss_by_feat( # from mmcv import imshow # import numpy as np # import cv2 + # h, w = mask_feats.size()[-2:] + # coord = self.mlvl_priors[0][:, :2] + # relative_coord = (pos_priors[:, None, :2] - coord[None, ...]) / ( + # pos_priors[:, -1, None, None] * 8) + # relative_coord = relative_coord.permute(0, 2, + # 1).reshape(len(pos_inds), 2, h, w) + # for idx, (mask, gt_mask) in enumerate(zip(pos_mask_logits, mask_targets)): # print('instance_id:', idx) # print('batch_idx:', batch_index[idx]) + # relative_coord_1= relative_coord[idx][0].detach().cpu().numpy() + # relative_coord_2= relative_coord[idx][1].detach().cpu().numpy() + # concat_coord = np.concatenate([relative_coord_1, relative_coord_2], axis=1) + # imshow(concat_coord, win_name='relative_coord', wait_time=1) # mask = mask.sigmoid().detach().cpu().numpy() * 255 # mask = mask.astype(np.uint8) # mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) @@ -896,7 +907,6 @@ def loss_by_feat( # concat_mask = np.concatenate([mask, gt_mask], axis=1) # imshow(concat_mask, win_name='mask_and_gt', wait_time=0) - loss_mask = self.loss_mask( pos_mask_logits, mask_targets, weight=None, avg_factor=num_pos) From a8a9d440eac39e3ea1edb56d77f1db49c2e4e6a8 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Mar 2023 14:42:25 +0800 Subject: [PATCH 08/13] fix pipeline2 mask2tensor --- configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py | 1 + configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py | 1 + 2 files changed, 2 insertions(+) diff --git a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py index 72f9c5fca..b17723ef2 100644 --- a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py @@ -216,6 +216,7 @@ dict(type='mmdet.YOLOXHSVRandomAug'), dict(type='mmdet.RandomFlip', prob=0.5), dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict(type='Mask2Tensor', downsample_stride=4), dict(type='mmdet.PackDetInputs') ] diff --git a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py index a9a9d1305..7ae06c414 100644 --- a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py @@ -92,6 +92,7 @@ dict(type='mmdet.YOLOXHSVRandomAug'), dict(type='mmdet.RandomFlip', prob=0.5), dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), + dict(type='Mask2Tensor', downsample_stride=4), dict(type='mmdet.PackDetInputs') ] From 9f922bd10980f5f55bdf975aa7efe7708d1f0942 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Mar 2023 15:05:42 +0800 Subject: [PATCH 09/13] lint --- mmyolo/datasets/transforms/transforms.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mmyolo/datasets/transforms/transforms.py b/mmyolo/datasets/transforms/transforms.py index 5aad80b8b..2b073b672 100644 --- a/mmyolo/datasets/transforms/transforms.py +++ b/mmyolo/datasets/transforms/transforms.py @@ -1569,19 +1569,17 @@ class Mask2Tensor(BaseTransform): - gt_masks """ + def __init__(self, downsample_stride=1) -> None: assert downsample_stride >= 1 # downsample_stride should be divisible by 2 assert downsample_stride % 2 == 0 self.downsample_stride = downsample_stride - def transform(self, results: dict) -> dict: mask = results['gt_masks'].to_tensor(dtype=torch.bool, device='cpu') if self.downsample_stride > 1: - mask = mask[:, self.downsample_stride // - 2::self.downsample_stride, - self.downsample_stride // - 2::self.downsample_stride] + mask = mask[:, self.downsample_stride // 2::self.downsample_stride, + self.downsample_stride // 2::self.downsample_stride] results['gt_masks'] = mask return results From b02bf5a7fbcd989197b7f70eaf6cf0f2a7c4a1ad Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Mar 2023 17:09:44 +0800 Subject: [PATCH 10/13] center of mask --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 57 +++++++++++++++---- .../assigners/batch_dsl_assigner.py | 16 ++++-- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 4e7458ce0..41fa8ec3f 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -643,8 +643,11 @@ def _bbox_mask_post_process( device=results.bboxes.device) return results - def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor, - priors: Tensor) -> Tensor: + def _mask_predict_by_feat(self, + mask_feat: Tensor, + kernels: Tensor, + priors: Tensor, + training=False) -> Tensor: """Generate mask logits from mask features with dynamic convs. Args: @@ -666,8 +669,10 @@ def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor, size=(num_inst, h, w), dtype=mask_feat.dtype, device=mask_feat.device) - - coord = self.mlvl_priors[0][:, :2] + if self.training: + coord = self.mlvl_priors_train[0][:, :2] + else: + coord = self.mlvl_priors[0][:, :2] relative_coord = (priors[:, None, :2] - coord[None, ...]) / ( priors[:, -1, None, None] * 8) relative_coord = relative_coord.permute(0, 2, @@ -783,12 +788,13 @@ def loss_by_feat( # If the shape does not equal, generate new one if featmap_sizes != self.featmap_sizes_train: self.featmap_sizes_train = featmap_sizes - self.mlvl_priors = self.prior_generator.grid_priors( + self.mlvl_priors_train = self.prior_generator.grid_priors( featmap_sizes, dtype=cls_scores[0].dtype, device=cls_scores[0].device, with_stride=True) - self.flatten_priors_train = torch.cat(self.mlvl_priors, dim=0) + self.flatten_priors_train = torch.cat( + self.mlvl_priors_train, dim=0) flatten_cls_scores = torch.cat([ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, @@ -810,10 +816,17 @@ def loss_by_feat( for kernel_pred in kernel_preds ], 1) + # get mask center for assigner + gt_centers = torch.zeros_like(gt_bboxes[..., :2]) + gt_centers[pad_bbox_flag.squeeze( + 2)] = _get_mask_center(batch_gt_masks) * self.mask_loss_stride + # import ipdb; ipdb.set_trace() + assigned_result = self.assigner(flatten_bboxes.detach(), flatten_cls_scores.detach(), - self.flatten_priors_train, gt_labels, - gt_bboxes, pad_bbox_flag.float()) + self.flatten_priors_train, + gt_labels, gt_bboxes, + pad_bbox_flag.float(), gt_centers) labels = assigned_result['assigned_labels'].reshape(-1) label_weights = assigned_result['assigned_labels_weights'].reshape(-1) @@ -869,20 +882,20 @@ def loss_by_feat( pos_mask_feats = mask_feats[batch_index] pos_mask_logits = self._mask_predict_by_feat( - pos_mask_feats, pos_kernels, pos_priors) + pos_mask_feats, pos_kernels, pos_priors, training=True) scale = self.prior_generator.strides[0][0] // self.mask_loss_stride pos_mask_logits = F.interpolate( pos_mask_logits.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) - + # # visualize mask and gt mask # from mmcv import imshow # import numpy as np # import cv2 # h, w = mask_feats.size()[-2:] - # coord = self.mlvl_priors[0][:, :2] + # coord = self.mlvl_priors_train[0][:, :2] # relative_coord = (pos_priors[:, None, :2] - coord[None, ...]) / ( # pos_priors[:, -1, None, None] * 8) # relative_coord = relative_coord.permute(0, 2, @@ -901,7 +914,7 @@ def loss_by_feat( # gt_mask = gt_mask.detach().cpu().numpy().astype(np.uint8) * 255 # gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_GRAY2BGR) - + # gt_bbox = bbox_targets[pos_inds][idx] / 4 # cv2.rectangle(gt_mask, (int(gt_bbox[0]), int(gt_bbox[1])), (int(gt_bbox[2]), int(gt_bbox[3])), (0, 0, 255), 1) # concat_mask = np.concatenate([mask, gt_mask], axis=1) @@ -915,3 +928,23 @@ def loss_by_feat( return dict( loss_cls=loss_cls, loss_bbox=loss_bbox, loss_mask=loss_mask) + + +def _get_mask_center(masks: Tensor, eps: float = 1e-7) -> Tensor: + """Compute the masks center of mass. + + Args: + masks: Mask tensor, has shape (num_masks, H, W). + eps: a small number to avoid normalizer to be zero. + Defaults to 1e-7. + Returns: + Tensor: The masks center of mass. Has shape (num_masks, 2). + """ + n, h, w = masks.shape + grid_h = torch.arange(h, device=masks.device)[:, None] + grid_w = torch.arange(w, device=masks.device) + normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps) + center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer + center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer + center = torch.cat([center_x[:, None], center_y[:, None]], dim=1) + return center diff --git a/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py b/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py index 2dbbfd700..270e24f78 100644 --- a/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py +++ b/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py @@ -119,9 +119,14 @@ def __init__( self.batch_iou = batch_iou @torch.no_grad() - def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor, - gt_labels: Tensor, gt_bboxes: Tensor, - pad_bbox_flag: Tensor) -> dict: + def forward(self, + pred_bboxes: Tensor, + pred_scores: Tensor, + priors: Tensor, + gt_labels: Tensor, + gt_bboxes: Tensor, + pad_bbox_flag: Tensor, + gt_centers: Tensor = None) -> dict: num_gt = gt_bboxes.size(1) decoded_bboxes = pred_bboxes batch_size, num_bboxes, box_dim = decoded_bboxes.size() @@ -155,11 +160,12 @@ def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor, # (B, N_points) valid_mask = is_in_gts.sum(dim=-1) > 0 - gt_center = get_box_center(gt_bboxes, box_dim) + if gt_centers is None: + gt_centers = get_box_center(gt_bboxes, box_dim) strides = priors[..., 2] distance = (priors[None].unsqueeze(2)[..., :2] - - gt_center[:, None, :, :] + gt_centers[:, None, :, :] ).pow(2).sum(-1).sqrt() / strides[None, :, None] # prevent overflow From dd184766bf4679cb6e45f138350268804581adf5 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 8 Mar 2023 13:57:51 +0800 Subject: [PATCH 11/13] add downsample stride --- ...tmdet-ins_l_syncbn_fast_8xb32-300e_coco.py | 8 ++++-- ...tmdet-ins_s_syncbn_fast_8xb32-300e_coco.py | 5 ++-- mmyolo/models/dense_heads/rtmdet_ins_head.py | 28 ++----------------- 3 files changed, 10 insertions(+), 31 deletions(-) diff --git a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py index b17723ef2..40ab1e274 100644 --- a/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py @@ -52,6 +52,7 @@ val_num_workers = 10 use_mask2refine = True copypaste_prob = 0.3 +mask_downsample_stride = 4 # Config of batch shapes. Only on val. batch_shapes_cfg = dict( @@ -143,7 +144,8 @@ type='mmdet.DiceLoss', loss_weight=loss_mask_weight, eps=5e-6, - reduction='mean')), + reduction='mean'), + mask_loss_stride=mask_downsample_stride), train_cfg=dict( assigner=dict( type='BatchDynamicSoftLabelAssigner', @@ -190,7 +192,7 @@ type='YOLOv5MixUp', use_cached=True, max_cached_images=mixup_max_cached_images), - dict(type='Mask2Tensor', downsample_stride=4), + dict(type='Mask2Tensor', downsample_stride=mask_downsample_stride), dict(type='mmdet.PackDetInputs') ] @@ -216,7 +218,7 @@ dict(type='mmdet.YOLOXHSVRandomAug'), dict(type='mmdet.RandomFlip', prob=0.5), dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), - dict(type='Mask2Tensor', downsample_stride=4), + dict(type='Mask2Tensor', downsample_stride=mask_downsample_stride), dict(type='mmdet.PackDetInputs') ] diff --git a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py index 7ae06c414..00dcb2c04 100644 --- a/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py +++ b/configs/rtmdet_ins/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py @@ -32,7 +32,6 @@ ), bbox_head=dict(head_module=dict(widen_factor=widen_factor))) - train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), dict( @@ -66,7 +65,7 @@ type='YOLOv5MixUp', use_cached=True, max_cached_images=mixup_max_cached_images), - dict(type='Mask2Tensor', downsample_stride=4), + dict(type='Mask2Tensor', downsample_stride=_base_.mask_downsample_stride), dict(type='mmdet.PackDetInputs') ] @@ -92,7 +91,7 @@ dict(type='mmdet.YOLOXHSVRandomAug'), dict(type='mmdet.RandomFlip', prob=0.5), dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))), - dict(type='Mask2Tensor', downsample_stride=4), + dict(type='Mask2Tensor', downsample_stride=_base_.mask_downsample_stride), dict(type='mmdet.PackDetInputs') ] diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 41fa8ec3f..1e2da8e03 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -330,6 +330,7 @@ def __init__(self, loss_weight=2.0, eps=5e-6, reduction='mean'), + mask_loss_stride=4, train_cfg: OptConfigType = None, test_cfg: OptConfigType = None, init_cfg: OptMultiConfig = None): @@ -348,7 +349,7 @@ def __init__(self, if isinstance(self.head_module, RTMDetInsSepBNHeadModule): assert self.use_sigmoid_cls == self.head_module.use_sigmoid_cls self.loss_mask = MODELS.build(loss_mask) - self.mask_loss_stride = 4 + self.mask_loss_stride = mask_loss_stride def predict_by_feat(self, cls_scores: List[Tensor], @@ -678,7 +679,7 @@ def _mask_predict_by_feat(self, relative_coord = relative_coord.permute(0, 2, 1).reshape(num_inst, 2, h, w) mask_feat = torch.cat([relative_coord, mask_feat], dim=1) - weights, biases = self.parse_dynamic_params2(kernels) + weights, biases = self.parse_dynamic_params(kernels) n_layers = len(weights) x = mask_feat @@ -691,29 +692,6 @@ def _mask_predict_by_feat(self, return x def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple: - """split kernel head prediction to conv weight and bias.""" - n_inst = flatten_kernels.size(0) - n_layers = len(self.head_module.weight_nums) - params_splits = list( - torch.split_with_sizes( - flatten_kernels, - self.head_module.weight_nums + self.head_module.bias_nums, - dim=1)) - weight_splits = params_splits[:n_layers] - bias_splits = params_splits[n_layers:] - for i in range(n_layers): - if i < n_layers - 1: - weight_splits[i] = weight_splits[i].reshape( - n_inst * self.head_module.dyconv_channels, -1, 1, 1) - bias_splits[i] = bias_splits[i].reshape( - n_inst * self.head_module.dyconv_channels) - else: - weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1) - bias_splits[i] = bias_splits[i].reshape(n_inst) - - return weight_splits, bias_splits - - def parse_dynamic_params2(self, flatten_kernels: Tensor) -> tuple: """split kernel head prediction to conv weight and bias.""" n_inst = flatten_kernels.size(0) n_layers = len(self.head_module.weight_nums) From f7017fe28628770d5ca799b498a83830bc5fc551 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 8 Mar 2023 19:42:29 +0800 Subject: [PATCH 12/13] fix mask decode --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 1e2da8e03..61298f620 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import math from typing import List, Optional, Tuple import numpy as np @@ -509,6 +510,7 @@ def predict_by_feat(self, if rescale: if pad_param is not None: + # import ipdb; ipdb.set_trace() results.bboxes -= results.bboxes.new_tensor([ pad_param[2], pad_param[0], pad_param[2], pad_param[0] ]) @@ -617,20 +619,18 @@ def _bbox_mask_post_process( if rescale_mask: # TODO: When use mmdet.Resize or mmdet.Pad, will meet bug # Use img_meta to crop and resize + scale_factor = [1 / s for s in img_meta['scale_factor']] ori_h, ori_w = img_meta['ori_shape'][:2] - if isinstance(pad_param, np.ndarray): - pad_param = pad_param.astype(np.int32) - crop_y1, crop_y2 = pad_param[ - 0], mask_logits.shape[-2] - pad_param[1] - crop_x1, crop_x2 = pad_param[ - 2], mask_logits.shape[-1] - pad_param[3] - mask_logits = mask_logits[..., crop_y1:crop_y2, - crop_x1:crop_x2] + pad_param = pad_param.astype(np.int32) + mask_logits = mask_logits[..., pad_param[0]:, pad_param[2]:] mask_logits = F.interpolate( mask_logits, - size=[ori_h, ori_w], + size=[ + math.ceil(mask_logits.shape[-2] * scale_factor[0]), + math.ceil(mask_logits.shape[-1] * scale_factor[1]) + ], mode='bilinear', - align_corners=False) + align_corners=False)[..., :ori_h, :ori_w] masks = mask_logits.sigmoid().squeeze(0) masks = masks > cfg.mask_thr_binary From ff0431b369146ef888a228608e81239a074ab842 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 8 Mar 2023 19:46:06 +0800 Subject: [PATCH 13/13] share k head --- mmyolo/models/dense_heads/rtmdet_ins_head.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmyolo/models/dense_heads/rtmdet_ins_head.py b/mmyolo/models/dense_heads/rtmdet_ins_head.py index 61298f620..2caf75503 100644 --- a/mmyolo/models/dense_heads/rtmdet_ins_head.py +++ b/mmyolo/models/dense_heads/rtmdet_ins_head.py @@ -216,8 +216,7 @@ def _init_layers(self): for i in range(self.stacked_convs): self.cls_convs[n][i].conv = self.cls_convs[0][i].conv self.reg_convs[n][i].conv = self.reg_convs[0][i].conv - # TODO: verify whether it is correct - # self.kernel_convs[n][i].conv = self.kernel_convs[0][i].conv + self.kernel_convs[n][i].conv = self.kernel_convs[0][i].conv self.mask_head = MaskFeatModule( in_channels=self.in_channels,