diff --git a/examples/example_hadamard_epilogue.py b/examples/example_hadamard_epilogue.py
new file mode 100644
index 0000000..800999d
--- /dev/null
+++ b/examples/example_hadamard_epilogue.py
@@ -0,0 +1,311 @@
+"""
+Example demonstrating Hadamard rotation epilogue in tritonBLAS.
+
+This example shows how to apply a Hadamard transformation to the entire
+output accumulator tile. This is useful for randomized numerical linear algebra
+and privacy-preserving computations.
+"""
+import torch
+import triton
+import triton.language as tl
+from tritonblas.kernels.persistent_gemm import persistent_matmul
+
+
+# ============================================================================
+# Define Hadamard Rotation Epilogue
+# ============================================================================
+
+@triton.jit
+def build_hadamard(SIZE: tl.constexpr):
+    """
+    Construct Hadamard matrix using H_{i,j} = (-1)^{popcount(i & j)}.
+
+    Args:
+        SIZE: Matrix dimension (must be power of 2, 16 <= SIZE <= 64)
+
+    Returns:
+        SIZE x SIZE normalized Hadamard matrix
+    """
+    tl.static_assert(16 <= SIZE)
+
+    # Create row and column indices
+    i = tl.arange(0, SIZE)
+    j = tl.arange(0, SIZE)
+
+    # Compute bitwise AND for all (i, j) pairs
+    matching_bits = i[:, None] & j[None, :]
+
+    # Count set bits (popcount)
+    bit_sum = tl.zeros_like(matching_bits)
+    temp = matching_bits
+    for _ in tl.static_range(6): # 6 bits cover indices 0..63, i.e. SIZE <= 64
+        bit_sum += temp & 1
+        temp >>= 1
+
+    # Map: even popcount -> +1, odd popcount -> -1
+    H = 1 - 2 * (bit_sum & 1)
+
+    # Normalize by sqrt(SIZE)
+    H = H / tl.math.sqrt(float(SIZE))
+    return H
+
+
+@triton.jit
+def is_power_of_two(n: tl.constexpr) -> tl.constexpr:
+    """Check if n is a power of 2."""
+    return (n & (n - 1)) == 0 and n > 0
+
+
+@triton.jit
+def hadamard_rotation_square(acc):
+    """
+    Apply Hadamard rotation to the entire square accumulator.
+
+    This works for any square accumulator with power-of-2 dimensions
+    between 16 and 64.
+
+    Args:
+        acc: Accumulator tensor [SIZE, SIZE] where SIZE is power of 2, 16 <= SIZE <= 64
+
+    Returns:
+        Accumulator with Hadamard transformation applied: acc @ H
+    """
+    SIZE: tl.constexpr = acc.shape[0]
+
+    # Static assertions to enforce layout constraints
+    tl.static_assert(acc.shape[0] == acc.shape[1], "Accumulator must be square")
+    tl.static_assert((SIZE & (SIZE - 1)) == 0, "Accumulator size must be power of 2")
+    tl.static_assert(SIZE >= 16, "Accumulator size must be >= 16")
+
+    # Build Hadamard matrix and apply
+    H = build_hadamard(SIZE)
+    return tl.dot(acc, H.to(acc.dtype))
+
+
+# ============================================================================
+# Helper Function
+# ============================================================================
+
+def matmul_with_hadamard(A, B, tile_size=64):
+    """
+    Perform matrix multiplication with Hadamard rotation epilogue.
+
+    Args:
+        A: Input matrix A [M, K]
+        B: Input matrix B [K, N] (transposed)
+        tile_size: Square tile size for GEMM (power of 2, between 16 and 64)
+
+    Returns:
+        Output matrix C [M, N] with Hadamard transformation applied to each tile
+    """
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros((M, N), device="cuda", dtype=A.dtype)
+
+    # Get device properties
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    # Use square tiles
+    BLOCK_SIZE_M = tile_size
+    BLOCK_SIZE_N = tile_size
+    BLOCK_SIZE_K = 32
+    GROUP_SIZE_M = 8
+
+    # Define grid
+    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
+
+    # Launch kernel with Hadamard epilogue
+    persistent_matmul[grid](
+        A, B, C,
+        None, None, # No quantization scales
+        A, # Dummy bias pointer (not used)
+        M, N, K,
+        A.stride(0), B.stride(1),
+        C.stride(0), C.stride(1),
+        0, # stride_bias (not used)
+        A.stride(1), B.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=GROUP_SIZE_M,
+        NUM_SMS=num_sms,
+        NUM_XCDS=1,
+        CHUNK_SIZE=1,
+        BIAS=False,
+        EVEN_K=(K % BLOCK_SIZE_K == 0),
+        CACHE_MODIFIER_A=".cg",
+        CACHE_MODIFIER_B=".cg",
+        epilogue_fn=hadamard_rotation_square,
+        QUANTIZED=False,
+    )
+
+    return C
+
+
+def build_hadamard_torch(size):
+    """Build Hadamard matrix in PyTorch for verification."""
+    i = torch.arange(size, dtype=torch.int32)
+    j = torch.arange(size, dtype=torch.int32)
+
+    # Compute bitwise AND for all (i, j) pairs
+    matching_bits = i[:, None] & j[None, :]
+
+    # Count set bits (popcount)
+    bit_sum = torch.zeros_like(matching_bits)
+    temp = matching_bits.clone()
+    for _ in range(6):
+        bit_sum += temp & 1
+        temp >>= 1
+
+    # Map: even popcount -> +1, odd popcount -> -1
+    H = 1 - 2 * (bit_sum & 1)
+
+    # Normalize
+    H = H.float() / (size ** 0.5)
+    return H
+
+
+def apply_hadamard_torch(matrix, tile_size):
+    """Apply Hadamard rotation to each tile in PyTorch using bmm."""
+    M, N = matrix.shape
+    num_tiles_m = M // tile_size
+    num_tiles_n = N // tile_size
+
+    # Build Hadamard matrix
+    H = build_hadamard_torch(tile_size).to(matrix.device, matrix.dtype)
+
+    # Reshape matrix into tiles: (num_tiles_m, num_tiles_n, tile_size, tile_size)
+    matrix_tiled = matrix.reshape(num_tiles_m, tile_size, num_tiles_n, tile_size)
+    matrix_tiled = matrix_tiled.permute(0, 2, 1, 3) # (num_tiles_m, num_tiles_n, tile_size, tile_size)
+    matrix_tiled = matrix_tiled.reshape(-1, tile_size, tile_size) # (num_tiles_m * num_tiles_n, tile_size, tile_size)
+
+    # Apply Hadamard to all tiles at once using bmm
+    result_tiled = torch.bmm(matrix_tiled, H.unsqueeze(0).expand(matrix_tiled.shape[0], -1, -1))
+
+    # Reshape back to original shape
+    result_tiled = result_tiled.reshape(num_tiles_m, num_tiles_n, tile_size, tile_size)
+    result_tiled = result_tiled.permute(0, 2, 1, 3) # (num_tiles_m, tile_size, num_tiles_n, tile_size)
+    result = result_tiled.reshape(M, N)
+
+    return result
+
+
+def main():
+    print("\n" + "="*70)
+    print("Hadamard Rotation Epilogue Example")
+    print("="*70 + "\n")
+
+    # Problem size (must be divisible by tile size)
+    M, N, K = 8192, 8192, 8192
+    tile_size = 64
+
+    print(f"Matrix dimensions: M={M}, N={N}, K={K}")
+    print(f"Square tile size: {tile_size}x{tile_size}")
+    print(f"Hadamard applied to entire accumulator tile\n")
+
+    # Allocate input matrices
+    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    B = torch.randn(N, K, device="cuda", dtype=torch.float16).T
+
+    # Run GEMM with Hadamard epilogue
+    print("Running GEMM with Hadamard rotation epilogue...")
+    C_triton = matmul_with_hadamard(A, B, tile_size=tile_size)
+
+    # Compute reference: GEMM then Hadamard
+    print("Computing PyTorch reference...")
+    C_gemm = torch.matmul(A, B)
+    C_torch = apply_hadamard_torch(C_gemm, tile_size=tile_size)
+
+    # Compare results
+    max_diff = torch.max(torch.abs(C_triton - C_torch)).item()
+    mean_diff = torch.mean(torch.abs(C_triton - C_torch)).item()
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    print(f"\nResults:")
+    print("-" * 70)
+    print(f"Max difference from PyTorch: {max_diff:.6f}")
+    print(f"Mean difference from PyTorch: {mean_diff:.6f}")
+    print(f"Output shape: {C_triton.shape}")
+    print(f"\nSample output (first 3x3):\n{C_triton[:3, :3]}")
+
+    # Verify Hadamard properties
+    print(f"\nHadamard Matrix Properties:")
+    print("-" * 70)
+    H = build_hadamard_torch(tile_size).to("cuda", torch.float16)
+    H_squared = H @ H.T
+    identity_diff = torch.max(torch.abs(H_squared - torch.eye(tile_size, device="cuda", dtype=torch.float16))).item()
+    print(f"H @ H^T ≈ I (max diff from identity): {identity_diff:.6f}")
+    print(f"Hadamard matrix is orthogonal: {identity_diff < 0.01}")
+
+    print("\n" + "="*70)
+    print("Key Points:")
+    print("="*70)
+    print("1. Hadamard rotation applied to entire square accumulator tile")
+    print("2. Tile must be square and power-of-2 (16, 32, or 64)")
+    print("3. Generic implementation works for any valid square tile size")
+    print("4. Hadamard matrices are orthogonal: H @ H^T = I")
+    print("5. Useful for randomized algorithms and privacy-preserving ML")
+    print("="*70 + "\n")
+
+    # ========================================================================
+    # Performance Benchmark
+    # ========================================================================
+    print("="*70)
+    print("Performance Benchmark")
+    print("="*70 + "\n")
+
+    from triton.testing import do_bench
+
+    # Benchmark GEMM without epilogue
+    print("Benchmarking GEMM without epilogue...")
+
+    def gemm_no_epilogue():
+        # Same launch configuration as matmul_with_hadamard so the comparison
+        # isolates the cost of the epilogue itself.
+        C_no_epi = torch.zeros((M, N), device="cuda", dtype=torch.float16)
+        grid = (triton.cdiv(M, tile_size) * triton.cdiv(N, tile_size),)
+        persistent_matmul[grid](
+            A, B, C_no_epi,
+            None, None, A,
+            M, N, K,
+            A.stride(0), B.stride(1),
+            C_no_epi.stride(0), C_no_epi.stride(1),
+            0, A.stride(1), B.stride(0),
+            BLOCK_SIZE_M=tile_size, BLOCK_SIZE_N=tile_size, BLOCK_SIZE_K=32,
+            GROUP_SIZE_M=8, NUM_SMS=num_sms, NUM_XCDS=1, CHUNK_SIZE=1,
+            BIAS=False, EVEN_K=(K % 32 == 0),
+            CACHE_MODIFIER_A=".cg", CACHE_MODIFIER_B=".cg",
+            epilogue_fn=None,
+            QUANTIZED=False,
+        )
+        return C_no_epi
+
+    time_no_epilogue = do_bench(gemm_no_epilogue)
+
+    # Benchmark GEMM with Hadamard epilogue
+    print("Benchmarking GEMM with Hadamard epilogue...")
+
+    def gemm_with_epilogue():
+        return matmul_with_hadamard(A, B, tile_size=tile_size)
+
+    time_with_epilogue = do_bench(gemm_with_epilogue)
+
+    # Benchmark separate GEMM + Hadamard
+    print("Benchmarking GEMM + separate Hadamard kernel...")
+
+    def gemm_then_hadamard():
+        C_temp = torch.matmul(A, B)
+        return apply_hadamard_torch(C_temp, tile_size=tile_size)
+
+    time_separate = do_bench(gemm_then_hadamard)
+
+    print(f"\nPerformance Results:")
+    print("-" * 70)
+    print(f"GEMM without epilogue: {time_no_epilogue:.3f} ms")
+    print(f"GEMM with Hadamard epilogue: {time_with_epilogue:.3f} ms")
+    print(f"GEMM + separate Hadamard: {time_separate:.3f} ms")
+    print(f"\nOverhead of epilogue: {time_with_epilogue - time_no_epilogue:.3f} ms ({((time_with_epilogue/time_no_epilogue - 1) * 100):.1f}%)")
+    print(f"Speedup vs separate kernels: {time_separate / time_with_epilogue:.2f}x")
+    print("="*70 + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/example_matmul_epilogue.py b/examples/example_matmul_epilogue.py
new file mode 100644
index 0000000..e146298
--- /dev/null
+++ b/examples/example_matmul_epilogue.py
@@ -0,0 +1,132 @@
+"""
+Minimal example demonstrating custom epilogue functions in tritonBLAS.
+
+This example shows how to:
+1. Define your own custom epilogue function
+2. Pass it to the persistent GEMM kernel
+3. Verify the results
+"""
+import torch
+import triton
+import triton.language as tl
+from tritonblas.kernels.persistent_gemm import persistent_matmul
+
+
+# ============================================================================
+# Define Custom Epilogue Function
+# ============================================================================
+
+@triton.jit
+def my_custom_clamp(acc):
+    """
+    Custom epilogue: clamp values between -1 and 1.
+
+    This is a simple example showing how to create your own epilogue function.
+    You can perform any element-wise operation on the accumulator.
+
+    Args:
+        acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+    Returns:
+        Clamped accumulator
+    """
+    return tl.minimum(tl.maximum(acc, -1.0), 1.0)
+
+
+# ============================================================================
+# Helper Function to Run GEMM with Custom Epilogue
+# ============================================================================
+
+def matmul_with_custom_epilogue(A, B, epilogue_fn=None):
+    """
+    Perform matrix multiplication with a custom epilogue function.
+
+    Args:
+        A: Input matrix A [M, K]
+        B: Input matrix B [K, N] (transposed)
+        epilogue_fn: Custom epilogue function to apply
+
+    Returns:
+        Output matrix C [M, N]
+    """
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros((M, N), device="cuda", dtype=A.dtype)
+
+    # Get device properties
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    # Fixed block sizes
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 32
+    GROUP_SIZE_M = 8
+
+    # Define grid
+    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
+
+    # Launch kernel with custom epilogue
+    persistent_matmul[grid](
+        A, B, C,
+        None, None, # No quantization scales
+        A, # Dummy bias pointer (not used)
+        M, N, K,
+        A.stride(0), B.stride(1),
+        C.stride(0), C.stride(1),
+        0, # stride_bias (not used)
+        A.stride(1), B.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=GROUP_SIZE_M,
+        NUM_SMS=num_sms,
+        NUM_XCDS=1,
+        CHUNK_SIZE=1,
+        BIAS=False,
+        EVEN_K=(K % BLOCK_SIZE_K == 0),
+        CACHE_MODIFIER_A=".cg",
+        CACHE_MODIFIER_B=".cg",
+        epilogue_fn=epilogue_fn, # Pass your custom epilogue here!
+        QUANTIZED=False,
+    )
+
+    return C
+
+
+# ============================================================================
+# Main Example
+# ============================================================================
+
+def main():
+    print("\n" + "="*70)
+    print("Custom Epilogue Function Example")
+    print("="*70 + "\n")
+
+    # Problem size
+    M, N, K = 512, 512, 512
+
+    # Allocate input matrices
+    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    B = torch.randn(N, K, device="cuda", dtype=torch.float16).T
+
+    # ========================================================================
+    # Example: Custom Clamp Epilogue
+    # ========================================================================
+    print("Custom Clamp Epilogue (values between -1 and 1)")
+    print("-" * 70)
+
+    C_custom = matmul_with_custom_epilogue(A, B, epilogue_fn=my_custom_clamp)
+
+    # Verify against PyTorch
+    C_torch = torch.clamp(torch.matmul(A, B), -1.0, 1.0)
+    max_diff = torch.max(torch.abs(C_custom - C_torch)).item()
+
+    print(f"Max difference from PyTorch: {max_diff:.6f}")
+    print(f"Min value: {C_custom.min().item():.6f}")
+    print(f"Max value: {C_custom.max().item():.6f}")
+    print(f"Sample output (first 3x3):\n{C_custom[:3, :3]}\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/include/tritonblas/kernels/persistent_gemm.py b/include/tritonblas/kernels/persistent_gemm.py
index d00d104..ef179f3 100644
--- a/include/tritonblas/kernels/persistent_gemm.py
+++ b/include/tritonblas/kernels/persistent_gemm.py
@@ -37,6 +37,7 @@ def persistent_matmul(
     EVEN_K: tl.constexpr,
     CACHE_MODIFIER_A: tl.constexpr,
     CACHE_MODIFIER_B: tl.constexpr,
+    epilogue_fn=None, # Epilogue function to apply (default: None)
     QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16
     ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32,
 ):
@@ -106,6 +107,10 @@ def persistent_matmul(
         # Check if we're using quantized mode based on whether scales were applied
         acc = add_vector(acc, bias_vector, QUANTIZED=(A_scale_ptr is not None)) #Add bias vector to output accumulator
 
+        # Apply epilogue function to accumulator if provided
+        if epilogue_fn is not None:
+            acc = epilogue_fn(acc)
+
         # Convert to output dtype
         result = convert_dtype(acc, C.type.element_ty) #Quantize output accumulator to output datatype
 
diff --git a/include/tritonblas/kernels/stages/algorithms/__init__.py b/include/tritonblas/kernels/stages/algorithms/__init__.py
index 827ae14..12260cf 100644
--- a/include/tritonblas/kernels/stages/algorithms/__init__.py
+++ b/include/tritonblas/kernels/stages/algorithms/__init__.py
@@ -33,6 +33,17 @@
     convert_dtype,
 )
 from .gemm_loop import gemm_loop
+from .epilogue import (
+    relu,
+    gelu,
+    gelu_tanh,
+    sigmoid,
+    silu,
+    tanh,
+    leaky_relu,
+    identity,
+    apply_epilogue,
+)
 
 __all__ = [
     # Binary operations
@@ -43,4 +54,14 @@
     'convert_dtype',
     # Composition
     'gemm_loop',
+    # Epilogue operations
+    'relu',
+    'gelu',
+    'gelu_tanh',
+    'sigmoid',
+    'silu',
+    'tanh',
+    'leaky_relu',
+    'identity',
+    'apply_epilogue',
 ]
diff --git a/include/tritonblas/kernels/stages/algorithms/epilogue.py b/include/tritonblas/kernels/stages/algorithms/epilogue.py
new file mode 100644
index 0000000..0a466ab
--- /dev/null
+++ b/include/tritonblas/kernels/stages/algorithms/epilogue.py
@@ -0,0 +1,158 @@
+"""
+Epilogue functions for composable Triton GEMM kernels.
+Operations applied to the output accumulator after GEMM computation.
+""" +import triton +import triton.language as tl + + +@triton.jit +def relu(acc): + """ + Apply ReLU activation: max(0, x) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with ReLU applied + """ + return tl.maximum(acc, 0.0) + + +@triton.jit +def gelu(acc): + """ + Apply GELU activation (approximate version using tanh). + GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with GELU applied + """ + # Constants for GELU approximation + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2/π) + coeff = 0.044715 + + # GELU approximation using numerically stable tanh + x_cubed = acc * acc * acc + inner = sqrt_2_over_pi * (acc + coeff * x_cubed) + + # Numerically stable tanh: use different formulas for positive/negative values + # For x > 0: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)) + # For x < 0: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + exp_neg_2x = tl.exp(-2.0 * tl.abs(inner)) + tanh_inner = tl.where( + inner >= 0, + (1.0 - exp_neg_2x) / (1.0 + exp_neg_2x), + -(1.0 - exp_neg_2x) / (1.0 + exp_neg_2x) + ) + return 0.5 * acc * (1.0 + tanh_inner) + + +@triton.jit +def gelu_tanh(acc): + """ + Apply GELU activation using tanh approximation (alias for gelu). + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with GELU applied + """ + return gelu(acc) + + +@triton.jit +def sigmoid(acc): + """ + Apply Sigmoid activation: 1 / (1 + exp(-x)) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with Sigmoid applied + """ + return tl.sigmoid(acc) + + +@triton.jit +def silu(acc): + """ + Apply SiLU (Swish) activation: x * sigmoid(x) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with SiLU applied + """ + return acc * tl.sigmoid(acc) + + +@triton.jit +def tanh(acc): + """ + Apply Tanh activation using numerically stable formula. + For x > 0: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)) + For x < 0: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Accumulator with Tanh applied + """ + # Use numerically stable formula to avoid overflow + exp_neg_2x = tl.exp(-2.0 * tl.abs(acc)) + result = (1.0 - exp_neg_2x) / (1.0 + exp_neg_2x) + # Apply sign + return tl.where(acc >= 0, result, -result) + + +@triton.jit +def leaky_relu(acc, negative_slope: tl.constexpr = 0.01): + """ + Apply Leaky ReLU activation: max(0, x) + negative_slope * min(0, x) + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + negative_slope: Slope for negative values (default: 0.01) + + Returns: + Accumulator with Leaky ReLU applied + """ + return tl.where(acc > 0, acc, acc * negative_slope) + + +@triton.jit +def identity(acc): + """ + Identity function (no activation). + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + + Returns: + Unchanged accumulator + """ + return acc + + +@triton.jit +def apply_epilogue(acc, epilogue_fn): + """ + Apply an epilogue function to the accumulator. + + Args: + acc: Accumulator tensor [BLOCK_SIZE_M, BLOCK_SIZE_N] + epilogue_fn: Epilogue function to apply (e.g., relu, gelu, etc.) 
+
+    Returns:
+        Accumulator with epilogue applied
+    """
+    return epilogue_fn(acc)
diff --git a/tests/test_epilogues.py b/tests/test_epilogues.py
new file mode 100644
index 0000000..7ddfb91
--- /dev/null
+++ b/tests/test_epilogues.py
@@ -0,0 +1,224 @@
+"""
+Tests for epilogue functions in tritonBLAS.
+"""
+import pytest
+import torch
+import triton
+from tritonblas.kernels.persistent_gemm import persistent_matmul
+from tritonblas.kernels.stages.algorithms.epilogue import (
+    relu, gelu, silu, tanh, sigmoid, leaky_relu, identity
+)
+
+
+def run_matmul_with_epilogue(M, N, K, epilogue_fn=None, bias=None, dtype=torch.float16):
+    """
+    Helper function to run matrix multiplication with epilogue.
+
+    Args:
+        M, N, K: Matrix dimensions
+        epilogue_fn: Epilogue function to apply
+        bias: Optional bias vector
+        dtype: Data type for tensors
+
+    Returns:
+        Tuple (C, A, B): the Triton output plus the generated input matrices
+    """
+    # Allocate tensors
+    A = torch.randn(M, K, device="cuda", dtype=dtype)
+    B = torch.randn(N, K, device="cuda", dtype=dtype).T
+    C = torch.zeros((M, N), device="cuda", dtype=dtype)
+
+    # Get device properties
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    # Setup bias
+    has_bias = bias is not None
+    if has_bias:
+        bias_ptr = bias
+        stride_bias = bias.stride(0)
+    else:
+        bias_ptr = A # Dummy pointer
+        stride_bias = 0
+
+    # Fixed block sizes
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 32
+    GROUP_SIZE_M = 8
+
+    # Define grid
+    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
+
+    # Launch kernel
+    persistent_matmul[grid](
+        A, B, C,
+        None, None, # No quantization scales
+        bias_ptr,
+        M, N, K,
+        A.stride(0), B.stride(1),
+        C.stride(0), C.stride(1),
+        stride_bias,
+        A.stride(1), B.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=GROUP_SIZE_M,
+        NUM_SMS=num_sms,
+        NUM_XCDS=1,
+        CHUNK_SIZE=1,
+        BIAS=has_bias,
+        EVEN_K=(K % BLOCK_SIZE_K == 0),
+        CACHE_MODIFIER_A=".cg",
+        CACHE_MODIFIER_B=".cg",
+        epilogue_fn=epilogue_fn,
+        QUANTIZED=False,
+    )
+
+    return C, A, B
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+    (128, 256, 512),
+])
+def test_identity_epilogue(M, N, K):
+    """Test identity epilogue (no activation)."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=identity)
+    C_torch = torch.matmul(A, B)
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"Identity epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_relu_epilogue(M, N, K):
+    """Test ReLU epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=relu)
+    C_torch = torch.relu(torch.matmul(A, B))
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"ReLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_gelu_epilogue(M, N, K):
+    """Test GELU epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=gelu)
+    C_torch = torch.nn.functional.gelu(torch.matmul(A, B), approximate='tanh')
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"GELU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_silu_epilogue(M, N, K):
+    """Test SiLU epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=silu)
+    C_torch = torch.nn.functional.silu(torch.matmul(A, B))
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"SiLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_tanh_epilogue(M, N, K):
+    """Test Tanh epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=tanh)
+    C_torch = torch.tanh(torch.matmul(A, B))
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"Tanh epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_sigmoid_epilogue(M, N, K):
+    """Test Sigmoid epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=sigmoid)
+    C_torch = torch.sigmoid(torch.matmul(A, B))
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"Sigmoid epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+    (512, 512, 512),
+])
+def test_leaky_relu_epilogue(M, N, K):
+    """Test Leaky ReLU epilogue."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=leaky_relu)
+    C_torch = torch.nn.functional.leaky_relu(torch.matmul(A, B), negative_slope=0.01)
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"Leaky ReLU epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+])
+def test_epilogue_with_bias(M, N, K):
+    """Test epilogue with bias addition."""
+    bias = torch.randn(M, device="cuda", dtype=torch.float16)
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=relu, bias=bias)
+
+    C_torch = torch.matmul(A, B) + bias.unsqueeze(1)
+    C_torch = torch.relu(C_torch)
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"ReLU epilogue with bias failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+])
+def test_no_epilogue(M, N, K):
+    """Test that None epilogue works correctly."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=None)
+    C_torch = torch.matmul(A, B)
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"No epilogue (None) failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"
+
+
+if __name__ == "__main__":
+    # Run tests manually
+    print("Running epilogue tests...")
+
+    test_functions = [
+        test_identity_epilogue,
+        test_relu_epilogue,
+        test_gelu_epilogue,
+        test_silu_epilogue,
+        test_tanh_epilogue,
+        test_sigmoid_epilogue,
+        test_leaky_relu_epilogue,
+        test_epilogue_with_bias,
+        test_no_epilogue,
+    ]
+
+    for test_func in test_functions:
+        try:
+            test_func(256, 256, 256)
+            print(f"✓ {test_func.__name__} PASSED")
+        except AssertionError as e:
+            print(f"✗ {test_func.__name__} FAILED: {e}")
+        except Exception as e:
+            print(f"✗ {test_func.__name__} ERROR: {e}")
+
+    print("\nAll tests completed!")
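+
+
+# ---------------------------------------------------------------------------
+# Illustrative sketch: the built-in @triton.jit epilogues can also be composed
+# into a user-defined epilogue and passed to the kernel the same way as above.
+# The names `_relu_then_tanh` and `test_composed_epilogue` are illustrative
+# only (not part of the tritonBLAS API); pytest collects this test, but it is
+# not listed in the manual runner above.
+# ---------------------------------------------------------------------------
+@triton.jit
+def _relu_then_tanh(acc):
+    """Example composed epilogue: ReLU followed by Tanh."""
+    return tanh(relu(acc))
+
+
+@pytest.mark.parametrize("M,N,K", [
+    (256, 256, 256),
+])
+def test_composed_epilogue(M, N, K):
+    """Test a user-defined epilogue built by composing two built-in ones."""
+    C_triton, A, B = run_matmul_with_epilogue(M, N, K, epilogue_fn=_relu_then_tanh)
+    C_torch = torch.tanh(torch.relu(torch.matmul(A, B)))
+
+    assert torch.allclose(C_triton, C_torch, rtol=1e-2, atol=1e-2), \
+        f"Composed epilogue failed: max_diff={torch.max(torch.abs(C_triton - C_torch))}"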