We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
大佬好,我在阅读源码过程中,发现在您的losses.py文件中的loss_labels函数中,独热编码好像好像有点问题,因为看sigmoid_focal_loss函数中的要求是target与input必须是相同维度,且值为1表示文本,值为0表示背景。您源码中的input shape应该是 [bs, num_queries, num_pts, 1], 但是如果按照您生成对应gt的独热编码代码,生成的shape是[bs, num_queries, num_pts, 1],但值全部都是0. 我看到源码在初始化target_classess矩阵的时候用的是num_class,也就是1. 同样的疑问在matcher中的计算分类的权重损失矩阵时也存在,在BoxHungarianMatcher()类中,cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids], pos_cost_class.shape = [bs,*num_queries, 1], tgt_ids由于是文本,应该都是1,这里还出现了数组越界的问题。不知道是否是我哪里理解得不正确,还请大佬能解答一下我的疑惑,非常感谢。
The text was updated successfully, but these errors were encountered:
因为原始的deformable detr中的num_classes是91,但这里的类别是1。在生成one-hot编码时候,以计算encoder的matcher和分类loss为例,tagets_class = torch.full(src_logits.shape[:-1], self.num_classes, dtype=torch.int64, device=src_logits.device) targets_classes.shape = [bs, num_queries],且值全部为1,而且后面taget_classes[idx] = target_classes_o,结果还是1 最后的target_class_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1)的shape是[bs, num_query, 2],但只有[bs, num_query, 1]的值为1,最后target_class_onehot = target_classes_onehot[..., -1]的值全是0,没有起到编码作用
Sorry, something went wrong.
No branches or pull requests
大佬好,我在阅读源码过程中,发现在您的losses.py文件中的loss_labels函数中,独热编码好像好像有点问题,因为看sigmoid_focal_loss函数中的要求是target与input必须是相同维度,且值为1表示文本,值为0表示背景。您源码中的input shape应该是



[bs, num_queries, num_pts, 1], 但是如果按照您生成对应gt的独热编码代码,生成的shape是[bs, num_queries, num_pts, 1],但值全部都是0.
我看到源码在初始化target_classess矩阵的时候用的是num_class,也就是1. 同样的疑问在matcher中的计算分类的权重损失矩阵时也存在,在BoxHungarianMatcher()类中,cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids], pos_cost_class.shape = [bs,*num_queries, 1], tgt_ids由于是文本,应该都是1,这里还出现了数组越界的问题。不知道是否是我哪里理解得不正确,还请大佬能解答一下我的疑惑,非常感谢。
The text was updated successfully, but these errors were encountered: