Skip to content

关于main.py中获取out_a与out_v相关操作的疑惑 #39

@XuecWu

Description

@XuecWu

首先感谢贵团队的辛勤付出!
我的疑问存在于main.py中(line80-93)在获取模型的输出a, v, out之后,通过矩阵相乘(torch.mm)方法来获取到out_v以及out_a。具体如下:
image
当我使用sum方法与concat方法时,该部分代码会正确运行。然而,当我使用FiLM以及Gated Fusion方法的时候,代码理应执行”else“部分的操作。但是却发生了报错,报错的主要原因是进行矩阵相乘操作的时候mat1与mat2的维度不相同。
以FiLM方法为例:
image
设定目标种类数为n_classes=8
视觉特征张量v的shape为bs, 512
音频特征张量a的shape为bs, 512

out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)) + model.module.fusion_module.fc_out.bias / 2)
out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)) + model.module.fusion_module.fc_out.bias / 2)
model.module.fusion_module.fc_out.weight的shape为[n_classes, input_dim],即为8以及512。
经过上述的操作之后,
torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)的shape为256, 8
torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)的shape亦为256, 8
然而v与a的shape均为bs, 512。因此会出现维度不匹配的相关报错。

基于上述观察,我将if opt.fusion_method == 'sum':时的代码移植到当前的情况中来,移植后的代码如下:
out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias)
out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias)
我经过测试后发现,这样代码是可以正常运行的,但是这样的视觉分支与音频分支的准确率很低,在整体准确率可以达到44.4的时候,视觉分支与音频分支的准确率只分别有14.6以及9.2。
这对于我造成了困惑,整体上就是我运行repo中的代码不正确,之后进行了相关更改。代码可以成功运行,但是两个分支的准确率很不理想的问题。
我想询问是哪里出现了问题还是我理解上出现了偏差,希望得到贵团队的回复!

谢谢!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions