From 2d869bb2432c4efc8a6b6b6fa402f661784db07b Mon Sep 17 00:00:00 2001 From: lnex Date: Tue, 11 Jun 2024 13:05:41 +0800 Subject: [PATCH] fix: line order in RobustLoss (#44) --- roma/losses/robust_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/roma/losses/robust_loss.py b/roma/losses/robust_loss.py index 853cadf..3ac3a60 100644 --- a/roma/losses/robust_loss.py +++ b/roma/losses/robust_loss.py @@ -49,10 +49,10 @@ def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale): G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99] + certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) if not torch.any(cls_loss): cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere - - certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) + losses = { f"gm_certainty_loss_{scale}": certainty_loss.mean(), f"gm_cls_loss_{scale}": cls_loss.mean(),