diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py new file mode 100644 index 0000000000..654dc8d193 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -0,0 +1,307 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py +# and also the test_nvfp4_rht_quantize_exact.py. +# Separate to make sure all the functionalities are working as expected. +# Otherwise reference implementation will get messy. + +# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality +# together with the quantization functionality. + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.common.recipe import NVFP4BlockScaling + +import pytest +import torch +import random +import math + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def generate_random_multiples_sum(total=8192, n=4, multiple=64): + if total % multiple != 0: + raise ValueError(f"Total ({total}) must be a multiple of {multiple}") + if (total // multiple) < n: + raise ValueError("Total too small for given n and multiple.") + + # Work in units of multiples + total_units = total // multiple + + # choose n−1 random cut points in [1, total_units−1) + cuts = sorted(random.sample(range(1, total_units), n - 1)) + + # convert to segment lengths + parts = ( + [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] + ) + + # convert back to multiples + return [p * multiple for p in parts] + + +def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]: + least_multiple = 64 + num_chunks = 4 + split_sections = None + + avg_split = M // num_chunks + + if M == 0 or N == 0: + # all zeros + return [0] * num_chunks + if edge_cases == "regular": + split_sections = [avg_split] * num_chunks + elif edge_cases == "zero_tokens_front": + split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] + elif edge_cases == "zero_tokens_end": + split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] + elif edge_cases == "zero_tokens_middle": + split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] + elif edge_cases == "random_uneven_split": + split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple) + else: + raise ValueError(f"Invalid edge case: {edge_cases}") + + # adds up the split_sections to make it M + assert sum(split_sections) == M, "The split_sections do not add up to M" + + # make sure every split_section is a multiple of least_multiple + for split_section in split_sections: + assert ( + split_section % least_multiple == 0 + ), "The split_sections are not multiples of least_multiple" + + return split_sections + + +# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding +def get_nvfp4_scale_shape_no_padding(shape, columnwise): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = K + inner = math.ceil(M / 16) + return (outer, inner) + # rowwise + outer = M + inner = math.ceil(K / 16) + 
return (outer, inner) + + +def reference_group_quantize( + x: torch.Tensor, + quantizers: list[NVFP4Quantizer], + split_sections: list[int], + return_identity: bool, + return_transpose: bool, +) -> torch.Tensor: + x_view = x.reshape(-1, x.size(-1)) + x_chunks = torch.split(x, split_sections) + + # rowwise quantization + x_qx = [] + x_sx = [] + x_amax_rowwise = [] + # columnwise quantization + x_qx_t = [] + x_sx_t = [] + x_amax_colwise = [] + + for i in range(len(x_chunks)): + x_chunk = x_chunks[i] + x_nvfp4_res = quantizers[i](x_chunk) + if return_identity: + x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) + x_sx.append(x_nvfp4_res._rowwise_scale_inv) + x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) + else: + x_qx.append(None) + x_sx.append(None) + x_amax_rowwise.append(None) + if return_transpose: + x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8)) + x_sx_t.append(x_nvfp4_res._columnwise_scale_inv) + x_amax_colwise.append(x_nvfp4_res._amax_columnwise) + else: + x_qx_t.append(None) + x_sx_t.append(None) + x_amax_colwise.append(None) + + return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise + + +def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: + assert x.shape == y.shape + assert x.dtype == y.dtype + + +def check_group_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, +) -> None: + + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + num_chunks = len(split_sections) + + x_splits = torch.split(x, split_sections) + + # Quantize + quantizers = [ + NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + for _ in range(len(split_sections)) + ] + x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( + reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) + ) + + split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same same and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close( + x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) + x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]] + x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, 
atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same same and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close( + x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) + x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]] + x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # edge case, zero tokens for all + (0, 512), + # full tile cases + (256, 1024), + (1024, 256), + # larger sizes + (8192, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +def test_rht_with_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + with_random_sign_mask: bool, + with_rht: bool, +) -> None: + + split_sections = generate_split_sections(M, N, edge_cases) + + # currently disable pre-RHT amax + with_post_rht_amax = with_rht + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_group_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index 0842de9ea4..2dd8d8db00 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -5,6 +5,9 @@ import pytest import torch import transformer_engine.pytorch as te + +import transformer_engine_torch as tex + from transformer_engine.pytorch import NVFP4Quantizer recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -151,6 +154,74 @@ def quantize_fp4( return qx, sx, qx_t, sx_t +def 
group_quantize_fp4( + x: torch.Tensor, + use_stochastic_rounding: bool, + use_2D: bool, + use_RHT: bool, + split_sections: list[int], + use_tex_split_quantize: bool = True, +) -> torch.Tensor: + """ + Group quantize function with toggle between tex.split_quantize and manual split/call methods. + + Args: + x (torch.Tensor): Input tensor. + use_stochastic_rounding (bool): Use stochastic rounding. + use_2D (bool): Use 2D quantization. + use_RHT (bool): Use RHT. + split_sections (list[int]): Split sizes for inputs. + use_tex_split_quantize (bool): Toggle method. If True, use tex.split_quantize, else use manual split and per-quantizer invocation. + + Returns: + tuple: Lists of quantized tensors and scale tensors for all sections. + """ + num_tensors = len(split_sections) + nvfp4_quantizers = [ + NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=use_RHT, + with_post_rht_amax=True, + stochastic_rounding=use_stochastic_rounding, + with_2d_quantization=use_2D, + ) + for _ in range(num_tensors) + ] + + if use_tex_split_quantize: + outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers) + qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs] + sx_list = [output._rowwise_scale_inv for output in outputs] + qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs] + sx_t_list = [output._columnwise_scale_inv for output in outputs] + else: + x_chunks = torch.split(x, split_sections) + qx_list = [] + sx_list = [] + qx_t_list = [] + sx_t_list = [] + for i in range(num_tensors): + x_chunk = x_chunks[i] + x_nvfp4_sut = nvfp4_quantizers[i](x_chunk) + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + assert x_nvfp4_sut._columnwise_data is not None + qx_t = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._columnwise_scale_inv is not None + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_list.append(qx) + sx_list.append(sx) + qx_t_list.append(qx_t) + sx_t_list.append(sx_t) + + return qx_list, sx_list, qx_t_list, sx_t_list + + def check_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool ) -> None: @@ -209,6 +280,92 @@ def check_quantization_nvfp4_versus_reference( assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." 
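+
+# Grouped counterpart of check_quantization_nvfp4_versus_reference: the concatenated
+# input is quantized once with round-to-nearest and repeatedly with stochastic rounding
+# (via group_quantize_fp4), then each split is dequantized and checked so that the SR
+# results, averaged over iterations, give a lower RMSE than the single RN result for
+# both the rowwise and the transposed (columnwise) outputs.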
+def check_group_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + use_2D: bool, + use_RHT: bool, + num_splits: int, + use_tex_split_quantize: bool = True, +) -> None: + device = "cuda" + torch.manual_seed(seed) + n_iters = 50 + + split_sections = [M // num_splits] * num_splits + x_total = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1 + x_splits = torch.split(x_total, split_sections) + + q_rn_list, s_rn_list, q_t_rn_list, s_t_rn_list = group_quantize_fp4( + x_total, + use_stochastic_rounding=False, + use_2D=use_2D, + use_RHT=use_RHT, + split_sections=split_sections, + use_tex_split_quantize=use_tex_split_quantize, + ) + sr_n_iter_results = [] + for i in range(n_iters): + q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list = group_quantize_fp4( + x_total, + use_stochastic_rounding=True, + use_2D=use_2D, + use_RHT=use_RHT, + split_sections=split_sections, + use_tex_split_quantize=use_tex_split_quantize, + ) + sr_n_iter_results.append((q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list)) + + for i, x in enumerate(x_splits): + y = x.t().contiguous() + if use_RHT: + y = RHT(y) + amax = torch.max(torch.abs(x)).float() + + # fetch q_rn, s_rn, q_t_rn, s_t_rn + q_rn = q_rn_list[i] + s_rn = s_rn_list[i] + q_t_rn = q_t_rn_list[i] + s_t_rn = s_t_rn_list[i] + + dq_rn = dequantize_fp4(q_rn, s_rn, amax) + dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax) + error_rn = (dq_rn - x).float() + me_rn = torch.sqrt((error_rn * error_rn).mean()) + error_t_rn = (dq_t_rn - y).float() + me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) + sr_result = torch.zeros_like(x).float() + sr_t_result = torch.zeros_like(x).float().t().contiguous() + for iter_idx in range(n_iters): + result_sr = sr_n_iter_results[iter_idx] + q_sr = result_sr[0][i] + s_sr = result_sr[1][i] + q_t_sr = result_sr[2][i] + s_t_sr = result_sr[3][i] + + dq_sr = dequantize_fp4(q_sr, s_sr, amax) + dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax) + sr_result += dq_sr.float() + sr_t_result += dq_t_sr.float() + + # Get the mean result of the stochastic rounding + # It should be more accurate than the RN result + sr_result /= n_iters + error_sr = (sr_result - x).float() + me_sr = torch.sqrt((error_sr * error_sr).mean()) + sr_t_result /= n_iters + error_t_sr = (sr_t_result - y).float() + me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean()) + + print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") + print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") + assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." + assert ( + me_t_sr < me_t_rn + ), "Stochastic rounding failed - error larger than the round to nearest." 
+ + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", @@ -236,3 +393,39 @@ def test_quantization_block_tiling_versus_reference( M=M, N=N, ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (8192, 8192), + (4096, 7168), + (16384, 2048), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_2D", [False], ids=str) +@pytest.mark.parametrize("use_RHT", [True], ids=str) +@pytest.mark.parametrize("num_splits", [4, 8], ids=str) +@pytest.mark.parametrize("use_tex_split_quantize", [True, False], ids=str) +def test_group_stochastic_rounding_quantization_versus_reference( + x_dtype: torch.dtype, + use_2D: bool, + use_RHT: bool, + num_splits: int, + use_tex_split_quantize: bool, + M: int, + N: int, +) -> None: + if x_dtype == torch.float32 and use_RHT: + pytest.skip("RHT is only supported with bfloat16") + check_group_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + use_2D=use_2D, + use_RHT=use_RHT, + M=M, + N=N, + num_splits=num_splits, + use_tex_split_quantize=use_tex_split_quantize, + ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 01f1deb983..b70e3f0c6c 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,10 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import ( + FP8GlobalStateManager, + get_align_size_for_quantization, +) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -1774,9 +1777,7 @@ def _test_grouped_linear_accuracy( if num_gemms > 1: split_size = 1 if fp8: - split_size = 16 - if recipe.mxfp8() or recipe.nvfp4(): - split_size = 32 + split_size = get_align_size_for_quantization(recipe) m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero @@ -2082,9 +2083,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe): def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): - align_size = 16 - if recipe.mxfp8() or recipe.nvfp4(): - align_size = 32 + align_size = get_align_size_for_quantization(recipe) padded_tokens_per_expert = [ (num_tokens + align_size - 1) // align_size * align_size for num_tokens in tokens_per_expert diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 62b769c77e..7ef8542f56 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -175,6 +175,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu hadamard_transform/hadamard_transform.cu + hadamard_transform/multi_hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu) # Compiling the files with the worst compilation time first to hopefully overlap diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 9d4bec41d5..c01ce7b78f 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu 
@@ -16,185 +16,12 @@ #include "common/common.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" +#include "hadamard_transform_utils.cuh" namespace transformer_engine { namespace { constexpr int kThreadsPerWarp = 32; -constexpr float k16x16HadamardScale = 0.25f; - -template -__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, - uint32_t& a2, uint32_t& a3, - void* addr) { - auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); - if constexpr (kTranspose) { - asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) - : "r"(smem_addr)); - } else { - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) - : "r"(smem_addr)); - } -} - -template -__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, - uint32_t& a2, uint32_t& a3, - void* addr, uint32_t stride) { - if constexpr (kTranspose) { - asm volatile( - "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " - "{%0,%1,%2,%3}, [%4], %5;\n" - : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) - : "l"(addr), "r"(stride)); - } else { - asm volatile( - "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " - "{%0,%1,%2,%3}, [%4], %5;\n" - : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) - : "l"(addr), "r"(stride)); - } -} - -template -__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, - uint32_t& a2, uint32_t& a3, void* addr, - uint32_t stride) { - if constexpr (kTranspose) { - asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" - : - : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); - } else { - asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" - : - : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); - } -} - -__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { - asm volatile( - "movmatrix.sync.aligned.m8n8.trans.b16 " - "%0, %1;\n\t" - : "=r"(a0) - : "r"(a0)); -} - -__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { - __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); - float f_a = __bfloat162float(bf16x2.x); - float f_b = __bfloat162float(bf16x2.y); - asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); - float_dst = fabsf(float_dst); -} - -template -__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( - uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, - uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, - uint32_t& amax_result) { - uint32_t zero = 0; - uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; - asm volatile( - "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" - "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" - "{%8, %9, %10, %11}, \n" - "{%12, %13, %14, %15}, \n" - "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" - : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), - "=r"(temp7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), - "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) 
: "r"(temp3), "r"(temp2)); - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); - if constexpr (kCalculateAmax) { - uint32_t max_even; - uint32_t max_odd; - // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); - // N.B. mma is only called up to once per thread for identity and transpose respectively, so - // we don't have to accumulate into amax_result and can directly store into it. - asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" - : "=r"(amax_result) - : "r"(max_even), "r"(max_odd)); - } -} - -template -__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, - uint16_t random_sign_mask, - uint32_t* had_frag_t, - uint16_t random_sign_mask_t) { - int32_t tid = threadIdx.x % 32; // Local tid - float temp_i[2]; - float temp_t[2]; -#pragma unroll - for (int i = 0; i < 2; i++) { - // i is the vertical fragment index. - // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. - uint32_t r = i * 8 + tid / 4; - -#pragma unroll - for (int j = 0; j < 2; j++) { -#pragma unroll - for (int k = 0; k < 2; k++) { - // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. - // j is the column fragment idx selecting between even and odd fragments. - // j increments 8 columns by switching fragments. - uint32_t c = j * 8 + k + tid % 4 * 2; - // 1 -> -1.0f, 0 -> 1.0f - int32_t base_sign = __popc(r & c); - if constexpr (kReturnIdentity) { - int32_t sign_i; - // Because tensor cores want the dot product dimension, - // contiguous, the regular, non-inverse hadamard swaps - // signs of columns and rows for inverse. In a simple reference, - // x.reshape(-1, 16) @ sign @ H16, this would be opposite but - // (sign @ H16) is transposed in this fragment. 
- if constexpr (kInverseHadamardIdentity) { - sign_i = ((random_sign_mask >> r) ^ base_sign); - } else { - sign_i = ((random_sign_mask >> c) ^ base_sign); - } - temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); - } - if constexpr (kReturnTransposed) { - int32_t sign_t; - if constexpr (kInverseHadamardTransposed) { - sign_t = ((random_sign_mask_t >> r) ^ base_sign); - } else { - sign_t = ((random_sign_mask_t >> c) ^ base_sign); - } - temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); - } - } - - if constexpr (kReturnIdentity) { - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" - : "=r"(had_frag_i[i * 2 + j]) - : "f"(temp_i[1]), "f"(temp_i[0])); - } - if constexpr (kReturnTransposed) { - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" - : "=r"(had_frag_t[i * 2 + j]) - : "f"(temp_t[1]), "f"(temp_t[0])); - } - } - } -} - -__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, - uint32_t gmem_col_idx) { - uint32_t smem_row_idx = gmem_row_idx; - uint32_t xor_factor = (smem_row_idx * 2) % 8; - uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; - return smem_row_idx * 8 + smem_col_idx; -} template diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh b/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh new file mode 100644 index 0000000000..ad3bbf5cd7 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh @@ -0,0 +1,198 @@ +/************************************************************************* +* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* +* See LICENSE for license information. +************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ +#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { + +constexpr float k16x16HadamardScale = 0.25f; + +template +__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr) { + auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); + if constexpr (kTranspose) { + asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } else { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } +} + +template +__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr, uint32_t stride) { + if constexpr (kTranspose) { + asm volatile( + "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } else { + asm volatile( + "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } +} + +template +__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, void* addr, + uint32_t stride) { + if constexpr (kTranspose) { + asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : 
"l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } else { + asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { + asm volatile( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) + : "r"(a0)); +} + +__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { + __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); + float f_a = __bfloat162float(bf16x2.x); + float f_b = __bfloat162float(bf16x2.y); + asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); + float_dst = fabsf(float_dst); +} + +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( + uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, + uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, + uint32_t& amax_result) { + uint32_t zero = 0; + uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + asm volatile( + "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" + "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" + "{%8, %9, %10, %11}, \n" + "{%12, %13, %14, %15}, \n" + "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), + "=r"(temp7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), + "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); + if constexpr (kCalculateAmax) { + uint32_t max_even; + uint32_t max_odd; + // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); + // N.B. mma is only called up to once per thread for identity and transpose respectively, so + // we don't have to accumulate into amax_result and can directly store into it. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(amax_result) + : "r"(max_even), "r"(max_odd)); + } +} + +template +__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, + uint16_t random_sign_mask, + uint32_t* had_frag_t, + uint16_t random_sign_mask_t) { + int32_t tid = threadIdx.x % 32; // Local tid + float temp_i[2]; + float temp_t[2]; +#pragma unroll + for (int i = 0; i < 2; i++) { + // i is the vertical fragment index. + // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. + uint32_t r = i * 8 + tid / 4; + +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int k = 0; k < 2; k++) { + // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. + // j is the column fragment idx selecting between even and odd fragments. + // j increments 8 columns by switching fragments. 
+ uint32_t c = j * 8 + k + tid % 4 * 2; + // 1 -> -1.0f, 0 -> 1.0f + int32_t base_sign = __popc(r & c); + if constexpr (kReturnIdentity) { + int32_t sign_i; + // Because tensor cores want the dot product dimension, + // contiguous, the regular, non-inverse hadamard swaps + // signs of columns and rows for inverse. In a simple reference, + // x.reshape(-1, 16) @ sign @ H16, this would be opposite but + // (sign @ H16) is transposed in this fragment. + if constexpr (kInverseHadamardIdentity) { + sign_i = ((random_sign_mask >> r) ^ base_sign); + } else { + sign_i = ((random_sign_mask >> c) ^ base_sign); + } + temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); + } + if constexpr (kReturnTransposed) { + int32_t sign_t; + if constexpr (kInverseHadamardTransposed) { + sign_t = ((random_sign_mask_t >> r) ^ base_sign); + } else { + sign_t = ((random_sign_mask_t >> c) ^ base_sign); + } + temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); + } + } + + if constexpr (kReturnIdentity) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_i[i * 2 + j]) + : "f"(temp_i[1]), "f"(temp_i[0])); + } + if constexpr (kReturnTransposed) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_t[i * 2 + j]) + : "f"(temp_t[1]), "f"(temp_t[0])); + } + } + } +} + +__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, + uint32_t gmem_col_idx) { + uint32_t smem_row_idx = gmem_row_idx; + uint32_t xor_factor = (smem_row_idx * 2) % 8; + uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; + return smem_row_idx * 8 + smem_col_idx; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ diff --git a/transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu new file mode 100644 index 0000000000..f840406c31 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu @@ -0,0 +1,601 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "hadamard_transform_utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB, expand 64 if needed +struct MultiAmaxArgs { + // (output) Amax buffer for pre-RHT amax buffer + void* output_pre_rht_amax_list[kMaxTensorsPerKernel]; + // (output) Amax buffer for RHT identity amax buffer + void* output_identity_amax_list[kMaxTensorsPerKernel]; + // (output) Amax buffer for RHT transpose amax buffer + void* output_transpose_amax_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + // Number of tensors (splits) being processed by kernel + int num_tensors; +}; + +constexpr int kThreadsPerWarp = 32; + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = (threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. 
+ if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : "r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? 
staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +// args: the mult-tensor amax arguments +__global__ void MultiZeroAmaxKernel(MultiAmaxArgs args) { + int num_tensors = args.num_tensors; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_tensors; tid += stride) { + float* output_pre_rht_amax_ptr = static_cast(args.output_pre_rht_amax_list[tid]); + float* output_identity_amax_ptr = static_cast(args.output_identity_amax_list[tid]); + float* output_transpose_amax_ptr = static_cast(args.output_transpose_amax_list[tid]); + if (output_pre_rht_amax_ptr != nullptr) { + *output_pre_rht_amax_ptr = 0; + } + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = 0; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = 0; + } + } +} + +// args: the mult-tensor amax arguments +__global__ void MultiAmaxMemcpyD2DKernelPreRHT(MultiAmaxArgs args) { + int num_tensors = args.num_tensors; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_tensors; tid += stride) { + float* output_pre_rht_amax_ptr = static_cast(args.output_pre_rht_amax_list[tid]); + float* output_identity_amax_ptr = static_cast(args.output_identity_amax_list[tid]); + float* output_transpose_amax_ptr = static_cast(args.output_transpose_amax_list[tid]); + if (output_pre_rht_amax_ptr != nullptr) { + float pre_rht_amax = *output_pre_rht_amax_ptr; + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = pre_rht_amax; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = pre_rht_amax; + } + } + } +} + +template +__global__ void MultiHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input, + const MultiAmaxArgs args, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, uint64_t num_rows, + uint64_t row_length) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + float* output_pre_rht_amax_ptr; + float* output_identity_amax_ptr; + float* output_transpose_amax_ptr; + + // calculate the global offset in Y direction to access the correct amax buffer + int global_offset_y = blockIdx.y * CHUNK_DIM_Y; + int tensor_id = 0; + while (args.split_sections_range[tensor_id + 1] <= global_offset_y) { + ++tensor_id; + } + output_pre_rht_amax_ptr = static_cast(args.output_pre_rht_amax_list[tensor_id]); + output_identity_amax_ptr = static_cast(args.output_identity_amax_list[tensor_id]); + output_transpose_amax_ptr = static_cast(args.output_transpose_amax_list[tensor_id]); + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? 
stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace + +// broadcast_pre_rht_amax: when it's true, hadamard transform will be disabled +// if at this time, the amax buffers for output expects both amax_rowwise and amax_colwise +// then call MultiAmaxMemcpyD2DKernelPreRHT to D2D copy the amax values +void multi_hadamard_transform_amax(const Tensor& input_, std::vector& output_list, + const int* split_sections, size_t num_tensors, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + bool broadcast_pre_rht_amax, cudaStream_t stream) { + NVTE_API_CALL(multi_hadamard_transform_amax); +#if CUDA_VERSION >= 12080 + + // Check input tensor + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor& input = input_.data; + + // TODO: validate num_tensors and split_sections + // assert if 
num_tensors is greater than kMaxTensorsPerKernel + // will expand 64 to higher value if needed + // if input size is going to exceed 4KB kernel launch limit, will then support multi-launch + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + // check split_sections + // TODO: support m_splits_tensor for device initiated API + NVTE_CHECK(split_sections != nullptr, "split_sections should not be nullptr"); + + MultiAmaxArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + bool all_return_pre_rht_amax = true; + bool all_return_identity_amax = true; + bool all_return_transposed_amax = true; + for (size_t i = 0; i < num_tensors; ++i) { + void* output_pre_rht_amax_ptr = output_list[i]->amax.dptr; + // disable RHT(x) for now, only RHT_T(x) should be used + void* output_identity_amax_ptr = nullptr; + void* output_transpose_amax_ptr = output_list[i]->columnwise_amax.dptr; + all_return_pre_rht_amax &= (output_pre_rht_amax_ptr != nullptr); + all_return_identity_amax &= (output_identity_amax_ptr != nullptr); + all_return_transposed_amax &= (output_transpose_amax_ptr != nullptr); + // sanity check split_sections component to see if it's 64 multiple for each element + NVTE_CHECK(split_sections[i] % 64 == 0, "component ", i, + " of split_sections should be 64 multiple"); + // also skip adding this tensor to the kernel args there are zero elements in this split + if (split_sections[i] == 0) { + continue; + } + // fill in kernel arguments + kernel_args.output_pre_rht_amax_list[kernel_args.num_tensors] = output_pre_rht_amax_ptr; + kernel_args.output_identity_amax_list[kernel_args.num_tensors] = output_identity_amax_ptr; + kernel_args.output_transpose_amax_list[kernel_args.num_tensors] = output_transpose_amax_ptr; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + // check overflow + NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, + "split_sections_range overflow the int32_t"); + kernel_args.num_tensors++; + } + + NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax, + "At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax " + "must be true"); + // currently we haven't supported all_return_identity_amax, assert error if it's mistakenly enabled + NVTE_CHECK(!all_return_identity_amax, + "Currently RHT transform should only be applied to transposed input"); + + if (broadcast_pre_rht_amax) { + NVTE_CHECK(all_return_pre_rht_amax, + "broadcast_pre_rht_amax is only supported when we compute pre-RHT amax"); + // if all_return_identity_amax and all_return_transposed_amax both are false, there is no need to broadcast anything + broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax); + } + + // Multi zero out multiple amaxes if needed + // Curretly don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel + // let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block + dim3 block_setup_amax(kMaxTensorsPerKernel); + dim3 grid_setup_amax(1); + MultiZeroAmaxKernel<<>>(kernel_args); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; 
i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + // four (1x4) 64x64 sub-tiles for ping-pong overlap + constexpr uint64_t kChunkBlockXSmall = 256; + constexpr uint64_t kChunkBlockYSmall = 64; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input, + /*globalY=*/num_rows, + /*globalX=*/row_length, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/row_length, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + + dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = MultiHadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>(tensor_map_input, kernel_args, + random_sign_mask, random_sign_mask_t, + num_rows, row_length); + if (broadcast_pre_rht_amax) { + MultiAmaxMemcpyD2DKernelPreRHT<<>>( + kernel_args); + }))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +// Multi hadamard transform API is unlike other multi-input & multi-output APIs +// Multi hadamard transform will take in a single input tensor, and directly calculate amax +// with optional RHT transform. That's because we can assume the input tensor list to be +// contiguous in memory, so the tensors are only splitted in dimension 0. +// RHT transform is 16x16, so as long as each split of the input has 16 multiple shape +// in dimension 0, we can treat the entire input as a single tensor. +// Although mathmatically 16 multple is enough for this function to be correct, +// for this kernel, we required 64 multiple of 16 in dimension 0 for better performance. 
+// Note: currently assumes split_sections is a list of integers in CPU +// TODO: split_sections could be a tensor for device initiated API +void nvte_multi_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs, + const int* split_sections, const size_t num_tensors, + int random_sign_mask, int random_sign_mask_t, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_hadamard_transform_amax); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor* input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + // Call the multi-tensor Hadamard transform amax implementation. + multi_hadamard_transform_amax(*input_tensor, output_list, split_sections, num_tensors, + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), false, stream); +} + +// Multi-tensor amax without doing hadamard transform +void nvte_multi_tensor_amax(const NVTETensor input, NVTETensor* outputs, const int* split_sections, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_hadamard_transform_amax); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor* input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + + multi_hadamard_transform_amax(*input_tensor, output_list, split_sections, num_tensors, 0, 0, true, + stream); +} diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index a0dd325da0..67432a7808 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -61,6 +61,38 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! + * \brief Perform multi-tensor Hadamard transform absolute maximum reduction (amax) with optional randomized Hadamard transform. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform, assumed contiguous in memory and split on dimension 0. + * \param[in,out] outputs Array of output tensors. + * \param[in] split_sections Array of splits in dimension 0 for each output tensor. + * \param[in] num_tensors Number of output tensors, must be > 0. + * \param[in] random_sign_mask 16-bit (int) sign mask for transform. + * \param[in] random_sign_mask_t 16-bit (int) sign mask for transform (transposed). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs, + const int* split_sections, const size_t num_tensors, + int random_sign_mask, int random_sign_mask_t, + cudaStream_t stream); + +/*! + * \brief Perform multi-tensor absolute maximum reduction (amax) without Hadamard transform. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor, assumed contiguous in memory and split on dimension 0. + * \param[in,out] outputs Array of output tensors. + * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor. 
diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
index a0dd325da0..67432a7808 100644
--- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h
+++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
@@ -61,6 +61,38 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
                                                     const NVTEQuantizationConfig quant_config,
                                                     cudaStream_t stream);
 
+/*!
+ * \brief Perform multi-tensor absolute maximum (amax) reduction with an optional randomized Hadamard transform (RHT).
+ *
+ * This function is experimental and the API is not stable.
+ *
+ * \param[in]     input              Input tensor to apply the Hadamard transform to, assumed contiguous in memory and split on dimension 0.
+ * \param[in,out] outputs            Array of output tensors.
+ * \param[in]     split_sections     Array of splits in dimension 0 for each output tensor.
+ * \param[in]     num_tensors        Number of output tensors, must be > 0.
+ * \param[in]     random_sign_mask   16-bit (int) sign mask for the transform.
+ * \param[in]     random_sign_mask_t 16-bit (int) sign mask for the transform (transposed).
+ * \param[in]     stream             CUDA stream used for the operation.
+ */
+void nvte_multi_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs,
+                                        const int* split_sections, const size_t num_tensors,
+                                        int random_sign_mask, int random_sign_mask_t,
+                                        cudaStream_t stream);
+
+/*!
+ * \brief Perform multi-tensor absolute maximum (amax) reduction without the Hadamard transform.
+ *
+ * This function is experimental and the API is not stable.
+ *
+ * \param[in]     input            Input tensor, assumed contiguous in memory and split on dimension 0.
+ * \param[in,out] outputs          Array of output tensors.
+ * \param[in]     split_sections   Array specifying splits in dimension 0 for each output tensor.
+ * \param[in]     num_tensors      Number of output tensors, must be > 0.
+ * \param[in]     stream           CUDA stream used for the operation.
+ */
+void nvte_multi_tensor_amax(const NVTETensor input, NVTETensor* outputs, const int* split_sections,
+                            const size_t num_tensors, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
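To make the distinction between the two declarations concrete: the RHT variant reports a rowwise amax taken over the original input and a columnwise amax taken over the randomized-Hadamard-transformed transpose (see the comments in the cast.cpp changes below). The sketch here is only a rough PyTorch analogue; `rht_16x16` stands in for the quantizer's 16x16 signed Hadamard matrix, and the real kernel's normalization and sign-mask conventions are deliberately ignored.

```python
import torch

def reference_rht_amaxes(x: torch.Tensor, split_sections: list[int], rht_16x16: torch.Tensor):
    """Illustrative only: (rowwise amax, columnwise amax) per split, where the columnwise value
    is taken after applying a 16x16 transform to the transposed split."""
    results = []
    for split in x.split(split_sections, dim=0):
        if split.numel() == 0:
            results.append((torch.zeros(()), torch.zeros(())))
            continue
        rowwise_amax = split.abs().amax().float()
        chunks = split.t().contiguous().reshape(-1, 16)   # transform acts on 16-element chunks
        columnwise_amax = (chunks @ rht_16x16).abs().amax().float()
        results.append((rowwise_amax, columnwise_amax))
    return results
```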
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 7d15e436ea..dda29450f0 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -100,10 +100,223 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype)
 
 namespace {
 
-void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
+void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
+                                      const std::vector<TensorWrapper> &input_list,
+                                      std::vector<TensorWrapper> &output_list,
+                                      const std::vector<int> &split_sections,
+                                      NVFP4Quantizer *quantizer) {
+  // Sanity check the input before splitting
+  if (input.numel() == 0) {
+    for (size_t i = 0; i < input_list.size(); ++i) {
+      if (input_list[i].numel() != 0) {
+        NVTE_CHECK(false,
+                   "NVFP4 multi_quantize: Single input tensor has zero elements but input_list "
+                   "contains a non-empty tensor, inconsistent args were provided.");
+      }
+    }
+    return;
+  }
+  // split_sections should have the same size as the input and output lists
+  NVTE_CHECK(input_list.size() == output_list.size(),
+             "Input and output list must have the same size");
+  NVTE_CHECK(split_sections.size() == input_list.size(),
+             "Split sections must have the same size as input and output list");
+  // this function is not responsible for 2D NVFP4 quantization
+  NVTE_CHECK(quantizer->with_2d_quantization == false,
+             "NVFP4 multi_quantize: 2D NVFP4 quantization is not supported");
+  // the multi quantize function doesn't have amax reduction support
+  NVTE_CHECK(quantizer->with_amax_reduction == false,
+             "NVFP4 multi_quantize: amax reduction is not supported");
+
+  size_t num_tensors = split_sections.size();
+
+  size_t rows = 1;
+  for (size_t i = 0; i < input.ndim() - 1; ++i) {
+    rows *= input.size(i);
+  }
+  size_t cols = input.size(input.ndim() - 1);
+
+  NVTE_CHECK(cols % 128 == 0, "NVFP4 multi_quantize: number of columns must be a multiple of 128");
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  std::vector<NVTETensor> nvte_tensor_input_list;
+  std::vector<NVTETensor> nvte_tensor_output_list;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    nvte_tensor_input_list.push_back(input_list[i].data());
+    nvte_tensor_output_list.push_back(output_list[i].data());
+  }
+
+  // Get a list of QuantizationConfigWrapper quant_config
+  std::vector<QuantizationConfigWrapper> quant_config_list;
+  for (size_t i = 0; i < num_tensors; i++) {
+    quant_config_list.emplace_back(QuantizationConfigWrapper());
+  }
+
+  // stochastic rounding support for multi-tensor quantize
+  std::vector<TensorWrapper> te_rng_state_list;
+  at::Tensor rng_states_tensor;
+
+  // assumes that one quantizer doing stochastic rounding means all quantizers do it
+  if (quantizer->stochastic_rounding) {
+    // TODO(zhongbo): remove the for loop of generating rng states with a single call
+    // with rng_elts_per_thread = 1024 * num_tensors
+    // Change to the bulk generate rng states api when grouped quantize is available
+    const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
+    auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
+    rng_states_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
+
+    for (size_t i = 0; i < num_tensors; ++i) {
+      auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+          std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+      at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
+      int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
+      philox_unpack(philox_args, rng_state_ptr);
+      te_rng_state_list.push_back(makeTransformerEngineTensor(
+          static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
+      quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
+      quant_config_list[i].set_stochastic_rounding(true);
+    }
+  }
+
+  // With or without RHT, compute amax with the fused multi-tensor kernels:
+  // out.amax is the rowwise amax, out.columnwise_amax is the columnwise amax.
+  // The rowwise amax is the amax of the original input;
+  // the columnwise amax is the amax of RHT(input.t).
+  if (quantizer->with_rht) {
+    // bf16 only for now
+    NVTE_CHECK(input.dtype() == DType::kBFloat16,
+               "NVFP4 multi_quantize: RHT is only supported for bfloat16 input");
+    if (quantizer->with_post_rht_amax) {
+      // We need:
+      // 1. Rowwise amax = amax of the input
+      // 2. Columnwise amax = amax of RHT(input.t)
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_multi_hadamard_transform_amax(
+            input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
+            split_sections.data(), num_tensors, 0, quantizer->rht_matrix_random_sign_mask_t,
+            stream);
+      });
+    } else {
+      // RHT is enabled, but the requested amax is the pre-RHT amax.
+      // The kernel for this is ready, but the case stays disabled until recipe convergence is verified.
+      NVTE_CHECK(false, "NVFP4 multi_quantize: Pre-RHT amax is not supported yet");
+    }
+  } else {
+    // We need:
+    // 1. Rowwise amax = amax of the input
+    // 2. Columnwise amax = amax of the input too
+    // The columnwise amax will be filled with a fused D2D copy from the rowwise amax.
+    // Note that the multi-tensor amax API expects the rowwise amax pointer to be non-null,
+    // so we set the pointer accordingly to make colwise-only quantization work.
+    std::vector<void *> orig_amax_ptr_list;
+    for (size_t i = 0; i < num_tensors; i++) {
+      auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr;
+      orig_amax_ptr_list.push_back(rowwise_amax_ptr);
+      auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr;
+      void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
+      NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");
+      output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
+    }
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_multi_tensor_amax(input.data(),
+                             reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
+                             split_sections.data(), num_tensors, stream);
+    });
+    for (size_t i = 0; i < num_tensors; i++) {
+      output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector<size_t>{1});
+    }
+  }
+
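+  // At this point every output's amax fields are populated: with RHT, the rowwise amax covers the
+  // original input and the columnwise amax covers RHT(input.t); without RHT, both views share the
+  // amax of the original input (broadcast by the fused D2D copy inside the call above).
+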
+  // Now quantize, with or without RHT.
+  if (quantizer->with_rht) {
+    // check the availability of the RHT matrix definition for best perf
+    NVTE_CHECK(quantizer->rht_matrix.defined() && quantizer->rht_matrix.numel() > 0,
+               "NVFP4 multi_quantize: RHT matrix is not set");
+    auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer->rht_matrix);
+
+    NVTE_SCOPED_GIL_RELEASE({
+      for (size_t i = 0; i < num_tensors; i++) {
+        // skip this round if the input is empty
+        if (input_list[i].numel() == 0) {
+          continue;
+        }
+        if (quantizer->rowwise_usage) {
+          TensorWrapper out_identity(output_list[i].scaling_mode());
+          auto out_identity_data = output_list[i].get_rowwise_data();
+          auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv();
+          auto out_identity_amax = output_list[i].get_amax();
+          out_identity.set_rowwise_data(out_identity_data.data_ptr,
+                                        static_cast<DType>(out_identity_data.dtype),
+                                        out_identity_data.shape);
+          out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
+                                             static_cast<DType>(out_identity_scale_inv.dtype),
+                                             out_identity_scale_inv.shape);
+          out_identity.set_amax(out_identity_amax.data_ptr,
+                                static_cast<DType>(out_identity_amax.dtype),
+                                out_identity_amax.shape);
+
+          NVTE_SCOPED_GIL_RELEASE({
+            nvte_quantize_v2(input_list[i].data(), out_identity.data(), quant_config_list[i],
+                             stream);
+          });
+        }
+
+        // already eligible for the RHT columnwise cast fusion after the dimension check
+        if (quantizer->columnwise_usage) {
+          // Get the output columnwise data, scale_inv, and amax
+          auto out_columnwise_data = output_list[i].get_columnwise_data();
+          auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv();
+          // NOTE: should already be populated.
+          auto out_columnwise_amax = output_list[i].get_columnwise_amax();
+
+          // Create a wrapper that exposes the columnwise output as a rowwise output.
+          // The reason is that the input `rht_output_t` is already in the transposed layout,
+          // so a rowwise quantization is enough to generate the columnwise output.
+          TensorWrapper out_transpose(output_list[i].scaling_mode());
+          auto colwise_data_shape = out_columnwise_data.shape;
+          std::vector<size_t> colwise_data_shape_2d;
+          colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
+          size_t last_dim = 1;
+          for (size_t j = 1; j < colwise_data_shape.ndim; ++j) {
+            last_dim *= colwise_data_shape.data[j];
+          }
+          colwise_data_shape_2d.push_back(last_dim);
+
+          out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
+                                         static_cast<DType>(out_columnwise_data.dtype),
+                                         colwise_data_shape_2d);
+          out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
+                                              static_cast<DType>(out_columnwise_scale_inv.dtype),
+                                              out_columnwise_scale_inv.shape);
+          out_transpose.set_amax(out_columnwise_amax.data_ptr,
+                                 static_cast<DType>(out_columnwise_amax.dtype),
+                                 out_columnwise_amax.shape);
+          nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(),
+                                                         rht_matrix_nvte.data(),
+                                                         quant_config_list[i], stream);
+        }
+      }
+    });
+  } else {
+    NVTE_SCOPED_GIL_RELEASE({
+      for (size_t i = 0; i < num_tensors; i++) {
+        // skip this round if the input is empty
+        if (input_list[i].numel() == 0) {
+          continue;
+        }
+        nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream);
+      }
+    });
+  }
+}
+
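Before the dispatch changes below, it may help to see the fused-path requirements for NVFP4 in one place. The helper below is hypothetical (not part of the extension) and simply mirrors the checks made in `multi_tensor_quantize_nvfp4_impl` above and in the eligibility loop that follows: split metadata must be present, every dim-0 split must be a multiple of 64, and the last dimension must be a multiple of 128.

```python
def nvfp4_fused_path_ok(split_sections: list[int], hidden_size: int, num_tensors: int) -> bool:
    """Mirror of the fused NVFP4 multi-tensor quantize eligibility checks (illustrative)."""
    if len(split_sections) != num_tensors:          # split_sections must describe every output
        return False
    if any(s % 64 != 0 for s in split_sections):    # dim-0 splits must be multiples of 64
        return False
    if hidden_size % 128 != 0:                      # last dimension must be a multiple of 128
        return False
    return True
```

The 64-element alignment is also what the new `get_align_size_for_quantization` helper reports for NVFP4 recipes (see the quantization.py change below), so the padding modules that consume it produce eligible `m_splits` automatically.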
+void multi_tensor_quantize_impl(const TensorWrapper &single_input,
+                                const std::vector<TensorWrapper> &input_list,
                                 std::vector<py::handle> &quantizer_py_list,
                                 std::vector<std::unique_ptr<Quantizer>> &quantizer_cpp_list,
-                                std::vector<TensorWrapper> &output_list) {
+                                std::vector<TensorWrapper> &output_list,
+                                const std::vector<int> &split_sections) {
   // Check number of tensors
   const size_t num_tensors = input_list.size();
   NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors,
@@ -114,15 +327,47 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
              " output tensors, but got ", output_list.size());
 
   // Choose implementation
-  // Note: Currently only have fused kernel for FP8 delayed scaling
   bool with_fused_kernel = true;
+  // Set the scaling mode based on the first quantizer's recipe
+  auto scaling_mode =
+      quantizer_cpp_list.empty() ? NVTE_INVALID_SCALING : quantizer_cpp_list[0]->get_scaling_mode();
+
+  // check whether split_sections is just a dummy input
+  bool valid_split_sections = split_sections.size() == num_tensors;
+
+  // Check scaling mode consistency across all tensors
   for (size_t i = 0; i < num_tensors; i++) {
-    if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
-      with_fused_kernel = false;
-      break;
-    }
-    if (nvte_tensor_data(output_list[i].data()) == nullptr ||
-        nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
+    if (detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
+      // for FP8 delayed scaling, only the fused FP8 quantize-transpose is supported
+      if (nvte_tensor_data(output_list[i].data()) == nullptr ||
+          nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
+        with_fused_kernel = false;
+        break;
+      }
+      // check that the scaling mode is FP8 delayed scaling for all quantizers
+      if (scaling_mode != NVTE_DELAYED_TENSOR_SCALING) {
+        with_fused_kernel = false;
+        break;
+      }
+    } else if (detail::IsNVFP4Quantizers(quantizer_py_list[i].ptr())) {
+      // check that the quantizers are all NVFP4 quantizers
+      if (scaling_mode != NVTE_NVFP4_1D_SCALING) {
+        with_fused_kernel = false;
+        break;
+      }
+      // the NVFP4 fused kernels also impose a dimension limit on split_sections
+      if (valid_split_sections) {
+        // bail out if any split section is not a multiple of 64
+        if (split_sections[i] % 64 != 0) {
+          with_fused_kernel = false;
+          break;
+        }
+      } else {
+        with_fused_kernel = false;
+        break;
+      }
+
+    } else {
       with_fused_kernel = false;
       break;
     }
@@ -130,17 +375,33 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
 
   // Launch TE kernel
   if (with_fused_kernel) {
-    // Fused kernel for multi-tensor quantize
-    std::vector<NVTETensor> nvte_tensor_input_list;
-    std::vector<NVTETensor> nvte_tensor_output_list;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      nvte_tensor_input_list.push_back(input_list[i].data());
-      nvte_tensor_output_list.push_back(output_list[i].data());
+    switch (scaling_mode) {
+      case NVTE_DELAYED_TENSOR_SCALING: {
+        // Fused kernel for multi-tensor quantize
+        std::vector<NVTETensor> nvte_tensor_input_list;
+        std::vector<NVTETensor> nvte_tensor_output_list;
+        for (size_t i = 0; i < num_tensors; ++i) {
+          nvte_tensor_input_list.push_back(input_list[i].data());
+          nvte_tensor_output_list.push_back(output_list[i].data());
+        }
+        NVTE_SCOPED_GIL_RELEASE({
+          nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
+                                    nvte_tensor_output_list.data(),
+                                    at::cuda::getCurrentCUDAStream());
+        });
+        break;
+      }
+      case NVTE_NVFP4_1D_SCALING: {
+        auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
+        multi_tensor_quantize_nvfp4_impl(single_input, input_list, output_list, split_sections,
+                                         nvfp4_quantizer);
+        break;
+      }
+      default:
+        NVTE_ERROR(
+            "Fused multi-tensor quantize is only supported for FP8 delayed scaling and NVFP4 1D "
+            "scaling");
     }
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
-                                nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
-    });
   } else {
     // Quantize kernels individually
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -184,8 +445,13 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
     output_py_list.emplace_back(std::move(output_py));
   }
 
+  // Prepare for multi-tensor quantization.
+  // Use empty split_sections and a dummy input wrapper, since the tensors are already individually provided.
+ std::vector dummy_split_sections; + TensorWrapper dummy_input_wrapper; // Perform multi-tensor quantization - multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + multi_tensor_quantize_impl(dummy_input_wrapper, input_cpp_list, quantizer_list, + quantizer_cpp_list, output_cpp_list, dummy_split_sections); return output_py_list; } @@ -584,7 +850,7 @@ std::tuple, std::vector> bulk_allocate_nv rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); } } @@ -639,7 +905,7 @@ std::tuple, std::vector> bulk_allocate_nv columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); } } @@ -718,6 +984,8 @@ std::vector split_quantize(const at::Tensor &tensor, input_size *= d; } NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims"); + // get a single tensor wrapper of the input + TensorWrapper input_wrapper = makeTransformerEngineTensor(input_dptr, input_shape, input_dtype); // Split input tensor along dim 0 std::vector input_list; @@ -803,7 +1071,8 @@ std::vector split_quantize(const at::Tensor &tensor, } // Perform multi-tensor quantization - multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + multi_tensor_quantize_impl(input_wrapper, input_list, quantizer_list, quantizer_cpp_list, + output_cpp_list, split_sections); return output_py_list; } diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index fca89fbaa9..50743dd939 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization from ..jit import no_torch_dynamo @@ -111,14 +111,8 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." if self.align_size is None: - self.align_size = ( - 32 - if ( - FP8GlobalStateManager.get_fp8_recipe().mxfp8() - or FP8GlobalStateManager.get_fp8_recipe().nvfp4() - ) - else 16 - ) + recipe = FP8GlobalStateManager.get_fp8_recipe() + self.align_size = get_align_size_for_quantization(recipe) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 7a01f15729..d1a1565981 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization from ..jit import no_torch_dynamo @@ -109,14 +109,8 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." 
if self.align_size is None: - self.align_size = ( - 32 - if ( - FP8GlobalStateManager.get_fp8_recipe().mxfp8() - or FP8GlobalStateManager.get_fp8_recipe().nvfp4() - ) - else 16 - ) + recipe = FP8GlobalStateManager.get_fp8_recipe() + self.align_size = get_align_size_for_quantization(recipe) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 030370b9db..1286c5f619 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -40,6 +40,7 @@ "is_fp8_block_scaling_available", "is_nvfp4_available", "get_default_recipe", + "get_align_size_for_quantization", ] @@ -114,6 +115,15 @@ def get_default_recipe() -> Recipe: return get_default_fp8_recipe() +def get_align_size_for_quantization(recipe: Recipe): + """Get the alignment size for quantization.""" + if recipe.mxfp8(): + return 32 + if recipe.nvfp4(): + return 64 + return 16 + + def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or (