[PyTorch] Fix amax computation using output_t data in normalization #2355
base: main
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
/te-ci pytorch
Greptile Overview
Greptile Summary
Fixes amax computation to accurately reflect output tensor precision by computing amax on output_t data for non-fp8 outputs instead of on compute_t.
Key Changes:
- LayerNorm and RMSNorm kernels now branch on `params.fp8_out` when computing amax
- For fp8 outputs: amax is computed on the pre-scale `compute_t` value (unchanged behavior)
- For non-fp8 outputs: amax is computed after the round-trip conversion `compute_t(output_t(value))` to capture precision loss
- Reference test implementation updated to match the corrected semantics
- Applied consistently across tuned and general kernel variants in both normalization types
Why This Matters:
When outputting to reduced-precision formats like bf16/fp16, the amax should reflect the actual maximum absolute value in the output tensor after quantization/conversion, not the higher-precision intermediate value. This ensures accurate downstream quantization scaling.
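To illustrate the precision effect, here is a minimal sketch using Python's stdlib half-precision (`struct` format `'e'`) as a stand-in for a bf16/fp16 output_t; the values and names are illustrative, not taken from the PR:

```python
import struct

def fp16_round_trip(x: float) -> float:
    # Round-trip through IEEE half precision ('e' struct format),
    # mimicking a cast to a reduced-precision output_t and back.
    return struct.unpack('e', struct.pack('e', x))[0]

vals = [1.0001, -0.5, 0.25]  # high-precision intermediate (compute_t) values

amax_before = max(abs(v) for v in vals)                   # old: amax on compute_t
amax_after = max(abs(fp16_round_trip(v)) for v in vals)   # new: amax on output_t

print(amax_before)  # 1.0001
print(amax_after)   # 1.0 (1.0001 is not representable in fp16)
```

The old amax (1.0001) slightly overstates the true maximum of the stored fp16 tensor (1.0), which skews any downstream scale derived from it.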
Confidence Score: 5/5
- This PR is safe to merge with minimal risk - it's a well-isolated correctness fix for amax computation
- The fix is surgical and correct: adds proper branching logic to compute amax on the appropriate data type, matches reference implementation, maintains backward compatibility for fp8 paths, and applies the pattern consistently across all affected kernels
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh | 5/5 | Fixed amax computation for non-fp8 outputs by computing on output_t instead of compute_t, matching expected semantics in two kernel variants |
| transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh | 5/5 | Fixed amax computation for non-fp8 outputs by computing on output_t instead of compute_t, matching expected semantics in two kernel variants |
| tests/cpp/operator/test_normalization.h | 5/5 | Updated reference implementation to match corrected amax semantics: fp8 uses pre-scale compute_t, non-fp8 uses value after output_t conversion |
Sequence Diagram
sequenceDiagram
    participant Kernel as Normalization Kernel
    participant Compute as compute_t (fp32)
    participant Output as output_t (bf16/fp16)
    participant Amax as Amax Computation
    Kernel->>Compute: Calculate normalized value (temp_output)
    alt fp8_out == true
        Compute->>Amax: Compute amax on pre-scale compute_t
        Note over Amax: amax = max(amax, |temp_output|)
        Compute->>Compute: Apply scale (temp_output *= scale)
        Compute->>Output: Convert to output_t
    else fp8_out == false (FIXED)
        Compute->>Output: Convert to output_t first
        Note over Output: out_t_val = output_t(temp_output)
        Output->>Compute: Convert back to compute_t
        Compute->>Amax: Compute amax on round-trip value
        Note over Amax: amax = max(amax, |compute_t(out_t_val)|)
    end
    Output->>Kernel: Store final output value
3 files reviewed, no comments
/te-ci
Description
Fused amax was computed on compute_t data (fp32) rather than on output_t data (e.g. bf16), which is not exactly correct.
Now amax is computed on output_t data, at the cost of one extra conversion.
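The corrected per-element flow can be sketched as follows. This is a Python stand-in, not the kernel code: `store_element`, `to_output_t`, and the use of fp16 as the output type are illustrative assumptions mirroring the branching described above.

```python
import struct

def to_output_t(x: float) -> float:
    # Stand-in for a cast to a reduced-precision output_t (here: fp16).
    return struct.unpack('e', struct.pack('e', x))[0]

def store_element(temp_output: float, amax: float,
                  fp8_out: bool, scale: float = 1.0):
    if fp8_out:
        # fp8 path (unchanged): amax over the pre-scale compute_t value
        amax = max(amax, abs(temp_output))
        out = to_output_t(temp_output * scale)
    else:
        # fixed non-fp8 path: convert to output_t first, then take amax
        # over the round-tripped value so amax matches the stored tensor
        out = to_output_t(temp_output)
        amax = max(amax, abs(out))
    return out, amax

out, amax = store_element(1.0001, 0.0, fp8_out=False)
print(out, amax)  # 1.0 1.0 (1.0001 rounds down in fp16)
```

In the non-fp8 branch the extra conversion makes amax agree exactly with the maximum absolute value actually present in the output tensor.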