From f9bf9a930cb7b04f12ee756bc1fc871ba3473162 Mon Sep 17 00:00:00 2001
From: Ryan Swann
Date: Wed, 21 Jan 2026 15:36:01 -0500
Subject: [PATCH 1/4] Add simple user epilogues to persistent gemm kernel

---
 examples/example_matmul_epilogue.py           | 132 +++++++++++
 include/tritonblas/kernels/persistent_gemm.py |   5 +
 .../kernels/stages/algorithms/__init__.py     |  21 ++
 .../kernels/stages/algorithms/epilogue.py     | 158 ++++++++++++
 tests/test_epilogues.py                       | 224 ++++++++++++++++++
 5 files changed, 540 insertions(+)
 create mode 100644 examples/example_matmul_epilogue.py
 create mode 100644 include/tritonblas/kernels/stages/algorithms/epilogue.py
 create mode 100644 tests/test_epilogues.py

diff --git a/examples/example_matmul_epilogue.py b/examples/example_matmul_epilogue.py
new file mode 100644
index 0000000..e146298
--- /dev/null
+++ b/examples/example_matmul_epilogue.py
@@ -0,0 +1,132 @@
+"""
+Minimal example demonstrating custom epilogue functions in tritonBLAS.
+
+This example shows how to:
+1. Define your own custom epilogue function
+2. Pass it to the persistent GEMM kernel
+3. Verify the results
+"""
+import torch
+import triton
+import triton.language as tl
+from tritonblas.kernels.persistent_gemm import persistent_matmul
+
+
+# ============================================================================
+# Define Custom Epilogue Function
+# ============================================================================
+
+@triton.jit
+def my_custom_clamp(acc):
+    """
+    Custom epilogue: clamp values between -1 and 1.
+
+    This is a simple example showing how to create your own epilogue function.
+    You can perform any element-wise operation on the accumulator.
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Clamped accumulator
+    """
+    return tl.minimum(tl.maximum(acc, -1.0), 1.0)
+
+
+# ============================================================================
+# Helper Function to Run GEMM with Custom Epilogue
+# ============================================================================
+
+def matmul_with_custom_epilogue(A, B, epilogue_fn=None):
+    """
+    Perform matrix multiplication with a custom epilogue function.
+
+    Args:
+        A: Input matrix A [M, K]
+        B: Input matrix B [K, N] (transposed)
+        epilogue_fn: Custom epilogue function to apply
+
+    Returns:
+        Output matrix C [M, N]
+    """
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros((M, N), device="cuda", dtype=A.dtype)
+
+    # Get device properties
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    # Fixed block sizes
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 32
+    GROUP_SIZE_M = 8
+
+    # Define grid
+    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
+
+    # Launch kernel with custom epilogue
+    persistent_matmul[grid](
+        A, B, C,
+        None, None,  # No quantization scales
+        A,  # Dummy bias pointer (not used)
+        M, N, K,
+        A.stride(0), B.stride(1),
+        C.stride(0), C.stride(1),
+        0,  # stride_bias (not used)
+        A.stride(1), B.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=GROUP_SIZE_M,
+        NUM_SMS=num_sms,
+        NUM_XCDS=1,
+        CHUNK_SIZE=1,
+        BIAS=False,
+        EVEN_K=(K % BLOCK_SIZE_K == 0),
+        CACHE_MODIFIER_A=".cg",
+        CACHE_MODIFIER_B=".cg",
+        epilogue_fn=epilogue_fn,  # Pass your custom epilogue here!
+        QUANTIZED=False,
+    )
+
+    return C
+
+
+# ============================================================================
+# Main Example
+# ============================================================================
+
+def main():
+    print("\n" + "="*70)
+    print("Custom Epilogue Function Example")
+    print("="*70 + "\n")
+
+    # Problem size
+    M, N, K = 512, 512, 512
+
+    # Allocate input matrices
+    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    B = torch.randn(N, K, device="cuda", dtype=torch.float16).T
+
+    # ========================================================================
+    # Example: Custom Clamp Epilogue
+    # ========================================================================
+    print("Custom Clamp Epilogue (values between -1 and 1)")
+    print("-" * 70)
+
+    C_custom = matmul_with_custom_epilogue(A, B, epilogue_fn=my_custom_clamp)
+
+    # Verify against PyTorch
+    C_torch = torch.clamp(torch.matmul(A, B), -1.0, 1.0)
+    max_diff = torch.max(torch.abs(C_custom - C_torch)).item()
+
+    print(f"Max difference from PyTorch: {max_diff:.6f}")
+    print(f"Min value: {C_custom.min().item():.6f}")
+    print(f"Max value: {C_custom.max().item():.6f}")
+    print(f"Sample output (first 3x3):\n{C_custom[:3, :3]}\n")
+
+
+
+if __name__ == "__main__":
+    main()
diff --git a/include/tritonblas/kernels/persistent_gemm.py b/include/tritonblas/kernels/persistent_gemm.py
index d00d104..ef179f3 100644
--- a/include/tritonblas/kernels/persistent_gemm.py
+++ b/include/tritonblas/kernels/persistent_gemm.py
@@ -37,6 +37,7 @@ def persistent_matmul(
     EVEN_K: tl.constexpr,
     CACHE_MODIFIER_A: tl.constexpr,
     CACHE_MODIFIER_B: tl.constexpr,
+    epilogue_fn=None,  # Epilogue function to apply (default: None)
     QUANTIZED: tl.constexpr = False,  # True for int8/fp8, False for fp16/bf16
     ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32,
 ):
@@ -106,6 +107,10 @@ def persistent_matmul(
     # Check if we're using quantized mode based on whether scales were applied
     acc = add_vector(acc, bias_vector, QUANTIZED=(A_scale_ptr is not None)) #Add bias vector to output accumulator
 
+    # Apply epilogue function to accumulator if provided
+    if epilogue_fn is not None:
+        acc = epilogue_fn(acc)
+
     # Convert to output dtype
     result = convert_dtype(acc, C.type.element_ty) #Quantize output accumulator to output datatype
 
diff --git a/include/tritonblas/kernels/stages/algorithms/__init__.py b/include/tritonblas/kernels/stages/algorithms/__init__.py
index 827ae14..12260cf 100644
--- a/include/tritonblas/kernels/stages/algorithms/__init__.py
+++ b/include/tritonblas/kernels/stages/algorithms/__init__.py
@@ -33,6 +33,17 @@
     convert_dtype,
 )
 from .gemm_loop import gemm_loop
+from .epilogue import (
+    relu,
+    gelu,
+    gelu_tanh,
+    sigmoid,
+    silu,
+    tanh,
+    leaky_relu,
+    identity,
+    apply_epilogue,
+)
 
 __all__ = [
     # Binary operations
@@ -43,4 +54,14 @@
     'convert_dtype',
     # Composition
     'gemm_loop',
+    # Epilogue operations
+    'relu',
+    'gelu',
+    'gelu_tanh',
+    'sigmoid',
+    'silu',
+    'tanh',
+    'leaky_relu',
+    'identity',
+    'apply_epilogue',
 ]
diff --git a/include/tritonblas/kernels/stages/algorithms/epilogue.py b/include/tritonblas/kernels/stages/algorithms/epilogue.py
new file mode 100644
index 0000000..a5b592a
--- /dev/null
+++ b/include/tritonblas/kernels/stages/algorithms/epilogue.py
@@ -0,0 +1,158 @@
+"""
+Epilogue functions for composable Triton GEMM kernels.
+Operations applied to the output accumulator after GEMM computation.
+"""
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def relu(acc):
+    """
+    Apply ReLU activation: max(0, x)
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with ReLU applied
+    """
+    return tl.maximum(acc, 0.0)
+
+
+@triton.jit
+def gelu(acc):
+    """
+    Apply GELU activation (approximate version using tanh).
+    GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with GELU applied
+    """
+    # Constants for GELU approximation
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2/π)
+    coeff = 0.044715
+
+    # GELU approximation using numerically stable tanh
+    x_cubed = acc * acc * acc
+    inner = sqrt_2_over_pi * (acc + coeff * x_cubed)
+
+    # Numerically stable tanh: use different formulas for positive/negative values
+    # For x > 0: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
+    # For x < 0: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+    exp_neg_2x = tl.exp(-2.0 * tl.abs(inner))
+    tanh_inner = tl.where(
+        inner >= 0,
+        (1.0 - exp_neg_2x) / (1.0 + exp_neg_2x),
+        -(1.0 - exp_neg_2x) / (1.0 + exp_neg_2x)
+    )
+    return 0.5 * acc * (1.0 + tanh_inner)
+
+
+@triton.jit
+def gelu_tanh(acc):
+    """
+    Apply GELU activation using tanh approximation (alias for gelu).
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with GELU applied
+    """
+    return gelu(acc)
+
+
+@triton.jit
+def sigmoid(acc):
+    """
+    Apply Sigmoid activation: 1 / (1 + exp(-x))
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with Sigmoid applied
+    """
+    return tl.sigmoid(acc)
+
+
+@triton.jit
+def silu(acc):
+    """
+    Apply SiLU (Swish) activation: x * sigmoid(x)
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with SiLU applied
+    """
+    return acc * tl.sigmoid(acc)
+
+
+@triton.jit
+def tanh(acc):
+    """
+    Apply Tanh activation using numerically stable formula.
+    For x > 0: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
+    For x < 0: tanh(x) = -(1 - exp(2x)) / (1 + exp(2x))
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Accumulator with Tanh applied
+    """
+    # Use numerically stable formula to avoid overflow
+    exp_neg_2x = tl.exp(-2.0 * tl.abs(acc))
+    result = (1.0 - exp_neg_2x) / (1.0 + exp_neg_2x)
+    # Apply sign
+    return tl.where(acc >= 0, result, -result)
+
+
+@triton.jit
+def leaky_relu(acc, negative_slope: tl.constexpr = 0.01):
+    """
+    Apply Leaky ReLU activation: max(0, x) + negative_slope * min(0, x)
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+        negative_slope: Slope for negative values (default: 0.01)
+
+    Returns:
+        Accumulator with Leaky ReLU applied
+    """
+    return tl.where(acc > 0, acc, acc * negative_slope)
+
+
+@triton.jit
+def identity(acc):
+    """
+    Identity function (no activation).
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Unchanged accumulator
+    """
+    return acc
+
+
+@triton.jit
+def apply_epilogue(acc, epilogue_fn):
+    """
+    Apply an epilogue function to the accumulator.
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+        epilogue_fn: Epilogue function to apply (e.g., relu, gelu, etc.)
+ + Returns: + Accumulator with epilogue applied + """ + return epilogue_fn(acc) diff --git a/tests/test_epilogues.py b/tests/test_epilogues.py new file mode 100644 index 0000000..7ddfb91 --- /dev/null +++ b/tests/test_epilogues.py @@ -0,0 +1,224 @@ +""" +Tests for epilogue functions in tritonBLAS. +""" +import pytest +import torch +import triton +from tritonblas.kernels.persistent_gemm import persistent_matmul +from tritonblas.kernels.stages.algorithms.epilogue import ( + relu, gelu, silu, tanh, sigmoid, leaky_relu, identity +) + + +def run_matmul_with_epilogue(M, N, K, epilogue_fn=None, bias=None, dtype=torch.float16): + """ + Helper function to run matrix multiplication with epilogue. + + Args: + M, N, K: Matrix dimensions + epilogue_fn: Epilogue function to apply + bias: Optional bias vector + dtype: Data type for tensors + + Returns: + Output tensor from Triton kernel + """ + # Allocate tensors + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(N, K, device="cuda", dtype=dtype).T + C = torch.zeros((M, N), device="cuda", dtype=dtype) + + # Get device properties + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + + # Setup bias + has_bias = bias is not None + if has_bias: + bias_ptr = bias + stride_bias = bias.stride(0) + else: + bias_ptr = A # Dummy pointer + stride_bias = 0 + + # Fixed block sizes + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + GROUP_SIZE_M = 8 + + # Define grid + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + + # Launch kernel + persistent_matmul[grid]( + A, B, C, + None, None, # No quantization scales + bias_ptr, + M, N, K, + A.stride(0), B.stride(1), + C.stride(0), C.stride(1), + stride_bias, + A.stride(1), B.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=num_sms, + NUM_XCDS=1, + CHUNK_SIZE=1, + BIAS=has_bias, + EVEN_K=(K % BLOCK_SIZE_K == 0), + CACHE_MODIFIER_A=".cg", + CACHE_MODIFIER_B=".cg", + epilogue_fn=epilogue_fn, + QUANTIZED=False, + ) + + return C, A, B + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), + (128, 256, 512), +]) +def test_identity_epilogue(M, N, K): + """Test identity epilogue (no activation).""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=identity) + C_torch = torch.matmul(A, B) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"Identity epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_relu_epilogue(M, N, K): + """Test ReLU epilogue.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=relu) + C_torch = torch.relu(torch.matmul(A, B)) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"ReLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_gelu_epilogue(M, N, K): + """Test GELU epilogue.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=gelu) + C_torch = torch.nn.functional.gelu(torch.matmul(A, B), approximate='tanh') + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"GELU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_silu_epilogue(M, N, K): + """Test SiLU epilogue.""" + C_triton, A, B 
= run_matmul_with_epilogue(M, N, K, epilogue_fn=silu) + C_torch = torch.nn.functional.silu(torch.matmul(A, B)) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"SiLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_tanh_epilogue(M, N, K): + """Test Tanh epilogue.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=tanh) + C_torch = torch.tanh(torch.matmul(A, B)) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"Tanh epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_sigmoid_epilogue(M, N, K): + """Test Sigmoid epilogue.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=sigmoid) + C_torch = torch.sigmoid(torch.matmul(A, B)) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"Sigmoid epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), + (512, 512, 512), +]) +def test_leaky_relu_epilogue(M, N, K): + """Test Leaky ReLU epilogue.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=leaky_relu) + C_torch = torch.nn.functional.leaky_relu(torch.matmul(A, B), negative_slope=0.01) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"Leaky ReLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), +]) +def test_epilogue_with_bias(M, N, K): + """Test epilogue with bias addition.""" + bias = torch.randn(M, device="cuda", dtype=torch.float16) + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=relu, bias=bias) + + C_torch = torch.matmul(A, B) + bias.unsqueeze(1) + C_torch = torch.relu(C_torch) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"ReLU epilogue with bias failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +@pytest.mark.parametrize("M,N,K", [ + (256, 256, 256), +]) +def test_no_epilogue(M, N, K): + """Test that None epilogue works correctly.""" + C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=None) + C_torch = torch.matmul(A, B) + + assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \ + f"No epilogue (None) failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}" + + +if __name__ == "__main__": + # Run tests manually + print("Running epilogue tests...") + + test_functions = [ + test_identity_epilogue, + test_relu_epilogue, + test_gelu_epilogue, + test_silu_epilogue, + test_tanh_epilogue, + test_sigmoid_epilogue, + test_leaky_relu_epilogue, + test_epilogue_with_bias, + test_no_epilogue, + ] + + for test_func in test_functions: + try: + test_func(256, 256, 256) + print(f"✓ {test_func.__name__} PASSED") + except AssertionError as e: + print(f"✗ {test_func.__name__} FAILED: {e}") + except Exception as e: + print(f"✗ {test_func.__name__} ERROR: {e}") + + print("\nAll tests completed!") From 15a021998f3f4894064f106b520b077352d8412f Mon Sep 17 00:00:00 2001 From: Ryan Swann <109695074+ryanswann-amd@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:44:16 -0600 Subject: [PATCH 2/4] Update include/tritonblas/kernels/stages/algorithms/epilogue.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/tritonblas/kernels/stages/algorithms/epilogue.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/include/tritonblas/kernels/stages/algorithms/epilogue.py b/include/tritonblas/kernels/stages/algorithms/epilogue.py index a5b592a..0a466ab 100644 --- a/include/tritonblas/kernels/stages/algorithms/epilogue.py +++ b/include/tritonblas/kernels/stages/algorithms/epilogue.py @@ -99,7 +99,7 @@ def tanh(acc): """ Apply Tanh activation using numerically stable formula. For x > 0: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)) - For x < 0: tanh(x) = -(1 - exp(2x)) / (1 + exp(2x)) + For x < 0: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) Args: acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] From 44787cd5808878566353294c29b27fcac160060d Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Wed, 21 Jan 2026 16:09:29 -0500 Subject: [PATCH 3/4] Add example using epilogues to fuse hadamard to gemm --- examples/example_hadamard_epilogue.py | 290 ++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 examples/example_hadamard_epilogue.py diff --git a/examples/example_hadamard_epilogue.py b/examples/example_hadamard_epilogue.py new file mode 100644 index 0000000..6312277 --- /dev/null +++ b/examples/example_hadamard_epilogue.py @@ -0,0 +1,290 @@ +""" +Example demonstrating Hadamard rotation epilogue in tritonBLAS. + +This example shows how to: +1. Define a custom Hadamard rotation epilogue function +2. Apply it to the GEMM output accumulator +3. Use it for randomized numerical linear algebra + +This is useful for privacy-preserving computations and randomized algorithms. +""" +import torch +import triton +import triton.language as tl +from include.tritonblas.kernels.persistent_gemm import persistent_matmul + + +# ============================================================================ +# Define Hadamard Rotation Epilogue +# ============================================================================ + +@triton.jit +def build_hadamard(SIZE: tl.constexpr): + """ + Construct Hadamard matrix using H_{i,j} = (-1)^{popcount(i & j)}. + + This computes the bitwise dot product of row index i and column index j: + - popcount(i & j) counts the number of matching 1-bits + - If count is even: H_{i,j} = 1 + - If count is odd: H_{i,j} = -1 + + Args: + SIZE: Matrix dimension (must be power of 2, 16 <= SIZE <= 64) + + Returns: + SIZE x SIZE normalized Hadamard matrix + """ + tl.static_assert(16 <= SIZE) + tl.static_assert(SIZE <= 64) + + # Create row and column indices + i = tl.arange(0, SIZE) + j = tl.arange(0, SIZE) + + # Compute bitwise AND for all (i, j) pairs + matching_bits = i[:, None] & j[None, :] + + # Count set bits (popcount) - iterative approach + bit_sum = tl.zeros_like(matching_bits) + temp = matching_bits + for _ in tl.static_range(6): # 6 iterations for up to 64 bits + bit_sum += temp & 1 + temp >>= 1 + + # Map: even popcount -> +1, odd popcount -> -1 + H = 1 - 2 * (bit_sum & 1) + + # Normalize by sqrt(SIZE) + H = H / tl.math.sqrt(float(SIZE)) + return H + + +@triton.jit +def hadamard_rotation(acc, BLOCK_SIZE: tl.constexpr = 16): + """ + Apply Hadamard rotation to the accumulator in blocks. 
+ + This epilogue applies a Hadamard transformation to blocks of the accumulator: + For each BLOCK_SIZE x BLOCK_SIZE block: result = block @ H + + Constraints: + - BLOCK_SIZE must be a power of 2 + - 16 <= BLOCK_SIZE <= 64 + - BLOCK_SIZE must evenly divide both accumulator dimensions + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + BLOCK_SIZE: Size of Hadamard blocks (default: 32) + + Returns: + Accumulator with Hadamard rotation applied to each block + """ + # Get accumulator dimensions + M = acc.shape[0] + N = acc.shape[1] + + # Static assertions for valid block size + tl.static_assert(16 <= BLOCK_SIZE) + tl.static_assert(BLOCK_SIZE <= 64) + tl.static_assert(BLOCK_SIZE <= M) + tl.static_assert(BLOCK_SIZE <= N) + + # Build Hadamard matrix once + H = build_hadamard(BLOCK_SIZE) + + # Process each block + result = tl.zeros_like(acc) + + # Iterate over blocks in M dimension + for m_block in tl.static_range(M // BLOCK_SIZE): + m_start = m_block * BLOCK_SIZE + m_end = m_start + BLOCK_SIZE + + # Iterate over blocks in N dimension + for n_block in tl.static_range(N // BLOCK_SIZE): + n_start = n_block * BLOCK_SIZE + n_end = n_start + BLOCK_SIZE + + # Extract block + block = acc[m_start:m_end, n_start:n_end] + + # Apply Hadamard: block @ H + rotated = tl.dot(block, H.to(block.dtype)) + + # Store result + result[m_start:m_end, n_start:n_end] = rotated + + return result + + +def matmul_with_hadamard(A, B, block_size=32): + """ + Perform matrix multiplication with Hadamard rotation epilogue. + + Args: + A: Input matrix A [M, K] + B: Input matrix B [K, N] (transposed) + block_size: Size of Hadamard blocks (must be 16, 32, or 64) + + Returns: + Output matrix C [M, N] with Hadamard rotation applied + """ + M, K = A.shape + _, N = B.shape + C = torch.zeros((M, N), device="cuda", dtype=A.dtype) + + # Get device properties + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + + # Block sizes must be compatible with Hadamard block size + # For this example, we use 128x128 tiles which are divisible by 32 + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + GROUP_SIZE_M = 8 + + # Verify dimensions are compatible + assert BLOCK_SIZE_M % block_size == 0, f"BLOCK_SIZE_M ({BLOCK_SIZE_M}) must be divisible by block_size ({block_size})" + assert BLOCK_SIZE_N % block_size == 0, f"BLOCK_SIZE_N ({BLOCK_SIZE_N}) must be divisible by block_size ({block_size})" + + # Define grid + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + + # Create epilogue function with specific block size + @triton.jit + def hadamard_epilogue(acc): + return hadamard_rotation(acc, BLOCK_SIZE=block_size) + + # Launch kernel with Hadamard epilogue + persistent_matmul[grid]( + A, B, C, + None, None, # No quantization scales + A, # Dummy bias pointer (not used) + M, N, K, + A.stride(0), B.stride(1), + C.stride(0), C.stride(1), + 0, # stride_bias (not used) + A.stride(1), B.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=num_sms, + NUM_XCDS=1, + CHUNK_SIZE=1, + BIAS=False, + EVEN_K=(K % BLOCK_SIZE_K == 0), + CACHE_MODIFIER_A=".cg", + CACHE_MODIFIER_B=".cg", + epilogue_fn=hadamard_epilogue, + QUANTIZED=False, + ) + + return C + + +def build_hadamard_torch(size): + """Build Hadamard matrix in PyTorch for verification.""" + i = torch.arange(size, dtype=torch.int32) + j = torch.arange(size, dtype=torch.int32) + + # Compute bitwise AND for all (i, j) pairs + matching_bits = i[:, None] & j[None, 
:] + + # Count set bits (popcount) + bit_sum = torch.zeros_like(matching_bits) + temp = matching_bits.clone() + for _ in range(6): # 6 iterations for up to 64 bits + bit_sum += temp & 1 + temp >>= 1 + + # Map: even popcount -> +1, odd popcount -> -1 + H = 1 - 2 * (bit_sum & 1) + + # Normalize + H = H.float() / (size ** 0.5) + return H + + +def apply_hadamard_torch(matrix, block_size=32): + """Apply Hadamard rotation in PyTorch for verification.""" + M, N = matrix.shape + result = torch.zeros_like(matrix) + + # Build Hadamard matrix + H = build_hadamard_torch(block_size).to(matrix.device, matrix.dtype) + + # Apply to each block + for m_block in range(M // block_size): + m_start = m_block * block_size + m_end = m_start + block_size + + for n_block in range(N // block_size): + n_start = n_block * block_size + n_end = n_start + block_size + + # Extract block and apply Hadamard + block = matrix[m_start:m_end, n_start:n_end] + result[m_start:m_end, n_start:n_end] = block @ H + + return result + + +def main(): + print("\n" + "="*70) + print("Hadamard Rotation Epilogue Example") + print("="*70 + "\n") + + # Problem size (must be divisible by block size) + M, N, K = 512, 512, 512 + block_size = 32 + + print(f"Matrix dimensions: M={M}, N={N}, K={K}") + print(f"Hadamard block size: {block_size}x{block_size}\n") + + # Allocate input matrices + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(N, K, device="cuda", dtype=torch.float16).T + + # Run GEMM with Hadamard epilogue + print("Running GEMM with Hadamard rotation epilogue...") + C_triton = matmul_with_hadamard(A, B, block_size=block_size) + + # Compute reference: GEMM then Hadamard + print("Computing PyTorch reference...") + C_gemm = torch.matmul(A, B) + C_torch = apply_hadamard_torch(C_gemm, block_size=block_size) + + # Compare results + max_diff = torch.max(torch.abs(C_triton - C_torch)).item() + mean_diff = torch.mean(torch.abs(C_triton - C_torch)).item() + + print(f"\nResults:") + print("-" * 70) + print(f"Max difference from PyTorch: {max_diff:.6f}") + print(f"Mean difference from PyTorch: {mean_diff:.6f}") + print(f"Output shape: {C_triton.shape}") + print(f"\nSample output (first 3x3):\n{C_triton[:3, :3]}") + + # Verify Hadamard properties + print(f"\nHadamard Matrix Properties:") + print("-" * 70) + H = build_hadamard_torch(block_size).to("cuda", torch.float16) + H_squared = H @ H.T + identity_diff = torch.max(torch.abs(H_squared - torch.eye(block_size, device="cuda", dtype=torch.float16))).item() + print(f"H @ H^T ≈ I (max diff from identity): {identity_diff:.6f}") + print(f"Hadamard matrix is orthogonal: {identity_diff < 0.01}") + + print("\n" + "="*70) + print("Key Points:") + print("="*70) + print("1. Hadamard rotation is applied to blocks of the accumulator") + print("2. Block size must be power of 2 between 16 and 64") + print("3. Accumulator dimensions must be divisible by block size") + print("4. Hadamard matrices are orthogonal: H @ H^T = I") + print("5. 
Useful for randomized algorithms and privacy-preserving ML") + print("="*70 + "\n") + + +if __name__ == "__main__": + main() From b0c16fa261b0c60bf92af3479931dbdb3bb42bd1 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Thu, 22 Jan 2026 12:38:46 -0500 Subject: [PATCH 4/4] Update hadamard example --- examples/example_hadamard_epilogue.py | 235 ++++++++++++++------------ 1 file changed, 128 insertions(+), 107 deletions(-) diff --git a/examples/example_hadamard_epilogue.py b/examples/example_hadamard_epilogue.py index 6312277..800999d 100644 --- a/examples/example_hadamard_epilogue.py +++ b/examples/example_hadamard_epilogue.py @@ -1,17 +1,14 @@ """ Example demonstrating Hadamard rotation epilogue in tritonBLAS. -This example shows how to: -1. Define a custom Hadamard rotation epilogue function -2. Apply it to the GEMM output accumulator -3. Use it for randomized numerical linear algebra - -This is useful for privacy-preserving computations and randomized algorithms. +This example shows how to apply a Hadamard transformation to the entire +output accumulator tile. This is useful for randomized numerical linear algebra +and privacy-preserving computations. """ import torch import triton import triton.language as tl -from include.tritonblas.kernels.persistent_gemm import persistent_matmul +from tritonblas.kernels.persistent_gemm import persistent_matmul # ============================================================================ @@ -23,11 +20,6 @@ def build_hadamard(SIZE: tl.constexpr): """ Construct Hadamard matrix using H_{i,j} = (-1)^{popcount(i & j)}. - This computes the bitwise dot product of row index i and column index j: - - popcount(i & j) counts the number of matching 1-bits - - If count is even: H_{i,j} = 1 - - If count is odd: H_{i,j} = -1 - Args: SIZE: Matrix dimension (must be power of 2, 16 <= SIZE <= 64) @@ -35,7 +27,6 @@ def build_hadamard(SIZE: tl.constexpr): SIZE x SIZE normalized Hadamard matrix """ tl.static_assert(16 <= SIZE) - tl.static_assert(SIZE <= 64) # Create row and column indices i = tl.arange(0, SIZE) @@ -44,7 +35,7 @@ def build_hadamard(SIZE: tl.constexpr): # Compute bitwise AND for all (i, j) pairs matching_bits = i[:, None] & j[None, :] - # Count set bits (popcount) - iterative approach + # Count set bits (popcount) bit_sum = tl.zeros_like(matching_bits) temp = matching_bits for _ in tl.static_range(6): # 6 iterations for up to 64 bits @@ -60,74 +51,52 @@ def build_hadamard(SIZE: tl.constexpr): @triton.jit -def hadamard_rotation(acc, BLOCK_SIZE: tl.constexpr = 16): +def is_power_of_two(n: tl.constexpr) -> tl.constexpr: + """Check if n is a power of 2.""" + return (n & (n - 1)) == 0 and n > 0 + + +@triton.jit +def hadamard_rotation_square(acc): """ - Apply Hadamard rotation to the accumulator in blocks. - - This epilogue applies a Hadamard transformation to blocks of the accumulator: - For each BLOCK_SIZE x BLOCK_SIZE block: result = block @ H + Apply Hadamard rotation to the entire square accumulator. - Constraints: - - BLOCK_SIZE must be a power of 2 - - 16 <= BLOCK_SIZE <= 64 - - BLOCK_SIZE must evenly divide both accumulator dimensions + This works for any square accumulator with power-of-2 dimensions + between 16 and 64. 
Args: - acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] - BLOCK_SIZE: Size of Hadamard blocks (default: 32) + acc: Accumulator tensor [SIZE, SIZE] where SIZE is power of 2, 16 <= SIZE <= 64 Returns: - Accumulator with Hadamard rotation applied to each block + Accumulator with Hadamard transformation applied: acc @ H """ - # Get accumulator dimensions - M = acc.shape[0] - N = acc.shape[1] - - # Static assertions for valid block size - tl.static_assert(16 <= BLOCK_SIZE) - tl.static_assert(BLOCK_SIZE <= 64) - tl.static_assert(BLOCK_SIZE <= M) - tl.static_assert(BLOCK_SIZE <= N) - - # Build Hadamard matrix once - H = build_hadamard(BLOCK_SIZE) - - # Process each block - result = tl.zeros_like(acc) - - # Iterate over blocks in M dimension - for m_block in tl.static_range(M // BLOCK_SIZE): - m_start = m_block * BLOCK_SIZE - m_end = m_start + BLOCK_SIZE - - # Iterate over blocks in N dimension - for n_block in tl.static_range(N // BLOCK_SIZE): - n_start = n_block * BLOCK_SIZE - n_end = n_start + BLOCK_SIZE - - # Extract block - block = acc[m_start:m_end, n_start:n_end] - - # Apply Hadamard: block @ H - rotated = tl.dot(block, H.to(block.dtype)) - - # Store result - result[m_start:m_end, n_start:n_end] = rotated + SIZE:tl.constexpr = acc.shape[0] - return result + # Static assertions to enforce layout constraints + tl.static_assert(acc.shape[0] == acc.shape[1], "Accumulator must be square") + tl.static_assert((SIZE & (SIZE - 1)) == 0, "Accumulator size must be power of 2") + tl.static_assert(SIZE >= 16, "Accumulator size must be >= 16") + + # Build Hadamard matrix and apply + H = build_hadamard(SIZE) + return tl.dot(acc, H.to(acc.dtype)) -def matmul_with_hadamard(A, B, block_size=32): +# ============================================================================ +# Helper Function +# ============================================================================ + +def matmul_with_hadamard(A, B, tile_size=256): """ Perform matrix multiplication with Hadamard rotation epilogue. 
Args: A: Input matrix A [M, K] B: Input matrix B [K, N] (transposed) - block_size: Size of Hadamard blocks (must be 16, 32, or 64) + tile_size: Square tile size for GEMM (must be 32 or 64) Returns: - Output matrix C [M, N] with Hadamard rotation applied + Output matrix C [M, N] with Hadamard transformation applied to each tile """ M, K = A.shape _, N = B.shape @@ -136,25 +105,15 @@ def matmul_with_hadamard(A, B, block_size=32): # Get device properties num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - # Block sizes must be compatible with Hadamard block size - # For this example, we use 128x128 tiles which are divisible by 32 - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 + # Use square tiles + BLOCK_SIZE_M = tile_size + BLOCK_SIZE_N = tile_size BLOCK_SIZE_K = 32 GROUP_SIZE_M = 8 - # Verify dimensions are compatible - assert BLOCK_SIZE_M % block_size == 0, f"BLOCK_SIZE_M ({BLOCK_SIZE_M}) must be divisible by block_size ({block_size})" - assert BLOCK_SIZE_N % block_size == 0, f"BLOCK_SIZE_N ({BLOCK_SIZE_N}) must be divisible by block_size ({block_size})" - # Define grid grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) - # Create epilogue function with specific block size - @triton.jit - def hadamard_epilogue(acc): - return hadamard_rotation(acc, BLOCK_SIZE=block_size) - # Launch kernel with Hadamard epilogue persistent_matmul[grid]( A, B, C, @@ -176,7 +135,7 @@ def hadamard_epilogue(acc): EVEN_K=(K % BLOCK_SIZE_K == 0), CACHE_MODIFIER_A=".cg", CACHE_MODIFIER_B=".cg", - epilogue_fn=hadamard_epilogue, + epilogue_fn=hadamard_rotation_square, QUANTIZED=False, ) @@ -194,7 +153,7 @@ def build_hadamard_torch(size): # Count set bits (popcount) bit_sum = torch.zeros_like(matching_bits) temp = matching_bits.clone() - for _ in range(6): # 6 iterations for up to 64 bits + for _ in range(6): bit_sum += temp & 1 temp >>= 1 @@ -206,26 +165,27 @@ def build_hadamard_torch(size): return H -def apply_hadamard_torch(matrix, block_size=32): - """Apply Hadamard rotation in PyTorch for verification.""" +def apply_hadamard_torch(matrix, tile_size): + """Apply Hadamard rotation to each tile in PyTorch using bmm.""" M, N = matrix.shape - result = torch.zeros_like(matrix) + num_tiles_m = M // tile_size + num_tiles_n = N // tile_size # Build Hadamard matrix - H = build_hadamard_torch(block_size).to(matrix.device, matrix.dtype) + H = build_hadamard_torch(tile_size).to(matrix.device, matrix.dtype) - # Apply to each block - for m_block in range(M // block_size): - m_start = m_block * block_size - m_end = m_start + block_size - - for n_block in range(N // block_size): - n_start = n_block * block_size - n_end = n_start + block_size - - # Extract block and apply Hadamard - block = matrix[m_start:m_end, n_start:n_end] - result[m_start:m_end, n_start:n_end] = block @ H + # Reshape matrix into tiles: (num_tiles_m, num_tiles_n, tile_size, tile_size) + matrix_tiled = matrix.reshape(num_tiles_m, tile_size, num_tiles_n, tile_size) + matrix_tiled = matrix_tiled.permute(0, 2, 1, 3) # (num_tiles_m, num_tiles_n, tile_size, tile_size) + matrix_tiled = matrix_tiled.reshape(-1, tile_size, tile_size) # (num_tiles_m * num_tiles_n, tile_size, tile_size) + + # Apply Hadamard to all tiles at once using bmm + result_tiled = torch.bmm(matrix_tiled, H.unsqueeze(0).expand(matrix_tiled.shape[0], -1, -1)) + + # Reshape back to original shape + result_tiled = result_tiled.reshape(num_tiles_m, num_tiles_n, tile_size, tile_size) + result_tiled = result_tiled.permute(0, 2, 1, 3) # (num_tiles_m, tile_size, 
num_tiles_n, tile_size) + result = result_tiled.reshape(M, N) return result @@ -235,12 +195,13 @@ def main(): print("Hadamard Rotation Epilogue Example") print("="*70 + "\n") - # Problem size (must be divisible by block size) - M, N, K = 512, 512, 512 - block_size = 32 + # Problem size (must be divisible by tile size) + M, N, K = 8192, 8192, 8192 + tile_size = 128 print(f"Matrix dimensions: M={M}, N={N}, K={K}") - print(f"Hadamard block size: {block_size}x{block_size}\n") + print(f"Square tile size: {tile_size}x{tile_size}") + print(f"Hadamard applied to entire accumulator tile\n") # Allocate input matrices A = torch.randn(M, K, device="cuda", dtype=torch.float16) @@ -248,17 +209,18 @@ def main(): # Run GEMM with Hadamard epilogue print("Running GEMM with Hadamard rotation epilogue...") - C_triton = matmul_with_hadamard(A, B, block_size=block_size) + C_triton = matmul_with_hadamard(A, B, tile_size=tile_size) # Compute reference: GEMM then Hadamard print("Computing PyTorch reference...") C_gemm = torch.matmul(A, B) - C_torch = apply_hadamard_torch(C_gemm, block_size=block_size) + C_torch = apply_hadamard_torch(C_gemm, tile_size=tile_size) # Compare results max_diff = torch.max(torch.abs(C_triton - C_torch)).item() mean_diff = torch.mean(torch.abs(C_triton - C_torch)).item() - + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + print(f"\nResults:") print("-" * 70) print(f"Max difference from PyTorch: {max_diff:.6f}") @@ -269,21 +231,80 @@ def main(): # Verify Hadamard properties print(f"\nHadamard Matrix Properties:") print("-" * 70) - H = build_hadamard_torch(block_size).to("cuda", torch.float16) + H = build_hadamard_torch(tile_size).to("cuda", torch.float16) H_squared = H @ H.T - identity_diff = torch.max(torch.abs(H_squared - torch.eye(block_size, device="cuda", dtype=torch.float16))).item() + identity_diff = torch.max(torch.abs(H_squared - torch.eye(tile_size, device="cuda", dtype=torch.float16))).item() print(f"H @ H^T ≈ I (max diff from identity): {identity_diff:.6f}") print(f"Hadamard matrix is orthogonal: {identity_diff < 0.01}") print("\n" + "="*70) print("Key Points:") print("="*70) - print("1. Hadamard rotation is applied to blocks of the accumulator") - print("2. Block size must be power of 2 between 16 and 64") - print("3. Accumulator dimensions must be divisible by block size") + print("1. Hadamard rotation applied to entire square accumulator tile") + print("2. Tile must be square and power-of-2 (16, 32, or 64)") + print("3. Generic implementation works for any valid square tile size") print("4. Hadamard matrices are orthogonal: H @ H^T = I") print("5. 
Useful for randomized algorithms and privacy-preserving ML") print("="*70 + "\n") + + # ======================================================================== + # Performance Benchmark + # ======================================================================== + print("="*70) + print("Performance Benchmark") + print("="*70 + "\n") + + from triton.testing import do_bench + + # Benchmark GEMM without epilogue + print("Benchmarking GEMM without epilogue...") + + def gemm_no_epilogue(): + C_no_epi = torch.zeros((M, N), device="cuda", dtype=torch.float16) + grid = (triton.cdiv(M, tile_size) * triton.cdiv(N, tile_size),) + persistent_matmul[grid]( + A, B, C_no_epi, + None, None, A, + M, N, K, + A.stride(0), B.stride(1), + C_no_epi.stride(0), C_no_epi.stride(1), + 0, A.stride(1), B.stride(0), + BLOCK_SIZE_M=tile_size, BLOCK_SIZE_N=tile_size, BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, NUM_SMS=num_sms, NUM_XCDS=8, CHUNK_SIZE=8, + BIAS=False, EVEN_K=(K % 32 == 0), + CACHE_MODIFIER_A=".cg", CACHE_MODIFIER_B=".cg", + epilogue_fn=None, + QUANTIZED=False, + ) + return C_no_epi + + time_no_epilogue = do_bench(gemm_no_epilogue) + + # Benchmark GEMM with Hadamard epilogue + print("Benchmarking GEMM with Hadamard epilogue...") + + def gemm_with_epilogue(): + return matmul_with_hadamard(A, B, tile_size=tile_size) + + time_with_epilogue = do_bench(gemm_with_epilogue) + + # Benchmark separate GEMM + Hadamard + print("Benchmarking GEMM + separate Hadamard kernel...") + + def gemm_then_hadamard(): + C_temp = torch.matmul(A, B) + return apply_hadamard_torch(C_temp, tile_size=tile_size) + + time_separate = do_bench(gemm_then_hadamard) + + print(f"\nPerformance Results:") + print("-" * 70) + print(f"GEMM without epilogue: {time_no_epilogue:.3f} ms") + print(f"GEMM with Hadamard epilogue: {time_with_epilogue:.3f} ms") + print(f"GEMM + separate Hadamard: {time_separate:.3f} ms") + print(f"\nOverhead of epilogue: {time_with_epilogue - time_no_epilogue:.3f} ms ({((time_with_epilogue/time_no_epilogue - 1) * 100):.1f}%)") + print(f"Speedup vs separate kernels: {time_separate / time_with_epilogue:.2f}x") + print("="*70 + "\n") if __name__ == "__main__":