defend against nans and Infs #6
base: tony/fix-statedict-prefix
Conversation
NaN Bug Fix: Training Crashes with Invalid Numeric Entries

Problem Summary

Training crashes after ~10-12 epochs with two related errors:

- `ValueError: matrix contains invalid numeric entries` in the Hungarian matcher (epoch ~10)
- `Loss is nan` in the aggregated loss (epoch ~12)

A previous fix attempt (commit d6d7900) added epsilon to the IoU division but didn't solve the root cause.

Stack Traces

Error 1: Invalid Numeric Entries in Hungarian Matcher (Epoch ~10)

This error occurs in `BinaryHungarianMatcherV2.forward()` (`sam3/train/matcher.py`), when scipy's `linear_sum_assignment` rejects a cost matrix containing NaN/Inf entries.

Error 2: Loss is NaN (Epoch ~12)

This error occurs when the aggregated loss becomes NaN, triggering the trainer's loss-finiteness check.

Root Cause Analysis

Why NaN/Inf Values Appear

During training, model logits can become extreme.
When logits saturate, the downstream cost and loss computations produce non-finite values:

```python
# Example: sigmoid of extreme values
# sigmoid(inf)  = 1.0
# sigmoid(-inf) = 0.0
#
# Focal loss with p = 1 (from an inf logit):
# cost = -α * (1-p)^γ * log(p) + (1-α) * p^γ * log(1-p)
#      = -0.25 * 0^2 * 0 + 0.75 * 1^2 * log(0)
#      = 0.75 * (-inf) = -inf
#
# GIoU with inf box coordinates returns NaN
```

Critical Discovery: the GIoU cost itself can return NaN, so the earlier epsilon-on-division fix could never catch every source of non-finite entries.
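This failure mode is easy to reproduce outside the repo. A minimal, self-contained sketch (illustrative numbers, not code from this codebase):

```python
import torch

# Extreme logits saturate the sigmoid exactly in float32.
logits = torch.tensor([50.0, float("inf")])
p = torch.sigmoid(logits)          # tensor([1., 1.])

# Matcher-style focal cost term: once p == 1, log(1 - p) = log(0) = -inf,
# so the negative-class term becomes non-finite.
alpha, gamma = 0.25, 2.0
neg_cost = (1 - alpha) * p**gamma * torch.log(1 - p)
print(neg_cost)                    # tensor([-inf, -inf])

# Inf box coordinates turn into NaN as soon as they are subtracted,
# which is exactly what happens inside an IoU/GIoU computation.
x0 = torch.tensor(float("inf"))
print(x0 - x0)                     # tensor(nan)
```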
Files Changed

| File | Changes |
|---|---|
| `sam3/train/matcher.py` | 4 `nan_to_num()` calls in `BinaryHungarianMatcherV2.forward()` |
| `sam3/train/loss/loss_fns.py` | 10 `nan_to_num()` calls across `_dice_loss`, `sigmoid_focal_loss`, `IABCEMdetr.get_loss`, `Boxes.get_loss`, `SemanticSegCriterion.get_loss` |
| `sam3/train/loss/sam3_loss.py` | 1 `nan_to_num()` call in `Sam3LossWrapper.forward()` as a final safety net |
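All of these guards follow the same `nan_to_num()` pattern. A minimal sketch of the idea (the helper name and clamp value are illustrative, not the repo's actual code):

```python
import torch

def sanitize(t: torch.Tensor, clamp: float = 1e4) -> torch.Tensor:
    """Replace NaN with 0 and clamp +/-inf so downstream consumers
    (scipy's Hungarian solver, the loss aggregation) only ever see
    finite entries. Illustrative helper, not the repo's exact code."""
    return torch.nan_to_num(t, nan=0.0, posinf=clamp, neginf=-clamp)

# E.g. applied to a matcher cost matrix before the assignment step:
cost = torch.tensor([[0.3, float("nan")],
                     [float("inf"), 1.2]])
print(sanitize(cost))
# tensor([[3.0000e-01, 0.0000e+00],
#         [1.0000e+04, 1.2000e+00]])
```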
Current Status
Status: Testing in Progress
Multiple rounds of fixes have been applied:
- Round 1: Added NaN guards to the Hungarian matcher cost matrix computations
  - Fixed the `ValueError: matrix contains invalid numeric entries` error
  - Training progressed past epoch 10
- Round 2: Added NaN guards to core loss functions (`_dice_loss`, `sigmoid_focal_loss`, `Boxes.get_loss`, `IABCEMdetr.get_loss`)
  - Training still failed at epoch 11 with `Loss is nan`
- Round 3 (Current): Added comprehensive NaN guards to:
  - `IABCEMdetr.get_loss()`: guards on `loss_bce` and `presence_loss`
  - `SemanticSegCriterion.get_loss()`: guards on `loss_presence`, `loss`, and `loss_dice`
  - `Sam3LossWrapper.forward()`: final safety net on `core_loss`
The fix now has multiple layers of protection to catch NaN at every level of the loss computation pipeline.
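At the wrapper level, such a final safety net can be as small as the following sketch (the `total_losses` dict and `CORE_LOSS_KEY` are taken from the debugging snippet below; the function shape is assumed for illustration):

```python
import torch

CORE_LOSS_KEY = "core_loss"  # key name as used in the debugging snippet below

def guard_core_loss(total_losses: dict) -> dict:
    """Final safety net: sanitize the aggregated loss so one bad batch
    cannot abort training with `Loss is nan`. Illustrative sketch only."""
    loss = total_losses[CORE_LOSS_KEY]
    if isinstance(loss, torch.Tensor) and not torch.isfinite(loss).all():
        total_losses[CORE_LOSS_KEY] = torch.nan_to_num(
            loss, nan=0.0, posinf=1e4, neginf=-1e4
        )
    return total_losses
```

Note that replacing NaN with 0 silences the symptom rather than the source, which is why the logging described under Debugging Tips remains useful.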
Testing
Run the full training test to verify:
```
python tests/test_instance_segmentation_finetune.py
```

The test should now complete all 40 epochs without NaN errors.
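If the test still crashes, PyTorch's built-in anomaly detection can point at the forward op that produced the first NaN during the backward pass (it slows training noticeably, so enable it only for debugging runs):

```python
import torch

# Enable before the training loop starts; any backward pass that hits a
# NaN will then raise with a traceback into the offending forward op.
torch.autograd.set_detect_anomaly(True)
```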
Debugging Tips
If NaN errors persist, add logging to identify the source:
```python
# Add to Sam3LossWrapper.forward() before the nan_to_num call
if isinstance(total_losses[CORE_LOSS_KEY], torch.Tensor):
    if not torch.isfinite(total_losses[CORE_LOSS_KEY]).all():
        import logging
        logging.warning(f"NaN/Inf detected in core_loss: {total_losses[CORE_LOSS_KEY]}")
        for k, v in total_losses.items():
            if isinstance(v, torch.Tensor) and not torch.isfinite(v).all():
                logging.warning(f"  NaN/Inf in {k}: {v}")
```

This will help identify which specific loss component is producing NaN values.
Description
This is an attempt to fix training errors found in production.
We still have some training runs failing with stack traces like the one below.
/datasets/jWj4jHcFzSROinhwByyZ/versions/7
More context on Slack.
The stack trace below happens around epoch 10, after ~3h of training.
This PR is the result of a "vibe-bugfixing" process with claude-code, running the finetuning test on the dataset above.
Claude's explanation of its fixes appears in the conversation above.
Type of change
Bug fix
How has this change been tested? Please provide a testcase or example of how you tested the change.
Reproduced with the problematic dataset, then fixed with the new code (verification pending: the training run is currently on epoch 23).
Any specific deployment considerations
For example, documentation changes, usability, usage/costs, secrets, etc.
Docs