From 30c997ca367f6fa6b41ec0127113aab2d6786546 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 6 Nov 2025 15:00:17 +0000 Subject: [PATCH] Fix amax computation using output_t data in normalization Signed-off-by: Evgeny --- tests/cpp/operator/test_normalization.h | 12 +++++++++++- .../normalization/layernorm/ln_fwd_kernels.cuh | 18 ++++++++++++++++-- .../rmsnorm/rmsnorm_fwd_kernels.cuh | 18 ++++++++++++++++-- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index fe69852d00..271345686e 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -114,8 +114,18 @@ void compute_ref_output(NormType norm_type, tmp = current * rsigma[i] * g; } + // Write output (scaled only for fp8 paths) output[i * H + j] = static_cast<OutputType>(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); + + // amax semantics: + // - fp8_out (scale != 1): amax on pre-scale compute value 'tmp' + // - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16) + if (scale != 1.f) { + current_max = fmaxf(current_max, fabsf(tmp)); + } else { + OutputType out_t_val = static_cast<OutputType>(tmp); + current_max = fmaxf(current_max, fabsf(static_cast<float>(out_t_val))); + } } } diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 6050b164d5..38c4096073 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -123,7 +123,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value 
converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -290,7 +297,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if (params.fp8_out) { z.data.elt[jt] = z_ij * scale; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index fc093b73a7..7fed7f123a 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -115,7 +115,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -265,7 +272,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = 
fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if (params.fp8_out) { z.data.elt[jt] = z_ij * scale; }