-
Notifications
You must be signed in to change notification settings - Fork 23
Description
首先感谢贵团队的辛勤付出!
我的疑问存在于main.py中(line80-93)在获取模型的输出a, v, out之后,通过矩阵相乘(torch.mm)方法来获取到out_v以及out_a。具体如下:

当我使用sum方法与concat方法时,该部分代码会正确运行。然而,当我使用FiLM以及Gated Fusion方法的时候,代码理应执行”else“部分的操作。但是却发生了报错,报错的主要原因是进行矩阵相乘操作的时候mat1与mat2的维度不相同。
以FiLM方法为例:

设定目标种类数为n_classes=8
视觉特征张量v的shape为bs, 512
音频特征张量a的shape为bs, 512
out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)) + model.module.fusion_module.fc_out.bias / 2)
out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)) + model.module.fusion_module.fc_out.bias / 2)
model.module.fusion_module.fc_out.weight的shape为[n_classes, input_dim],即为8以及512。
经过上述的操作之后,
torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)的shape为256, 8
torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)的shape亦为256, 8
然而v与a的shape均为bs, 512。因此会出现维度不匹配的相关报错。
基于上述观察,我将if opt.fusion_method == 'sum':时的代码移植到当前的情况中来,移植后的代码如下:
out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias)
out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias)
我经过测试后发现,这样代码是可以正常运行的,但是这样的视觉分支与音频分支的准确率很低,在整体准确率可以达到44.4的时候,视觉分支与音频分支的准确率只分别有14.6以及9.2。
这对于我造成了困惑,整体上就是我运行repo中的代码不正确,之后进行了相关更改。代码可以成功运行,但是两个分支的准确率很不理想的问题。
我想询问是哪里出现了问题还是我理解上出现了偏差,希望得到贵团队的回复!
谢谢!