diff --git a/aia_net.py b/aia_net.py index c7818d9..b405205 100644 --- a/aia_net.py +++ b/aia_net.py @@ -186,7 +186,7 @@ def forward(self, input1, input2): input_ri = self.input(input_merge) for i in range(len(self.row_trans)): if i >=1: - output_mag_i = input_mag + output_list_ri[-1] + output_mag_i = output_list_mag[-1] + output_list_ri[-1] else: output_mag_i = input_mag AFA_input_mag = output_mag_i.permute(3, 0, 2, 1).contiguous().view(dim1, b*dim2, -1) # [F, B*T, c] AFA_output_mag = self.row_trans[i](AFA_input_mag) # [F, B*T, c] @@ -314,4 +314,4 @@ def forward(self, input_list): #X:BCTFG Y:B11G1 # model2 = AHAM(64) # output_mag, output_mag_list, output_ri, output_ri_list = model(x, x) # aham = model2(output_mag_list) -# print(str(aham.shape)) \ No newline at end of file +# print(str(aham.shape))