Skip to content

Latest commit

 

History

History
84 lines (55 loc) · 3.89 KB

File metadata and controls

84 lines (55 loc) · 3.89 KB

DETR是最早利用transformer结构的模型做object detection的模型

DETR的几个关键点:

  • position encoding
  • Hungarian matching,set prediction loss
  • CNN + transformer,端到端训练

DETR的模型结构与pipeline:

image-20210719161853311

DETR的基本流程:首先,利用CNN提取特征,然后,将image特征过一个transformer结构的enc-dec,得到一系列的bbox set prediction,然后,将bbox prediction与真实的gt bbox通过bipartite match建立loss函数,进行端到端的训练优化。简言之,DETR的基本思路就是将目标检测问题转化为bbox set prediction的过程,从而直接端到端训练。

具体的操作如下:

image-20210719164209652

首先,用CNN提特征,然后将feature map与固定的pos emb进行相加(feature map需要1x1 conv降维),得到的结果作为encoder的输入,encoder是一个transformer结构,得到各个位置编码后的结果。

然后,将encoder的结果输入decoder。注意,这里的decoder不是直接接受encoder的输入,然后做transformer的,而是定义一组可以学的object queries作为learnt pos emb,用来生成最后的bbox。这里的object queries的数量N是预先设定的,比如N=100就意味着最终要产生100个预测框,而如果没有这么多实际的object的话,那么就用背景补充,并且将该类别置为no object,相当于给背景也增加了一个类别。最后,将预测的bbox set与gt bbox set进行二分匹配,找到每个gt bbox对应的预测bbox,然后计算loss。

position encoding

DETR的位置编码用的是手动的sin和cos的编码方式,并且将图像的两个方向,即(i, j)都进行编码。

参考官方代码:(ref:https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py)

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        # 两个方向分别编码
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # 利用sin和cos进行计算
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

匈牙利算法实现set matching

代码中使用的是:scipy.optimize.linear_sum_assignment(cost_matrix),将pred boxes中的每个元素与gt boxes中的每个元素计算一个weight,然后按照assignment problem的方式找到最优匹配。