Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion tests/cpp/operator/test_normalization.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,18 @@ void compute_ref_output(NormType norm_type,
tmp = current * rsigma[i] * g;
}

// Write output (scaled only for fp8 paths)
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));

// amax semantics:
// - fp8_out (scale != 1): amax on pre-scale compute value 'tmp'
// - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16)
if (scale != 1.f) {
current_max = fmaxf(current_max, fabsf(tmp));
} else {
OutputType out_t_val = static_cast<OutputType>(tmp);
current_max = fmaxf(current_max, fabsf(static_cast<compute_t>(out_t_val)));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(

if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(temp_output));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(temp_output);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
}
if (params.fp8_out) {
temp_output = temp_output * scale;
Expand Down Expand Up @@ -290,7 +297,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(z_ij));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(z_ij);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke

if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(temp_output));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(temp_output);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
}
if (params.fp8_out) {
temp_output = temp_output * scale;
Expand Down Expand Up @@ -265,7 +272,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(z_ij));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(z_ij);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
Expand Down
Loading