Skip to content

Conversation

@negvet
Copy link
Collaborator

@negvet negvet commented Nov 6, 2025

Description

Fused amax was computed for compute_t data (fp32), not the output_t (e.g. bf16), which is not exactly correct.
Now amax is computed on output_t data (adding one more conversion).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet
Copy link
Collaborator Author

negvet commented Nov 6, 2025

/te-ci pytorch

@negvet negvet marked this pull request as ready for review November 6, 2025 16:24
@negvet negvet requested review from ptrendx and timmoon10 November 6, 2025 16:24
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Fixes amax computation to accurately reflect output tensor precision by computing on output_t data for non-fp8 outputs instead of compute_t.

Key Changes:

  • LayerNorm and RMSNorm kernels now branch on params.fp8_out when computing amax
  • For fp8 outputs: amax computed on pre-scale compute_t (unchanged behavior)
  • For non-fp8 outputs: amax computed after round-trip conversion compute_t(output_t(value)) to capture precision loss
  • Reference test implementation updated to match corrected semantics
  • Applied consistently across tuned and general kernel variants in both normalization types

Why This Matters:
When outputting to reduced-precision formats like bf16/fp16, the amax should reflect the actual maximum absolute value in the output tensor after quantization/conversion, not the higher-precision intermediate value. This ensures accurate downstream quantization scaling.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it's a well-isolated correctness fix for amax computation
  • The fix is surgical and correct: adds proper branching logic to compute amax on the appropriate data type, matches reference implementation, maintains backward compatibility for fp8 paths, and applies the pattern consistently across all affected kernels
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh 5/5 Fixed amax computation for non-fp8 outputs by computing on output_t instead of compute_t, matching expected semantics in two kernel variants
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh 5/5 Fixed amax computation for non-fp8 outputs by computing on output_t instead of compute_t, matching expected semantics in two kernel variants
tests/cpp/operator/test_normalization.h 5/5 Updated reference implementation to match corrected amax semantics: fp8 uses pre-scale compute_t, non-fp8 uses value after output_t conversion

Sequence Diagram

sequenceDiagram
    participant Kernel as Normalization Kernel
    participant Compute as compute_t (fp32)
    participant Output as output_t (bf16/fp16)
    participant Amax as Amax Computation

    Kernel->>Compute: Calculate normalized value (temp_output)
    
    alt fp8_out == true
        Compute->>Amax: Compute amax on pre-scale compute_t
        Note over Amax: amax = max(amax, |temp_output|)
        Compute->>Compute: Apply scale (temp_output *= scale)
        Compute->>Output: Convert to output_t
    else fp8_out == false (FIXED)
        Compute->>Output: Convert to output_t first
        Note over Output: out_t_val = output_t(temp_output)
        Output->>Compute: Convert back to compute_t
        Compute->>Amax: Compute amax on round-trip value
        Note over Amax: amax = max(amax, |compute_t(out_t_val)|)
    end
    
    Output->>Kernel: Store final output value
Loading

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@negvet
Copy link
Collaborator Author

negvet commented Nov 7, 2025

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants