From db561e0846ef03a5db743ab056df8dfad06c1362 Mon Sep 17 00:00:00 2001 From: Tony Franca Date: Fri, 28 Nov 2025 21:20:02 +0000 Subject: [PATCH] defend against nans and Infs --- sam3/train/loss/loss_fns.py | 21 +++++++++++++++++++++ sam3/train/loss/sam3_loss.py | 7 +++++++ sam3/train/matcher.py | 9 +++++++++ 3 files changed, 37 insertions(+) diff --git a/sam3/train/loss/loss_fns.py b/sam3/train/loss/loss_fns.py index e54608f2..cd3d1be0 100644 --- a/sam3/train/loss/loss_fns.py +++ b/sam3/train/loss/loss_fns.py @@ -113,6 +113,8 @@ def _dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True) numerator = 2 * (inputs * targets).sum(1) denominator = inputs.sum(-1) + targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) + # Replace NaN/Inf in loss (dice loss should be in [0, 1]) + loss = torch.nan_to_num(loss, nan=1.0, posinf=1.0, neginf=0.0) if loss_on_multimask: return loss / num_boxes if not reduce: @@ -163,6 +165,9 @@ def sigmoid_focal_loss( alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss + # Replace NaN/Inf in loss (can occur with extreme logits) + loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0) + if not reduce: return loss @@ -361,6 +366,8 @@ def get_loss(self, outputs, targets, indices, num_boxes): ) iou = box_ops.fast_diag_box_iou(src_boxes_xyxy, target_boxes_giou) + # Replace NaN/Inf in IoU (IoU should be in [0, 1]) + iou = torch.nan_to_num(iou, nan=0.0, posinf=1.0, neginf=0.0) t = prob[(indices[0], indices[1])] ** self.alpha * iou ** (1 - self.alpha) t = torch.clamp(t, 0.01).detach() positive_target_classes = target_classes.clone() @@ -503,6 +510,10 @@ def get_loss(self, outputs, targets, indices, num_boxes): task="binary", ) + # Replace NaN/Inf in losses + loss_bce = torch.nan_to_num(loss_bce, nan=0.0, posinf=100.0, neginf=0.0) + presence_loss = torch.nan_to_num(presence_loss, nan=0.0, posinf=100.0, neginf=0.0) + losses = { "loss_ce": loss_bce, "ce_f1": bce_f1, @@ -551,6 +562,8 @@ def get_loss(self, outputs, targets, indices, num_boxes): ) loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + # Replace NaN/Inf in loss_bbox (can occur with extreme predicted boxes) + loss_bbox = torch.nan_to_num(loss_bbox, nan=1.0, posinf=1.0, neginf=0.0) losses = {} losses["loss_bbox"] = loss_bbox.sum() / num_boxes @@ -558,6 +571,8 @@ def get_loss(self, outputs, targets, indices, num_boxes): loss_giou = 1 - box_ops.fast_diag_generalized_box_iou( src_boxes_xyxy, target_boxes_giou ) + # Replace NaN/Inf in loss_giou (GIoU loss should be in [0, 2]) + loss_giou = torch.nan_to_num(loss_giou, nan=2.0, posinf=2.0, neginf=0.0) losses["loss_giou"] = loss_giou.sum() / num_boxes return losses @@ -1128,6 +1143,8 @@ def get_loss(self, out_dict, targets): # should also track presence_acc presence_acc = torch.tensor(0.0, device=loss.device) + # Replace NaN/Inf in presence loss + loss_presence = torch.nan_to_num(loss_presence, nan=0.0, posinf=100.0, neginf=0.0) loss_dict["loss_semantic_presence"] = loss_presence loss_dict["presence_acc"] = presence_acc @@ -1141,6 +1158,10 @@ def get_loss(self, out_dict, targets): loss = (loss * mask.float()).sum() / (nb_valid + 1e-6) loss_dice = (loss_dice * mask.float()).sum() / (nb_valid + 1e-6) + # Replace NaN/Inf in semantic segmentation losses + loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0) + loss_dice = torch.nan_to_num(loss_dice, nan=1.0, posinf=1.0, neginf=0.0) + loss_dict.update( { "loss_semantic_seg": loss, diff --git a/sam3/train/loss/sam3_loss.py b/sam3/train/loss/sam3_loss.py index c2b90d6e..3ff92f95 100644 --- a/sam3/train/loss/sam3_loss.py +++ b/sam3/train/loss/sam3_loss.py @@ -200,4 +200,11 @@ def forward(self, find_stages: SAM3Output, find_targets): else: total_losses[k] += v + # Final safety check: replace any NaN/Inf in the core loss + # This catches any NaN that slipped through individual loss guards + if isinstance(total_losses[CORE_LOSS_KEY], torch.Tensor): + total_losses[CORE_LOSS_KEY] = torch.nan_to_num( + total_losses[CORE_LOSS_KEY], nan=0.0, posinf=1e6, neginf=0.0 + ) + return total_losses diff --git a/sam3/train/matcher.py b/sam3/train/matcher.py index fb22aecd..42570f84 100644 --- a/sam3/train/matcher.py +++ b/sam3/train/matcher.py @@ -569,11 +569,15 @@ def forward( # Compute the L1 cost between boxes cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + # Replace NaN/Inf with large values to prevent linear_sum_assignment failure + cost_bbox = torch.nan_to_num(cost_bbox, nan=1e6, posinf=1e6, neginf=1e6) # Compute the giou cost betwen boxes cost_giou = -generalized_box_iou( box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) ) + # Replace NaN/Inf (GIoU should be in [-1, 1], so cost_giou in [-1, 1]) + cost_giou = torch.nan_to_num(cost_giou, nan=0.0, posinf=2.0, neginf=-2.0) out_prob = self.norm(out_score) if not self.focal: @@ -596,6 +600,9 @@ def forward( if not self.stable: cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox) + # Replace NaN/Inf in cost_class (can occur with extreme logits) + cost_class = torch.nan_to_num(cost_class, nan=1e6, posinf=1e6, neginf=-1e6) + assert cost_class.shape == cost_bbox.shape # Final cost matrix @@ -604,6 +611,8 @@ def forward( + self.cost_class * cost_class + self.cost_giou * cost_giou ) + # Final safety check: replace any remaining NaN/Inf in the cost matrix + C = torch.nan_to_num(C, nan=1e9, posinf=1e9, neginf=-1e9) # assign a very high cost (1e9) to invalid outputs and targets, so that we can # filter them out (in `_do_matching`) from bipartite matching results do_filtering = out_is_valid is not None or target_is_valid_padded is not None