Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions sam3/train/loss/loss_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def _dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
# Replace NaN/Inf in loss (dice loss should be in [0, 1])
loss = torch.nan_to_num(loss, nan=1.0, posinf=1.0, neginf=0.0)
if loss_on_multimask:
return loss / num_boxes
if not reduce:
Expand Down Expand Up @@ -163,6 +165,9 @@ def sigmoid_focal_loss(
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

# Replace NaN/Inf in loss (can occur with extreme logits)
loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0)

if not reduce:
return loss

Expand Down Expand Up @@ -361,6 +366,8 @@ def get_loss(self, outputs, targets, indices, num_boxes):
)

iou = box_ops.fast_diag_box_iou(src_boxes_xyxy, target_boxes_giou)
# Replace NaN/Inf in IoU (IoU should be in [0, 1])
iou = torch.nan_to_num(iou, nan=0.0, posinf=1.0, neginf=0.0)
t = prob[(indices[0], indices[1])] ** self.alpha * iou ** (1 - self.alpha)
t = torch.clamp(t, 0.01).detach()
positive_target_classes = target_classes.clone()
Expand Down Expand Up @@ -503,6 +510,10 @@ def get_loss(self, outputs, targets, indices, num_boxes):
task="binary",
)

# Replace NaN/Inf in losses
loss_bce = torch.nan_to_num(loss_bce, nan=0.0, posinf=100.0, neginf=0.0)
presence_loss = torch.nan_to_num(presence_loss, nan=0.0, posinf=100.0, neginf=0.0)

losses = {
"loss_ce": loss_bce,
"ce_f1": bce_f1,
Expand Down Expand Up @@ -551,13 +562,17 @@ def get_loss(self, outputs, targets, indices, num_boxes):
)

loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
# Replace NaN/Inf in loss_bbox (can occur with extreme predicted boxes)
loss_bbox = torch.nan_to_num(loss_bbox, nan=1.0, posinf=1.0, neginf=0.0)

losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes

loss_giou = 1 - box_ops.fast_diag_generalized_box_iou(
src_boxes_xyxy, target_boxes_giou
)
# Replace NaN/Inf in loss_giou (GIoU loss should be in [0, 2])
loss_giou = torch.nan_to_num(loss_giou, nan=2.0, posinf=2.0, neginf=0.0)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

Expand Down Expand Up @@ -1128,6 +1143,8 @@ def get_loss(self, out_dict, targets):
# should also track presence_acc
presence_acc = torch.tensor(0.0, device=loss.device)

# Replace NaN/Inf in presence loss
loss_presence = torch.nan_to_num(loss_presence, nan=0.0, posinf=100.0, neginf=0.0)
loss_dict["loss_semantic_presence"] = loss_presence
loss_dict["presence_acc"] = presence_acc

Expand All @@ -1141,6 +1158,10 @@ def get_loss(self, out_dict, targets):
loss = (loss * mask.float()).sum() / (nb_valid + 1e-6)
loss_dice = (loss_dice * mask.float()).sum() / (nb_valid + 1e-6)

# Replace NaN/Inf in semantic segmentation losses
loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0)
loss_dice = torch.nan_to_num(loss_dice, nan=1.0, posinf=1.0, neginf=0.0)

loss_dict.update(
{
"loss_semantic_seg": loss,
Expand Down
7 changes: 7 additions & 0 deletions sam3/train/loss/sam3_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,11 @@ def forward(self, find_stages: SAM3Output, find_targets):
else:
total_losses[k] += v

# Final safety check: replace any NaN/Inf in the core loss
# This catches any NaN that slipped through individual loss guards
if isinstance(total_losses[CORE_LOSS_KEY], torch.Tensor):
total_losses[CORE_LOSS_KEY] = torch.nan_to_num(
total_losses[CORE_LOSS_KEY], nan=0.0, posinf=1e6, neginf=0.0
)

return total_losses
9 changes: 9 additions & 0 deletions sam3/train/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,15 @@ def forward(

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
# Replace NaN/Inf with large values to prevent linear_sum_assignment failure
cost_bbox = torch.nan_to_num(cost_bbox, nan=1e6, posinf=1e6, neginf=1e6)

# Compute the giou cost betwen boxes
cost_giou = -generalized_box_iou(
box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
)
# Replace NaN/Inf (GIoU should be in [-1, 1], so cost_giou in [-1, 1])
cost_giou = torch.nan_to_num(cost_giou, nan=0.0, posinf=2.0, neginf=-2.0)

out_prob = self.norm(out_score)
if not self.focal:
Expand All @@ -596,6 +600,9 @@ def forward(
if not self.stable:
cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)

# Replace NaN/Inf in cost_class (can occur with extreme logits)
cost_class = torch.nan_to_num(cost_class, nan=1e6, posinf=1e6, neginf=-1e6)

assert cost_class.shape == cost_bbox.shape

# Final cost matrix
Expand All @@ -604,6 +611,8 @@ def forward(
+ self.cost_class * cost_class
+ self.cost_giou * cost_giou
)
# Final safety check: replace any remaining NaN/Inf in the cost matrix
C = torch.nan_to_num(C, nan=1e9, posinf=1e9, neginf=-1e9)
# assign a very high cost (1e9) to invalid outputs and targets, so that we can
# filter them out (in `_do_matching`) from bipartite matching results
do_filtering = out_is_valid is not None or target_is_valid_padded is not None
Expand Down