Skip to content

Commit 6a4f6ab

Browse files
committed
fix interface bug
1 parent 7b455fb commit 6a4f6ab

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

criterions/comparison_methods.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, w=0.36, **kwargs):
8080
self.w = w
8181
self.fc = nn.Linear(kwargs.get('feat_dim', 128), 1).cuda()
8282

83-
def forward(self, f_feat, g_pred, labels, f_pred):
83+
def forward(self, f_feat, g_feat, labels, f_pred, g_pred):
8484
f_feat = f_feat.view(f_feat.shape[0], -1)
8585
f_pred = f_pred.view(f_pred.shape[0], -1)
8686
g_pred = g_pred.view(g_pred.shape[0], -1)
@@ -89,7 +89,7 @@ def forward(self, f_feat, g_pred, labels, f_pred):
8989
factor = F.softplus(factor)
9090
g_pred *= factor
9191

92-
loss = F.cross_entropy(f_pred+g_pred, labels)
92+
loss = F.cross_entropy(f_pred + g_pred, labels)
9393

9494
bias_lp = F.log_softmax(g_pred, 1)
9595
entropy = -(torch.exp(bias_lp) * bias_lp).sum(1).mean()

trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def _update_f(self, x, labels, update_outer_loop=True, loss_dict=None, prefix=''
337337
for g_idx, g_net in enumerate(self.model.g_nets):
338338
_g_preds, _g_feats = g_net(x)
339339

340-
_f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)
340+
_f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds, g_pred=_g_preds)
341341
f_loss_indep += _f_loss_indep
342342

343343
loss_dict['{}f_loss_indep_g_{}'.format(prefix, g_idx)] = _f_loss_indep.item()

0 commit comments

Comments
 (0)