diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py
index 78a46ef..1e8c687 100644
--- a/include/tritonblas/__init__.py
+++ b/include/tritonblas/__init__.py
@@ -2,3 +2,6 @@ from .matmul import matmul_lt, matmul_a8w8_lt
 from .matmul import matmul_fp4
 from .origami import MatmulHeuristicResult
+from .hadamard import hadamard_blocked_fast
+from .fused_mxfp4_quant import fused_rms_mxfp4_quant, fused_rms_hadamard_mxfp4_quant, fused_mxfp4_quant
+from .rmsnorm import rms_norm, rmsnorm2d_fwd_with_dynamicquant
\ No newline at end of file
diff --git a/include/tritonblas/fused_mxfp4_quant.py b/include/tritonblas/fused_mxfp4_quant.py
new file mode 100644
index 0000000..0335ae3
--- /dev/null
+++ b/include/tritonblas/fused_mxfp4_quant.py
@@ -0,0 +1,409 @@
+import torch
+import triton
+import triton.language as tl
+from typing import Optional
+
+from tritonblas.kernels.fused_mxfp4_quant import (
+    _rmsmorm_op,
+    _fused_rms_mxfp4_quant_kernel,
+    _fused_flatten_mxfp4_quant,
+    _fused_mxfp4_quant_kernel,
+    _fused_rms_hadamard_mxfp4_quant_kernel
+)
+from aiter.ops.triton.utils.logger import AiterTritonLogger
+
+_LOGGER = AiterTritonLogger()
+
+
+def fused_mxfp4_quant(
+    x1: torch.Tensor,
+    x1_weight: torch.Tensor,
+    x1_epsilon: float,
+    x2: Optional[torch.Tensor] = None,
+    x2_weight: Optional[torch.Tensor] = None,
+    x2_epsilon: float = 0.0,
+    res1: Optional[torch.Tensor] = None,
+    shuffle: Optional[bool] = False,
+    scale_shuffle_padding: Optional[bool] = False,
+):
+    """
+    This op contains several steps:
+    1. if res1 is not None, x1 = x1 + res1, and store x1 to out_res1
+    2. perform RMS norm along the last dimension for x1
+    3. if x2 is not None, perform RMS norm along the last dimension for x2
+    4. perform mxfp4 quantization for x1 only
+
+    Key parameters:
+    - x1: Matrix X1 with shape (M, N1); x1_weight: RMS norm weight with shape (N1,).
+    - x2: Optional matrix X2 with shape (M, N2); x2_weight: RMS norm weight with shape (N2,).
+    - res1: Optional residual added to x1, with shape (M, N1).
+
+    Returns:
+    - out1_fp4: The output matrix with shape (M, N1 // 2).
+    - out1_bs: The output matrix with shape (M, cdiv(N1, MXFP4_QUANT_BLOCK_SIZE)).
+    - out2: The output matrix with shape (M, N2).
+    - out_res1: The output matrix with shape (M, N1).
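+
+    Example (illustrative sketch; shapes follow the descriptions above, while the
+    device and dtype below are placeholders, not requirements of this op):
+        x1 = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
+        w1 = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
+        (x1_fp4, x1_bs), out2, out_res1 = fused_mxfp4_quant(x1, w1, 1e-6)
+        # x1_fp4: (128, 2048) uint8, two e2m1 values packed per byte
+        # x1_bs:  (128, 128) uint8 e8m0 block scales (4096 / 32 = 128 per row)
+        # out2 and out_res1 are None because x2 and res1 were not given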
+ + always returns (out1_fp4, out1_bs), out2, out_res1 + """ + _LOGGER.info(f"FUSED_RMS_MXFP4_QUANT: inp1={tuple(x1.shape)}") + + MXFP4_QUANT_BLOCK_SIZE = 32 + M, N1 = x1.shape + BLOCK_SIZE_N = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE) + BLOCK_SIZE_N2 = 1 + if x2 is not None: + N2 = x2.shape[1] + BLOCK_SIZE_N2 = triton.next_power_of_2(N2) + else: + N2 = 0 + # as we merge 2 fp4s to 1 uint8 + assert N1 % 2 == 0 + BLOCK_SIZE_M = 1 + # BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = max(BLOCK_SIZE_N, MXFP4_QUANT_BLOCK_SIZE) + out1_fp4 = torch.empty((M, N1 // 2), dtype=torch.uint8, device=x1.device) + SCALE_N_valid = triton.cdiv(N1, MXFP4_QUANT_BLOCK_SIZE) + use_scale_shuffle_padding = shuffle or scale_shuffle_padding + if use_scale_shuffle_padding: + SCALE_M = triton.cdiv(M, 256) * 256 + SCALE_N = triton.cdiv(SCALE_N_valid, 8) * 8 + # BLOCK_SIZE_M = triton.cdiv(BLOCK_SIZE_M, 32) * 32 + BLOCK_SIZE_N = triton.cdiv(BLOCK_SIZE_N, 32) * 32 + else: + SCALE_M = M + SCALE_N = SCALE_N_valid + out1_bs = torch.empty( + (SCALE_M, SCALE_N), + dtype=torch.uint8, + device=x1.device, + ) + + out_res1 = None + res1_stride_m = 0 + out_res1_stride_m = 0 + if res1 is not None: + out_res1 = torch.empty((M, N1), dtype=x1.dtype, device=x1.device) + res1_stride_m = res1.stride(0) + out_res1_stride_m = out_res1.stride(0) + + out2 = None + out2_stride_m = 0 + x2_stride_m = 0 + if x2 is not None: + out2 = torch.empty((M, N2), dtype=x1.dtype, device=x1.device) + x2_stride_m = x2.stride(0) + out2_stride_m = out2.stride(0) + + grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) + _fused_mxfp4_quant_kernel[grid]( + x1, + x1_weight, + x2, + x2_weight, + res1, + out1_fp4, + out1_bs, + out2, + out_res1, + x1_epsilon, + x2_epsilon, + M, + N1, + N2, + x1.stride(0), + x2_stride_m, + res1_stride_m, + out1_fp4.stride(0), + *out1_bs.stride(), + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + HAS_SECOND_INPUT=(x2 is not None), + FIRST_INPUT_RES=(res1 is not None), + SCALE_N=SCALE_N_valid, + SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), + SCALE_N_PAD=SCALE_N, + SHUFFLE=shuffle, + SHUFFLE_PAD=use_scale_shuffle_padding, + ) + + return (out1_fp4, out1_bs), out2, out_res1 + +def fused_rms_hadamard_mxfp4_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: float = 0.0, + res1: Optional[torch.Tensor] = None, + shuffle: Optional[bool] = False, + scale_shuffle_padding: Optional[bool] = False, +): + """ + This op contains several steps: + 1. if res1 is not None, x1 = x1 + res1, and store x1 to out_res1 + 2. perform RMS norm along the last dimenion for x1 + 3. if x2 is not None, perform RMS norm along the last dimenion for x2 + 4. perform mxfp4 quantization for x1 only + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out1_fp4: The output matrix with shape (M, N1 // 2). + - out1_bs: The output matrix with shape (M, cdiv(N1, MXFP4_QUANT_BLOCK_SIZE)). + - out2: The output matrix with shape (M, N2). + - out_res1: The output matrix with shape (M, N1). 
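+
+    Note (clarifying the difference from fused_rms_mxfp4_quant): each 32-element
+    quantization block of the RMS-normalized x1 is additionally multiplied by a
+    normalized 32x32 Hadamard matrix (x_block @ H) before mxfp4 quantization.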
+ + always returns (out1_fp4, out1_bs), out2, out_res1 + """ + _LOGGER.info(f"FUSED_RMS_MXFP4_QUANT: inp1={tuple(x1.shape)}") + + MXFP4_QUANT_BLOCK_SIZE = 32 + M, N1 = x1.shape + BLOCK_SIZE_N = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE) + BLOCK_SIZE_N2 = 1 + if x2 is not None: + N2 = x2.shape[1] + BLOCK_SIZE_N2 = triton.next_power_of_2(N2) + else: + N2 = 0 + # as we merge 2 fp4s to 1 uint8 + assert N1 % 2 == 0 + BLOCK_SIZE_M = 1 + # BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = max(BLOCK_SIZE_N, MXFP4_QUANT_BLOCK_SIZE) + out1_fp4 = torch.empty((M, N1 // 2), dtype=torch.uint8, device=x1.device) + SCALE_N_valid = triton.cdiv(N1, MXFP4_QUANT_BLOCK_SIZE) + use_scale_shuffle_padding = shuffle or scale_shuffle_padding + if use_scale_shuffle_padding: + SCALE_M = triton.cdiv(M, 256) * 256 + SCALE_N = triton.cdiv(SCALE_N_valid, 8) * 8 + # BLOCK_SIZE_M = triton.cdiv(BLOCK_SIZE_M, 32) * 32 + BLOCK_SIZE_N = triton.cdiv(BLOCK_SIZE_N, 32) * 32 + else: + SCALE_M = M + SCALE_N = SCALE_N_valid + out1_bs = torch.empty( + (SCALE_M, SCALE_N), + dtype=torch.uint8, + device=x1.device, + ) + + out_res1 = None + res1_stride_m = 0 + out_res1_stride_m = 0 + if res1 is not None: + out_res1 = torch.empty((M, N1), dtype=x1.dtype, device=x1.device) + res1_stride_m = res1.stride(0) + out_res1_stride_m = out_res1.stride(0) + + out2 = None + out2_stride_m = 0 + x2_stride_m = 0 + if x2 is not None: + out2 = torch.empty((M, N2), dtype=x1.dtype, device=x1.device) + x2_stride_m = x2.stride(0) + out2_stride_m = out2.stride(0) + + grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) + _fused_rms_hadamard_mxfp4_quant_kernel[grid]( + x1, + x1_weight, + x2, + x2_weight, + res1, + out1_fp4, + out1_bs, + out2, + out_res1, + x1_epsilon, + x2_epsilon, + M, + N1, + N2, + x1.stride(0), + x2_stride_m, + res1_stride_m, + out1_fp4.stride(0), + *out1_bs.stride(), + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + HAS_SECOND_INPUT=(x2 is not None), + FIRST_INPUT_RES=(res1 is not None), + SCALE_N=SCALE_N_valid, + SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), + SCALE_N_PAD=SCALE_N, + SHUFFLE=shuffle, + SHUFFLE_PAD=use_scale_shuffle_padding, + ) + + return (out1_fp4, out1_bs), out2, out_res1 + +def fused_rms_mxfp4_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: float = 0.0, + res1: Optional[torch.Tensor] = None, + shuffle: Optional[bool] = False, + scale_shuffle_padding: Optional[bool] = False, +): + """ + This op contains several steps: + 1. if res1 is not None, x1 = x1 + res1, and store x1 to out_res1 + 2. perform RMS norm along the last dimenion for x1 + 3. if x2 is not None, perform RMS norm along the last dimenion for x2 + 4. perform mxfp4 quantization for x1 only + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out1_fp4: The output matrix with shape (M, N1 // 2). + - out1_bs: The output matrix with shape (M, cdiv(N1, MXFP4_QUANT_BLOCK_SIZE)). + - out2: The output matrix with shape (M, N2). + - out_res1: The output matrix with shape (M, N1). 
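+
+    Example (illustrative sketch with a second input and a residual; shapes only,
+    device and dtype are placeholders):
+        x1 = torch.randn(64, 2048, device="cuda", dtype=torch.float16)
+        w1 = torch.ones(2048, device="cuda", dtype=torch.float16)
+        x2 = torch.randn(64, 1024, device="cuda", dtype=torch.float16)
+        w2 = torch.ones(1024, device="cuda", dtype=torch.float16)
+        res1 = torch.randn_like(x1)
+        (x1_fp4, x1_bs), out2, out_res1 = fused_rms_mxfp4_quant(
+            x1, w1, 1e-6, x2=x2, x2_weight=w2, x2_epsilon=1e-6, res1=res1
+        )
+        # x1_fp4: (64, 1024) uint8, x1_bs: (64, 64) uint8 e8m0 scales
+        # out2: (64, 1024) fp16 RMS-normed x2, out_res1: (64, 2048) fp16 (x1 + res1)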
+ + always returns (out1_fp4, out1_bs), out2, out_res1 + """ + _LOGGER.info(f"FUSED_RMS_MXFP4_QUANT: inp1={tuple(x1.shape)}") + + MXFP4_QUANT_BLOCK_SIZE = 32 + M, N1 = x1.shape + BLOCK_SIZE_N = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE) + BLOCK_SIZE_N2 = 1 + if x2 is not None: + N2 = x2.shape[1] + BLOCK_SIZE_N2 = triton.next_power_of_2(N2) + else: + N2 = 0 + # as we merge 2 fp4s to 1 uint8 + assert N1 % 2 == 0 + BLOCK_SIZE_M = 1 + # BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = max(BLOCK_SIZE_N, MXFP4_QUANT_BLOCK_SIZE) + out1_fp4 = torch.empty((M, N1 // 2), dtype=torch.uint8, device=x1.device) + SCALE_N_valid = triton.cdiv(N1, MXFP4_QUANT_BLOCK_SIZE) + use_scale_shuffle_padding = shuffle or scale_shuffle_padding + if use_scale_shuffle_padding: + SCALE_M = triton.cdiv(M, 256) * 256 + SCALE_N = triton.cdiv(SCALE_N_valid, 8) * 8 + # BLOCK_SIZE_M = triton.cdiv(BLOCK_SIZE_M, 32) * 32 + BLOCK_SIZE_N = triton.cdiv(BLOCK_SIZE_N, 32) * 32 + else: + SCALE_M = M + SCALE_N = SCALE_N_valid + out1_bs = torch.empty( + (SCALE_M, SCALE_N), + dtype=torch.uint8, + device=x1.device, + ) + + out_res1 = None + res1_stride_m = 0 + out_res1_stride_m = 0 + if res1 is not None: + out_res1 = torch.empty((M, N1), dtype=x1.dtype, device=x1.device) + res1_stride_m = res1.stride(0) + out_res1_stride_m = out_res1.stride(0) + + out2 = None + out2_stride_m = 0 + x2_stride_m = 0 + if x2 is not None: + out2 = torch.empty((M, N2), dtype=x1.dtype, device=x1.device) + x2_stride_m = x2.stride(0) + out2_stride_m = out2.stride(0) + + grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) + _fused_rms_mxfp4_quant_kernel[grid]( + x1, + x1_weight, + x2, + x2_weight, + res1, + out1_fp4, + out1_bs, + out2, + out_res1, + x1_epsilon, + x2_epsilon, + M, + N1, + N2, + x1.stride(0), + x2_stride_m, + res1_stride_m, + out1_fp4.stride(0), + *out1_bs.stride(), + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + HAS_SECOND_INPUT=(x2 is not None), + FIRST_INPUT_RES=(res1 is not None), + SCALE_N=SCALE_N_valid, + SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), + SCALE_N_PAD=SCALE_N, + SHUFFLE=shuffle, + SHUFFLE_PAD=use_scale_shuffle_padding, + ) + + return (out1_fp4, out1_bs), out2, out_res1 + + +def fused_flatten_mxfp4_quant( + x: torch.Tensor, +): + """ + Flatten the last two dimension of x and perform mxfp4 quantization along the last dimension + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out: The output matrix with shape (M, (N1 * N2) // 2). + - out_block_scales: The output matrix with shape (M, cdiv(N1 * N2, MXFP4_QUANT_BLOCK_SIZE)). 
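+
+    Example (illustrative sketch; shapes only, device and dtype are placeholders):
+        x = torch.randn(8, 32, 64, device="cuda", dtype=torch.float16)
+        out, out_bs = fused_flatten_mxfp4_quant(x)
+        # out:    (8, 1024) uint8  -- (32 * 64) // 2 packed e2m1 values per row
+        # out_bs: (8, 64)   uint8  -- (32 * 64) // 32 e8m0 block scales per row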
+ """ + _LOGGER.info(f"FUSED_FLATTEN_MXFP4_QUANT: x={tuple(x.shape)}") + M, N1, N2 = x.shape + + MXFP4_QUANT_BLOCK_SIZE = 32 + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), MXFP4_QUANT_BLOCK_SIZE) + N = N1 * N2 + out = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device) + out_block_scales = torch.empty( + (triton.cdiv(N, MXFP4_QUANT_BLOCK_SIZE), M), + dtype=torch.uint8, + device=x.device, + ).T + + grid = ( + M, + N1, + ) + _fused_flatten_mxfp4_quant[grid]( + x, + out, + out_block_scales, + *x.stride(), + *out.stride(), + *out_block_scales.stride(), + N2, + BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE, + ) + + return out, out_block_scales diff --git a/include/tritonblas/hadamard.py b/include/tritonblas/hadamard.py new file mode 100644 index 0000000..83961b3 --- /dev/null +++ b/include/tritonblas/hadamard.py @@ -0,0 +1,152 @@ +import triton +import triton.language as tl +import torch + +# @triton.jit +# def build_H(SIZE: tl.constexpr, dtype: tl.constexpr): +# r""" +# Construct small Hadamard matrices, in such a way that Triton can optimize the code away. +# This uses the identity $H_{i,j} = (-1)^{i \cdot j}$, +# where the operation $\cdot$ is the BITWISE dot product of integers. +# """ +# tl.static_assert(0 < SIZE) +# tl.static_assert(SIZE <= 64) + +# i = tl.arange(0, SIZE) +# j = tl.arange(0, SIZE) +# matching_bits = (i[:, None] & j) + +# bit_sum = tl.zeros_like(matching_bits) +# for i in tl.static_range(5): +# bit_sum += matching_bits & 1 +# matching_bits >>= 1 + +# # map odd to -1, even to 1 +# H = 2 * ((bit_sum % 2) == 0) - 1 +# return H.cast(dtype) + +@triton.jit +def build_H(SIZE: tl.constexpr): + r""" + Construct Hadamard matrix using H_{i,j} = (-1)^{popcount(i & j)}. + + This computes the bitwise dot product of row index i and column index j: + - popcount(i & j) counts the number of matching 1-bits + - If count is even: H_{i,j} = 1 + - If count is odd: H_{i,j} = -1 + + Args: + SIZE: Matrix dimension (must be power of 2, max 64) + dtype: Output data type + + Returns: + SIZE x SIZE Hadamard matrix + """ + tl.static_assert(0 < SIZE) + tl.static_assert(SIZE <= 64) # extend to 128 ? + + # Create row and column indices + i = tl.arange(0, SIZE) + j = tl.arange(0, SIZE) + + # Compute bitwise AND for all (i, j) pairs + matching_bits = i[:, None] & j[None, :] + + # Count set bits (popcount) - simple iterative approach + bit_sum = tl.zeros_like(matching_bits) + temp = matching_bits + for _ in tl.static_range(6): # 6 iterations for up to 64 bits + bit_sum += temp & 1 + temp >>= 1 + + # Map: even popcount -> +1, odd popcount -> -1 + H = 1 - 2 * (bit_sum & 1) + + # normalize by sqrt(d) + H = H / tl.math.sqrt(float(SIZE)) + return H + + +@triton.jit +def hadamard_blocked_kernel( + A_ptr, # Pointer to input matrix A [M, K] + Out_ptr, # Pointer to output matrix [M, K] + M, # Number of rows in A + K, # Number of columns in A + stride_am, # Stride of A in M dimension + stride_ak, # Stride of A in K dimension + stride_om, # Stride of output in M dimension + stride_ok, # Stride of output in K dimension + BLOCK_SIZE: tl.constexpr, # Block size (32) +): + """ + Kernel that applies Hadamard transformation to each 32x32 block of A. 
+ + Each program processes one 32x32 block independently: + Output[m_block, k_block] = A[m_block, k_block] @ H + """ + # Get program IDs for M and K dimensions + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + # Compute starting indices for this block + m_start = pid_m * BLOCK_SIZE + k_start = pid_k * BLOCK_SIZE + + # Create offset ranges + m_offs = m_start + tl.arange(0, BLOCK_SIZE) + k_offs = k_start + tl.arange(0, BLOCK_SIZE) + + # Create masks for boundary conditions + m_mask = m_offs < M + k_mask = k_offs < K + + # Load A block [BLOCK_SIZE, BLOCK_SIZE] + a_ptrs = A_ptr + m_offs[:, None] * stride_am + k_offs[None, :] * stride_ak + a_block = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0) + + # Materialize Hadamard matrix [BLOCK_SIZE, BLOCK_SIZE] + h_block = build_H(BLOCK_SIZE) + + # Perform matrix multiplication: A_block @ H_block + # This is a single 32x32 @ 32x32 operation + result = tl.dot(a_block, h_block.to(a_block.dtype)) + + # Store result to output + out_ptrs = Out_ptr + m_offs[:, None] * stride_om + k_offs[None, :] * stride_ok + tl.store(out_ptrs, result, mask=m_mask[:, None] & k_mask[None, :]) + + +def hadamard_blocked_fast(A: torch.Tensor) -> torch.Tensor: + """ + Apply Hadamard transformation to each 32x32 block of matrix A. + + Args: + A: Input matrix of shape [M, K] + + Returns: + Output matrix of shape [M, K] with each 32x32 block transformed + """ + assert A.is_cuda, "Tensors must be on CUDA" + + M, K = A.shape + + # Allocate output with same shape as A + Out = torch.zeros_like(A) + + # Define block size + BLOCK_SIZE = 32 + + # Calculate grid dimensions - one program per 32x32 block + grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(K, BLOCK_SIZE)) + + # Launch kernel + hadamard_blocked_kernel[grid]( + A, Out, + M, K, + A.stride(0), A.stride(1), + Out.stride(0), Out.stride(1), + BLOCK_SIZE=BLOCK_SIZE, + ) + + return Out \ No newline at end of file diff --git a/include/tritonblas/kernels/__init__.py b/include/tritonblas/kernels/__init__.py index b7ea70c..004d41e 100644 --- a/include/tritonblas/kernels/__init__.py +++ b/include/tritonblas/kernels/__init__.py @@ -29,6 +29,9 @@ # FP4 kernel from .fp4_matmul import fp4_matmul +# fused mxfp4 with rms +from .fused_mxfp4_quant import _fused_rms_hadamard_mxfp4_quant_kernel, _fused_mxfp4_quant_kernel + # Export stages submodule from . import stages diff --git a/include/tritonblas/kernels/fused_mxfp4_quant.py b/include/tritonblas/kernels/fused_mxfp4_quant.py new file mode 100644 index 0000000..2f46377 --- /dev/null +++ b/include/tritonblas/kernels/fused_mxfp4_quant.py @@ -0,0 +1,844 @@ +import triton +import triton.language as tl + +@triton.jit +def build_H(SIZE: tl.constexpr): + r""" + Construct Hadamard matrix using H_{i,j} = (-1)^{popcount(i & j)}. + + This computes the bitwise dot product of row index i and column index j: + - popcount(i & j) counts the number of matching 1-bits + - If count is even: H_{i,j} = 1 + - If count is odd: H_{i,j} = -1 + + Args: + SIZE: Matrix dimension (must be power of 2, max 64) + dtype: Output data type + + Returns: + SIZE x SIZE Hadamard matrix + """ + tl.static_assert(0 < SIZE) + tl.static_assert(SIZE <= 64) # extend to 128 ? 
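+    # Illustrative note (not executed): for SIZE = 4 the popcount identity reproduces
+    # the familiar Sylvester matrix before the 1/sqrt(SIZE) scaling, e.g. H[3, 1]:
+    # popcount(3 & 1) = 1 (odd) -> -1.
+    #
+    #   H_4 = [[ 1,  1,  1,  1],
+    #          [ 1, -1,  1, -1],
+    #          [ 1,  1, -1, -1],
+    #          [ 1, -1, -1,  1]]
+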
+ + # Create row and column indices + i = tl.arange(0, SIZE) + j = tl.arange(0, SIZE) + + # Compute bitwise AND for all (i, j) pairs + matching_bits = i[:, None] & j[None, :] + + # Count set bits (popcount) - simple iterative approach + bit_sum = tl.zeros_like(matching_bits) + temp = matching_bits + for _ in tl.static_range(6): # 6 iterations for up to 64 bits + bit_sum += temp & 1 + temp >>= 1 + + # Map: even popcount -> +1, odd popcount -> -1 + H = 1 - 2 * (bit_sum & 1) + + # normalize by sqrt(d) + H = H / tl.math.sqrt(float(SIZE)) + + return H + +@triton.jit +def _hadamard_mxfp4_quant_op( + x, + BLOCK_SIZE_N, + BLOCK_SIZE_M, + MXFP4_QUANT_BLOCK_SIZE, +): + """ + Converts given x (in fp32) to mxfp4 format. + x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32 + + """ + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x = x.reshape(BLOCK_SIZE_M * NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE) + + # to print: tl.device_print to print values @ runtime maybe.... + + # add hadamard here + h_block = build_H(MXFP4_QUANT_BLOCK_SIZE).to(x.dtype) + x = tl.dot(x, h_block) + x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE) + # 1 x 32 block size + # Calculate scale + amax = tl.max(tl.abs(x), axis=-1, keep_dims=True) # 32 x 1 scales, 32 threads doing 1 x 32 parallel across threads + amax = amax.to(tl.int32, bitcast=True) + amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax = amax.to(tl.float32, bitcast=True) + scale_e8m0_unbiased = tl.log2(amax).floor() - 2 + scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127) + + # blockscale_e8m0 + bs_e8m0 = scale_e8m0_unbiased.to(tl.uint8) + 127 # in fp32, we have 2&(e - 127) + + quant_scale = tl.exp2(-scale_e8m0_unbiased) + + # Compute quantized x + qx = x * quant_scale + + # Convert quantized fp32 tensor to uint32 before converting to mxfp4 format + # Note: MXFP4 S:1-bit, E:2-bit, M:1-bit + # Zeros: S000 -> +/-0 + # Denormal Numbers: S001 -> +/- 0.5 + # Normal Numbers: + # S010 -> +/- 1.0 + # S011 -> +/- 1.5 + # S100 -> +/- 2.0 + # S101 -> +/- 3.0 + # S110 -> +/- 4.0 + # S111 -> +/- 6.0 + qx = qx.to(tl.uint32, bitcast=True) + + # Extract sign, exponents and mantissa fields from FP32 + s = qx & 0x80000000 + e = (qx >> 23) & 0xFF + m = qx & 0x7FFFFF + E8_BIAS: tl.constexpr = 127 + E2_BIAS: tl.constexpr = 1 + + # Denormal numbers + # If exponent is less than 127, then it's a denormal number + # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa + adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False) + m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m) + # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0. + # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that. 
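+    # Worked example (illustrative): fp32 1.5 has e = 127, m = 0x400000.  After the
+    # re-bias below e becomes 1 and m >> 21 is 0b10, so ((e << 2) | (m >> 21)) = 0b110;
+    # adding 1 and shifting right rounds this to 0b011, the e2m1 encoding of 1.5.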
+ e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8) + e2m1_value = tl.reshape( + e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2] + ) + evens, odds = tl.split(e2m1_value) + x_fp4 = evens | (odds << 4) + x_fp4 = x_fp4.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N // 2) + + return x_fp4, bs_e8m0.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS) + +@triton.jit +def _mxfp4_quant_op( + x, + BLOCK_SIZE_N, + BLOCK_SIZE_M, + MXFP4_QUANT_BLOCK_SIZE, +): + """ + Converts given x (in fp32) to mxfp4 format. + x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32 + + """ + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE) + + # to print: tl.device_print to print values @ runtime maybe.... + + # add hadamard here + # tl.dot (x, hadamard) + + # 1 x 32 block size + # Calculate scale + amax = tl.max(tl.abs(x), axis=-1, keep_dims=True) # 32 x 1 scales, 32 threads doing 1 x 32 parallel across threads + amax = amax.to(tl.int32, bitcast=True) + amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax = amax.to(tl.float32, bitcast=True) + scale_e8m0_unbiased = tl.log2(amax).floor() - 2 + scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127) + + # blockscale_e8m0 + bs_e8m0 = scale_e8m0_unbiased.to(tl.uint8) + 127 # in fp32, we have 2&(e - 127) + + quant_scale = tl.exp2(-scale_e8m0_unbiased) + + # Compute quantized x + qx = x * quant_scale + + # Convert quantized fp32 tensor to uint32 before converting to mxfp4 format + # Note: MXFP4 S:1-bit, E:2-bit, M:1-bit + # Zeros: S000 -> +/-0 + # Denormal Numbers: S001 -> +/- 0.5 + # Normal Numbers: + # S010 -> +/- 1.0 + # S011 -> +/- 1.5 + # S100 -> +/- 2.0 + # S101 -> +/- 3.0 + # S110 -> +/- 4.0 + # S111 -> +/- 6.0 + qx = qx.to(tl.uint32, bitcast=True) + + # Extract sign, exponents and mantissa fields from FP32 + s = qx & 0x80000000 + e = (qx >> 23) & 0xFF + m = qx & 0x7FFFFF + E8_BIAS: tl.constexpr = 127 + E2_BIAS: tl.constexpr = 1 + + # Denormal numbers + # If exponent is less than 127, then it's a denormal number + # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa + adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False) + m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m) + # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0. + # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that. 
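+    # Worked example (illustrative): fp32 0.5 has e = 126 < 127, so the denormal path
+    # above sets m = 0x400000; the re-bias below then gives e = 0, and the rounding
+    # step produces 0b001, the e2m1 encoding of 0.5.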
+ e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8) + e2m1_value = tl.reshape( + e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2] + ) + evens, odds = tl.split(e2m1_value) + x_fp4 = evens | (odds << 4) + x_fp4 = x_fp4.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N // 2) + + return x_fp4, bs_e8m0.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS) + + +@triton.jit +def _rmsmorm_op(row, weight, n_cols, epsilon): + row_norm = row * row + row_norm = tl.sum(row_norm, axis=-1) + norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + + rms_norm = row * norm_factor[:, None] * weight + return rms_norm + + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, + "EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0, + } +) +@triton.jit +def _fused_mxfp4_quant_kernel( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_N2: tl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: tl.constexpr, + HAS_SECOND_INPUT: tl.constexpr, + FIRST_INPUT_RES: tl.constexpr, + SCALE_N: tl.constexpr, + SCALE_M_PAD: tl.constexpr, + SCALE_N_PAD: tl.constexpr, + SHUFFLE: tl.constexpr, + SHUFFLE_PAD: tl.constexpr, + EVEN_M_N: tl.constexpr, + EVEN_M_N2: tl.constexpr, +): + # TODO: XCD remapping where every 32-token block should share the same XCD + # TODO: debug for large M + # TODO: investigate cache_modifier='.cg' on tl.store + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + + if pid >= num_pid_m: + if HAS_SECOND_INPUT: + pid -= num_pid_m + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + x_offs_n2 = tl.arange(0, BLOCK_SIZE_N2) + mask2 = None + other2 = None + if not EVEN_M_N2: + mask2 = (x_offs_m < M)[:, None] & (x_offs_n2 < N2)[None, :] + other2 = 0.0 + + x2 = tl.load( + x2_ptr + x_offs_m[:, None] * x2_stride_m + x_offs_n2[None, :], + mask=mask2, + other=other2, + cache_modifier=".cg", + ).to(tl.float32) + + w_mask2 = None + w_other2 = None + if not EVEN_M_N2: + w_mask2 = x_offs_n2 < N2 + w_other2 = 0.0 + + w2 = tl.load(w2_ptr + x_offs_n2, mask=w_mask2, other=w_other2).to( + tl.float32 + ) + + norm2 = _rmsmorm_op(x2, w2, N2, eps2) + + tl.store( + out2_ptr + x_offs_m[:, None] * out2_stride_m + x_offs_n2[None, :], + norm2.to(out2_ptr.type.element_ty), + mask=mask2, + cache_modifier=".cg", + ) + return + + x_offs_n = tl.arange(0, BLOCK_SIZE_N) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + mask1 = None + other1 = None + if not EVEN_M_N: + mask1 = (x_offs_m < M)[:, None] & (x_offs_n < N1)[None, :] + other1 = 0.0 + + x1 = tl.load( + x1_ptr + x_offs_m[:, None] * x1_stride_m + x_offs_n[None, :], + mask=mask1, + other=other1, + cache_modifier=".cg", + ).to(tl.float32) + + if FIRST_INPUT_RES: + res1 = tl.load( + res1_ptr + x_offs_m[:, None] * res1_stride_m + x_offs_n[None, :], + mask=mask1, + 
other=other1, + cache_modifier=".cg", + ).to(tl.float32) + x1 = x1 + res1 + + w_mask1 = None + w_other1 = None + if not EVEN_M_N: + w_mask1 = x_offs_n < N1 + w_other1 = 0.0 + + w1 = tl.load(w1_ptr + x_offs_n, mask=w_mask1, other=w_other1).to(tl.float32) + + norm1 = _rmsmorm_op(x1, w1, N1, eps1) + out1_fp4, bs_e8m0 = _hadamard_mxfp4_quant_op( + x1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + + # store the results + half_x_offs_n = tl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = None + if not EVEN_M_N: + out_mask1 = (x_offs_m < M)[:, None] & (half_x_offs_n < (N1 // 2))[None, :] + + tl.store( + out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + cache_modifier=".cg", + ) + + bs_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + bs_offs_n = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = tl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if SHUFFLE_PAD: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] + else: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + + tl.store( + out1_bs_ptr + bs_offs, + bs_e8m0.to(out1_bs_ptr.type.element_ty), + mask=bs_mask, + cache_modifier=".cg", + ) + + if FIRST_INPUT_RES: + tl.store( + out_res1_ptr + x_offs_m[:, None] * out_res1_stride_m + x_offs_n[None, :], + x1.to(out_res1_ptr.dtype.element_ty), + mask=mask1, + cache_modifier=".cg", + ) + + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, + "EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0, + } +) +@triton.jit +def _fused_rms_hadamard_mxfp4_quant_kernel( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_N2: tl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: tl.constexpr, + HAS_SECOND_INPUT: tl.constexpr, + FIRST_INPUT_RES: tl.constexpr, + SCALE_N: tl.constexpr, + SCALE_M_PAD: tl.constexpr, + SCALE_N_PAD: tl.constexpr, + SHUFFLE: tl.constexpr, + SHUFFLE_PAD: tl.constexpr, + EVEN_M_N: tl.constexpr, + EVEN_M_N2: tl.constexpr, +): + # TODO: XCD remapping where every 32-token block should share the same XCD + # TODO: debug for large M + # TODO: investigate cache_modifier='.cg' on tl.store + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + + if pid >= num_pid_m: + if HAS_SECOND_INPUT: + pid -= num_pid_m + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + x_offs_n2 = tl.arange(0, BLOCK_SIZE_N2) 
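+            # Programs in the second half of the grid (pid >= num_pid_m) only RMS-norm
+            # x2 and return early; the host launches twice as many programs when x2 is
+            # provided (see the grid computation in the Python wrapper).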
+ mask2 = None + other2 = None + if not EVEN_M_N2: + mask2 = (x_offs_m < M)[:, None] & (x_offs_n2 < N2)[None, :] + other2 = 0.0 + + x2 = tl.load( + x2_ptr + x_offs_m[:, None] * x2_stride_m + x_offs_n2[None, :], + mask=mask2, + other=other2, + cache_modifier=".cg", + ).to(tl.float32) + + w_mask2 = None + w_other2 = None + if not EVEN_M_N2: + w_mask2 = x_offs_n2 < N2 + w_other2 = 0.0 + + w2 = tl.load(w2_ptr + x_offs_n2, mask=w_mask2, other=w_other2).to( + tl.float32 + ) + + norm2 = _rmsmorm_op(x2, w2, N2, eps2) + + tl.store( + out2_ptr + x_offs_m[:, None] * out2_stride_m + x_offs_n2[None, :], + norm2.to(out2_ptr.type.element_ty), + mask=mask2, + cache_modifier=".cg", + ) + return + + x_offs_n = tl.arange(0, BLOCK_SIZE_N) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + mask1 = None + other1 = None + if not EVEN_M_N: + mask1 = (x_offs_m < M)[:, None] & (x_offs_n < N1)[None, :] + other1 = 0.0 + + x1 = tl.load( + x1_ptr + x_offs_m[:, None] * x1_stride_m + x_offs_n[None, :], + mask=mask1, + other=other1, + cache_modifier=".cg", + ).to(tl.float32) + + if FIRST_INPUT_RES: + res1 = tl.load( + res1_ptr + x_offs_m[:, None] * res1_stride_m + x_offs_n[None, :], + mask=mask1, + other=other1, + cache_modifier=".cg", + ).to(tl.float32) + x1 = x1 + res1 + + w_mask1 = None + w_other1 = None + if not EVEN_M_N: + w_mask1 = x_offs_n < N1 + w_other1 = 0.0 + + w1 = tl.load(w1_ptr + x_offs_n, mask=w_mask1, other=w_other1).to(tl.float32) + + norm1 = _rmsmorm_op(x1, w1, N1, eps1) + out1_fp4, bs_e8m0 = _hadamard_mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + + # store the results + half_x_offs_n = tl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = None + if not EVEN_M_N: + out_mask1 = (x_offs_m < M)[:, None] & (half_x_offs_n < (N1 // 2))[None, :] + + tl.store( + out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + cache_modifier=".cg", + ) + + bs_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + bs_offs_n = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = tl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if SHUFFLE_PAD: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] + else: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + + tl.store( + out1_bs_ptr + bs_offs, + bs_e8m0.to(out1_bs_ptr.type.element_ty), + mask=bs_mask, + cache_modifier=".cg", + ) + + if FIRST_INPUT_RES: + tl.store( + out_res1_ptr + x_offs_m[:, None] * out_res1_stride_m + x_offs_n[None, :], + x1.to(out_res1_ptr.dtype.element_ty), + mask=mask1, + cache_modifier=".cg", + ) + + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % 
(args["BLOCK_SIZE_N"]) == 0, + "EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0, + } +) +@triton.jit +def _fused_rms_mxfp4_quant_kernel( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_N2: tl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: tl.constexpr, + HAS_SECOND_INPUT: tl.constexpr, + FIRST_INPUT_RES: tl.constexpr, + SCALE_N: tl.constexpr, + SCALE_M_PAD: tl.constexpr, + SCALE_N_PAD: tl.constexpr, + SHUFFLE: tl.constexpr, + SHUFFLE_PAD: tl.constexpr, + EVEN_M_N: tl.constexpr, + EVEN_M_N2: tl.constexpr, +): + # TODO: XCD remapping where every 32-token block should share the same XCD + # TODO: debug for large M + # TODO: investigate cache_modifier='.cg' on tl.store + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + + if pid >= num_pid_m: + if HAS_SECOND_INPUT: + pid -= num_pid_m + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + x_offs_n2 = tl.arange(0, BLOCK_SIZE_N2) + mask2 = None + other2 = None + if not EVEN_M_N2: + mask2 = (x_offs_m < M)[:, None] & (x_offs_n2 < N2)[None, :] + other2 = 0.0 + + x2 = tl.load( + x2_ptr + x_offs_m[:, None] * x2_stride_m + x_offs_n2[None, :], + mask=mask2, + other=other2, + cache_modifier=".cg", + ).to(tl.float32) + + w_mask2 = None + w_other2 = None + if not EVEN_M_N2: + w_mask2 = x_offs_n2 < N2 + w_other2 = 0.0 + + w2 = tl.load(w2_ptr + x_offs_n2, mask=w_mask2, other=w_other2).to( + tl.float32 + ) + + norm2 = _rmsmorm_op(x2, w2, N2, eps2) + + tl.store( + out2_ptr + x_offs_m[:, None] * out2_stride_m + x_offs_n2[None, :], + norm2.to(out2_ptr.type.element_ty), + mask=mask2, + cache_modifier=".cg", + ) + return + + x_offs_n = tl.arange(0, BLOCK_SIZE_N) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + mask1 = None + other1 = None + if not EVEN_M_N: + mask1 = (x_offs_m < M)[:, None] & (x_offs_n < N1)[None, :] + other1 = 0.0 + + x1 = tl.load( + x1_ptr + x_offs_m[:, None] * x1_stride_m + x_offs_n[None, :], + mask=mask1, + other=other1, + cache_modifier=".cg", + ).to(tl.float32) + + if FIRST_INPUT_RES: + res1 = tl.load( + res1_ptr + x_offs_m[:, None] * res1_stride_m + x_offs_n[None, :], + mask=mask1, + other=other1, + cache_modifier=".cg", + ).to(tl.float32) + x1 = x1 + res1 + + w_mask1 = None + w_other1 = None + if not EVEN_M_N: + w_mask1 = x_offs_n < N1 + w_other1 = 0.0 + + w1 = tl.load(w1_ptr + x_offs_n, mask=w_mask1, other=w_other1).to(tl.float32) + + norm1 = _rmsmorm_op(x1, w1, N1, eps1) + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + + # store the results + half_x_offs_n = tl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = None + if not EVEN_M_N: + out_mask1 = (x_offs_m < M)[:, None] & (half_x_offs_n < (N1 // 2))[None, :] + + tl.store( + out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + cache_modifier=".cg", + ) + + bs_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + bs_offs_n = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, 
None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = tl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if SHUFFLE_PAD: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] + else: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + + tl.store( + out1_bs_ptr + bs_offs, + bs_e8m0.to(out1_bs_ptr.type.element_ty), + mask=bs_mask, + cache_modifier=".cg", + ) + + if FIRST_INPUT_RES: + tl.store( + out_res1_ptr + x_offs_m[:, None] * out_res1_stride_m + x_offs_n[None, :], + x1.to(out_res1_ptr.dtype.element_ty), + mask=mask1, + cache_modifier=".cg", + ) + + +@triton.jit +def _fused_flatten_mxfp4_quant( + x_ptr, + out_ptr, + out_scales_ptr, + x_stride_m, + x_stride_n1, + x_stride_n2, + out_stride_m, + out_stride_n, + out_scales_stride_m, + out_scales_stride_n, + N2, + BLOCK_SIZE_N2: tl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: tl.constexpr, +): + m = tl.program_id(0) + n1 = tl.program_id(1) + + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N2 // MXFP4_QUANT_BLOCK_SIZE + n2_offs = tl.arange(0, BLOCK_SIZE_N2) + x_offs = m * x_stride_m + n1 * x_stride_n1 + n2_offs * x_stride_n2 + x = tl.load(x_ptr + x_offs, mask=n2_offs < N2) + + out, out_block_scales = _mxfp4_quant_op(x, BLOCK_SIZE_N2, 1, MXFP4_QUANT_BLOCK_SIZE) + out = tl.ravel(out) + out_block_scales = tl.ravel(out_block_scales) + + half_block_offs = tl.arange(0, BLOCK_SIZE_N2 // 2) + tl.store( + out_ptr + + m * out_stride_m + + (n1 * (BLOCK_SIZE_N2 // 2) + half_block_offs) * out_stride_n, + out, + mask=half_block_offs < (N2 // 2), + ) + block_scale_offs = tl.arange(0, NUM_QUANT_BLOCKS) + tl.store( + out_scales_ptr + + m * out_scales_stride_m + + (n1 * NUM_QUANT_BLOCKS + block_scale_offs) * out_scales_stride_n, + out_block_scales, + mask=block_scale_offs < tl.cdiv(N2, MXFP4_QUANT_BLOCK_SIZE), + ) diff --git a/include/tritonblas/rmsnorm.py b/include/tritonblas/rmsnorm.py new file mode 100644 index 0000000..8c8c02d --- /dev/null +++ b/include/tritonblas/rmsnorm.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +import torch +from torch import Tensor +from aiter.jit.core import compile_ops +from typing import Optional + +MD_NAME = "module_rmsnorm" + + +@compile_ops("module_rmsnorm") +def rms_norm_cu( + out: Tensor, + input: Tensor, + weight: Tensor, + epsilon: float, +) -> None: + """ + Cuda version of rmsnorm + """ + ... + + +@compile_ops("module_rmsnorm") +def fused_add_rms_norm_cu( + input: Tensor, # input/out + residual_in: Tensor, # residual_in/out + weight: Tensor, + epsilon: float, +) -> None: + """ + Cuda version of rmsnorm fused add + """ + ... 
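+
+
+# Illustrative reference only (not wired into the compiled ops above): a minimal
+# PyTorch sketch of the math rms_norm / fused_add_rms_norm are expected to compute,
+# assuming the usual RMSNorm definition y = x * rsqrt(mean(x^2) + eps) * weight.
+# The helper name below is hypothetical and is not part of the aiter API.
+def _rms_norm_torch_reference(
+    x: Tensor, weight: Tensor, epsilon: float, residual: Optional[Tensor] = None
+) -> Tensor:
+    if residual is not None:
+        x = x + residual
+    x_f32 = x.float()
+    norm = x_f32 * torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + epsilon)
+    return (norm * weight.float()).to(x.dtype)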
+ + +def gen_rms_norm_fake_tensor( + input: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: + return torch.empty_like(input, dtype=input.dtype, device=input.device) + + +@compile_ops( + "module_rmsnorm", fc_name="rmsnorm2d_fwd", gen_fake=gen_rms_norm_fake_tensor +) +def rms_norm( + input: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: + """ + CK version of rmsnorm + """ + ... + + +@compile_ops("module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor) +def rmsnorm2d_fwd( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: ... + + +@compile_ops("module_rmsnorm") +def rmsnorm2d_fwd_with_add( + out: Tensor, + input: Tensor, + residual_in: Tensor, + residual_out: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... + + +@compile_ops("module_rmsnorm") +def rmsnorm2d_fwd_with_smoothquant( + out: Tensor, + input: Tensor, + xscale: Tensor, + yscale: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... + + +@compile_ops("module_rmsnorm") +def rmsnorm2d_fwd_with_add_smoothquant( + out: Tensor, + input: Tensor, + residual_in: Tensor, + residual_out: Tensor, + xscale: Tensor, + yscale: Tensor, + weight: Tensor, + epsilon: float, + out_before_quant: Optional[Tensor] = None, + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... + + +@compile_ops("module_rmsnorm") +def rmsnorm2d_fwd_with_dynamicquant( + out: Tensor, + input: Tensor, + yscale: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... + + +@compile_ops("module_rmsnorm") +def rmsnorm2d_fwd_with_add_dynamicquant( + out: Tensor, + input: Tensor, + residual_in: Tensor, + residual_out: Tensor, + yscale: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... diff --git a/tests/fp4_utils.py b/tests/fp4_utils.py new file mode 100644 index 0000000..b7c9904 --- /dev/null +++ b/tests/fp4_utils.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: MIT +# FP4 utilities for tritonblas testing +# Based on aiter's fp4_utils implementation + +import torch +import triton +import triton.language as tl + + +def mxfp4_to_f32(x: torch.Tensor) -> torch.Tensor: + """ + Convert packed FP4 e2m1 data to FP32. 
+ + FP4 e2m1 format (4 bits per value): + - 1 sign bit + - 2 exponent bits + - 1 mantissa bit + + Representable values: + - 0x0 (0000): +0.0 + - 0x1 (0001): +0.5 + - 0x2 (0010): +1.0 + - 0x3 (0011): +1.5 + - 0x4 (0100): +2.0 + - 0x5 (0101): +3.0 + - 0x6 (0110): +4.0 + - 0x7 (0111): +6.0 + - 0x8 (1000): -0.0 + - 0x9 (1001): -0.5 + - 0xA (1010): -1.0 + - 0xB (1011): -1.5 + - 0xC (1100): -2.0 + - 0xD (1101): -3.0 + - 0xE (1110): -4.0 + - 0xF (1111): -6.0 + + Args: + x: Packed FP4 tensor (2 values per uint8) + + Returns: + Unpacked FP32 tensor + """ + if x.dtype == torch.float4_e2m1fn_x2: + x = x.view(torch.uint8) + + # Unpack: 2 FP4 values per uint8 + # Shape: (..., N) -> (..., N*2) + x = x.repeat_interleave(2, dim=-1) + x[..., ::2] = x[..., ::2] & 0xF # Lower 4 bits + x[..., 1::2] = x[..., 1::2] >> 4 # Upper 4 bits + + # Lookup table for FP4 e2m1 values + mxfp4_list = [ + 0.0, # 0x0 + 0.5, # 0x1 + 1.0, # 0x2 + 1.5, # 0x3 + 2.0, # 0x4 + 3.0, # 0x5 + 4.0, # 0x6 + 6.0, # 0x7 + -0.0, # 0x8 + -0.5, # 0x9 + -1.0, # 0xA + -1.5, # 0xB + -2.0, # 0xC + -3.0, # 0xD + -4.0, # 0xE + -6.0, # 0xF + ] + mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device) + return mxfp4_in_f32[x.long()] + + +def e8m0_to_f32(scale_e8m0: torch.Tensor) -> torch.Tensor: + """ + Convert e8m0 scales to FP32. + + E8M0 format stores only the exponent (8 bits, biased by 127). + The value is 2^(exponent - 127). + + Special cases: + - 0x00: Represents 2^(-126) (minimum normal) + - 0xFF: Represents NaN/Inf + + Args: + scale_e8m0: E8M0 scale tensor + + Returns: + FP32 scale tensor + """ + scale_e8m0_biased = scale_e8m0.view(torch.uint8) + + # Special cases + zero_case = scale_e8m0_biased == 0 + nan_case = scale_e8m0_biased == 0xFF + + # Convert to FP32 by placing exponent in correct position + scale_f32 = scale_e8m0_biased.to(torch.int32) << 23 + + # Handle special cases + scale_f32[zero_case] = 0x00400000 # 2^(-126) + scale_f32[nan_case] = 0x7F800001 # NaN + + scale_f32 = scale_f32.view(torch.float32) + return scale_f32 + + +@triton.jit +def _dynamic_mxfp4_quant_kernel( + x_ptr, + x_fp4_ptr, + bs_ptr, + stride_x_m, + stride_x_n, + stride_x_fp4_m, + stride_x_fp4_n, + stride_bs_m, + stride_bs_n, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: tl.constexpr, +): + """ + Triton kernel for quantizing FP32/FP16/BF16 to FP4 e2m1 format. + + Each row is divided into blocks of MXFP4_QUANT_BLOCK_SIZE elements. + Each block gets one e8m0 scale computed from the max absolute value. 
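+
+    Expected launch (see dynamic_mxfp4_quant below): a 2D grid of
+    (cdiv(M, BLOCK_SIZE), cdiv(N, MXFP4_QUANT_BLOCK_SIZE)) programs, each handling one
+    BLOCK_SIZE x MXFP4_QUANT_BLOCK_SIZE tile of the input.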
+ """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Load input block + x_offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x_offs_n = pid_n * MXFP4_QUANT_BLOCK_SIZE + tl.arange(0, MXFP4_QUANT_BLOCK_SIZE) + x_offs = x_offs_m[:, None] * stride_x_m + x_offs_n[None, :] * stride_x_n + x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] + x = tl.load(x_ptr + x_offs, mask=x_mask).to(tl.float32) + + # Calculate scale per row (max absolute value) + amax = tl.max(tl.abs(x), axis=-1, keep_dims=True) + + # Convert to e8m0 format + # Round up to nearest power of 2 for better numerical stability + amax = amax.to(tl.int32, bitcast=True) + amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax = amax.to(tl.float32, bitcast=True) + + # Compute unbiased exponent: log2(amax) - 2 (because max FP4 value is 6.0 = 2^2 * 1.5) + scale_e8m0_unbiased = tl.log2(amax).floor() - 2 + scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127) + + # Quantization scale + quant_scale = tl.exp2(-scale_e8m0_unbiased) + + # Quantize to FP4 range + qx = x * quant_scale + + # Store e8m0 scale (add bias of 127) + bs_e8m0 = scale_e8m0_unbiased.to(tl.uint8) + 127 + + # Convert quantized FP32 to FP4 e2m1 + qx = qx.to(tl.uint32, bitcast=True) + + # Extract sign, exponent, mantissa + s = qx & 0x80000000 + e = (qx >> 23) & 0xFF + m = qx & 0x7FFFFF + + E8_BIAS: tl.constexpr = 127 + E2_BIAS: tl.constexpr = 1 + + # Handle denormal numbers + adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False) + m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m) + + # Adjust exponent bias + e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine and saturate to 4 bits + # Round nearest with tie breaking up + e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8) + + # Pack two FP4 values into one uint8 + e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE, MXFP4_QUANT_BLOCK_SIZE // 2, 2]) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + + # Store packed FP4 output + out_offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + out_offs_n = pid_n * MXFP4_QUANT_BLOCK_SIZE // 2 + tl.arange(0, MXFP4_QUANT_BLOCK_SIZE // 2) + out_offs = out_offs_m[:, None] * stride_x_fp4_m + out_offs_n[None, :] * stride_x_fp4_n + out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] + tl.store(x_fp4_ptr + out_offs, out_tensor, mask=out_mask) + + # Store e8m0 scales + bs_offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + bs_offs_n = pid_n + bs_offs = bs_offs_m[:, None] * stride_bs_m + bs_offs_n[None, :] * stride_bs_n + bs_mask = (bs_offs_m < M)[:, None] + tl.store(bs_ptr + bs_offs, bs_e8m0, mask=bs_mask) + + +def dynamic_mxfp4_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize a tensor to MX FP4 e2m1 format with e8m0 scales. + + The input is divided into blocks of 32 elements along the last dimension. + Each block gets one e8m0 scale computed from the max absolute value. 
+ + Args: + x: Input tensor (FP32, FP16, or BF16), shape (..., N) where N % 32 == 0 + + Returns: + Tuple of: + - x_fp4: Packed FP4 data, shape (..., N//2) + - blockscale_e8m0: E8M0 scales, shape (..., N//32) + """ + assert x.ndim == 2, "Input must be 2D tensor" + M, N = x.shape + assert N % 32 == 0, f"N ({N}) must be divisible by 32" + + # Fixed by MXFP4 spec + MXFP4_QUANT_BLOCK_SIZE = 32 + BLOCK_SIZE = 128 + + # Allocate output tensors + x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device) + scaleN = triton.cdiv(N, MXFP4_QUANT_BLOCK_SIZE) + blockscale_e8m0 = torch.empty((M, scaleN), dtype=torch.uint8, device=x.device) + + # Launch kernel + grid = (triton.cdiv(M, BLOCK_SIZE), scaleN) + _dynamic_mxfp4_quant_kernel[grid]( + x, + x_fp4, + blockscale_e8m0, + x.stride(0), + x.stride(1), + x_fp4.stride(0), + x_fp4.stride(1), + blockscale_e8m0.stride(0), + blockscale_e8m0.stride(1), + M=M, + N=N, + BLOCK_SIZE=BLOCK_SIZE, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + ) + + return x_fp4, blockscale_e8m0 diff --git a/tests/hadamard.py b/tests/hadamard.py new file mode 100644 index 0000000..b207011 --- /dev/null +++ b/tests/hadamard.py @@ -0,0 +1,316 @@ +""" +Standalone Triton kernel for blocked Hadamard transformation. +This kernel applies a 32x32 Hadamard transformation to each 32x32 subblock of an MxK matrix A. +The output is also MxK, with each 32x32 block rotated by the Hadamard matrix. +For each 32x32 block A[i:i+32, j:j+32], we compute: + Output[i:i+32, j:j+32] = A[i:i+32, j:j+32] @ H +where H is the 32x32 Hadamard matrix. +""" + +import triton +import triton.language as tl +import torch +import time +import numpy as np + + +@triton.jit +def hadamard_blocked_kernel( + A_ptr, # Pointer to input matrix A [M, K] + H_ptr, # Pointer to Hadamard matrix [32, 32] + Out_ptr, # Pointer to output matrix [M, K] + M, # Number of rows in A + K, # Number of columns in A + stride_am, # Stride of A in M dimension + stride_ak, # Stride of A in K dimension + stride_hrow, # Stride of H in row dimension + stride_hcol, # Stride of H in column dimension + stride_om, # Stride of output in M dimension + stride_ok, # Stride of output in K dimension + BLOCK_SIZE: tl.constexpr, # Block size (32) +): + """ + Kernel that applies Hadamard transformation to each 32x32 block of A. 
+ + Each program processes one 32x32 block independently: + Output[m_block, k_block] = A[m_block, k_block] @ H + """ + # Get program IDs for M and K dimensions + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + # Compute starting indices for this block + m_start = pid_m * BLOCK_SIZE + k_start = pid_k * BLOCK_SIZE + + # Create offset ranges + m_offs = m_start + tl.arange(0, BLOCK_SIZE) + k_offs = k_start + tl.arange(0, BLOCK_SIZE) + + # Create masks for boundary conditions + m_mask = m_offs < M + k_mask = k_offs < K + + # Load A block [BLOCK_SIZE, BLOCK_SIZE] + a_ptrs = A_ptr + m_offs[:, None] * stride_am + k_offs[None, :] * stride_ak + a_block = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0) + + # Load Hadamard matrix [BLOCK_SIZE, BLOCK_SIZE] + h_row_offs = tl.arange(0, BLOCK_SIZE) + h_col_offs = tl.arange(0, BLOCK_SIZE) + h_ptrs = H_ptr + h_row_offs[:, None] * stride_hrow + h_col_offs[None, :] * stride_hcol + h_block = tl.load(h_ptrs) + + # Perform matrix multiplication: A_block @ H_block + # This is a single 32x32 @ 32x32 operation + result = tl.dot(a_block, h_block) + + # Store result to output + out_ptrs = Out_ptr + m_offs[:, None] * stride_om + k_offs[None, :] * stride_ok + tl.store(out_ptrs, result, mask=m_mask[:, None] & k_mask[None, :]) + + +def hadamard_blocked_transform(A: torch.Tensor, H: torch.Tensor) -> torch.Tensor: + """ + Apply Hadamard transformation to each 32x32 block of matrix A. + + Args: + A: Input matrix of shape [M, K] + H: Hadamard matrix of shape [32, 32] + + Returns: + Output matrix of shape [M, K] with each 32x32 block transformed + """ + assert A.is_cuda and H.is_cuda, "Tensors must be on CUDA" + assert H.shape == (32, 32), f"Hadamard must be 32x32, got {H.shape}" + + M, K = A.shape + + # Allocate output with same shape as A + Out = torch.zeros_like(A) + + # Define block size + BLOCK_SIZE = 32 + + # Calculate grid dimensions - one program per 32x32 block + grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(K, BLOCK_SIZE)) + + # Launch kernel + hadamard_blocked_kernel[grid]( + A, H, Out, + M, K, + A.stride(0), A.stride(1), + H.stride(0), H.stride(1), + Out.stride(0), Out.stride(1), + BLOCK_SIZE=BLOCK_SIZE, + ) + + return Out + +def generate_hadamard_matrix(n: int, dtype=torch.float32) -> torch.Tensor: + """ + Generate a Hadamard matrix of size n x n using Sylvester's construction. + n must be a power of 2. + """ + assert n > 0 and (n & (n - 1)) == 0, "n must be a power of 2" + + if n == 1: + return torch.tensor([[1.0]]) + + # Recursive construction + H_half = generate_hadamard_matrix(n // 2) + H = torch.zeros((n, n), dtype=dtype) + half = n // 2 + + H[:half, :half] = H_half + H[:half, half:] = H_half + H[half:, :half] = H_half + H[half:, half:] = -H_half + + # Note: cannot normalize when applying recursion! + return H + + +def hadamard_blocked_transform_torch(A: torch.Tensor, H: torch.Tensor) -> torch.Tensor: + """ + PyTorch reference implementation using batched matrix multiplication. + + Reshapes A into blocks and performs batched GEMM with the Hadamard matrix. 
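+
+    Illustrative equivalence (per 32x32 block, ignoring padding):
+        Out[i*32:(i+1)*32, j*32:(j+1)*32] = A[i*32:(i+1)*32, j*32:(j+1)*32] @ H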
+ """ + M, K = A.shape + + # Pad to multiples of 32 if necessary + M_pad = ((M + 31) // 32) * 32 + K_pad = ((K + 31) // 32) * 32 + + if M_pad != M or K_pad != K: + A_padded = torch.zeros((M_pad, K_pad), device=A.device, dtype=A.dtype) + A_padded[:M, :K] = A + else: + A_padded = A + + # Reshape into blocks: [M_pad, K_pad] -> [M_blocks, 32, K_blocks, 32] + M_blocks = M_pad // 32 + K_blocks = K_pad // 32 + + # Reshape: [M_pad, K_pad] -> [M_blocks, 32, K_blocks, 32] -> [M_blocks, K_blocks, 32, 32] + A_blocks = A_padded.reshape(M_blocks, 32, K_blocks, 32).permute(0, 2, 1, 3) + + # Flatten batch dimensions: [M_blocks * K_blocks, 32, 32] + A_blocks_flat = A_blocks.reshape(-1, 32, 32) + + # Batched matrix multiplication: [batch, 32, 32] @ [32, 32] -> [batch, 32, 32] + Out_blocks_flat = torch.bmm(A_blocks_flat, H.unsqueeze(0).expand(A_blocks_flat.shape[0], -1, -1)) + + # Reshape back: [batch, 32, 32] -> [M_blocks, K_blocks, 32, 32] + Out_blocks = Out_blocks_flat.reshape(M_blocks, K_blocks, 32, 32) + + # Permute and reshape: [M_blocks, K_blocks, 32, 32] -> [M_blocks, 32, K_blocks, 32] -> [M_pad, K_pad] + Out_padded = Out_blocks.permute(0, 2, 1, 3).reshape(M_pad, K_pad) + + # Remove padding if necessary + if M_pad != M or K_pad != K: + return Out_padded[:M, :K] + else: + return Out_padded + + +def test_correctness(): + """Test correctness against PyTorch reference implementation.""" + print("=" * 80) + print("Testing Correctness") + print("=" * 80) + + # Set random seed for reproducibility + torch.manual_seed(42) + + # Test dimensions + M, K = 128, 128 + + # Generate test data + A = torch.randn((M, K), device='cuda', dtype=torch.float32) + H = generate_hadamard_matrix(32).cuda() + + # Triton implementation + Out_triton = hadamard_blocked_transform(A, H) + + # PyTorch reference using batched GEMM + Out_torch = hadamard_blocked_transform_torch(A, H) + + # Compare results + max_diff = torch.max(torch.abs(Out_triton - Out_torch)).item() + mean_diff = torch.mean(torch.abs(Out_triton - Out_torch)).item() + + print(f"Input shape: A={A.shape}, H={H.shape}") + print(f"Output shape: {Out_triton.shape}") + print(f"Max difference: {max_diff:.6e}") + print(f"Mean difference: {mean_diff:.6e}") + + if max_diff < 1e-4: + print("✓ Correctness test PASSED") + else: + print("✗ Correctness test FAILED") + print(f"\nSample values:") + print(f"Triton output [0:3, 0:3]:\n{Out_triton[0:3, 0:3]}") + print(f"PyTorch output [0:3, 0:3]:\n{Out_torch[0:3, 0:3]}") + + return max_diff < 1e-4 + + +def benchmark_performance(): + """Benchmark performance against PyTorch implementation.""" + print("\n" + "=" * 80) + print("Performance Benchmark") + print("=" * 80) + + # Test configurations + configs = [ + (512, 512), + (1024, 1024), + (2048, 2048), + (4096, 4096), + (8192, 8192), + ] + + num_warmup = 10 + num_iterations = 100 + + print(f"\nRunning {num_warmup} warmup iterations and {num_iterations} timed iterations") + print(f"{'M':>6} {'K':>6} {'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} {'TFLOPS':>8}") + print("-" * 80) + + for M, K in configs: + # Generate test data + A = torch.randn((M, K), device='cuda', dtype=torch.float32) + H = generate_hadamard_matrix(32).cuda() + + # Warmup - Triton + for _ in range(num_warmup): + _ = hadamard_blocked_transform(A, H) + torch.cuda.synchronize() + + # Benchmark - Triton + start = time.time() + for _ in range(num_iterations): + Out_triton = hadamard_blocked_transform(A, H) + torch.cuda.synchronize() + triton_time = (time.time() - start) / num_iterations * 1000 + + # Warmup - PyTorch 
(batched GEMM) + for _ in range(num_warmup): + _ = hadamard_blocked_transform_torch(A, H) + torch.cuda.synchronize() + + # Benchmark - PyTorch (batched GEMM) + start = time.time() + for _ in range(num_iterations): + Out_torch = hadamard_blocked_transform_torch(A, H) + torch.cuda.synchronize() + torch_time = (time.time() - start) / num_iterations * 1000 + + # Calculate metrics + speedup = torch_time / triton_time + + # FLOPS calculation: (M/32) * (K/32) blocks, each doing 32*32*32*2 FLOPs + num_blocks = ((M + 31) // 32) * ((K + 31) // 32) + flops = num_blocks * 32 * 32 * 32 * 2 + tflops = (flops / (triton_time * 1e-3)) / 1e12 + + print(f"{M:6d} {K:6d} {triton_time:12.4f} {torch_time:13.4f} {speedup:8.2f}x {tflops:8.2f}") + + print("=" * 80) + + +def main(): + """Main function to run tests and benchmarks.""" + print("\n" + "=" * 80) + print("Hadamard Blocked Transformation - Triton Implementation") + print("=" * 80) + print("\nThis kernel applies a 32x32 Hadamard transformation to each") + print("32x32 subblock of the input matrix A, producing an output of") + print("the same size as A.") + + # Check CUDA availability + if not torch.cuda.is_available(): + print("\nERROR: CUDA is not available. This kernel requires a CUDA GPU.") + return + + print(f"\nDevice: {torch.cuda.get_device_name()}") + print(f"CUDA Version: {torch.version.cuda}") + + # Run correctness test + passed = test_correctness() + + if passed: + # Run performance benchmark + benchmark_performance() + else: + print("\nSkipping performance benchmark due to correctness test failure.") + + print("\n" + "=" * 80) + print("Done!") + print("=" * 80 + "\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_fused_had.py b/tests/test_fused_had.py new file mode 100644 index 0000000..6bdd3ff --- /dev/null +++ b/tests/test_fused_had.py @@ -0,0 +1,266 @@ +import math +import torch +import torch.nn.functional as F +import numpy as np +from hadamard import hadamard_blocked_transform, generate_hadamard_matrix +import tritonblas +from fp4_utils import dynamic_mxfp4_quant, mxfp4_to_f32, e8m0_to_f32 +import aiter +def is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def is_triton_available() -> bool: + """Check if Triton is available.""" + try: + import triton + return True + except ImportError: + return False + + +# Reference implementation +def hadamard_matrix(n: int, dtype=torch.float32) -> np.ndarray: + """ + Create a Hadamard matrix of size n x n using Sylvester's construction. + n must be a power of 2. + + Args: + n: Size of the Hadamard matrix (must be power of 2) + + Returns: + Hadamard matrix as numpy array + """ + if n < 1: + lg2 = 0 + else: + lg2 = int(math.log(n, 2)) + if 2 ** lg2 != n: + raise ValueError("n must be int and power of 2.") + + if n == 1: + return np.array([[1]], dtype=np.float32) + + H = np.array([[1, 1], [1, -1]], dtype=np.float32) + + # Hadamard stacking via Sylvester's construction + # H H + # H -H + for i in range(0, lg2 - 1): + H = np.vstack((np.hstack((H, H)), np.hstack((H, -H)))) + + return H.to(dtype) + + +def fwht_matmul(x: torch.Tensor) -> torch.Tensor: + """ + Matrix multiplication version of Hadamard transform (un-normalized). + Power-of-two length required along last dimension. 
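+    Note: the matrix constructed below is scaled by 1 / sqrt(N), so the
+    output is orthonormally normalized and comparable to
+    fwht_torch_reference. For N = 2 (illustrative), [a, b] maps to
+    [(a + b) / sqrt(2), (a - b) / sqrt(2)].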
+ + Args: + x: Input tensor with power-of-2 size in last dimension + + Returns: + Hadamard transformed tensor + """ + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + # Create Hadamard matrix + H = generate_hadamard_matrix(N, dtype=x.dtype).cuda() / math.sqrt(N) + + # H_tensor = torch.from_numpy(H).to(x.device).to(x.dtype) + # Perform matrix multiplication: x @ H + return x @ H + + +def fwht_torch_reference(x: torch.Tensor, N: int = 0) -> torch.Tensor: + """ + Reference FWHT implementation matching Sylvester's Hadamard construction (un-normalized). + Power-of-two length required along last dimension. + """ + y = x.clone() + if not N: + N = y.shape[-1] + + # Perform butterfly operations matching Sylvester's construction + # Start with stride N/2 and work down to stride 1 + stride = N // 2 + while stride >= 1: + for start in range(0, N, stride * 2): + for i in range(stride): + idx1 = start + i + idx2 = start + i + stride + a = y[..., idx1].clone() + b = y[..., idx2].clone() + y[..., idx1] = a + b + y[..., idx2] = a - b + stride //= 2 + + y = y / math.sqrt(N) + return y + + +# Triton kernels +if is_triton_available(): + import triton + import triton.language as tl + + @triton.jit + def _fwht_tile_kernel(X_ptr, stride_row, N, TILE: tl.constexpr, LOG_TILE: tl.constexpr): + """Blockwise FWHT (per-tile) with masked final (partial) tile.""" + pid = tl.program_id(0) + tiles_per_row = tl.cdiv(N, TILE) + row = pid // tiles_per_row + tile_id = pid % tiles_per_row + row_base = row * stride_row + start = tile_id * TILE + + offs = tl.arange(0, TILE) + idx_i = start + offs + valid_i = idx_i < N + + # Iterate butterfly stages + for s in range(LOG_TILE): + dist = 1 << s + idx_j = idx_i ^ dist + valid_j = idx_j < N + process = (offs & dist) == 0 + + # Load both ends + vi = tl.load(X_ptr + row_base + idx_i, mask=valid_i, other=0.0) + vj = tl.load(X_ptr + row_base + idx_j, mask=valid_j, other=0.0) + + vsum = vi + vj + vdiff = vi - vj + # Store results + tl.store(X_ptr + row_base + idx_i, vsum, mask=valid_i & process) + tl.store(X_ptr + row_base + idx_j, vdiff, mask=valid_j & process) + + @triton.jit + def _fwht_merge_kernel(X_ptr, stride_row, N, STAGE_DIST: tl.constexpr): + """Merge stage across tiles.""" + pid = tl.program_id(0) + groups_per_row = N // (2 * STAGE_DIST) + row_id = pid // groups_per_row + group_id = pid % groups_per_row + base0 = row_id * stride_row + group_id * 2 * STAGE_DIST + base1 = base0 + STAGE_DIST + offs = tl.arange(0, STAGE_DIST) + a = tl.load(X_ptr + base0 + offs) + b = tl.load(X_ptr + base1 + offs) + tl.store(X_ptr + base0 + offs, a + b) + tl.store(X_ptr + base1 + offs, a - b) + + +def _select_block(N: int, candidate_blocks=(32, 64, 128), max_mono_kernel_n=2048, block_size=None): + """Choose tile size for tiled FWHT.""" + divs = [b for b in candidate_blocks if b <= N and N % b == 0] + if not divs: + return N + if block_size and block_size in divs: + return block_size + if (N in divs) and (N <= max_mono_kernel_n) and not block_size: + return N + small = [b for b in divs if b <= 128] + if small: + return sorted(small, reverse=True)[0] + return max(divs) + +def torch_rmsnorm(x, g, out_dtype=torch.float16, epsilon=1e-6): + M, N = x.shape + # cast to float32 as the triton kernel + x_f32 = x.float() + g_f32 = g.float() + rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N) + rsigma = 1.0 / rms + rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32 + rms_norm = rms_norm_f32.to(out_dtype) + return rms_norm + +def 
full_hadamard_triton(x: torch.Tensor, block_size=None): + """ + Full Hadamard transform using Triton kernels. + Un-normalized. Requires power-of-two length. + """ + if not is_triton_available(): + raise RuntimeError("Triton not available") + + assert x.is_contiguous() + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + rows = int(torch.tensor(leading).prod()) if leading else 1 + BLOCK = _select_block(N, block_size=block_size) + LOG_BLOCK = int(math.log2(BLOCK)) + tiles_per_row = N // BLOCK + total_intra = rows * tiles_per_row + stride_row = N + + # Intra-tile FWHT + _fwht_tile_kernel[(total_intra,)](x.view(-1, N), stride_row, N, TILE=BLOCK, LOG_TILE=LOG_BLOCK) + + # If single tile (BLOCK == N), we're done + if BLOCK == N: + return x + + # Inter-tile merges + dist = BLOCK + while dist < N: + groups_per_row = N // (2 * dist) + total_groups = rows * groups_per_row + _fwht_merge_kernel[(total_groups,)](x.view(-1, N), stride_row, N, STAGE_DIST=dist) + dist *= 2 + + return x + + +def test_full_hadamard(): + """Test full Hadamard: matrix multiplication, PyTorch reference, and Triton kernel.""" + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 + + test_sizes = [32] #, 64, 128, 256, 512, 1024] + batch_sizes = [1, 4, 32, 64] + + print(f"\nRunning tests on device: {device}") + print("=" * 60) + + + for N in test_sizes: + for batch in batch_sizes: + print(f"\nTesting N={N}, batch={batch}") + print("-" * 40) + + # Create random input + x = torch.randn(batch, N, device=device, dtype=dtype) + + # rmsnorm parameters + weight = torch.randn(N, dtype=dtype, device=device) + eps = 1e-5 + use_model_sensitive_rmsnorm = 0 + + # 5. fast blocked GEMM + x_norm_1 = torch_rmsnorm(x.clone(), weight, epsilon=eps, out_dtype=dtype) + x_had = tritonblas.hadamard_blocked_fast(x_norm_1) + x_fp4_v1, x_scales_v1 = dynamic_mxfp4_quant(x_had) + + (x_fp4, x_scales), _, _ = tritonblas.fused_rms_hadamard_mxfp4_quant(x.clone(), weight, eps) + # (x_fp4, x_scales), _, _ = tritonblas.fused_mxfp4_quant(x.clone(), weight, eps) + + print(f"unfused vs. fused xfp4: {F.mse_loss(x_fp4_v1.to(torch.float), x_fp4.to(torch.float))}") + print(f"unfused vs. fused fp4 scales: {F.mse_loss(x_scales_v1.to(torch.float), x_scales.to(torch.float))}") + + + print("\n" + "=" * 60) + print("Test completed") + + +if __name__ == "__main__": + test_full_hadamard() \ No newline at end of file diff --git a/tests/test_hadamard.py b/tests/test_hadamard.py new file mode 100644 index 0000000..96f34f0 --- /dev/null +++ b/tests/test_hadamard.py @@ -0,0 +1,272 @@ +import math +import torch +import torch.nn.functional as F +import numpy as np +from hadamard import hadamard_blocked_transform, generate_hadamard_matrix +import tritonblas + +def is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def is_triton_available() -> bool: + """Check if Triton is available.""" + try: + import triton + return True + except ImportError: + return False + + +# Reference implementation +def hadamard_matrix(n: int, dtype=torch.float32) -> np.ndarray: + """ + Create a Hadamard matrix of size n x n using Sylvester's construction. + n must be a power of 2. 
+ + Args: + n: Size of the Hadamard matrix (must be power of 2) + + Returns: + Hadamard matrix as numpy array + """ + if n < 1: + lg2 = 0 + else: + lg2 = int(math.log(n, 2)) + if 2 ** lg2 != n: + raise ValueError("n must be int and power of 2.") + + if n == 1: + return np.array([[1]], dtype=np.float32) + + H = np.array([[1, 1], [1, -1]], dtype=np.float32) + + # Hadamard stacking via Sylvester's construction + # H H + # H -H + for i in range(0, lg2 - 1): + H = np.vstack((np.hstack((H, H)), np.hstack((H, -H)))) + + return H.to(dtype) + + +def fwht_matmul(x: torch.Tensor) -> torch.Tensor: + """ + Matrix multiplication version of Hadamard transform (un-normalized). + Power-of-two length required along last dimension. + + Args: + x: Input tensor with power-of-2 size in last dimension + + Returns: + Hadamard transformed tensor + """ + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + # Create Hadamard matrix + H = generate_hadamard_matrix(N, dtype=x.dtype).cuda() / math.sqrt(N) + + # H_tensor = torch.from_numpy(H).to(x.device).to(x.dtype) + # Perform matrix multiplication: x @ H + return x @ H + + +def fwht_torch_reference(x: torch.Tensor, N: int = 0) -> torch.Tensor: + """ + Reference FWHT implementation matching Sylvester's Hadamard construction (un-normalized). + Power-of-two length required along last dimension. + """ + y = x.clone() + if not N: + N = y.shape[-1] + + # Perform butterfly operations matching Sylvester's construction + # Start with stride N/2 and work down to stride 1 + stride = N // 2 + while stride >= 1: + for start in range(0, N, stride * 2): + for i in range(stride): + idx1 = start + i + idx2 = start + i + stride + a = y[..., idx1].clone() + b = y[..., idx2].clone() + y[..., idx1] = a + b + y[..., idx2] = a - b + stride //= 2 + + y = y / math.sqrt(N) + return y + + +# Triton kernels +if is_triton_available(): + import triton + import triton.language as tl + + @triton.jit + def _fwht_tile_kernel(X_ptr, stride_row, N, TILE: tl.constexpr, LOG_TILE: tl.constexpr): + """Blockwise FWHT (per-tile) with masked final (partial) tile.""" + pid = tl.program_id(0) + tiles_per_row = tl.cdiv(N, TILE) + row = pid // tiles_per_row + tile_id = pid % tiles_per_row + row_base = row * stride_row + start = tile_id * TILE + + offs = tl.arange(0, TILE) + idx_i = start + offs + valid_i = idx_i < N + + # Iterate butterfly stages + for s in range(LOG_TILE): + dist = 1 << s + idx_j = idx_i ^ dist + valid_j = idx_j < N + process = (offs & dist) == 0 + + # Load both ends + vi = tl.load(X_ptr + row_base + idx_i, mask=valid_i, other=0.0) + vj = tl.load(X_ptr + row_base + idx_j, mask=valid_j, other=0.0) + + vsum = vi + vj + vdiff = vi - vj + # Store results + tl.store(X_ptr + row_base + idx_i, vsum, mask=valid_i & process) + tl.store(X_ptr + row_base + idx_j, vdiff, mask=valid_j & process) + + @triton.jit + def _fwht_merge_kernel(X_ptr, stride_row, N, STAGE_DIST: tl.constexpr): + """Merge stage across tiles.""" + pid = tl.program_id(0) + groups_per_row = N // (2 * STAGE_DIST) + row_id = pid // groups_per_row + group_id = pid % groups_per_row + base0 = row_id * stride_row + group_id * 2 * STAGE_DIST + base1 = base0 + STAGE_DIST + offs = tl.arange(0, STAGE_DIST) + a = tl.load(X_ptr + base0 + offs) + b = tl.load(X_ptr + base1 + offs) + tl.store(X_ptr + base0 + offs, a + b) + tl.store(X_ptr + base1 + offs, a - b) + + +def _select_block(N: int, candidate_blocks=(32, 64, 128), max_mono_kernel_n=2048, block_size=None): + """Choose tile size for tiled 
FWHT.""" + divs = [b for b in candidate_blocks if b <= N and N % b == 0] + if not divs: + return N + if block_size and block_size in divs: + return block_size + if (N in divs) and (N <= max_mono_kernel_n) and not block_size: + return N + small = [b for b in divs if b <= 128] + if small: + return sorted(small, reverse=True)[0] + return max(divs) + + +def full_hadamard_triton(x: torch.Tensor, block_size=None): + """ + Full Hadamard transform using Triton kernels. + Un-normalized. Requires power-of-two length. + """ + if not is_triton_available(): + raise RuntimeError("Triton not available") + + assert x.is_contiguous() + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + rows = int(torch.tensor(leading).prod()) if leading else 1 + BLOCK = _select_block(N, block_size=block_size) + LOG_BLOCK = int(math.log2(BLOCK)) + tiles_per_row = N // BLOCK + total_intra = rows * tiles_per_row + stride_row = N + + # Intra-tile FWHT + _fwht_tile_kernel[(total_intra,)](x.view(-1, N), stride_row, N, TILE=BLOCK, LOG_TILE=LOG_BLOCK) + + # If single tile (BLOCK == N), we're done + if BLOCK == N: + return x + + # Inter-tile merges + dist = BLOCK + while dist < N: + groups_per_row = N // (2 * dist) + total_groups = rows * groups_per_row + _fwht_merge_kernel[(total_groups,)](x.view(-1, N), stride_row, N, STAGE_DIST=dist) + dist *= 2 + + return x + + +def test_full_hadamard(): + """Test full Hadamard: matrix multiplication, PyTorch reference, and Triton kernel.""" + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + test_sizes = [32] #, 64, 128, 256, 512, 1024] + batch_sizes = [1, 4, 32, 64, 128, 512] + + print(f"\nRunning tests on device: {device}") + print("=" * 60) + + for N in test_sizes: + for batch in batch_sizes: + print(f"\nTesting N={N}, batch={batch}") + print("-" * 40) + + # Create random input + x = torch.randn(batch, N, device=device, dtype=dtype) + + # 1. Matrix multiplication version + matmul_result = fwht_matmul(x.clone()) + + # 2. Python FWHT reference + ref_result = fwht_torch_reference(x.clone(), N) + + # 3. triton FWHT reference + triton_result = full_hadamard_triton(x.clone()) / math.sqrt(N) + + # 4. blocked GEMM + H = generate_hadamard_matrix(N, dtype=dtype).cuda() / math.sqrt(N) + blocked_result = hadamard_blocked_transform(x.clone(), H) + + # 5. fast blocked GEMM + fast_blocked_result = tritonblas.hadamard_blocked_fast(x.clone()) + + # matmul vs ref + print(f"matmul vs ref: {F.mse_loss(matmul_result, ref_result).item():.2e}") + + # matmul vs triton + print(f"matmul vs triton: {F.mse_loss(matmul_result, triton_result).item():.2e}") + + # ref vs triton + print(f"ref vs triton: {F.mse_loss(ref_result, triton_result).item():.2e}") + + # blocked GEMM vs triton + print(f"blocked gemm vs. triton: {F.mse_loss(blocked_result, triton_result)}") + + # blocked GEMM vs ref + print(f"blocked gemm vs. matmul: {F.mse_loss(blocked_result, matmul_result)}") + + # fast blocked GEMM vs blocked GEMM + print(f"fast blocked gemm vs. 
blocked gemm: {F.mse_loss(fast_blocked_result, blocked_result)}") + + + print("\n" + "=" * 60) + print("Test completed") + + +if __name__ == "__main__": + test_full_hadamard() \ No newline at end of file diff --git a/tests/test_hadamard_latency.py b/tests/test_hadamard_latency.py new file mode 100644 index 0000000..1187703 --- /dev/null +++ b/tests/test_hadamard_latency.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: MIT +# Comprehensive FP4 matmul test suite for tritonblas +# Based on aiter's test_gemm_a4w4.py + +import torch +import tritonblas +import time +import argparse +import aiter +from aiter.ops.triton.fused_mxfp4_quant import ( + fused_rms_mxfp4_quant, +) +from fp4_utils import dynamic_mxfp4_quant, mxfp4_to_f32, e8m0_to_f32 +from hadamard import hadamard_blocked_transform, generate_hadamard_matrix +torch.set_default_device("cuda") +torch.set_printoptions(sci_mode=False) +import math +from test_hadamard import full_hadamard_triton, fwht_matmul, fwht_torch_reference + +def run_torch_reference(x_fp4, w_fp4, x_scales, w_scales, dtype): + """ + Compute reference result using PyTorch with dequantized FP4 inputs. + + This provides the ground truth for correctness validation. + """ + m, k_packed = x_fp4.shape + n, k_packed = w_fp4.shape + k = k_packed * 2 + + # Dequantize FP4 to FP32 + x_f32 = mxfp4_to_f32(x_fp4) + w_f32 = mxfp4_to_f32(w_fp4) + + # Convert e8m0 scales to FP32 and expand to match data shape + x_scales_f32 = e8m0_to_f32(x_scales) + x_scales_f32 = x_scales_f32.repeat_interleave(32, dim=1) + + w_scales_f32 = e8m0_to_f32(w_scales) + w_scales_f32 = w_scales_f32.repeat_interleave(32, dim=1) + + # Apply scales + x_f32 = x_f32 * x_scales_f32 + w_f32 = w_f32 * w_scales_f32 + + # Compute matmul + return torch.mm(x_f32, w_f32.T).to(dtype)[:m, :n] + + +def benchmark_kernel(func, *args, num_iters=10, warmup=3): + """Benchmark a kernel with warmup iterations.""" + # Warmup + for _ in range(warmup): + func(*args) + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(num_iters): + func(*args) + torch.cuda.synchronize() + end_time = time.time() + + avg_time_us = (end_time - start_time) / num_iters * 1e6 + return avg_time_us + + +def test_gemm_fp4(dtype, M, N, K, verbose=True): + """ + Test FP4 GEMM with given dimensions and dtype. + + Returns dictionary with performance metrics and error statistics. 
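+
+    The dict includes the problem shape ("M", "N", "K", "dtype"), timing and
+    throughput ("us", "TFLOPS", "TB/s"), output sanity counts ("valid_%",
+    "nan_%", "inf_%"), and error vs. the dequantized reference
+    ("mean_abs_err", "max_abs_err", "mean_rel_err"). TFLOPS is derived as
+    2 * M * N * K / us / 1e6.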
+ """ + ret = {} + + # Generate random input data + x = torch.randn((M, K), dtype=dtype) + w = torch.randn((N, K), dtype=dtype) + + # Quantize to FP4 + x_fp4, x_scales = dynamic_mxfp4_quant(x) + w_fp4, w_scales = dynamic_mxfp4_quant(w) + + # Allocate output + out = torch.empty((M, N), dtype=dtype) + + # Compute reference + ref = run_torch_reference(x_fp4, w_fp4, x_scales, w_scales, dtype) + + # Run tritonblas FP4 matmul + def run_tritonblas(): + tritonblas.matmul_fp4(x_fp4, w_fp4, out, x_scales, w_scales) + + us = benchmark_kernel(run_tritonblas, num_iters=10, warmup=3) + + # Compute performance metrics + total_ops = 2 * M * N * K + ret["M"] = M + ret["N"] = N + ret["K"] = K + ret["dtype"] = str(dtype) + ret["us"] = us + ret["TFLOPS"] = total_ops / us / 1e6 + ret["TB/s"] = (x_fp4.nbytes + w_fp4.nbytes) / us / 1e6 + + # Compute error metrics + nan_mask = torch.isnan(out) + inf_mask = torch.isinf(out) + valid_mask = ~nan_mask & ~inf_mask + + num_valid = valid_mask.sum().item() + num_nan = nan_mask.sum().item() + num_inf = inf_mask.sum().item() + total = M * N + + ret["valid_%"] = 100 * num_valid / total + ret["nan_%"] = 100 * num_nan / total + ret["inf_%"] = 100 * num_inf / total + + # Compute error against reference + ref_valid_mask = ~torch.isnan(ref) & ~torch.isinf(ref) + both_valid = valid_mask & ref_valid_mask + + if both_valid.sum() > 0: + out_valid = out[both_valid] + ref_valid = ref[both_valid] + + abs_error = torch.abs(out_valid - ref_valid) + ret["mean_abs_err"] = abs_error.mean().item() + ret["max_abs_err"] = abs_error.max().item() + + rel_error = abs_error / (torch.abs(ref_valid) + 1e-8) + ret["mean_rel_err"] = rel_error.mean().item() + else: + ret["mean_abs_err"] = float('nan') + ret["max_abs_err"] = float('nan') + ret["mean_rel_err"] = float('nan') + + if verbose: + print(f"\n{'='*80}") + print(f"FP4 GEMM Test: M={M}, N={N}, K={K}, dtype={dtype}") + print(f"{'='*80}") + print(f"Performance:") + print(f" Time: {us:.2f} us") + print(f" Throughput: {ret['TFLOPS']:.2f} TFLOPS") + print(f" Bandwidth: {ret['TB/s']:.2f} TB/s") + print(f"Correctness:") + print(f" Valid values: {num_valid}/{total} ({ret['valid_%']:.1f}%)") + print(f" NaN values: {num_nan}/{total} ({ret['nan_%']:.1f}%)") + print(f" Inf values: {num_inf}/{total} ({ret['inf_%']:.1f}%)") + if both_valid.sum() > 0: + print(f"Error vs Reference:") + print(f" Mean absolute error: {ret['mean_abs_err']:.6f}") + print(f" Max absolute error: {ret['max_abs_err']:.6f}") + print(f" Mean relative error: {ret['mean_rel_err']:.6f}") + print(f"{'='*80}\n") + + return ret + +def benchmark_block_sizes(): + """Sweep block sizes to find optimal configuration.""" + print("\n" + "="*80) + print("FP4 GEMM Block Size Sweep (8192x8192x8192)") + print("="*80) + + # Use smaller size to avoid OOM in Docker + M, N, K = 16384, 16384, 16384 + dtype = torch.bfloat16 + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Generate test data once + x = torch.randn((M, K), dtype=dtype) + w = torch.randn((N, K), dtype=dtype) + w_fp4, w_scales = dynamic_mxfp4_quant(w) + out = torch.empty((M, N), dtype=dtype) + + # for RMSNorm + weight = torch.randn(K, dtype=dtype) + eps = 1e-5 + use_model_sensitive_rmsnorm = 0 + quant_dtype = torch.float8_e4m3fnuz + xq_fused = torch.empty(x.shape, dtype=quant_dtype, device=device) + xscale_fused = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda") + # x_norm = tritonblas.rms_norm(x, weight, eps, use_model_sensitive_rmsnorm) + + quant_dtype = torch.float8_e4m3fn + w_fp8 = torch.empty(w.shape, 
dtype=quant_dtype, device=device) + w_fp8_scales = torch.empty(w.shape[0], 1, dtype=torch.float32, device=device) + # w_fp8, w_fp8_scales = quantize_tensor_per_channel(w.clone(), quant_dtype, axis=1 + x_fp8 = torch.empty(x.shape, dtype=quant_dtype, device=device) + x_fp8_scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device=device) + aiter.ops.triton.quant.dynamic_per_tensor_quant_fp8_i8(w_fp8, w, w_fp8_scales) + # for Hadamard + had_size = 32 + + # Block size configurations to test + block_m_sizes = [64, 128, 256] + block_n_sizes = [64, 128, 256] + block_k_sizes = [128, 256, 512] + + results = {} + best_tflops = 0 + best_config = None + print(x.shape) + for block_k in block_k_sizes: + results[block_k] = {} + for block_m in block_m_sizes: + for block_n in block_n_sizes: + try: + def run_kernel(): + x_norm = tritonblas.rms_norm(x.clone(), weight, eps, use_model_sensitive_rmsnorm) + # x_fp82, x_fp8_scales2 = quantize_tensor_per_channel(x_norm.clone(), quant_dtype, axis=1) + aiter.ops.triton.quant.dynamic_per_tensor_quant_fp8_i8(x_fp8, x_norm, x_fp8_scales) + selector = tritonblas.MatmulHeuristicResult( + M, N, K, x_fp8.dtype, w_fp8.dtype, out.dtype + ) + tritonblas.matmul_a8w8_lt(x_fp8, w_fp8, x_fp8_scales, w_fp8_scales, out, selector) + + # x_norm = tritonblas.rms_norm(x, weight, eps, use_model_sensitive_rmsnorm) + # (x_fp4, x_scales), _, _ = tritonblas.fused_rms_hadamard_mxfp4_quant(x, weight, eps) + # x_norm = tritonblas.rmsnorm2d_fwd_with_dynamicquant(xq_fused, x, xscale_fused, weight, eps) + # x_fp4, x_scales = dynamic_mxfp4_quant(x) + + # tritonblas.matmul_fp4( + # x_fp4, w_fp4, out, x_scales, w_scales, + # block_m=block_m, block_n=block_n, block_k=block_k + # ) + + # triton_result = full_hadamard_triton(x) / math.sqrt(had_size) + + # triton_result = full_hadamard_triton(x.reshape(-1, 32)) / math.sqrt(had_size) + + # ref_result = fwht_torch_reference(x.clone(), had_size) + + # H = generate_hadamard_matrix(had_size).cuda() / math.sqrt(had_size) + # blocked_result = hadamard_blocked_transform(x, H) + # fast_blocked_result = tritonblas.hadamard_blocked_fast(x.reshape(-1,had_size)) + # fast_blocked_result = tritonblas.hadamard_blocked_fast(x) + + us = benchmark_kernel(run_kernel, num_iters=5, warmup=2) + total_ops = 2 * M * N * K + tflops = total_ops / us / 1e6 + + key = f"M{block_m}_N{block_n}" + results[block_k][key] = us #tflops + + if tflops > best_tflops: + best_tflops = tflops + best_config = (block_m, block_n, block_k) + + except Exception as e: + key = f"M{block_m}_N{block_n}" + results[block_k][key] = 0.0 + print(f"BLK_M={block_m}, BLK_N={block_n}, BLK_K={block_k}: FAILED - {str(e)}") + + # Print results table + print("\nThroughput Table (TFLOPS):") + print("-" * 80) + + # Header + header = "BLK_K |" + for block_m in block_m_sizes: + for block_n in block_n_sizes: + header += f" M{block_m:3d}xN{block_n:3d} |" + print(header) + print("-" * len(header)) + + # Rows + for block_k in block_k_sizes: + row = f" {block_k:3d} |" + for block_m in block_m_sizes: + for block_n in block_n_sizes: + key = f"M{block_m}_N{block_n}" + tflops = results[block_k].get(key, 0.0) + + # Highlight best configuration + if best_config and (block_m, block_n, block_k) == best_config: + row += f" *{tflops:6.2f}* |" + else: + row += f" {tflops:7.2f} |" + print(row) + + print("-" * 80) + + if best_config: + print(f"\nBest Configuration:") + print(f" BLK_M={best_config[0]}, BLK_N={best_config[1]}, BLK_K={best_config[2]}") + print(f" Performance: {best_tflops:.2f} TFLOPS") + + print("="*80 + "\n") + return 
results + + +def main(): + """Main test runner.""" + parser = argparse.ArgumentParser( + description="TritonBLAS FP4 GEMM Test Suite", + formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "-d", "--dtype", + type=str, + choices=["bf16", "fp16"], + default="bf16", + help="Data type for output (default: bf16)" + ) + parser.add_argument( + "-m", "--mode", + type=str, + choices=["all", "correctness", "production", "blocksweep", "single"], + default="all", + help="Test mode (default: all)" + ) + parser.add_argument( + "--mnk", + type=str, + default=None, + help="Single test size as M,N,K (e.g., --mnk 1024,1024,1024)" + ) + + args = parser.parse_args() + + # Map dtype string to torch dtype + dtype_map = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + } + dtype = dtype_map[args.dtype] + + print("\n" + "="*80) + print("TritonBLAS FP4 GEMM Test Suite") + print("="*80) + print(f"Output dtype: {dtype}") + print(f"Test mode: {args.mode}") + print("="*80) + + if args.mode == "blocksweep" or args.mode == "all": + # Block size sweep + benchmark_block_sizes() + + print("\n" + "="*80) + print("Test suite completed!") + print("="*80 + "\n") + + +if __name__ == "__main__": + main() diff --git a/tests/test_rmse_gemm.py b/tests/test_rmse_gemm.py new file mode 100644 index 0000000..82422a9 --- /dev/null +++ b/tests/test_rmse_gemm.py @@ -0,0 +1,357 @@ +import math +import torch +import torch.nn.functional as F +import numpy as np +from hadamard import hadamard_blocked_transform, generate_hadamard_matrix +import tritonblas +from tritonblas.utils import quantize_tensor_per_channel +from fp4_utils import dynamic_mxfp4_quant, mxfp4_to_f32, e8m0_to_f32 +import aiter +def is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def is_triton_available() -> bool: + """Check if Triton is available.""" + try: + import triton + return True + except ImportError: + return False + + +# Reference implementation +def hadamard_matrix(n: int, dtype=torch.float32) -> np.ndarray: + """ + Create a Hadamard matrix of size n x n using Sylvester's construction. + n must be a power of 2. + + Args: + n: Size of the Hadamard matrix (must be power of 2) + + Returns: + Hadamard matrix as numpy array + """ + if n < 1: + lg2 = 0 + else: + lg2 = int(math.log(n, 2)) + if 2 ** lg2 != n: + raise ValueError("n must be int and power of 2.") + + if n == 1: + return np.array([[1]], dtype=np.float32) + + H = np.array([[1, 1], [1, -1]], dtype=np.float32) + + # Hadamard stacking via Sylvester's construction + # H H + # H -H + for i in range(0, lg2 - 1): + H = np.vstack((np.hstack((H, H)), np.hstack((H, -H)))) + + return H.to(dtype) + + +def fwht_matmul(x: torch.Tensor) -> torch.Tensor: + """ + Matrix multiplication version of Hadamard transform (un-normalized). + Power-of-two length required along last dimension. + + Args: + x: Input tensor with power-of-2 size in last dimension + + Returns: + Hadamard transformed tensor + """ + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + # Create Hadamard matrix + H = generate_hadamard_matrix(N, dtype=x.dtype).cuda() / math.sqrt(N) + + # H_tensor = torch.from_numpy(H).to(x.device).to(x.dtype) + # Perform matrix multiplication: x @ H + return x @ H + + +def fwht_torch_reference(x: torch.Tensor, N: int = 0) -> torch.Tensor: + """ + Reference FWHT implementation matching Sylvester's Hadamard construction (un-normalized). + Power-of-two length required along last dimension. 
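+
+    This runs the butterfly network in plain Python loops, so it is meant
+    only as a small-N correctness reference. A minimal check against the
+    Triton path might look like this (sizes and tolerances are
+    illustrative, not taken from the tests below):
+
+        x = torch.randn(4, 32, device="cuda", dtype=torch.bfloat16)
+        ref = fwht_torch_reference(x.clone())
+        out = full_hadamard_triton(x.clone()) / math.sqrt(32)
+        torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)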
+ """ + y = x.clone() + if not N: + N = y.shape[-1] + + # Perform butterfly operations matching Sylvester's construction + # Start with stride N/2 and work down to stride 1 + stride = N // 2 + while stride >= 1: + for start in range(0, N, stride * 2): + for i in range(stride): + idx1 = start + i + idx2 = start + i + stride + a = y[..., idx1].clone() + b = y[..., idx2].clone() + y[..., idx1] = a + b + y[..., idx2] = a - b + stride //= 2 + + y = y / math.sqrt(N) + return y + + +# Triton kernels +if is_triton_available(): + import triton + import triton.language as tl + + @triton.jit + def _fwht_tile_kernel(X_ptr, stride_row, N, TILE: tl.constexpr, LOG_TILE: tl.constexpr): + """Blockwise FWHT (per-tile) with masked final (partial) tile.""" + pid = tl.program_id(0) + tiles_per_row = tl.cdiv(N, TILE) + row = pid // tiles_per_row + tile_id = pid % tiles_per_row + row_base = row * stride_row + start = tile_id * TILE + + offs = tl.arange(0, TILE) + idx_i = start + offs + valid_i = idx_i < N + + # Iterate butterfly stages + for s in range(LOG_TILE): + dist = 1 << s + idx_j = idx_i ^ dist + valid_j = idx_j < N + process = (offs & dist) == 0 + + # Load both ends + vi = tl.load(X_ptr + row_base + idx_i, mask=valid_i, other=0.0) + vj = tl.load(X_ptr + row_base + idx_j, mask=valid_j, other=0.0) + + vsum = vi + vj + vdiff = vi - vj + # Store results + tl.store(X_ptr + row_base + idx_i, vsum, mask=valid_i & process) + tl.store(X_ptr + row_base + idx_j, vdiff, mask=valid_j & process) + + @triton.jit + def _fwht_merge_kernel(X_ptr, stride_row, N, STAGE_DIST: tl.constexpr): + """Merge stage across tiles.""" + pid = tl.program_id(0) + groups_per_row = N // (2 * STAGE_DIST) + row_id = pid // groups_per_row + group_id = pid % groups_per_row + base0 = row_id * stride_row + group_id * 2 * STAGE_DIST + base1 = base0 + STAGE_DIST + offs = tl.arange(0, STAGE_DIST) + a = tl.load(X_ptr + base0 + offs) + b = tl.load(X_ptr + base1 + offs) + tl.store(X_ptr + base0 + offs, a + b) + tl.store(X_ptr + base1 + offs, a - b) + + +def _select_block(N: int, candidate_blocks=(32, 64, 128), max_mono_kernel_n=2048, block_size=None): + """Choose tile size for tiled FWHT.""" + divs = [b for b in candidate_blocks if b <= N and N % b == 0] + if not divs: + return N + if block_size and block_size in divs: + return block_size + if (N in divs) and (N <= max_mono_kernel_n) and not block_size: + return N + small = [b for b in divs if b <= 128] + if small: + return sorted(small, reverse=True)[0] + return max(divs) + +def torch_rmsnorm(x, g, out_dtype=torch.float16, epsilon=1e-6): + M, N = x.shape + # cast to float32 as the triton kernel + x_f32 = x.float() + g_f32 = g.float() + rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N) + rsigma = 1.0 / rms + rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32 + rms_norm = rms_norm_f32.to(out_dtype) + return rms_norm + +def make_outlier_tensor(shape, seed=0, outlier_ratio=0.01): + """Create a tensor with outliers in random rows for testing (row-wise outlier).""" + g = torch.Generator(device="cuda").manual_seed(seed) + base = torch.randn(shape, generator=g, device="cuda", dtype=torch.bfloat16) + + # Create mask based on random rows (axis 0) + num_rows = shape[0] + num_outlier_rows = max(1, int(num_rows * outlier_ratio)) + + # Randomly select rows to be outliers + outlier_indices = torch.randperm(num_rows, generator=g, device="cuda")[:num_outlier_rows] + + # Create mask: 1 for outlier rows, 0 for others + mask = torch.zeros(shape, device="cuda", dtype=torch.bfloat16) + mask[outlier_indices, :] = 
1.0 + + spikes = torch.randn(shape, generator=g, device="cuda", dtype=torch.bfloat16) * 25.0 + return base + mask * spikes + +def full_hadamard_triton(x: torch.Tensor, block_size=None): + """ + Full Hadamard transform using Triton kernels. + Un-normalized. Requires power-of-two length. + """ + if not is_triton_available(): + raise RuntimeError("Triton not available") + + assert x.is_contiguous() + *leading, N = x.shape + + if not is_power_of_2(N): + raise ValueError(f"N must be power-of-two, got {N}") + + rows = int(torch.tensor(leading).prod()) if leading else 1 + BLOCK = _select_block(N, block_size=block_size) + LOG_BLOCK = int(math.log2(BLOCK)) + tiles_per_row = N // BLOCK + total_intra = rows * tiles_per_row + stride_row = N + + # Intra-tile FWHT + _fwht_tile_kernel[(total_intra,)](x.view(-1, N), stride_row, N, TILE=BLOCK, LOG_TILE=LOG_BLOCK) + + # If single tile (BLOCK == N), we're done + if BLOCK == N: + return x + + # Inter-tile merges + dist = BLOCK + while dist < N: + groups_per_row = N // (2 * dist) + total_groups = rows * groups_per_row + _fwht_merge_kernel[(total_groups,)](x.view(-1, N), stride_row, N, STAGE_DIST=dist) + dist *= 2 + + return x + + +def test_full_hadamard(): + """Test full Hadamard: matrix multiplication, PyTorch reference, and Triton kernel.""" + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + test_sizes = [16384] #[32] #, 64, 128, 256, 512, 1024] + batch_sizes = [1] #, 4, 32, 64] + + print(f"\nRunning tests on device: {device}") + print("=" * 60) + + # block_m_sizes = [64, 128, 256] + # block_n_sizes = [64, 128, 256] + # block_k_sizes = [128, 256, 512] + block_m = 256 + block_n = 256 + block_k = 256 + for N in test_sizes: + for batch in batch_sizes: + print(f"\nTesting N={N}, batch={batch}") + print("-" * 40) + + K = N + M = N + # Create outlier tensor + x = make_outlier_tensor((batch,N)) + x = x.to(dtype) + + # x = torch.randn(batch, N, device=device, dtype=dtype) + # quant_dtype = torch.float8_e4m3fnuz + quant_dtype = torch.float8_e4m3fn + x_quant = torch.empty(x.shape, dtype=quant_dtype, device=device) + x_fp8 = torch.empty(x.shape, dtype=quant_dtype, device=device) + # x_quant = torch.empty(x.shape, dtype=dtype, device=device) + x_fp8_scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device=device) + x_quant_scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device=device) + w = torch.randn((N, K), device=device, dtype=dtype) + w_fp8 = torch.empty(w.shape, dtype=quant_dtype, device=device) + w_fp8_scales = torch.empty(w.shape[0], 1, dtype=torch.float32, device=device) + # w_fp8, w_fp8_scales = quantize_tensor_per_channel(w.clone(), quant_dtype, axis=1) + aiter.ops.triton.quant.dynamic_per_tensor_quant_fp8_i8(w_fp8, w, w_fp8_scales) + w_fp4, w_scales = dynamic_mxfp4_quant(w) + out_fp16 = torch.empty((M, N), device=device, dtype=dtype) + out_fp8 = torch.empty((M, N), device=device, dtype=dtype) + out_fp4 = torch.empty((M, N), device=device, dtype=dtype) + out_fp4_fused = torch.empty((M, N), device=device, dtype=dtype) + out_had_unfused = torch.empty((M, N), device=device, dtype=dtype) + out_had_fused = torch.empty((M, N), device=device, dtype=dtype) + + # rmsnorm parameters + weight = torch.randn(N, dtype=dtype, device=device) + eps = 1e-5 + use_model_sensitive_rmsnorm = 0 + + # 1. FP16 RMSNorm + FP16 GEMM + x_norm = tritonblas.rms_norm(x.clone(), weight, eps, use_model_sensitive_rmsnorm) + tritonblas.matmul(x_norm, w, out_fp16) + + # 1. 
FP16 RMSNorm + FP8 GEMM + x_norm = tritonblas.rms_norm(x.clone(), weight, eps, use_model_sensitive_rmsnorm) + # x_fp82, x_fp8_scales2 = quantize_tensor_per_channel(x_norm.clone(), quant_dtype, axis=1) + aiter.ops.triton.quant.dynamic_per_tensor_quant_fp8_i8(x_fp8, x_norm, x_fp8_scales) + selector = tritonblas.MatmulHeuristicResult( + M, N, K, x_fp8.dtype, w_fp8.dtype, out_fp8.dtype + ) + tritonblas.matmul_a8w8_lt(x_fp8, w_fp8, x_fp8_scales, w_fp8_scales, out_fp8, selector) + + # 2. FP16 RMSNorm + quant + FP4 GEMM + x_norm = tritonblas.rms_norm(x, weight, eps, use_model_sensitive_rmsnorm) + x_fp4_v1, x_scales_v1 = dynamic_mxfp4_quant(x_norm) + tritonblas.matmul_fp4( + x_fp4_v1, w_fp4, out_fp4, x_scales_v1, w_scales, + block_m=block_m, block_n=block_n, block_k=block_k + ) + + # 3. fused (FP16 RMSNorm + quant) + FP4 GEMM + x_norm = tritonblas.rmsnorm2d_fwd_with_dynamicquant(x_quant, x.clone(), x_quant_scales, weight, eps) + tritonblas.matmul_fp4( + x_fp4_v1, w_fp4, out_fp4_fused, x_scales_v1, w_scales, + block_m=block_m, block_n=block_n, block_k=block_k + ) + + # 4. unfused rmsnorm + hadamard + quant + FP4 GEMM + x_norm_1 = tritonblas.rms_norm(x, weight, eps, use_model_sensitive_rmsnorm) + x_had = tritonblas.hadamard_blocked_fast(x_norm_1) + x_fp4_v1, x_scales_v1 = dynamic_mxfp4_quant(x_had) + tritonblas.matmul_fp4( + x_fp4_v1, w_fp4, out_had_unfused, x_scales_v1, w_scales, + block_m=block_m, block_n=block_n, block_k=block_k + ) + + # 5. fused (rmsnorm + hadamard + quant) + FP4 GEMM + (x_fp4, x_scales), _, _ = tritonblas.fused_rms_hadamard_mxfp4_quant(x.clone(), weight, eps) + tritonblas.matmul_fp4( + x_fp4, w_fp4, out_had_fused, x_scales, w_scales, + block_m=block_m, block_n=block_n, block_k=block_k + ) + # (x_fp4, x_scales), _, _ = tritonblas.fused_mxfp4_quant(x.clone(), weight, eps) + + print(f"||out||: {torch.norm(out_fp16)}") + # print(f"FP16 vs. FP8: {F.mse_loss(x_norm, x_fp8.to(torch.float32))}") + print(f"FP16 vs. FP8: {F.mse_loss(out_fp16, out_fp8)}") + print(f"FP16 vs. FP4: {F.mse_loss(out_fp16, out_fp4)}") + print(f"FP16 vs. FP4 fused: {F.mse_loss(out_fp16, out_fp4_fused)}") + print(f"FP16 vs. FP4 hadamard: {F.mse_loss(out_fp16, out_had_unfused)}") + print(f"FP16 vs. FP4 hadamard fused: {F.mse_loss(out_fp16, out_had_fused)}") + + + print("\n" + "=" * 60) + print("Test completed") + + +if __name__ == "__main__": + test_full_hadamard() \ No newline at end of file