
fix: restore cached AMP step context after no_grad workaround #21616

Open
littlebullGit wants to merge 4 commits into Lightning-AI:master from littlebullGit:fix/21611-autocast-cache-enabled

Conversation

@littlebullGit (Contributor) commented Mar 26, 2026

What does this PR do?

Fixes #21611

This PR addresses the AMP memory regression introduced when #20921 changed MixedPrecision.autocast_context_manager() to use
cache_enabled=False globally as a workaround for the nested no_grad() autocast cache-poisoning bug reported in #20644.

That workaround fixed the original correctness issue, but it also forced Lightning's step execution onto the uncached autocast
path. In workloads that call the same module repeatedly within one training_step (for example iterative decoding / RL-style
loops), this can cause repeated recasting of the same parameters and significant memory growth, which is the regression
reported in #21611.
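The recasting cost described above can be sketched directly with `torch.autocast`. This is an illustrative example (CPU/bfloat16 is used so it runs anywhere; the reported regression involves CUDA/fp16, and the inner loop stands in for an iterative-decoding workload):

```python
import torch

lin = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# With cache_enabled=True (the default), the low-precision copy of lin.weight
# made on the first call is reused on later calls in the same autocast region.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=True):
    for _ in range(3):          # stand-in for an iterative decoding inner loop
        y = lin(x)              # weight cast once, then served from the cache

# With cache_enabled=False, every call re-casts the fp32 weight, which is the
# repeated-allocation pattern behind the memory growth reported in #21611.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    for _ in range(3):
        y = lin(x)              # fresh low-precision weight copy on every call
```

Either way the outputs are numerically identical low-precision tensors; only the caching of the cast parameters differs.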

This PR keeps the existing public behavior of MixedPrecision.autocast_context_manager() unchanged for compatibility, but
narrows the workaround in Lightning's internal runtime path:

  • autocast_context_manager() still uses cache_enabled=False
  • forward_context() now uses the cached autocast path again
  • when nested torch.no_grad() or torch.inference_mode() exits inside forward_context(), the autocast cache is cleared,
    so the original "Computation graph not being built" bug (#20644) remains fixed

In short, this restores the cached AMP behavior for normal Lightning step execution while preserving the original nested
no_grad() safeguard.
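A minimal sketch of the mechanism, not the PR's actual implementation (the function name `forward_context_sketch` and the subclass are hypothetical; the real change lives in `MixedPrecision.forward_context()`): keep the cached autocast path, but patch `torch.no_grad` so that exiting a nested no-grad block clears the autocast cache via `torch.clear_autocast_cache()`.

```python
import contextlib
import torch


@contextlib.contextmanager
def forward_context_sketch(device_type="cpu", dtype=torch.bfloat16):
    """Hypothetical sketch: cached autocast, plus cache-clearing on
    nested no_grad() exit so stale no-grad casts are never reused."""
    original_no_grad = torch.no_grad

    class _ClearingNoGrad(original_no_grad):
        def __exit__(self, *args):
            result = super().__exit__(*args)
            # Drop any low-precision weight copies cached while grad was
            # disabled, so they cannot poison later grad-enabled forwards
            # (the #20644 failure mode).
            torch.clear_autocast_cache()
            return result

    torch.no_grad = _ClearingNoGrad
    try:
        # cache_enabled is left at its default (True): the cached path.
        with torch.autocast(device_type, dtype=dtype):
            yield
    finally:
        # Restore the patched grad-mode context manager on exit.
        torch.no_grad = original_no_grad
```

The `finally` block mirrors the last bullet in the test plan below: the patched context manager must be restored even if the step raises.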

Changes

  • keep MixedPrecision.autocast_context_manager() behavior unchanged
  • move the narrower workaround into MixedPrecision.forward_context()
  • clear the autocast cache after nested no_grad() / inference_mode() exits within forward_context()
  • add regression tests covering:
    • raw PyTorch autocast behavior with cache_enabled=True vs False
    • nested no_grad() inside Lightning AMP forward_context()
    • nested inference_mode() inside Lightning AMP forward_context()
    • restoration of patched grad-mode context managers after exiting forward_context()
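A regression test for the nested-no_grad case could look roughly like this (a sketch of the idea, not the PR's actual test code; it exercises the safeguard directly with `torch.clear_autocast_cache()` rather than through Lightning's `forward_context()`):

```python
import torch


def test_nested_no_grad_does_not_poison_autocast_cache():
    lin = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4)
    with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=True):
        with torch.no_grad():
            _ = lin(x)                  # casts weight while grad is disabled
        torch.clear_autocast_cache()    # the safeguard under test
        y = lin(x)                      # must re-cast with grad enabled
    # Without the cache clear, a cast cached under no_grad could be reused
    # here and the forward would build no graph (#20644).
    assert y.grad_fn is not None
```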

📚 Documentation preview 📚: https://pytorch-lightning--21616.org.readthedocs.build/en/21616/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Mar 26, 2026
@littlebullGit littlebullGit force-pushed the fix/21611-autocast-cache-enabled branch from 1511c8a to af1af99 on March 26, 2026 at 21:32

codecov bot commented Mar 26, 2026

Codecov Report

❌ Patch coverage is 96.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 79%. Comparing base (612ab08) to head (3d5a8d6).
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (612ab08) and HEAD (3d5a8d6). Click for more details.

HEAD has 168 uploads less than BASE
Flag               BASE (612ab08)   HEAD (3d5a8d6)
cpu                84               42
lightning_fabric   27               0
pytest             42               0
python3.12         24               12
python             6                3
lightning          30               15
python3.11         12               6
python3.13         18               9
python3.12.7       18               9
python3.10         6                3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21616     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       23898    23866     -32     
=========================================
- Hits        20678    18794   -1884     
- Misses       3220     5072   +1852     

Collaborator

@deependujha deependujha left a comment

Is there any specific reason for not having CUDA tests? Adding them would be valuable.


Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

cache_enabled=False in autocast causes OOM regression for iterative decoding workloads

3 participants