[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag #2311
Conversation
Greptile Overview
Greptile Summary
This review covers the entire PR, which introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The core implementation trades compute for memory by avoiding saving intermediate activations (ln_out, fc1_out, gelu_out) during forward passes and instead recomputing them during backward passes—crucially skipping fc2 recomputation since it's not needed for gradients. The module integrates with TransformerEngine's existing infrastructure (sequence_parallel, tensor_parallel, FP8 quantization) while explicitly disabling CPU offloading. The benchmark data shows dramatic memory reduction (64GB → 10GB forward memory in the example) at the cost of 20% slower backward passes. The change maintains full compatibility with existing LayerNormMLP configurations and is implemented as an additional module rather than modifying existing code.
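For readers unfamiliar with the technique, here is a minimal sketch of selective activation checkpointing as an autograd function, using plain torch ops instead of TransformerEngine's fused kernels; the class and variable names are hypothetical and biases, FP8, and parallelism are omitted. The forward saves only inputs and weights, and the backward recomputes LayerNorm → FC1 → GELU but never FC2, whose output is not needed for any gradient.

```python
import torch
import torch.nn.functional as F


class SelectiveMLPFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, ln_weight, fc1_weight, fc2_weight):
        ln_out = F.layer_norm(inp, inp.shape[-1:], ln_weight)
        fc1_out = ln_out @ fc1_weight.t()          # [.., ffn_hidden]
        act_out = F.gelu(fc1_out)
        fc2_out = act_out @ fc2_weight.t()         # [.., hidden]
        # Save only inputs/weights; ln_out, fc1_out, act_out are dropped here
        # and rebuilt in backward -- this is where the memory saving comes from.
        ctx.save_for_backward(inp, ln_weight, fc1_weight, fc2_weight)
        return fc2_out

    @staticmethod
    def backward(ctx, grad_out):
        inp, ln_weight, fc1_weight, fc2_weight = ctx.saved_tensors
        # Recompute LayerNorm -> FC1 -> GELU; FC2's output is never needed for
        # gradients, so it is not recomputed (the "selective" part).
        with torch.enable_grad():
            inp_d = inp.detach().requires_grad_()
            ln_w = ln_weight.detach().requires_grad_()
            fc1_w = fc1_weight.detach().requires_grad_()
            act_out = F.gelu(F.layer_norm(inp_d, inp_d.shape[-1:], ln_w) @ fc1_w.t())
        grad_act = grad_out @ fc2_weight                                   # dL/d act_out
        grad_fc2_w = grad_out.flatten(0, -2).t() @ act_out.flatten(0, -2)  # dL/d fc2_weight
        grad_inp, grad_ln_w, grad_fc1_w = torch.autograd.grad(
            act_out, (inp_d, ln_w, fc1_w), grad_outputs=grad_act
        )
        return grad_inp, grad_ln_w, grad_fc1_w, grad_fc2_w
```

Usage would be `out = SelectiveMLPFunction.apply(x, ln_w, fc1_w, fc2_w)`; the PR's actual implementation layers biases, quantizer handling, sequence/tensor parallelism, and state restoration on top of this basic pattern.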
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 3/5 | Implements the new SelectiveLayerNormMLP module with selective checkpointing; contains commented-out FSDP/offloading code that needs verification |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 2/5 | Adds deferred init tests but has a critical bug: seq_length kwarg is not set for SelectiveLayerNormMLP (line 34 condition only checks LayerNormMLP) |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 4/5 | Distributed test file that creates LayerNormMLP instances instead of SelectiveLayerNormMLP (lines 403-404), so doesn't actually test the new feature |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Adds CUDA graph tests with proper module instantiation and parametrization; PR description notes tests fail with 4e-4 errors |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Minor formatting changes only: function signature reformatted and blank lines removed for style consistency |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Comprehensive sanity tests for the new module with minor formatting adjustments to match project style |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Formatting cleanup: removed duplicate imports and adjusted parameter list formatting |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Benchmark script comparing LayerNormMLP vs SelectiveLayerNormMLP; formatting changes only (PEP8 compliance) |
Confidence score: 2/5
- This PR requires careful review due to critical bugs in test files and incomplete CUDA graph compatibility
- Score reflects four major issues: (1) test_deferred_init.py will fail due to missing seq_length parameter handling, (2) distributed/run_numerics.py doesn't actually test the new module, (3) CUDA graph tests are failing per the PR description, and (4) extensive commented-out code in the main implementation suggests incomplete feature development, particularly around FSDP and CPU offloading
- Pay close attention to transformer_engine/pytorch/module/selective_layernorm_mlp.py (commented sections lines 550-571, 915-929, 1545-1552), tests/pytorch/selective_layernorm_mlp/test_deferred_init.py (line 34 conditional logic), and tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (lines 403-404 model instantiation)
8 files reviewed, 2 comments
Greptile Overview
Greptile Summary
This review covers only the formatting and style changes made since the last review, not the entire PR. The changes consist entirely of whitespace adjustments: removing extraneous blank lines from seven test files (test_sanity.py, test_cuda_graphs.py, compare.py, test_recipe.py, test_numerics.py, and distributed/run_numerics.py) and the end-of-file blank line. These modifications align the new test suite with the project's formatting standards enforced by pre-commit hooks (black formatter, end-of-file-fixer, trailing-whitespace checker) and the .clang-format configuration that specifies 100-character column limits, 2-space indentation for C++, and strict whitespace rules. Since SelectiveLayerNormMLP is a memory-optimized variant of LayerNormMLP with identical API and forward-pass behavior, the test files are appropriately adapted copies of the existing LayerNormMLP test infrastructure, ensuring comprehensive coverage of FP8 recipes, distributed configurations, deferred initialization, and numerical accuracy without duplicating test logic.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Removed two extraneous blank lines (lines 131,161) for formatting consistency |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Removed trailing blank line at end of file |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Removed leading blank line and reformatted long print statement to fit 100-char limit |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 4/5 | Contains duplicate import of transformer_engine_torch (lines 22 and 29) |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | No visible formatting changes in this diff |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Removed three blank lines (lines 202, 273, 531) and reformatted function signature (lines 406-408) |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Removed three blank lines (lines 33, 383, 491) for consistent spacing |
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 5/5 | No whitespace-related changes; core implementation is unchanged |
Confidence score: 4/5
- This PR is generally safe to merge with respect to the formatting changes; however, the test_cuda_graphs.py failure and the duplicate import in test_recipe.py require attention before merging.
- Score reflects that the formatting changes are correct and align with project standards, but the duplicate import (test_recipe.py lines 22 and 29) should be removed, and the test_cuda_graphs.py failures mentioned in the PR description remain unresolved and may indicate numerical instability or a CUDA graph compatibility issue that requires investigation.
- Pay close attention to tests/pytorch/selective_layernorm_mlp/test_recipe.py (duplicate import) and confirm that test_cuda_graphs.py failures are understood and either resolved or deemed acceptable before merging.
8 files reviewed, 2 comments
Greptile Overview
Greptile Summary
This review covers only the most recent docstring correction in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (line 391), updating the reference from "LayerNormMLP" to "SelectiveLayerNormMLP". This is a trivial documentation fix that completes the refactoring of existing tests to validate the new SelectiveLayerNormMLP module introduced by this PR. All substantive issues (duplicate import, missing condition logic, extra blank lines) were already flagged in earlier reviews and are not repeated here.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Corrected docstring to reference SelectiveLayerNormMLP instead of LayerNormMLP (line 391). |
Confidence score: 5/5
- This single‑line docstring fix is safe to merge with no risk; it improves documentation accuracy.
- Score remains perfect because the change is purely cosmetic—no logic, imports, or functionality is affected.
- No files require additional attention; earlier reviews have already identified all substantive issues in this test file.
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. A duplicate import of transformer_engine_torch as tex has been removed from the test file tests/pytorch/selective_layernorm_mlp/test_recipe.py, cleaning up the code. However, the file still contains the same import on line 22, indicating either an incomplete change or a subsequent revert not captured in the diff.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Removed duplicate import statement, no functional impact on test logic |
Confidence score: 5/5
- This change is safe to merge with minimal risk
- The duplicate import removal is a simple cleanup with no functional impact; the test logic remains identical and validation of FP8 quantizer behavior is unchanged
- No files require special attention
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent updates focus on cleaning up test files for the new SelectiveLayerNormMLP module by removing unused imports and fixing minor formatting issues. The changes address previous review feedback by eliminating duplicate imports (like transformer_engine_torch as tex) and removing extraneous blank lines. However, several issues remain unresolved from earlier reviews, including the incomplete fix in test_deferred_init.py where the conditional logic in get_module_args still references unused module types (te.Linear, te.LayerNormLinear, te.MultiheadAttention, te.TransformerLayer), creating unreachable dead code. The CUDA graph test failures mentioned in the PR description (errors around 4e-4 and 0.1 absolute errors) persist and are acknowledged as incomplete work.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Removed numerous unused imports while preserving FP8 recipe test functionality |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Removed blank line but left dead code in get_module_args referencing unused module types |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 3/5 | Simplified to test only SelectiveLayerNormMLP; tests currently failing per PR description |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 4/5 | Cleaned up imports for focused SelectiveLayerNormMLP testing with comprehensive parameterization |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 4/5 | Removed unused imports and functions; minor formatting inconsistency remains |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Removed unused imports with no functional changes |
Confidence score: 3/5
- This PR requires careful review due to incomplete CUDA graph functionality and remaining code quality issues.
- Score reflects acknowledged test failures (test_cuda_graphs.py), dead code in helper methods (test_deferred_init.py line 25-43), and minor formatting inconsistencies. The changes are primarily test infrastructure cleanup, but the core feature is not yet production-ready.
- Pay close attention to test_cuda_graphs.py (numerical precision failures) and test_deferred_init.py (unreachable conditional branches for te.Linear, te.MultiheadAttention, and te.TransformerLayer).
6 files reviewed, 3 comments
Greptile Overview
Greptile Summary
This PR introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The forward pass saves only input tensors and weights rather than intermediate activations (fc1_out, gelu_out), reducing memory consumption by ~80% during training. During backward, FC1 and activation are recomputed, but FC2 recomputation is skipped since it's unnecessary for gradient calculation. The implementation maintains compatibility with TransformerEngine's quantization systems (FP8, MXFP8, NVFP4) and distributed features (tensor parallel, sequence parallel), though CPU offloading is explicitly disabled. The module integrates into the existing PyTorch API alongside LayerNormMLP and includes comprehensive test coverage mirroring the standard module's test suite. CUDA graph tests currently fail with ~4e-4 numerical differences, indicating potential interactions between graph capture and the selective recomputation strategy.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 4/5 | New 1000+ line module implementing selective checkpointing with bifurcated forward logic for training vs. recomputation |
| transformer_engine/pytorch/module/__init__.py | 5/5 | Adds SelectiveLayerNormMLP import to module's public API |
| transformer_engine/pytorch/__init__.py | 5/5 | Exposes SelectiveLayerNormMLP in top-level PyTorch API |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 1/5 | Critical bug: reference implementations initialize normalization weights to zeros instead of ones when zero_centered_gamma=False |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 4/5 | Missing import warnings causes NameError when zero-tensors are detected; otherwise sound distributed validation |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 4/5 | CUDA graph tests currently failing per PR description; unused import and incorrect return type annotation present |
| tests/pytorch/selective_layernorm_mlp/compare.py | 3/5 | Performance comparison script with reversed weight-copy direction (copies from SLN to LN instead of vice versa) |
| tests/pytorch/selective_layernorm_mlp/utils.py | 4/5 | Test utilities with incorrect return type annotation (declares 2-tuple but returns 3 elements) |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Comprehensive sanity tests covering dtypes, recipes, activations, normalizations, and microbatching |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Deferred initialization test with dead code for untested modules |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 4/5 | FP8 recipe validation test with unused capability-check imports |
| tests/pytorch/selective_layernorm_mlp/distributed/test_numerics.py | 4/5 | Distributed test wrapper with typo in docstring and unused variable |
Confidence score: 2/5
- This PR cannot be merged safely due to critical bugs in the test reference implementations and missing imports that will cause runtime failures.
- Score reflects: (1) test_numerics.py initializes reference normalization weights to zeros instead of ones, making all tests invalid; (2) distributed/run_numerics.py uses warnings.warn() without importing warnings; (3) CUDA graph tests are explicitly failing per the PR description; (4) compare.py has a reversed weight-copy direction; (5) multiple type annotation mismatches will cause type-checking failures
- Pay close attention to tests/pytorch/selective_layernorm_mlp/test_numerics.py (lines 144-145, 170-171), distributed/run_numerics.py (line 34), and the core module's commented-out FSDP code, which may indicate incomplete distributed functionality
Sequence Diagram
sequenceDiagram
participant User
participant SelectiveLayerNormMLP
participant _SelectiveLayerNormMLP
participant ForwardPass
participant BackwardPass
participant Quantizers
participant GEMM
User->>SelectiveLayerNormMLP: forward(inp, is_first_microbatch)
SelectiveLayerNormMLP->>SelectiveLayerNormMLP: prepare_forward()
SelectiveLayerNormMLP->>SelectiveLayerNormMLP: _get_quantizers()
SelectiveLayerNormMLP->>Quantizers: Initialize quantizers
Quantizers-->>SelectiveLayerNormMLP: Return quantizers
SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=False)
Note over _SelectiveLayerNormMLP: Save tensors for backward (inp, weights, etc.)
_SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
alt sequence_parallel
_SelectiveLayerNormMLP->>ForwardPass: gather_along_first_dim()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out_total
end
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
Note over GEMM: FC1 GEMM with optional gelu fusion
GEMM-->>_SelectiveLayerNormMLP: fc1_out
_SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
ForwardPass-->>_SelectiveLayerNormMLP: act_out
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc2_weight, act_out)
Note over GEMM: FC2 GEMM
GEMM-->>_SelectiveLayerNormMLP: fc2_out
alt sequence_parallel
_SelectiveLayerNormMLP->>ForwardPass: reduce_scatter_along_first_dim()
ForwardPass-->>_SelectiveLayerNormMLP: fc2_out
end
_SelectiveLayerNormMLP-->>SelectiveLayerNormMLP: fc2_out
SelectiveLayerNormMLP-->>User: Return output
User->>BackwardPass: loss.backward()
BackwardPass->>_SelectiveLayerNormMLP: _recompute(ctx)
Note over _SelectiveLayerNormMLP: Restore saved tensors
_SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=True)
Note over _SelectiveLayerNormMLP: Recompute activations WITHOUT saving fc2
_SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
GEMM-->>_SelectiveLayerNormMLP: fc1_out
_SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
ForwardPass-->>_SelectiveLayerNormMLP: act_out
Note over _SelectiveLayerNormMLP: Skip FC2 forward (not needed for backward)
_SelectiveLayerNormMLP->>BackwardPass: backward(grad_outputs)
BackwardPass->>GEMM: general_gemm(fc2_weight, grad_output)
Note over GEMM: FC2 DGRAD
GEMM-->>BackwardPass: fc2_dgrad
BackwardPass->>ForwardPass: activation_func_bwd(fc2_dgrad, fc1_out)
ForwardPass-->>BackwardPass: dact
alt fc2_weight_requires_grad
BackwardPass->>GEMM: general_gemm(act_out, grad_output)
Note over GEMM: FC2 WGRAD
GEMM-->>BackwardPass: fc2_wgrad, fc2_bias_grad
end
BackwardPass->>GEMM: general_gemm(fc1_weight, dact)
Note over GEMM: FC1 DGRAD
GEMM-->>BackwardPass: fc1_dgrad
alt fc1_weight_requires_grad
BackwardPass->>GEMM: general_gemm(ln_out_total, dact)
Note over GEMM: FC1 WGRAD
GEMM-->>BackwardPass: fc1_wgrad, fc1_bias_grad
end
BackwardPass->>ForwardPass: layernorm_bwd/rmsnorm_bwd()
ForwardPass-->>BackwardPass: dgrad, dgamma, dbeta
BackwardPass-->>User: Return gradients
12 files reviewed, 13 comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer removed redundant nested if-condition logic in the deferred initialization test file for SelectiveLayerNormMLP. Previously, the code checked if module == te.SelectiveLayerNormMLP twice: once at line 29 and again at line 33 within the same block. This review captures the cleanup of that duplication, streamlining the conditional logic. The test file verifies that SelectiveLayerNormMLP can be instantiated on a 'meta' device with zero memory allocation and then properly materialized via reset_parameters().
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | Removed redundant nested if-check for SelectiveLayerNormMLP module condition |
Confidence score: 5/5
- This PR is safe to merge with minimal risk
- Score reflects straightforward refactoring with no functional changes or edge cases introduced
- No files require special attention
1 file reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest commit fixes a critical bug in the test reference implementations (TorchLayerNorm and TorchRMSNorm) within tests/pytorch/selective_layernorm_mlp/test_numerics.py. Previously, the weight initialization logic was inverted: when zero_centered_gamma=False, the gamma parameter was incorrectly initialized to zeros instead of ones, which would cause the normalized output to be scaled incorrectly. The fix swaps the conditional branches so that zero_centered_gamma=False correctly initializes weights to ones (standard LayerNorm behavior), and zero_centered_gamma=True initializes to zeros (zero-centered variant). This correction ensures that the torch-based reference models produce the expected numerical outputs, allowing accurate validation of the SelectiveLayerNormMLP module against known-correct baselines. This change is isolated to the test infrastructure and does not affect the production implementation of SelectiveLayerNormMLP.
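Concretely, the corrected convention is: gamma starts at ones for the standard parameterization (y = gamma * x_hat) and at zeros for the zero-centered one (y = (1 + gamma) * x_hat). A minimal sketch of a reference norm with that initialization follows; the class is hypothetical and not the test file's exact code.

```python
import torch


class TorchRMSNormRef(torch.nn.Module):
    """Reference RMSNorm; gamma initialization depends on zero_centered_gamma."""

    def __init__(self, hidden_size: int, zero_centered_gamma: bool, eps: float = 1e-5):
        super().__init__()
        # Standard parameterization: y = gamma * x_hat, so gamma starts at 1.
        # Zero-centered parameterization: y = (1 + gamma) * x_hat, so gamma starts at 0.
        init = torch.zeros if zero_centered_gamma else torch.ones
        self.weight = torch.nn.Parameter(init(hidden_size))
        self.zero_centered_gamma = zero_centered_gamma
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        gamma = 1 + self.weight if self.zero_centered_gamma else self.weight
        return gamma * x_hat
```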
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Fixed inverted weight initialization logic in TorchLayerNorm and TorchRMSNorm reference implementations |
Confidence score: 5/5
- This change is safe to merge—it fixes a clear bug in test reference code with no risk to production paths
- The fix corrects an inverted conditional that was causing reference models to initialize gamma weights incorrectly, ensuring accurate numerical validation going forward
- No files require special attention—the change is a straightforward two-line correction in test utilities
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has addressed previous feedback by fixing the parameter-setting logic in test_deferred_init.py, correcting the weight-copy direction in compare.py, and refining timing/measurement code. However, several critical issues remain: a type annotation bug in test_cuda_graphs.py (get_nvfp4_inp_supported_dtypes returns List[torch.dtype] but is annotated as bool), unused imports (warnings in run_numerics.py, time in compare.py), and a typo (ffn_fidden on line 261 of compare.py). The test files validate that SelectiveLayerNormMLP maintains numerical parity with LayerNormMLP across distributed and non-distributed configurations while significantly reducing memory usage by recomputing activations instead of caching them. The PR description notes that test_cuda_graphs.py is failing with numerical errors (typically 4e-4, occasionally 0.1 absolute), which correlates with the type annotation bug in that file.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 2/5 | Critical type annotation bug on line 71: function returns List[torch.dtype] but annotated as bool, causing type-checking failures and possibly contributing to test failures |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 4/5 | Adds comprehensive numerical validation comparing SelectiveLayerNormMLP against torch reference implementations across dtypes, activations, and FP8/NVFP4 recipes |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 3/5 | Validates distributed numerics for tensor/sequence parallelism; unused warnings import added but comment updated correctly |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Simplified parameter-setting logic by removing redundant conditional check; correct for single-module test file |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 4/5 | Standard sanity tests adapted for SelectiveLayerNormMLP; comprehensive coverage of configurations and FP8 recipes |
| tests/pytorch/selective_layernorm_mlp/compare.py | 3/5 | Benchmark script refactored to use CUDA events and per-config stats; unused time import remains and typo ffn_fidden on line 261 |
Confidence score: 3/5
- This PR requires careful attention due to a critical type annotation bug and known test failures (cuda_graphs), though the core numerical validation appears sound
- Score reflects: (1) a critical type annotation error in test_cuda_graphs.py that breaks type checking and may contribute to the reported test failures; (2) unused imports in two files suggesting incomplete cleanup; (3) the PR description acknowledges the cuda_graphs tests are failing with numerical errors, indicating incomplete functionality
- Pay close attention to test_cuda_graphs.py (line 71 type annotation) and investigate why CUDA graph tests fail with 4e-4 errors; the annotation bug may be masking logic errors in the supported-dtypes check
6 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent changes apply code formatting to tests/pytorch/selective_layernorm_mlp/compare.py, reformatting a multi-line f-string descriptor to comply with line-length and style guidelines. While spacing, operator placement, and list formatting improvements are beneficial, a critical syntax error was introduced: the descriptor string is now missing its closing parenthesis after the f-string definition, preventing the script from running. No other changes were made to imports, logic, or functionality in this update.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/compare.py | 0/5 | Applied code formatting but introduced syntax error: missing closing parenthesis after multi-line f-string definition (line 261) |
Confidence score: 0/5
- This PR will fail to run due to a syntax error that prevents the comparison script from executing
- Score is zero because a closing parenthesis is missing after the f-string definition on line 261, causing a guaranteed SyntaxError when the script is imported or executed
- The file tests/pytorch/selective_layernorm_mlp/compare.py requires immediate attention to add the missing closing parenthesis after line 262
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This update addresses a single-character syntax fix in the benchmark comparison script for SelectiveLayerNormMLP. The change adds a missing closing parenthesis on line 262 of tests/pytorch/selective_layernorm_mlp/compare.py that completes the multi-line f-string definition for the desc variable. This variable formats configuration metadata (sequence length, hidden size, FFN hidden size, and number of layers) that labels benchmark results comparing LayerNormMLP and SelectiveLayerNormMLP. Without this closing parenthesis, the script would raise a SyntaxError and fail to execute. The fix ensures the benchmarking tool can run properly and display properly formatted configuration descriptions in its output.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Added missing closing parenthesis to complete multi-line f-string definition |
Confidence score: 5/5
- This PR is safe to merge with minimal risk
- The change fixes a syntax error that would have prevented the script from running, and introduces no new logic or behavioral changes
- No files require special attention; this is a trivial syntax correction
1 file reviewed, no comments
Hi @jaimec00, thank you for the contribution!
Thanks for the comment, @ptrendx!
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent update adds a new SelectiveLayerNormMLP module that implements selective activation checkpointing, trading ~20% additional backward compute time for dramatic memory savings (up to 6x reduction in forward peak memory). The implementation introduces a checkpoint parameter that gates three key behaviors: (1) which tensors are saved during forward (only inputs/weights vs. full activations), (2) whether FC2 is recomputed during backward (it's skipped since FC2's output isn't needed for weight gradients), and (3) CPU offloading availability (disabled when checkpointing is active). The module closely mirrors LayerNormMLP's structure, reusing the same FP8/tensor-parallel/sequence-parallel infrastructure while adding conditional logic through save_for_checkpoint and is_recomputation flags. This new module integrates with the existing PyTorch module hierarchy under transformer_engine/pytorch/module/, alongside sibling modules like LayerNormMLP and LayerNormLinear.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 3/5 | Adds new SelectiveLayerNormMLP module with selective activation checkpointing that skips FC2 recomputation during backward, reducing memory at the cost of additional compute |
Confidence score: 3/5
- This PR requires careful review due to incomplete testing and a known CUDA graphs compatibility issue that causes numerical errors (4e-4 typical, 0.1 absolute in some cases)
- Score reflects the unresolved CUDA graphs test failures and the significant complexity of the selective checkpointing logic that conditionally saves/recomputes tensors across forward/backward passes—while the core implementation appears sound, the interaction between checkpoint flags, FP8 metadata tracking, and CUDA graph capture needs verification
- Pay close attention to the CUDA graphs test failures in test_cuda_graphs.py (lines documenting 4e-4 typical errors with occasional 0.1 absolute errors), the conditional tensor saving logic in lines 621-691 (which determines what gets saved vs. recomputed based on the checkpoint/is_grad_enabled/save_for_checkpoint flags), and the _recompute method's dual-path behavior (lines 920-925) that either reruns forward with recompute_for_bwd=True or loads saved tensors
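As a rough reading aid for the flag interplay mentioned above, here is a hypothetical sketch using the names from the review (checkpoint, is_grad_enabled, recompute_for_bwd); the derivation is illustrative, not the module's exact logic.

```python
def resolve_checkpoint_flags(checkpoint: bool, is_grad_enabled: bool, recompute_for_bwd: bool):
    """Illustrative only: derive what one _forward call should save or recompute."""
    # During the original forward, a checkpointed run saves only inputs/weights.
    save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
    # During backward, the same _forward is re-entered to rebuild activations.
    is_recomputation = checkpoint and recompute_for_bwd
    # FC2's output is only needed to return a value to the caller, never for
    # gradients, so it can be skipped entirely while recomputing.
    run_fc2 = not is_recomputation
    return save_for_checkpoint, is_recomputation, run_fc2
```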
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has addressed the review feedback by correcting module references, removing dead code, fixing initialization logic for RMSNorm/LayerNorm gamma parameters, refactoring the comparison script to test SelectiveLayerNormMLP against itself (with/without checkpointing), and adding missing imports. The core implementation of selective activation checkpointing in SelectiveLayerNormMLP._forward remains unchanged. Test infrastructure has been updated to parametrize the new checkpoint flag across all test suites, though CUDA graph tests remain skipped when checkpoint=True.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 5/5 | Core implementation of selective activation checkpointing for LayerNormMLP; no changes since last review |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Refactored to compare SelectiveLayerNormMLP(checkpoint=False) vs SelectiveLayerNormMLP(checkpoint=True) instead of comparing against LayerNormMLP |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Added checkpoint parameter to test matrix; fixed gamma initialization logic for RMSNorm/LayerNorm |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Added test cases for checkpoint=True and checkpoint=False; added missing warnings import |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Added checkpoint parameter to sanity test matrix |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Added checkpoint parameter to quantizer update test |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | Added checkpoint parameter to deferred initialization tests |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 4/5 | Added checkpoint parameter but explicitly skips tests when checkpoint=True due to known failures |
Confidence score: 5/5
- This PR is safe to merge with minimal risk; the selective checkpointing implementation is sound and well-tested
- Score reflects that all previous review issues were addressed, tests pass (except known CUDA graph limitations), and the implementation follows TE patterns for custom autograd functions with memory management
- Pay close attention to test_cuda_graphs.py: the developer explicitly asks for help with CUDA graph failures when checkpoint=True, which remain unresolved and are currently skipped
8 files reviewed, 3 comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest changes address previous review feedback by:
- Fixing the skip message in test_cuda_graphs.py: Updated the skip message from "SelectiveLayerNorm" to "LayerNormMLP" to correctly reference the module being tested (lines 297-298). This aligns with the PR's approach of adding a checkpoint parameter directly to LayerNormMLP rather than creating a separate module.
- Minor formatting in compare.py: Collapsed lines 35-36 from two lines to one for improved readability. No functional changes were made to the comparison script, which benchmarks LayerNormMLP with and without the checkpoint flag to validate the performance claims (6X memory reduction, ~20% slower backward).
- Core implementation in layernorm_mlp.py: The main logic implements selective activation checkpointing by (a) adding a recompute_for_bwd parameter to _forward that skips FC2 computation during recomputation, (b) saving minimal state when checkpointing is enabled, and (c) introducing a _recompute dispatcher that either re-executes the forward pass or loads saved tensors. The change bypasses CPU offloading and FSDP scattering when checkpoint=True since there are no activations to offload.
The implementation integrates with TransformerEngine's existing FP8 quantization, distributed training (sequence/tensor parallelism), and autocast infrastructure. The primary trade-off is ~20% slower backward pass for ~83% memory reduction, enabling training of larger MLP models.
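A hypothetical simplification of the dual-path dispatcher described above is sketched below; the real _recompute also restores autocast, FP8 quantizer, and RNG state around the recompute, so this is only the skeleton of the idea.

```python
def recompute_sketch(ctx, forward_fn):
    """Illustrative dispatcher: either recompute from saved inputs or unpack activations."""
    if ctx.checkpoint:
        # Checkpointed path: re-enter the forward body on the saved inputs with
        # recompute_for_bwd=True so FC2 (and any activation saving) is skipped.
        return forward_fn(ctx, *ctx.saved_tensors, recompute_for_bwd=True)
    # Non-checkpointed path: all intermediate activations were saved in forward.
    return [ctx, *ctx.saved_tensors]
```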
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Fixed typo in skip message from "SelectiveLayerNorm" to "LayerNormMLP" |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Minor formatting change collapsing module instantiation from 2 lines to 1 |
| transformer_engine/pytorch/module/layernorm_mlp.py | 3/5 | Core implementation of selective checkpointing with checkpoint flag, recompute_for_bwd parameter, and recomputation logic; potential issues with return type consistency and stale workspace comment |
Confidence score: 3/5
- This PR requires careful review before merging due to unresolved issues in the core implementation logic and incomplete CUDA graphs support
- Score reflects three main concerns: (1) a return type inconsistency in _recompute where the non-checkpoint branch returns a list but the caller expects tuple unpacking, (2) a workspace update comment at line 442 that references commented-out code and may be stale, and (3) CUDA graphs tests are explicitly skipped when checkpoint=True, which means a known compatibility gap exists
- Pay close attention to transformer_engine/pytorch/module/layernorm_mlp.py, especially the _recompute dispatcher (line 950), the FC2-skip condition (line 574), and the workspace update logic (line 442)
3 files reviewed, 2 comments
Hi @ptrendx, just wanted to give an update on the PR. I implemented selective activation checkpointing directly in LayerNormMLP, which can be controlled by initializing the module with a 'checkpoint' flag. When checkpoint=False, behavior is unchanged, and it is simply LayerNormMLP (all tests pass, INCLUDING test_cuda_graphs.py). When checkpoint=True, selective activation checkpointing is implemented, getting >6X memory reduction in the forward pass at the cost of a 20% slower backward pass. All tests pass EXCEPT for test_cuda_graphs.py. I have started debugging this and I am confident that it is due to the quantizer state changing between the initial forward pass and the recomputation forward pass. I am currently working on a fix for this. In the meantime, I was wondering if this feature is something that you and the people at NVIDIA would be interested in merging into the main repo. If so, are there any notes you have for me? Thanks!
@jaimec00 yes, we are definitely interested in the functionality :-). It is still a little sad to me that PyTorch does not really give us a way to do that in a more generic way. I will review this PR either this week or at the beginning of next week.
@ptrendx perfect, thank you!
Greptile Overview
Greptile Summary
This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint flag. When enabled, activations are recomputed during backward pass instead of being saved, reducing memory usage by ~83% with only ~20% slower backward pass. The implementation recomputes LayerNorm and FC1 operations but skips FC2 (not needed for gradients).
Key changes:
- Refactored _LayerNormMLP.forward() into _forward() with a recompute_for_bwd parameter
- Added a _recompute() method that either recomputes from saved inputs (checkpoint=True) or loads saved activations (checkpoint=False)
- Clones all quantizers during the forward pass and saves them to ctx.other_args for recomputation
- Added comprehensive test coverage including CUDA graph compatibility tests
Critical issues:
- Quantizer state cloning (lines 191-204, 1034-1037): The _copy_quantizer() function creates independent copies of scale/amax tensors for Float8CurrentScalingQuantizer. These cloned tensors capture a snapshot of the quantizer state during the forward pass but don't reflect updates that occur during training. During backward recomputation, stale quantizer values are used, causing numerical mismatches. This is the root cause of the CUDA graph test failures with Float8CurrentScaling + checkpoint=True + non-fp32 dtypes.
- The workspace update logic (lines 520-528) is not skipped during recomputation, but quantized weights may have changed between the forward and backward passes when using dynamic quantization.
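The staleness concern is easy to reproduce with plain tensors: a clone is a snapshot, so later in-place updates to the live scale are invisible through it. A tiny, purely illustrative example:

```python
import torch

scale = torch.tensor([1.0])   # stands in for a quantizer's live scale tensor
snapshot = scale.clone()      # what a _copy_quantizer-style helper would capture
scale.mul_(2.0)               # the live quantizer updates in place during training

print(scale)     # tensor([2.])  -> value the real forward pass would use
print(snapshot)  # tensor([1.])  -> stale value a recompute would see via the clone
```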
Confidence Score: 3/5
- This PR has a critical quantizer cloning issue that breaks Float8CurrentScaling with CUDA graphs, but is safe for other quantization recipes and non-graph usage
- Score of 3 reflects that the core checkpointing mechanism is well-implemented and tests pass for most scenarios, but the quantizer state management has a fundamental design flaw. The cloning approach creates stale copies that don't track live quantizer updates, breaking Float8CurrentScaling compatibility. The PR author acknowledges this limitation and skips affected test cases. The feature provides significant memory benefits (6x reduction) with acceptable performance tradeoff (~20% slower backward) for supported configurations.
- transformer_engine/pytorch/module/layernorm_mlp.py requires attention for the quantizer cloning logic (especially the _copy_quantizer function and its usage in recomputation)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 2/5 | Implements selective activation checkpointing with checkpoint flag. Critical issue: quantizer cloning creates stale state copies that break Float8CurrentScaling with CUDA graphs. Additional concerns with skipped workspace updates during recomputation. |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 4/5 | Test file for CUDA graph compatibility with checkpoint flag. Skip condition at line 299-307 correctly identifies known failure case (Float8CurrentScaling + checkpoint + non-fp32 dtypes). |
Sequence Diagram
sequenceDiagram
participant User
participant LayerNormMLP
participant Forward as _forward()
participant Backward as backward()
participant Recompute as _recompute()
User->>LayerNormMLP: forward(checkpoint=True)
LayerNormMLP->>Forward: Call with checkpoint=True, recompute_for_bwd=False
Note over Forward: save_for_checkpoint=True<br/>is_recomputation=False
Forward->>Forward: Save input tensors (inp, weights, bias)
Forward->>Forward: Clone all quantizers to ctx.other_args
Forward->>Forward: Compute LayerNorm
Forward->>Forward: Compute FC1 GEMM
Forward->>Forward: Apply activation (GELU)
Forward->>Forward: Compute FC2 GEMM
Forward->>Forward: Clear activation tensors (fc1_out, act_out)
Forward-->>LayerNormMLP: Return fc2_out
User->>LayerNormMLP: backward(grad_output)
LayerNormMLP->>Backward: Process gradients
Backward->>Recompute: Call _recompute(ctx)
alt checkpoint=True
Recompute->>Forward: Call _forward() with recompute_for_bwd=True
Note over Forward: is_recomputation=True<br/>Recompute from saved inputs
Forward->>Forward: Restore saved tensors from ctx
Forward->>Forward: Use cloned quantizers from ctx.other_args
Forward->>Forward: Re-execute LayerNorm
Forward->>Forward: Re-execute FC1 GEMM
Forward->>Forward: Re-execute activation
Forward->>Forward: Skip FC2 GEMM (not needed for backward)
Forward->>Forward: Save recomputed activations to ctx
Forward-->>Recompute: Return [ctx, activations...]
else checkpoint=False
Recompute->>Recompute: Load saved activations from ctx
Recompute-->>Backward: Return [ctx, activations...]
end
Backward->>Backward: Compute FC2 weight/input gradients
Backward->>Backward: Compute activation gradients
Backward->>Backward: Compute FC1 weight/input gradients
Backward->>Backward: Compute LayerNorm gradients
Backward-->>User: Return input/weight gradients
2 files reviewed, 2 comments
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, activations are not saved during forward pass; instead, LayerNorm and FC1 are recomputed during backward (FC2 is skipped as it's not needed for gradients).
Key changes:
- Refactored forward() to call _forward() with a recompute_for_bwd flag
- Added a _recompute() method that either loads saved activations (checkpoint=False) or recomputes from inputs (checkpoint=True)
- Saves and restores autocast state and FP8 quantizer states (for the DelayedScaling recipe) to ensure numerical consistency between forward and recompute
- Conditionally skips FC2, CPU offloading, and FSDP scattering when checkpointing is enabled
- Tests confirm CUDA graphs now work with checkpoint=True after fixing autocast/quantizer state management
Trade-offs:
- Memory: ~83% reduction (6X lower in forward pass)
- Compute: ~22% slower backward pass
- Numerics: Identical outputs (0.0 max absolute error)
Issue found: the backwards_needs_fc1_input condition doesn't match its comment; it should include the is_recomputation case
Confidence Score: 4/5
- This PR is safe to merge with one logical fix needed for the backwards_needs_fc1_input condition
- The implementation is well-tested (all tests pass including CUDA graphs), has impressive benchmarks showing proper memory/compute tradeoffs, and correctly handles complex state management for autocast and quantizers. The one logic issue with backwards_needs_fc1_input should be fixed before merge—the condition doesn't match the comment and may cause incorrect quantizer usage during recomputation. Otherwise, the checkpointing logic is sound: inputs are saved during forward, FC2 is correctly skipped during recompute (not needed for gradients), and state restoration ensures numerical consistency.
- transformer_engine/pytorch/module/layernorm_mlp.py needs fix for backwards_needs_fc1_input logic at line 330
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 4/5 | Implements selective activation checkpointing by refactoring forward into _forward, adding recomputation logic with autocast/quantizer state management, and conditionally skipping FC2 during backward recompute. Minor logic issue with backwards_needs_fc1_input condition. |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Test file for CUDA graphs with checkpoint parameter, correctly parametrized to test both checkpoint=True and checkpoint=False with various FP8 recipes and dtypes. |
Sequence Diagram
sequenceDiagram
participant User
participant Forward as forward()
participant FwdImpl as _forward()
participant Recompute as _recompute()
participant Backward as backward()
Note over User,Backward: checkpoint=True Path
User->>Forward: input, checkpoint=True
Forward->>FwdImpl: recompute_for_bwd=False
Note over FwdImpl: Save ctx.checkpoint=True
Note over FwdImpl: Save autocast state
Note over FwdImpl: Copy quantizer states (DelayedScaling only)
Note over FwdImpl: save_for_checkpoint=True
Note over FwdImpl: Save inputs: inp, ln_weight, ln_bias, fc1_weight, etc.
Note over FwdImpl: Compute LayerNorm
Note over FwdImpl: Compute FC1 + Activation
Note over FwdImpl: Skip FC2 caching (cleared immediately)
Note over FwdImpl: Compute FC2
FwdImpl-->>Forward: fc2_out
Forward-->>User: output
User->>Backward: grad_output
Backward->>Recompute: ctx.checkpoint=True
Note over Recompute: Restore saved inputs
Note over Recompute: Set autocast state to saved
Note over Recompute: Restore quantizer states (DelayedScaling)
Recompute->>FwdImpl: recompute_for_bwd=True
Note over FwdImpl: is_recomputation=True
Note over FwdImpl: Compute LayerNorm
Note over FwdImpl: Compute FC1 + Activation
Note over FwdImpl: Skip FC2 (not needed for backward)
FwdImpl-->>Recompute: [ctx, recomputed tensors]
Note over Recompute: Restore autocast state
Note over Recompute: Restore quantizer states
Recompute-->>Backward: tensors for backward
Note over Backward: Compute FC2 gradients
Note over Backward: Compute activation gradients
Note over Backward: Compute FC1 gradients
Note over Backward: Compute LayerNorm gradients
Backward-->>User: input_grad
2 files reviewed, 1 comment
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, it skips saving large intermediate activations (fc1_out, act_out) in the forward pass and instead recomputes them during backward, while cleverly skipping the FC2 computation during recomputation since it's not needed for gradient calculation.
Key Implementation Details:
- Forward pass with checkpoint=True saves only 7 input tensors instead of 14 activation tensors, achieving ~83% memory reduction
- Recomputation phase (recompute_for_bwd=True) executes LayerNorm + FC1 + Activation, but skips FC2
- Properly handles autocast state and FP8 quantizer states for the DelayedScaling recipe
- CPU offloading and FSDP scattering are bypassed when checkpoint=True (no activations to offload)
- Forward: negligible overhead
- Backward: ~22% slower (recomputation cost)
- Memory: 6-8x reduction in forward pass, overall significant savings
Critical Issue Found:
- Lines 829-835: fc2_out is undefined when is_recomputation=True and execution reaches the return statement, causing UnboundLocalError in edge cases where is_grad_enabled=False during recomputation
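A stripped-down illustration of that failure mode (hypothetical control flow, not the module's code): when the branch that defines a name is skipped but a later statement still references it, Python raises UnboundLocalError at call time.

```python
def forward_sketch(is_recomputation: bool, is_grad_enabled: bool):
    if not is_recomputation:
        fc2_out = "computed"           # FC2 output only exists in the real forward
    if is_grad_enabled:
        return "tensors for backward"  # the recompute path normally exits here
    return fc2_out                     # falls through when is_grad_enabled=False


try:
    forward_sketch(is_recomputation=True, is_grad_enabled=False)
except UnboundLocalError as err:
    print("edge case reproduced:", err)
```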
Confidence Score: 3/5
- This PR has a critical logic bug that could cause runtime errors in specific edge cases, but the core implementation is sound
- Score reflects well-designed selective checkpointing mechanism with proper FP8 state management and comprehensive test coverage. However, critical issue at line 829-835 where fc2_out is undefined during recomputation could cause UnboundLocalError. The bug is localized and fixable, but must be resolved before merge.
- transformer_engine/pytorch/module/layernorm_mlp.py requires immediate attention to fix the undefined fc2_out variable issue in the return logic (lines 829-835)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 3/5 | Implements selective activation checkpointing with checkpoint flag. Adds recomputation logic to skip FC2 during backward pass. Found critical issue where fc2_out is undefined when is_recomputation=True with is_grad_enabled=False, causing UnboundLocalError at return statement. |
Sequence Diagram
sequenceDiagram
participant User
participant LayerNormMLP
participant Forward as _forward
participant Recompute as _recompute
participant Backward
User->>LayerNormMLP: forward(inp, weights, checkpoint=True)
LayerNormMLP->>Forward: _forward(checkpoint=True, recompute_for_bwd=False)
Note over Forward: save_for_checkpoint=True
Note over Forward: is_recomputation=False
Forward->>Forward: Save autocast state & quantizer states
Forward->>Forward: Save input tensors to ctx (7 tensors)
Forward->>Forward: Save other_args dict to ctx
Forward->>Forward: Compute LayerNorm
Forward->>Forward: Compute FC1 + Activation
Forward->>Forward: Compute FC2 (full forward)
Forward->>Forward: Skip activation saving (checkpoint=True)
Forward-->>LayerNormMLP: fc2_out
LayerNormMLP-->>User: output
User->>LayerNormMLP: backward(grad_output)
LayerNormMLP->>Recompute: _recompute(ctx)
alt checkpoint=True
Note over Recompute: Restore saved tensors (7 tensors)
Recompute->>Recompute: Set autocast state from ctx
Recompute->>Recompute: Restore quantizer states (DelayedScaling)
Recompute->>Forward: _forward(saved_tensors, checkpoint=True, recompute_for_bwd=True)
Note over Forward: is_recomputation=True
Forward->>Forward: Compute LayerNorm
Forward->>Forward: Compute FC1 + Activation
Forward->>Forward: SKIP FC2 (is_recomputation=True)
Forward->>Forward: Save recomputed activations to ctx
Forward-->>Recompute: [ctx, activations...]
Recompute->>Recompute: Restore autocast & quantizer states
Recompute-->>Backward: [ctx, activations...]
else checkpoint=False
Note over Recompute: Restore saved tensors (14 tensors)
Recompute-->>Backward: [ctx, activations...]
end
Backward->>Backward: Unpack activations
Backward->>Backward: Compute gradients using recomputed activations
Backward-->>User: gradients
Additional Comments (1)
- transformer_engine/pytorch/module/layernorm_mlp.py, lines 829-835 (logic): fc2_out is undefined when is_recomputation=True and checkpoint=True (FC2 is skipped at line 592), causing UnboundLocalError. The return at line 826 exits early only when is_grad_enabled=True, but if is_grad_enabled=False during recomputation, execution continues to line 835 where fc2_out doesn't exist.
1 file reviewed, 1 comment
Hi @ptrendx, I have another quick update: all tests pass now, including the CUDA graph tests. Solution is described in the edited description. Thanks!
@@ -0,0 +1,269 @@
import time
import torch
A general comment about this file - it is really nice, but it is not a test - it doesn't actually test anything, it just measures. We could introduce some test functionality here by e.g. ensuring that the error between the checkpointed and non-checkpointed LayerNormMLP is zero (since this shouldn't affect numerics) or that the memory used is lower (ideally we would quantify the expected memory usage and test against that, but for now even just making sure that the memory usage goes down would be good).
Sounds good, I converted it into a test for checking that memory goes down at least 6X in the forward pass. I also asserted that checkpointing is slower than not checkpointing in the backward pass (not sure if this is helpful, but let me know), and that the differences are 0. I put this test in tests/pytorch/layernorm_mlp/test_selective_activation_checkpointing.py because I wasn't sure where it fit in the rest of the testing scripts, but let me know if this test would be better elsewhere!
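A test along those lines could measure peak forward memory growth with torch.cuda.max_memory_allocated and backward time with CUDA events, then assert the expected ordering. A rough sketch follows, assuming the PR's checkpoint constructor flag; sizes, thresholds, and function names are illustrative rather than the committed test.

```python
import torch
import transformer_engine.pytorch as te


def run_once(checkpoint: bool, hidden=1024, ffn=4096, seq=512, batch=4):
    torch.manual_seed(0)  # identical weights and inputs across both runs
    mlp = te.LayerNormMLP(
        hidden, ffn, params_dtype=torch.bfloat16, checkpoint=checkpoint  # flag added by this PR
    ).cuda()
    inp = torch.randn(seq, batch, hidden, device="cuda", dtype=torch.bfloat16)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    base = torch.cuda.memory_allocated()
    out = mlp(inp)
    fwd_peak = torch.cuda.max_memory_allocated() - base  # forward memory growth

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out.sum().backward()
    end.record()
    torch.cuda.synchronize()
    return out.detach().float().cpu(), fwd_peak, start.elapsed_time(end)


def test_selective_checkpointing():
    out_ref, mem_ref, bwd_ref = run_once(checkpoint=False)
    out_ckpt, mem_ckpt, bwd_ckpt = run_once(checkpoint=True)
    assert torch.equal(out_ref, out_ckpt)  # checkpointing must not change forward numerics
    assert mem_ckpt < mem_ref              # forward peak memory must go down
    assert bwd_ckpt > bwd_ref              # and the backward pass pays for it in time
```

The committed test reportedly asserts at least a 6X reduction in forward memory; the sketch only checks the direction of the change to stay robust to small problem sizes.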
# save the state of autocast and quantizers for recomputation
ctx.autocast_state = (
    FP8GlobalStateManager.get_autocast_state()
)  # to restore autocast state during recomputation
You need to also save the RNG state here, e.g. to get the same results when using stochastic rounding with FP4. You can see in transformer_engine/pytorch/distributed.py how the checkpoint function handles that when there is no custom RNG tracker present. @ksivaman do you think we should add here logic to handle custom RNG trackers for Megatron?
Thanks, just added this for the CPU and CUDA RNG states. Let me know if I should try to add functionality for a custom RNG tracker.
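A minimal sketch of the CPU/CUDA RNG bookkeeping mentioned in that reply, using only plain torch calls (the PR may structure this differently):

```python
import torch

def capture_rng_states():
    # snapshot generator states so recomputation reproduces e.g. stochastic rounding exactly
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    return cpu_state, cuda_state

def restore_rng_states(cpu_state, cuda_state):
    torch.set_rng_state(cpu_state)
    if cuda_state is not None:
        torch.cuda.set_rng_state(cuda_state)
```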
@jaimec00 I left some comments. Generally it looks good, other than the duplication in the test infrastructure and not handling the RNG state. I did resolve some of the Greptile comments that you already addressed, but there are still a few that apply, so I left those - please address those as well.
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
Implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, the forward pass saves only inputs instead of intermediate activations, achieving >6X memory reduction. During the backward pass, FC1 and activation are recomputed while FC2 computation is skipped (since it's not needed for gradients).
Key Implementation Details:
- Adds a `checkpoint` boolean parameter to the `LayerNormMLP` constructor (default `False`)
- Refactors `forward()` into `_forward()` with a `recompute_for_bwd` flag to control behavior
- Introduces a `_recompute()` method that restores FP8 quantizer states, autocast context, and RNG states before recomputing
- Skips the FC2 GEMM during recomputation via an `if not (is_recomputation and checkpoint)` guard (see the toy sketch after this list)
- Properly handles the FP8 DelayedScaling recipe by saving/restoring quantizer states using `FP8GlobalStateManager`
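A toy sketch of the FC2-skip guard named in that list, using plain PyTorch linears in place of TE's GEMMs (all names here are illustrative, not the module's real variables):

```python
import torch
import torch.nn.functional as F

def mlp_forward_toy(x, fc1_w, fc2_w, *, checkpoint: bool, is_recomputation: bool):
    ln_out = F.layer_norm(x, x.shape[-1:])
    act_out = F.gelu(F.linear(ln_out, fc1_w))
    fc2_out = None
    if not (is_recomputation and checkpoint):
        # FC2 output only feeds the module's forward result, never a gradient,
        # so it is skipped when this function is re-run for the backward pass.
        fc2_out = F.linear(act_out, fc2_w)
    return ln_out, act_out, fc2_out
```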
Test Coverage:
- New dedicated test file validates numerical correctness (zero gradient/output differences)
- Existing test suites extended with the `checkpoint` parameter across numerics, sanity, recipe, CUDA graphs, and distributed tests
- Performance benchmarks confirm expected tradeoffs: ~83% memory reduction with ~20% backward slowdown
Minor Issues:
- Several typo fixes included (`Unsuppored` → `Unsupported`, `corectness` → `correctness`)
- CPU offloading and FSDP scattering bypassed when `checkpoint=True` (expected behavior per PR description)
Confidence Score: 4/5
- This PR is safe to merge with minimal risk, pending verification of edge cases around return values during recomputation
- The implementation is well-structured with comprehensive test coverage across multiple scenarios (FP8 recipes, distributed modes, CUDA graphs). The core logic for skipping FC2 during recomputation is sound. FP8 quantizer state and RNG state management appears correct. However, there are some edge cases that should be verified: (1) behavior when `return_layernorm_output=True` during recomputation may not match expectations, and (2) the interaction between checkpoint mode and features like CPU offloading/FSDP is intentionally disabled but should be clearly documented. Previous review comments about quantizer cloning and state management have been addressed.
- transformer_engine/pytorch/module/layernorm_mlp.py: verify return value handling during recomputation with `return_layernorm_output=True`
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 4/5 | Implements selective activation checkpointing by adding checkpoint parameter. Recomputes FC1 and activation in backward pass while skipping FC2. Properly handles FP8 quantizer state and RNG state restoration. Some edge cases around return values during recomputation need verification. |
| tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py | 5/5 | New test file validating checkpoint feature. Tests numerical correctness (outputs and gradients match), memory reduction (>6X), and expected backward slowdown. Covers multiple model sizes and sequence lengths. |
| tests/pytorch/test_numerics.py | 5/5 | Adds checkpoint parameter to existing LayerNormMLP test cases. Includes typo fixes ('Unsuppored' → 'Unsupported', 'Shoule' → 'Should'). Ensures checkpoint functionality is tested across all accuracy and wgrad tests. |
| tests/pytorch/test_cuda_graphs.py | 5/5 | Splits layernorm_mlp test case into separate layernorm_mlp_nocheckpoint and layernorm_mlp_checkpoint variants to test CUDA graph compatibility with both checkpoint modes. Tests now pass for both configurations. |
Sequence Diagram
sequenceDiagram
participant User
participant LayerNormMLP
participant Forward as _forward()
participant Ctx as Context
participant FC1 as FC1 GEMM
participant Act as Activation
participant FC2 as FC2 GEMM
participant Backward as backward()
participant Recompute as _recompute()
Note over User,Recompute: Forward Pass (checkpoint=True)
User->>LayerNormMLP: forward(input)
LayerNormMLP->>Forward: _forward(checkpoint=True, recompute_for_bwd=False)
Forward->>Ctx: Save checkpoint flag, autocast state, RNG state
Forward->>Ctx: Save input tensors for recompute
Forward->>Forward: LayerNorm computation
Forward->>FC1: Compute FC1 GEMM
FC1-->>Forward: fc1_out
Forward->>Act: Apply activation (GELU)
Act-->>Forward: act_out
Forward->>FC2: Compute FC2 GEMM
FC2-->>Forward: fc2_out
Note over Forward: Deallocate activations (saved memory!)
Forward-->>User: fc2_out
Note over User,Recompute: Backward Pass (checkpoint=True)
User->>Backward: backward(grad_output)
Backward->>Recompute: _recompute(ctx)
Recompute->>Ctx: Restore autocast state & RNG state
Recompute->>Forward: _forward(checkpoint=True, recompute_for_bwd=True)
Note over Forward: Recompute from saved inputs
Forward->>Forward: LayerNorm computation
Forward->>FC1: Compute FC1 GEMM (again)
FC1-->>Forward: fc1_out
Forward->>Act: Apply activation (again)
Act-->>Forward: act_out
Note over Forward,FC2: Skip FC2 computation (not needed for gradients!)
Forward->>Ctx: Save recomputed activations to ctx
Forward-->>Recompute: [ctx, activations]
Recompute->>Ctx: Restore final autocast & RNG states
Recompute-->>Backward: ctx + recomputed tensors
Backward->>Backward: Compute gradients using recomputed activations
Backward-->>User: input gradients
17 files reviewed, no comments
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
Signed-off-by: Jaime Cardenas <jaime@evolutionaryscale.ai>
Greptile Overview
Greptile Summary
This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, it trades ~20% slower backward pass for >6X memory reduction by recomputing activations during backward instead of saving them.
Key implementation details:
- Refactored `forward` into `_forward` (main logic) and added a `_recompute` method for backward recomputation
- When `checkpoint=True`, only inputs (weights, biases, input tensor) are saved during the forward pass instead of all intermediate activations
- During backward, activations are recomputed by calling `_forward` again with `recompute_for_bwd=True`
- FC2 computation is intelligently skipped during recomputation since it's not needed for gradient calculation (see the toy sketch after this list)
- RNG state and FP8 quantizer states are properly saved/restored to ensure numerical correctness during recomputation
- CPU offloading and FSDP scattering only apply when `checkpoint=False` (when activations exist to offload/scatter)
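Putting those bullets together, a self-contained toy of the save-inputs/recompute pattern in plain PyTorch (no TE internals, no FP8; purely illustrative, not the PR's implementation):

```python
import torch
import torch.nn.functional as F

class SelectiveCkptMLPToy(torch.autograd.Function):
    """Save inputs only in forward; recompute FC1 + activation in backward; never rerun FC2."""

    @staticmethod
    def forward(ctx, x, fc1_w, fc2_w):
        ctx.save_for_backward(x, fc1_w, fc2_w)   # inputs only, no intermediate activations
        # (the real module also snapshots RNG and FP8 quantizer/autocast state here)
        ln = F.layer_norm(x, x.shape[-1:])
        act = F.gelu(F.linear(ln, fc1_w))
        return F.linear(act, fc2_w)              # FC2 runs only in the real forward

    @staticmethod
    def backward(ctx, grad_out):
        x, fc1_w, fc2_w = ctx.saved_tensors
        with torch.enable_grad():                # rebuild the FC1/activation graph
            x_d = x.detach().requires_grad_()
            fc1_w_d = fc1_w.detach().requires_grad_()
            ln = F.layer_norm(x_d, x_d.shape[-1:])
            act = F.gelu(F.linear(ln, fc1_w_d))  # recomputed; FC2 is deliberately skipped
        # FC2's backward needs only `act` and fc2_w, never fc2_out itself:
        grad_fc2_w = grad_out.reshape(-1, grad_out.shape[-1]).t() @ act.reshape(-1, act.shape[-1])
        grad_act = grad_out @ fc2_w
        grad_x, grad_fc1_w = torch.autograd.grad(act, (x_d, fc1_w_d), grad_outputs=grad_act)
        return grad_x, grad_fc1_w, grad_fc2_w

# usage: y = SelectiveCkptMLPToy.apply(x, fc1_w, fc2_w)
```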
Test coverage:
- New test validates >6X memory reduction with identical numerical outputs and gradients
- Existing test suites extended with the `checkpoint` parameter across accuracy, sanity, CUDA graphs, and distributed tests
- All tests reportedly pass, including CUDA graphs (fixed by proper autocast/quantizer state management)
Confidence Score: 4/5
- This PR is safe to merge with thorough testing and review of the checkpointing logic
- The implementation is well-structured with proper state management for RNG and quantizers, comprehensive test coverage validates correctness and performance claims, and the feature is opt-in (default checkpoint=False maintains existing behavior). Score is 4 rather than 5 due to the complexity of the checkpointing logic and potential edge cases in distributed/FP8 scenarios that may need additional validation in production workloads.
- Primary attention needed for transformer_engine/pytorch/module/layernorm_mlp.py due to complex state management logic in the `_forward` and `_recompute` methods, particularly around FP8 quantizer state handling and conditional tensor saving/clearing paths.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 4/5 | Implements selective activation checkpointing by adding checkpoint parameter, refactoring forward into _forward and _recompute, with proper RNG/quantizer state management. Logic correctly skips FC2 during recomputation and handles activation saving conditionally. |
| tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py | 5/5 | New test validates checkpointing achieves >6X memory reduction in forward pass with identical numerical results. Tests appropriately verify performance tradeoffs and gradient correctness. |
| tests/pytorch/test_numerics.py | 5/5 | Adds checkpoint parameter to existing LayerNormMLP accuracy tests, plus typo fixes. Changes are minimal and properly extend test coverage. |
| tests/pytorch/test_cuda_graphs.py | 5/5 | Splits LayerNormMLP CUDA graph tests into separate checkpoint/no-checkpoint cases to test both modes independently. |
Sequence Diagram
sequenceDiagram
participant User
participant LayerNormMLP
participant _LayerNormMLP.forward
participant _LayerNormMLP._forward
participant _LayerNormMLP.backward
participant _LayerNormMLP._recompute
Note over User,_LayerNormMLP._recompute: Forward Pass (checkpoint=True)
User->>LayerNormMLP: forward(input)
LayerNormMLP->>_LayerNormMLP.forward: apply(input, checkpoint=True)
_LayerNormMLP.forward->>_LayerNormMLP._forward: _forward(..., recompute_for_bwd=False)
Note over _LayerNormMLP._forward: save_for_checkpoint=True
Note over _LayerNormMLP._forward: Save inputs only (inp, weights, biases)
Note over _LayerNormMLP._forward: Save RNG state & FP8 quantizer state
Note over _LayerNormMLP._forward: Compute LayerNorm → FC1 → Activation → FC2
Note over _LayerNormMLP._forward: Clear intermediate activations
_LayerNormMLP._forward-->>_LayerNormMLP.forward: fc2_out
_LayerNormMLP.forward-->>User: output
Note over User,_LayerNormMLP._recompute: Backward Pass (checkpoint=True)
User->>_LayerNormMLP.backward: backward(grad_output)
_LayerNormMLP.backward->>_LayerNormMLP._recompute: _recompute(ctx)
Note over _LayerNormMLP._recompute: Restore RNG state & quantizer state
_LayerNormMLP._recompute->>_LayerNormMLP._forward: _forward(..., recompute_for_bwd=True)
Note over _LayerNormMLP._forward: is_recomputation=True
Note over _LayerNormMLP._forward: Recompute LayerNorm → FC1 → Activation
Note over _LayerNormMLP._forward: Skip FC2 (not needed for gradients)
Note over _LayerNormMLP._forward: Save recomputed activations for backward
_LayerNormMLP._forward-->>_LayerNormMLP._recompute: (ctx, inputmat, ln_out, fc1_out, act_out, ...)
Note over _LayerNormMLP._recompute: Restore current RNG & quantizer state
_LayerNormMLP._recompute-->>_LayerNormMLP.backward: saved tensors
Note over _LayerNormMLP.backward: Compute gradients using recomputed activations
_LayerNormMLP.backward-->>User: input_grad, weight_grads
3 files reviewed, no comments
@ptrendx Thank you for reviewing! I committed a few changes which I believe cover the points you made.
A few points:
Thanks!
Description
Implement Selective Activation Checkpointing for LayerNormMLP by adding a "checkpoint" flag. If checkpoint=True, activations are recomputed in the backward pass, skipping the recomputation of fc2 since it is not needed for the backward pass. This reduces memory significantly, allowing larger MLPs without running into OOMs, while still keeping the functionality for SequenceParallel and TensorParallel. The only functionality that changes is CPU offloading, since there are no activations left to offload when checkpoint=True.
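A minimal usage sketch of the flag described here (sizes are arbitrary; `checkpoint` is the new argument, everything else follows the existing LayerNormMLP API):

```python
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=16384,
    checkpoint=True,   # new flag: recompute FC1 + activation in backward, skip FC2 recompute
).cuda()

x = torch.randn(2048, 2, 4096, device="cuda", requires_grad=True)
out = mlp(x)
out.sum().backward()   # activations are rebuilt here instead of being held since forward
```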
When checkpoint=False, it runs the regular LayerNormMLP, and all tests in tests/pytorch/selective_layernorm_mlp (listed in "Changes") pass. When checkpoint=True, all tests pass as well. NOTE: the earlier test_cuda_graphs.py failures were caused by the recomputation being done outside of the autocast context; saving the autocast state (and quantizer states) via FP8GlobalStateManager in the forward pass, setting them during recomputation, and restoring them for the backward pass fixed the issue.
Fixes #623
Type of change
Changes
Performance
With checkpoint=True, the forward pass sees no notable change in runtime but a >6X reduction in memory. The backward pass is $\approx$ 20% slower and uses more memory than with checkpoint=False, but overall memory use is still significantly lower.
(Benchmark tables omitted from this view; each reports Model Config, TIME (ms), MEM (MB), and MAX ABSOLUTE ERRORS.)
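For reference, a hedged sketch of how per-pass time and peak-memory numbers like these could be collected (not the PR's compare.py, just an illustration with arbitrary sizes):

```python
import torch
import transformer_engine.pytorch as te

def bench(module, x, iters=20):
    start_f, end_f = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start_b, end_b = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats()

    fwd_ms = bwd_ms = 0.0
    for _ in range(iters):
        start_f.record()
        out = module(x)
        end_f.record()

        start_b.record()
        out.sum().backward()
        end_b.record()

        torch.cuda.synchronize()
        fwd_ms += start_f.elapsed_time(end_f)
        bwd_ms += start_b.elapsed_time(end_b)
        module.zero_grad(set_to_none=True)
        x.grad = None

    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return fwd_ms / iters, bwd_ms / iters, peak_mb

mlp = te.LayerNormMLP(4096, 16384, checkpoint=True).cuda()   # checkpoint flag from this PR
x = torch.randn(2048, 2, 4096, device="cuda", requires_grad=True)
print(bench(mlp, x))
```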
Checklist: