From e9e1217e21726a8871b0ffdccefb7a199471df97 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Fri, 16 Jan 2026 17:57:27 -0600 Subject: [PATCH 01/13] Add addmm() support and adjust APIs to mimic torch Most of torch's APIs allocate new tensor objects when calling high level functions and provide the option to perform in-place operations. To keep things consistent with argument names and better match what torch does, this commit also moves the project in that direction by changing the output matrix to be called `out` inside the matmul() and addmm() functions and, unless that argument is set, does output tensor allocations before calling the triton kernels within. --- include/tritonblas/__init__.py | 1 + include/tritonblas/matmul.py | 49 +++++++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py index 6f9365f..2a2c8b3 100644 --- a/include/tritonblas/__init__.py +++ b/include/tritonblas/__init__.py @@ -1,4 +1,5 @@ from .matmul import matmul, matmul_a8w8 from .matmul import matmul_lt, matmul_a8w8_lt from .matmul import matmul_fp4 +from .matmul import addmm from .origami import OrigamiMatmulSelector diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 9b23986..650ea4a 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -52,6 +52,7 @@ def persistent_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, + bias: Optional[torch.Tensor] = None, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, @@ -98,7 +99,7 @@ def persistent_matmul_lt( c, a_scale if quantized else None, # A_scale_ptr b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. + bias if bias is not None else None, M, N, K, @@ -106,7 +107,7 @@ def persistent_matmul_lt( b.stride(1), c.stride(0), c.stride(1), - 0, # TODO: Enable bias stride. + bias.stride(0) if bias is not None else 0, stride_ak=a.stride(1), stride_bk=b.stride(0), BLOCK_SIZE_M=BLK_M, @@ -116,7 +117,7 @@ def persistent_matmul_lt( NUM_SMS=total_programs, NUM_XCDS=num_xcds, CHUNK_SIZE=chunk_size, - BIAS=False, + BIAS=bias is not None, EVEN_K=even_k, CACHE_MODIFIER_A=CACHE_MODIFIER_A, CACHE_MODIFIER_B=CACHE_MODIFIER_B, @@ -135,6 +136,7 @@ def streamk_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, + bias: Optional[torch.Tensor] = None, sk_grid: Optional[int] = None, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, @@ -198,7 +200,7 @@ def streamk_matmul_lt( c, a_scale if quantized else None, # A_scale_ptr b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. + bias if bias is not None else None, P, locks, M, @@ -208,7 +210,7 @@ def streamk_matmul_lt( b.stride(1), c.stride(0), c.stride(1), - 0, # TODO: Enable bias stride. + bias.stride(0) if bias is not None else None, stride_ak=a.stride(1), stride_bk=b.stride(0), BLOCK_SIZE_M=BLK_M, @@ -219,7 +221,7 @@ def streamk_matmul_lt( NUM_XCDS=num_xcds, CHUNK_SIZE=chunk_size, STREAMK_TILES=total_tiles_streamk, - BIAS=False, + BIAS=bias is not None, EVEN_K=even_k, CACHE_MODIFIER_A=CACHE_MODIFIER_A, CACHE_MODIFIER_B=CACHE_MODIFIER_B, @@ -256,7 +258,7 @@ def matmul_a8w8_lt( def matmul( a: torch.Tensor, b: torch.Tensor, - c: torch.Tensor, + out: Optional[torch.Tensor] = None, enable_streamk=False, sk_grid=None, ): @@ -264,11 +266,14 @@ def matmul( M, K = a.shape _, N = b.shape - selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) + if out is None: + out = a.new_empty(M, N) + + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid) + return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) else: - return persistent_matmul_lt(a, b, c, selector) + return persistent_matmul_lt(a, b, out, selector) def matmul_a8w8( a: torch.Tensor, @@ -396,3 +401,27 @@ def matmul_fp4( ) return c + + +def addmm( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape + + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) + + if out is None: + out = a.new_empty(M, N) + + if enable_streamk: + return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + else: + return persistent_matmul_lt(a, b, out, selector, bias=bias) + From 37baec16fa61a3fc31b7ecdcba7cc7cc1ad30568 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Mon, 19 Jan 2026 14:23:05 -0600 Subject: [PATCH 02/13] Add addmm() triton_op and backwards implementation Despite this addition, backwards pass doesn't currently work because autograd is incompatible with mutated input args (`out=`). --- include/tritonblas/matmul.py | 62 ++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 650ea4a..1d046a5 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -1,12 +1,15 @@ -import torch -import triton -import random import functools +import random import time +from typing import Any, Dict, Optional, Tuple + +import torch +from torch.library import triton_op, wrap_triton +import triton + from .kernels import persistent_matmul, streamk_matmul from .kernels.fp4_matmul import fp4_matmul from .origami import OrigamiMatmulSelector -from typing import Dict, Tuple, Optional _tensor_cache = {} current_device_index = torch.cuda.current_device() @@ -93,7 +96,8 @@ def persistent_matmul_lt( chunk_size = min(chunk_size, total_programs // num_xcds) # TODO: Support other matmul algs. - kk = persistent_matmul[(grids,)]( + #kk = persistent_matmul[(grids,)]( + kk = wrap_triton(persistent_matmul)[(grids,)]( a, b, c, @@ -194,7 +198,8 @@ def streamk_matmul_lt( chunk_size = gsize_m * gsize_m chunk_size = min(chunk_size, grids // num_xcds) - kk = streamk_matmul[(grids,)]( + #kk = streamk_matmul[(grids,)]( + kk = wrap_triton(streamk_matmul)[(grids,)]( a, b, c, @@ -266,6 +271,7 @@ def matmul( M, K = a.shape _, N = b.shape + # Allocate an output tensor iff one is not provided from inputs if out is None: out = a.new_empty(M, N) @@ -403,6 +409,7 @@ def matmul_fp4( return c +@triton_op("tritonblas::addmm", mutates_args={"out"}) def addmm( bias: torch.Tensor, a: torch.Tensor, @@ -417,6 +424,7 @@ def addmm( selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) + # Allocate an output tensor iff one is not provided from inputs if out is None: out = a.new_empty(M, N) @@ -425,3 +433,45 @@ def addmm( else: return persistent_matmul_lt(a, b, out, selector, bias=bias) + +def _setup_context_addmm_backwards( + ctx: Any, + inputs: tuple[Any, ...], + output: Any +): + bias, a, b, out, enable_streamk, sk_grid = inputs + ctx.save_for_backwards(a, b) + ctx.enable_streamk = enable_streamk + ctx.sk_grid = sk_grid + + +def _addmm_backwards( + ctx: Any, + grad_output: torch.Tensor +): + a, b = ctx.saved_tensors + enable_streamk = ctx.enable_streamk + sk_grid = ctx.sk_grid + + # Need to make grad_output contiguous? + + # grad_a = grad_output @ b^T + b_t = b.T.contiguous() + grad_a = matmul(grad_output, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_b = a^T @ grad_output + a_t = a.T.contiguous() + grad_b = matmul(a_t, grad_output, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_bias = sum(grad_output) + grad_bias = grad_output.sum(dim=0) + + # tuple[bias, a, b, out, enable_streamk, sk_grid] + # First 3 must be in the order that matches addmm()'s forward args + # Last 3 are not part of the gradient and so are None + return grad_bias, grad_a, grad_b, None, None, None + + +addmm.register_autograd(_addmm_backwards, + setup_context=_setup_context_addmm_backwards) + From d59f0c9b0c5e4488fc0f619fc76d985414e12e5a Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Wed, 28 Jan 2026 19:42:24 -0500 Subject: [PATCH 03/13] Split in-place and out-of-place addm paths --- include/tritonblas/matmul.py | 85 +++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 16 deletions(-) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 1d046a5..8994218 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -409,12 +409,11 @@ def matmul_fp4( return c -@triton_op("tritonblas::addmm", mutates_args={"out"}) -def addmm( +@triton_op("tritonblas::_addmm", mutates_args={}) +def _addmm( bias: torch.Tensor, a: torch.Tensor, b: torch.Tensor, - out: Optional[torch.Tensor] = None, enable_streamk: Optional[bool] = False, sk_grid: Optional[int] = None ) -> torch.Tensor: @@ -422,11 +421,11 @@ def addmm( M, K = a.shape _, N = b.shape + # Query Origami for solution selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) - # Allocate an output tensor iff one is not provided from inputs - if out is None: - out = a.new_empty(M, N) + # Allocate an output tensor + out = a.new_empty(M, N) if enable_streamk: return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) @@ -439,8 +438,8 @@ def _setup_context_addmm_backwards( inputs: tuple[Any, ...], output: Any ): - bias, a, b, out, enable_streamk, sk_grid = inputs - ctx.save_for_backwards(a, b) + bias, a, b, enable_streamk, sk_grid = inputs + ctx.save_for_backward(a, b) ctx.enable_streamk = enable_streamk ctx.sk_grid = sk_grid @@ -453,25 +452,79 @@ def _addmm_backwards( enable_streamk = ctx.enable_streamk sk_grid = ctx.sk_grid - # Need to make grad_output contiguous? + # Make grad_output contiguous + grad_output_cont = grad_output.contiguous() # grad_a = grad_output @ b^T b_t = b.T.contiguous() - grad_a = matmul(grad_output, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) + grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) # grad_b = a^T @ grad_output a_t = a.T.contiguous() - grad_b = matmul(a_t, grad_output, enable_streamk=enable_streamk, sk_grid=sk_grid) + grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid) # grad_bias = sum(grad_output) grad_bias = grad_output.sum(dim=0) - # tuple[bias, a, b, out, enable_streamk, sk_grid] + # tuple[bias, a, b, enable_streamk, sk_grid] # First 3 must be in the order that matches addmm()'s forward args - # Last 3 are not part of the gradient and so are None - return grad_bias, grad_a, grad_b, None, None, None + # Last 2 are not part of the gradient and so are None + return grad_bias, grad_a, grad_b, None, None + + +_addmm.register_autograd(_addmm_backwards, + setup_context=_setup_context_addmm_backwards) + +@triton_op("tritonblas::_addmm_out", mutates_args={'out'}) +def _addmm_out( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> None: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape -addmm.register_autograd(_addmm_backwards, - setup_context=_setup_context_addmm_backwards) + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) + + if enable_streamk: + streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + else: + persistent_matmul_lt(a, b, out, selector, bias=bias) + + # Custom torch ops cannot return a value which is an alias of an input. So + # even though torch returns a pointer to the out arg when used, we can't. + return None + + +def addmm( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> Optional[torch.Tensor]: + # If no out tensor provided - we do the allocation - we support autograd + if out is None: + return _addmm(bias, a, b, enable_streamk, sk_grid) + + # If out tensor provided - in-place - we do NOT support autograd + # Check for autograd conditions (global and per-tensor) + if torch.is_grad_enabled() and ( + bias.requires_grad + or a.requires_grad + or b.requires_grad + or out.requires_grad + ): + raise RuntimeError( + "tritonblas.addmm(): functions with out=... arguments don't support " + "automatic differentiation, but one of the arguments requires grad." + ) + return _addmm_out(bias, a, b, out, enable_streamk, sk_grid) From d7590341c2794c70d8fad552e3efeb306aaca9e3 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Wed, 28 Jan 2026 19:43:39 -0500 Subject: [PATCH 04/13] Hack fixes for bias into persistent GEMM kernels --- include/tritonblas/kernels/persistent_gemm.py | 4 ++-- include/tritonblas/kernels/stages/algorithms/binary.py | 4 ++-- include/tritonblas/kernels/stages/algorithms/gemm_loop.py | 2 +- include/tritonblas/matmul.py | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/tritonblas/kernels/persistent_gemm.py b/include/tritonblas/kernels/persistent_gemm.py index d00d104..7d8012c 100644 --- a/include/tritonblas/kernels/persistent_gemm.py +++ b/include/tritonblas/kernels/persistent_gemm.py @@ -38,7 +38,7 @@ def persistent_matmul( CACHE_MODIFIER_A: tl.constexpr, CACHE_MODIFIER_B: tl.constexpr, QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16 - ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32, + ALLOW_TF32: tl.constexpr = True, ): # Stride guards tl.assume(stride_am > 0) @@ -102,7 +102,7 @@ def persistent_matmul( # Add bias if provided if BIAS: - bias_vector = tl.load(bias_ptr + row_indices * stride_bias, mask=row_indices < M, other=0.0) #Load Bias vector + bias_vector = tl.load(bias_ptr + col_indices * stride_bias, mask=col_indices < N, other=0.0) #Load Bias vector # Check if we're using quantized mode based on whether scales were applied acc = add_vector(acc, bias_vector, QUANTIZED=(A_scale_ptr is not None)) #Add bias vector to output accumulator diff --git a/include/tritonblas/kernels/stages/algorithms/binary.py b/include/tritonblas/kernels/stages/algorithms/binary.py index b66a0db..10bc7b8 100644 --- a/include/tritonblas/kernels/stages/algorithms/binary.py +++ b/include/tritonblas/kernels/stages/algorithms/binary.py @@ -72,6 +72,6 @@ def add_vector(acc, bias_vector, QUANTIZED: tl.constexpr): Accumulator with bias added """ if QUANTIZED: - return acc + bias_vector[:, None].to(tl.float32) + return acc + bias_vector[None, :].to(tl.float32) else: - return acc + bias_vector[:, None] + return acc + bias_vector[None, :] diff --git a/include/tritonblas/kernels/stages/algorithms/gemm_loop.py b/include/tritonblas/kernels/stages/algorithms/gemm_loop.py index 001a266..db3c118 100644 --- a/include/tritonblas/kernels/stages/algorithms/gemm_loop.py +++ b/include/tritonblas/kernels/stages/algorithms/gemm_loop.py @@ -82,7 +82,7 @@ def gemm_loop( loop_k = tl.cdiv(K, BLOCK_SIZE_K) if not EVEN_K: loop_k -= 1 - tl.assume(loop_k > 0) + tl.assume(loop_k >= 0) # Main loop over K dimension for k_iter in range(loop_k): diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 8994218..e913cef 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -131,6 +131,7 @@ def persistent_matmul_lt( waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, ) return c From 223bcec3de225005340fe3e9dbf99ed2f8637838 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Wed, 28 Jan 2026 19:47:41 -0500 Subject: [PATCH 05/13] Add pytests for addmm functionality --- tests/test_addmm.py | 305 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 tests/test_addmm.py diff --git a/tests/test_addmm.py b/tests/test_addmm.py new file mode 100644 index 0000000..acd200d --- /dev/null +++ b/tests/test_addmm.py @@ -0,0 +1,305 @@ +""" +Tests for tritonblas.addmm with torch.autograd and torch.compile support + +Tests cover: +1. Forward pass correctness against torch.addmm +2. Backward pass gradient correctness against torch.addmm +3. In-place (out=...) functionality and autograd restrictions +4. Edge cases including small dimensions +5. torch.compile compatibility +6. Persistent vs. StreamK compatibility +""" + +import pytest +import torch +import tritonblas + + +# If we don't increase this, torch will complain about too many recompilations. +torch._dynamo.config.cache_size_limit = 10000 + +# Standard test dimensions +STANDARD_DIMS = [ + (128, 256, 512), # Medium sizes + (256, 256, 256), # Square + (512, 1024, 768), # Larger + (2048, 1024, 512), # Wide output + (1024, 2048, 512), # Tall output +] + +# Edge case dimensions (small dimensions, N < 16 cases) +EDGE_CASE_DIMS = [ + (32, 32, 32), # Small square + (64, 16, 128), # Small N + (16, 64, 128), # Small M + (128, 8, 256), # N < 16 + (8, 128, 256), # M < 16 + (12, 12, 512), # Small M and N + (128, 64, 12), # Small K +] + +# Skinny matrix dimensions (stress tests) +SKINNY_DIMS = [ + (16, 16, 4096), # Very large K + (32, 32, 8192), # Large K +] + +# Data types to test +DTYPES = [torch.bfloat16, torch.float16] + +# Whether to test with torch.compile +USE_COMPILE = [False, True] + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_forward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.addmm forward pass matches torch.addmm.""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + bias = torch.randn(n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # tritonblas result + result = addmm_fn(bias, a, b) + + # torch reference + expected = torch.addmm(bias, a, b) + + # Check forward correctness with relaxed tolerance for low precision + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_backward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.addmm backward pass produces correct gradients.""" + torch.manual_seed(42) + + # Create inputs with requires_grad for tritonblas + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + + # Clone for torch reference + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Forward pass + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + # Backward pass with same upstream gradient + grad_output = torch.randn_like(result) + result.backward(grad_output) + result_ref.backward(grad_output) + + # Check gradients match + torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1, + msg="bias gradient mismatch") + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1, + msg="a gradient mismatch") + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1, + msg="b gradient mismatch") + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", SKINNY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_skinny_matrices(m, n, k, dtype, use_compile): + """Test addmm with skinny matrices (large K dimension).""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Forward + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_with_grad_raises(use_compile): + """Test that addmm with out=... raises RuntimeError when autograd is enabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + with pytest.raises(RuntimeError, match="don't support automatic differentiation"): + addmm_fn(bias, a, b, out=out) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_without_grad_works(use_compile): + """Test that addmm with out=... works when autograd is disabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Should work with torch.no_grad() + with torch.no_grad(): + result = addmm_fn(bias, a, b, out=out) + + # In-place path returns None (custom ops don't support aliasing) + assert result is None, "in-place addmm should return None" + + # Verify correctness against torch + expected = torch.addmm(bias, a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_output_correctness(use_compile): + """Test that addmm in-place mode produces correct results.""" + torch.manual_seed(42) + m, n, k = 128, 256, 512 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + bias = torch.randn(n, device='cuda', dtype=dtype) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + with torch.no_grad(): + addmm_fn(bias, a, b, out=out) + + expected = torch.addmm(bias, a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_no_grad_tensors(use_compile): + """Test addmm works when input tensors don't require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=False) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=False) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + result = addmm_fn(bias, a, b) + expected = torch.addmm(bias, a, b) + + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_partial_grad(use_compile): + """Test addmm when only some inputs require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + # Only a requires grad + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=False) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone() + bias_ref = bias.detach().clone() + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", [False, True]) +def test_addmm_streamk_modes(enable_streamk, use_compile): + """Test addmm with different streamk settings.""" + torch.manual_seed(42) + m, n, k = 256, 256, 256 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Forward + result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1) From 06cedccf861ece2b895dada3984e29b8b57a18a3 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Wed, 28 Jan 2026 19:48:26 -0500 Subject: [PATCH 06/13] Temporarily disable selector cache due to hash bug --- include/tritonblas/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index e913cef..f03a3f9 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -25,7 +25,7 @@ # Function will behave like an LRU-Cache of heuristic results # Saves several microseconds for previously seen problems by not rerunning the heuristic unnecessarily -@functools.lru_cache(maxsize=1024) +#@functools.lru_cache(maxsize=1024) def _make_matmul_selector( M: int, N: int, From 19730978378e5796e089af79a9592b8fd2398566 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Thu, 29 Jan 2026 18:58:04 -0500 Subject: [PATCH 07/13] Add matmul() triton_op with split paths like addmm --- include/tritonblas/matmul.py | 110 ++++++++++++++++++++++++++++++++--- 1 file changed, 101 insertions(+), 9 deletions(-) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index f03a3f9..051c221 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -261,27 +261,119 @@ def matmul_a8w8_lt( else: return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) -def matmul( + +@triton_op("tritonblas::_matmul", mutates_args={}) +def _matmul( a: torch.Tensor, b: torch.Tensor, - out: Optional[torch.Tensor] = None, - enable_streamk=False, - sk_grid=None, -): - assert a.shape[1] == b.shape[0], "Incompatible Dimensions" + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape _, N = b.shape - # Allocate an output tensor iff one is not provided from inputs - if out is None: - out = a.new_empty(M, N) + # Allocate an output tensor + out = a.new_empty(M, N) + # Query Origami for solution selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) if enable_streamk: return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) else: return persistent_matmul_lt(a, b, out, selector) + +def _setup_context_matmul_backwards( + ctx: Any, + inputs: tuple[Any, ...], + output: Any +): + a, b, enable_streamk, sk_grid = inputs + ctx.save_for_backward(a, b) + ctx.enable_streamk = enable_streamk + ctx.sk_grid = sk_grid + + +def _matmul_backwards( + ctx: Any, + grad_output: torch.Tensor +): + a, b = ctx.saved_tensors + enable_streamk = ctx.enable_streamk + sk_grid = ctx.sk_grid + + # Make grad_output contiguous + grad_output_cont = grad_output.contiguous() + + # grad_a = grad_output @ b^T + b_t = b.T.contiguous() + grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_b = a^T @ grad_output + a_t = a.T.contiguous() + grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # tuple[a, b, enable_streamk, sk_grid] + # First 2 must be in the order that matches matmul()'s forward args + # Last 2 are not part of the gradient and so are None + return grad_a, grad_b, None, None + + +_matmul.register_autograd(_matmul_backwards, + setup_context=_setup_context_matmul_backwards) + + +@triton_op("tritonblas::_matmul_out", mutates_args={'out'}) +def _matmul_out( + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> None: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape + + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) + + if enable_streamk: + streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) + else: + persistent_matmul_lt(a, b, out, selector) + + # Custom torch ops cannot return a value which is an alias of an input. So + # even though torch returns a pointer to the out arg when used, we can't. + return None + + +def matmul( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> Optional[torch.Tensor]: + # If no out tensor provided - we do the allocation - we support autograd + if out is None: + return _matmul(a, b, enable_streamk, sk_grid) + + # If out tensor provided - in-place - we do NOT support autograd + # Check for autograd conditions (global and per-tensor) + if torch.is_grad_enabled() and ( + a.requires_grad + or b.requires_grad + or out.requires_grad + ): + raise RuntimeError( + "tritonblas.matmul(): functions with out=... arguments don't support " + "automatic differentiation, but one of the arguments requires grad." + ) + return _matmul_out(a, b, out, enable_streamk, sk_grid) + + def matmul_a8w8( a: torch.Tensor, b: torch.Tensor, From 09674e6be6766f0e8788f79839037cc4ec9364a9 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Mon, 2 Feb 2026 17:16:33 -0500 Subject: [PATCH 08/13] Add workaround for torch.compile triton issue The version of triton that torch.compile uses has an issue with the way tl.multiple_of generates load operations - in some situations it is possible to generate an illegal vectorized load setup which will result in incorrect operation results but without any warning/error. This does not happen outside of torch.compile where a maineline version of triton is used instead of torch's specific version. This commit adds a workaround only for the torch.compile case which will be removed once the torch.compile triton version is updated to fix this issue. --- include/tritonblas/kernels/persistent_gemm.py | 6 +- .../kernels/stages/algorithms/gemm_loop.py | 11 +- .../tritonblas/kernels/stages/memory/load.py | 11 +- include/tritonblas/matmul.py | 91 ++++-- include/tritonblas/origami.py | 1 + ...est_addmm.py => test_addmm_correctness.py} | 5 + tests/test_matmul_correctness.py | 293 ++++++++++++++++++ 7 files changed, 384 insertions(+), 34 deletions(-) rename tests/{test_addmm.py => test_addmm_correctness.py} (97%) create mode 100644 tests/test_matmul_correctness.py diff --git a/include/tritonblas/kernels/persistent_gemm.py b/include/tritonblas/kernels/persistent_gemm.py index 7d8012c..290cca7 100644 --- a/include/tritonblas/kernels/persistent_gemm.py +++ b/include/tritonblas/kernels/persistent_gemm.py @@ -24,8 +24,8 @@ def persistent_matmul( stride_cm, stride_cn, stride_bias, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, + stride_ak, + stride_bk, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -39,6 +39,7 @@ def persistent_matmul( CACHE_MODIFIER_B: tl.constexpr, QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16 ALLOW_TF32: tl.constexpr = True, + VECTORIZED_LOAD_SIZE: tl.constexpr = 16, #FIXME: Temporary workaround for torch.compile - remove once upstream is fixed ): # Stride guards tl.assume(stride_am > 0) @@ -88,6 +89,7 @@ def persistent_matmul( BLOCK_SIZE_K, #Block Size in K dimension CACHE_MODIFIER_A, CACHE_MODIFIER_B, #Cache modifiers to control locality QUANTIZED, ALLOW_TF32, EVEN_K, #Extra compile time constants + VECTORIZED_LOAD_SIZE, #FIXME: Temporary workaround for torch.compile - remove once upstream is fixed ) # ============================================================ diff --git a/include/tritonblas/kernels/stages/algorithms/gemm_loop.py b/include/tritonblas/kernels/stages/algorithms/gemm_loop.py index db3c118..66e0acc 100644 --- a/include/tritonblas/kernels/stages/algorithms/gemm_loop.py +++ b/include/tritonblas/kernels/stages/algorithms/gemm_loop.py @@ -30,6 +30,7 @@ def gemm_loop( QUANTIZED: tl.constexpr, ALLOW_TF32: tl.constexpr, EVEN_K: tl.constexpr, + VECTORIZED_LOAD_SIZE: tl.constexpr = 16, #FIXME: Temporary workaround for torch.compile - remove once upstream is fixed ): """ Execute the main GEMM loop over the K dimension. @@ -89,8 +90,9 @@ def gemm_loop( k0 = k_iter * BLOCK_SIZE_K # Load - Address math + global → CU load - a = load(A, row_indices, k0, stride_am, stride_ak, BLOCK_SIZE_K, K, CACHE_MODIFIER_A, mask_k=False, is_row_major=True) - b = load(B, col_indices, k0, stride_bn, stride_bk, BLOCK_SIZE_K, K, CACHE_MODIFIER_B, mask_k=False, is_row_major=False) + # FIXME: VECTORIZED_LOAD_SIZE is temporary workaround for torch.compile - remove once upstream is fixed + a = load(A, row_indices, k0, stride_am, stride_ak, BLOCK_SIZE_K, K, CACHE_MODIFIER_A, mask_k=False, is_row_major=True, VECTORIZED_LOAD_SIZE=VECTORIZED_LOAD_SIZE) + b = load(B, col_indices, k0, stride_bn, stride_bk, BLOCK_SIZE_K, K, CACHE_MODIFIER_B, mask_k=False, is_row_major=False, VECTORIZED_LOAD_SIZE=VECTORIZED_LOAD_SIZE) # Compute - Math only acc = multiply_accumulate(acc, a, b, QUANTIZED, ALLOW_TF32) @@ -100,8 +102,9 @@ def gemm_loop( k0 = loop_k * BLOCK_SIZE_K # Load with masking - a = load(A, row_indices, k0, stride_am, stride_ak, BLOCK_SIZE_K, K, CACHE_MODIFIER_A, mask_k=True, is_row_major=True) - b = load(B, col_indices, k0, stride_bn, stride_bk, BLOCK_SIZE_K, K, CACHE_MODIFIER_B, mask_k=True, is_row_major=False) + # FIXME: VECTORIZED_LOAD_SIZE is temporary workaround for torch.compile - remove once upstream is fixed + a = load(A, row_indices, k0, stride_am, stride_ak, BLOCK_SIZE_K, K, CACHE_MODIFIER_A, mask_k=True, is_row_major=True, VECTORIZED_LOAD_SIZE=VECTORIZED_LOAD_SIZE) + b = load(B, col_indices, k0, stride_bn, stride_bk, BLOCK_SIZE_K, K, CACHE_MODIFIER_B, mask_k=True, is_row_major=False, VECTORIZED_LOAD_SIZE=VECTORIZED_LOAD_SIZE) # Compute acc = multiply_accumulate(acc, a, b, QUANTIZED, ALLOW_TF32) diff --git a/include/tritonblas/kernels/stages/memory/load.py b/include/tritonblas/kernels/stages/memory/load.py index ac7facb..0328650 100644 --- a/include/tritonblas/kernels/stages/memory/load.py +++ b/include/tritonblas/kernels/stages/memory/load.py @@ -16,6 +16,7 @@ def load( CACHE_MODIFIER: tl.constexpr, mask_k: tl.constexpr = False, is_row_major: tl.constexpr = True, + VECTORIZED_LOAD_SIZE: tl.constexpr = 16, # FIXME: Temporary workaround for torch.compile - remove once upstream is fixed ): """ Load a single tile from global memory. @@ -39,6 +40,8 @@ def load( """ # Compute K indices rk = k0 + tl.arange(0, BLOCK_SIZE_K) + + # vec_load_size: tl.constexpr = VECTORIZED_LOAD_SIZE # Compute addresses based on layout if is_row_major: @@ -46,9 +49,9 @@ def load( ptrs = matrix_ptr + indices[:, None] * stride_major + rk[None, :] * stride_k # Apply alignment hints if stride_k == 1: - ptrs = tl.multiple_of(ptrs, (1, 16)) + ptrs = tl.multiple_of(ptrs, (1, VECTORIZED_LOAD_SIZE)) else: - ptrs = tl.multiple_of(ptrs, (16, 1)) + ptrs = tl.multiple_of(ptrs, (VECTORIZED_LOAD_SIZE, 1)) # Load with optional K masking if mask_k: tile = tl.load(ptrs, mask=rk[None, :] < K, other=0.0, cache_modifier=CACHE_MODIFIER) @@ -59,9 +62,9 @@ def load( ptrs = matrix_ptr + rk[:, None] * stride_k + indices[None, :] * stride_major # Apply alignment hints if stride_k == 1: - ptrs = tl.multiple_of(ptrs, (16, 1)) + ptrs = tl.multiple_of(ptrs, (VECTORIZED_LOAD_SIZE, 1)) else: - ptrs = tl.multiple_of(ptrs, (1, 16)) + ptrs = tl.multiple_of(ptrs, (1, VECTORIZED_LOAD_SIZE)) # Load with optional K masking if mask_k: tile = tl.load(ptrs, mask=rk[:, None] < K, other=0.0, cache_modifier=CACHE_MODIFIER) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 051c221..69e8531 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -23,6 +23,25 @@ _global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32) +def _compute_safe_vectorize_size(a: torch.Tensor) -> int: + """ + Compute a safe VECTORIZED_LOAD_SIZE for torch.compile workaround. + + FIXME: Temporary workaround for torch.compile Triton bug - remove once upstream is fixed. + The bundled Triton version has a bug where tl.multiple_of can generate illegal + vectorized loads. This computes a safe alignment value based on M dimension and dtype. + + Returns the largest power-of-2 X such that (M * dtype_size_bytes) % X == 0. + """ + M = a.shape[0] + dtype_size = a.element_size() + total_bytes = M * dtype_size + + # Find largest power of 2 that divides total_bytes using bit manipulation + # n & (-n) isolates the lowest set bit, which is the largest power of 2 factor + return total_bytes & (-total_bytes) + + # Function will behave like an LRU-Cache of heuristic results # Saves several microseconds for previously seen problems by not rerunning the heuristic unnecessarily #@functools.lru_cache(maxsize=1024) @@ -59,6 +78,7 @@ def persistent_matmul_lt( a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape @@ -95,6 +115,9 @@ def persistent_matmul_lt( chunk_size = gsize_m * gsize_m chunk_size = min(chunk_size, total_programs // num_xcds) + # FIXME: Temporary workaround for torch.compile - remove once upstream is fixed + VECTORIZED_LOAD_SIZE = _vectorized_load_size if _vectorized_load_size is not None else 16 + # TODO: Support other matmul algs. #kk = persistent_matmul[(grids,)]( kk = wrap_triton(persistent_matmul)[(grids,)]( @@ -132,6 +155,7 @@ def persistent_matmul_lt( matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + VECTORIZED_LOAD_SIZE=VECTORIZED_LOAD_SIZE, # FIXME: Temporary workaround for torch.compile - remove once upstream is fixed ) return c @@ -146,6 +170,7 @@ def streamk_matmul_lt( a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape @@ -267,7 +292,8 @@ def _matmul( a: torch.Tensor, b: torch.Tensor, enable_streamk: Optional[bool] = False, - sk_grid: Optional[int] = None + sk_grid: Optional[int] = None, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ) -> torch.Tensor: assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape @@ -279,9 +305,9 @@ def _matmul( # Query Origami for solution selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) if enable_streamk: - return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) + return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid, _vectorized_load_size=_vectorized_load_size) else: - return persistent_matmul_lt(a, b, out, selector) + return persistent_matmul_lt(a, b, out, selector, _vectorized_load_size=_vectorized_load_size) def _setup_context_matmul_backwards( @@ -289,10 +315,11 @@ def _setup_context_matmul_backwards( inputs: tuple[Any, ...], output: Any ): - a, b, enable_streamk, sk_grid = inputs + a, b, enable_streamk, sk_grid, _vectorized_load_size = inputs ctx.save_for_backward(a, b) ctx.enable_streamk = enable_streamk ctx.sk_grid = sk_grid + # Note: _vectorized_load_size is not saved - backward pass will recompute if needed def _matmul_backwards( @@ -314,10 +341,10 @@ def _matmul_backwards( a_t = a.T.contiguous() grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid) - # tuple[a, b, enable_streamk, sk_grid] + # tuple[a, b, enable_streamk, sk_grid, _vectorized_load_size] # First 2 must be in the order that matches matmul()'s forward args - # Last 2 are not part of the gradient and so are None - return grad_a, grad_b, None, None + # Last 3 are not part of the gradient and so are None + return grad_a, grad_b, None, None, None _matmul.register_autograd(_matmul_backwards, @@ -330,7 +357,8 @@ def _matmul_out( b: torch.Tensor, out: torch.Tensor, enable_streamk: Optional[bool] = False, - sk_grid: Optional[int] = None + sk_grid: Optional[int] = None, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ) -> None: assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape @@ -340,9 +368,9 @@ def _matmul_out( selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) if enable_streamk: - streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) + streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid, _vectorized_load_size=_vectorized_load_size) else: - persistent_matmul_lt(a, b, out, selector) + persistent_matmul_lt(a, b, out, selector, _vectorized_load_size=_vectorized_load_size) # Custom torch ops cannot return a value which is an alias of an input. So # even though torch returns a pointer to the out arg when used, we can't. @@ -356,9 +384,15 @@ def matmul( enable_streamk: Optional[bool] = False, sk_grid: Optional[int] = None ) -> Optional[torch.Tensor]: + # FIXME: Workaround for torch.compile Triton bug - detect at trace time + # This check works here because matmul() is traced by dynamo (not a @triton_op) + _vectorized_load_size = None + if torch.compiler.is_compiling(): + _vectorized_load_size = _compute_safe_vectorize_size(a) + # If no out tensor provided - we do the allocation - we support autograd if out is None: - return _matmul(a, b, enable_streamk, sk_grid) + return _matmul(a, b, enable_streamk, sk_grid, _vectorized_load_size) # If out tensor provided - in-place - we do NOT support autograd # Check for autograd conditions (global and per-tensor) @@ -371,7 +405,7 @@ def matmul( "tritonblas.matmul(): functions with out=... arguments don't support " "automatic differentiation, but one of the arguments requires grad." ) - return _matmul_out(a, b, out, enable_streamk, sk_grid) + return _matmul_out(a, b, out, enable_streamk, sk_grid, _vectorized_load_size) def matmul_a8w8( @@ -508,7 +542,8 @@ def _addmm( a: torch.Tensor, b: torch.Tensor, enable_streamk: Optional[bool] = False, - sk_grid: Optional[int] = None + sk_grid: Optional[int] = None, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ) -> torch.Tensor: assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape @@ -521,9 +556,9 @@ def _addmm( out = a.new_empty(M, N) if enable_streamk: - return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid, _vectorized_load_size=_vectorized_load_size) else: - return persistent_matmul_lt(a, b, out, selector, bias=bias) + return persistent_matmul_lt(a, b, out, selector, bias=bias, _vectorized_load_size=_vectorized_load_size) def _setup_context_addmm_backwards( @@ -531,10 +566,11 @@ def _setup_context_addmm_backwards( inputs: tuple[Any, ...], output: Any ): - bias, a, b, enable_streamk, sk_grid = inputs + bias, a, b, enable_streamk, sk_grid, _vectorized_load_size = inputs ctx.save_for_backward(a, b) ctx.enable_streamk = enable_streamk ctx.sk_grid = sk_grid + # Note: _vectorized_load_size is not saved - backward pass will recompute if needed def _addmm_backwards( @@ -559,10 +595,10 @@ def _addmm_backwards( # grad_bias = sum(grad_output) grad_bias = grad_output.sum(dim=0) - # tuple[bias, a, b, enable_streamk, sk_grid] + # tuple[bias, a, b, enable_streamk, sk_grid, _vectorized_load_size] # First 3 must be in the order that matches addmm()'s forward args - # Last 2 are not part of the gradient and so are None - return grad_bias, grad_a, grad_b, None, None + # Last 3 are not part of the gradient and so are None + return grad_bias, grad_a, grad_b, None, None, None _addmm.register_autograd(_addmm_backwards, @@ -576,7 +612,8 @@ def _addmm_out( b: torch.Tensor, out: torch.Tensor, enable_streamk: Optional[bool] = False, - sk_grid: Optional[int] = None + sk_grid: Optional[int] = None, + _vectorized_load_size: Optional[int] = None, # FIXME: torch.compile workaround - remove once upstream fixed ) -> None: assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape @@ -586,9 +623,9 @@ def _addmm_out( selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) if enable_streamk: - streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid, _vectorized_load_size=_vectorized_load_size) else: - persistent_matmul_lt(a, b, out, selector, bias=bias) + persistent_matmul_lt(a, b, out, selector, bias=bias, _vectorized_load_size=_vectorized_load_size) # Custom torch ops cannot return a value which is an alias of an input. So # even though torch returns a pointer to the out arg when used, we can't. @@ -603,9 +640,15 @@ def addmm( enable_streamk: Optional[bool] = False, sk_grid: Optional[int] = None ) -> Optional[torch.Tensor]: + # FIXME: Workaround for torch.compile Triton bug - detect at trace time + # This check works here because addmm() is traced by dynamo (not a @triton_op) + _vectorized_load_size = None + if torch.compiler.is_compiling(): + _vectorized_load_size = _compute_safe_vectorize_size(a) + # If no out tensor provided - we do the allocation - we support autograd if out is None: - return _addmm(bias, a, b, enable_streamk, sk_grid) + return _addmm(bias, a, b, enable_streamk, sk_grid, _vectorized_load_size) # If out tensor provided - in-place - we do NOT support autograd # Check for autograd conditions (global and per-tensor) @@ -619,5 +662,5 @@ def addmm( "tritonblas.addmm(): functions with out=... arguments don't support " "automatic differentiation, but one of the arguments requires grad." ) - return _addmm_out(bias, a, b, out, enable_streamk, sk_grid) + return _addmm_out(bias, a, b, out, enable_streamk, sk_grid, _vectorized_load_size) diff --git a/include/tritonblas/origami.py b/include/tritonblas/origami.py index 6195e0f..93a73e0 100644 --- a/include/tritonblas/origami.py +++ b/include/tritonblas/origami.py @@ -1,6 +1,7 @@ import itertools import torch import origami +import math from math import ceil diff --git a/tests/test_addmm.py b/tests/test_addmm_correctness.py similarity index 97% rename from tests/test_addmm.py rename to tests/test_addmm_correctness.py index acd200d..b24aeb6 100644 --- a/tests/test_addmm.py +++ b/tests/test_addmm_correctness.py @@ -17,6 +17,9 @@ # If we don't increase this, torch will complain about too many recompilations. torch._dynamo.config.cache_size_limit = 10000 +# Also disable caches so every compile is fresh and new issues are caught. +# Note this causes a single UserWarning that notes caches are disabled. +torch._inductor.config.force_disable_caches = True # Standard test dimensions STANDARD_DIMS = [ @@ -35,6 +38,8 @@ (128, 8, 256), # N < 16 (8, 128, 256), # M < 16 (12, 12, 512), # Small M and N + (15, 17, 512), # Weird and small M and N + (19, 13, 512), # Weird and small M and N (128, 64, 12), # Small K ] diff --git a/tests/test_matmul_correctness.py b/tests/test_matmul_correctness.py new file mode 100644 index 0000000..1178b48 --- /dev/null +++ b/tests/test_matmul_correctness.py @@ -0,0 +1,293 @@ +""" +Tests for tritonblas.matmul with torch.autograd and torch.compile support + +Tests cover: +1. Forward pass correctness against torch.mm +2. Backward pass gradient correctness against torch.mm +3. In-place (out=...) functionality and autograd restrictions +4. Edge cases including small dimensions +5. torch.compile compatibility +6. Persistent vs. StreamK compatibility +""" + +import pytest +import torch +import tritonblas + + +# If we don't increase this, torch will complain about too many recompilations. +torch._dynamo.config.cache_size_limit = 10000 +# Also disable caches so every compile is fresh and new issues are caught. +# Note this causes a single UserWarning that notes caches are disabled. +torch._inductor.config.force_disable_caches = True + +# Standard test dimensions +STANDARD_DIMS = [ + (128, 256, 512), # Medium sizes + (256, 256, 256), # Square + (512, 1024, 768), # Larger + (2048, 1024, 512), # Wide output + (1024, 2048, 512), # Tall output +] + +# Edge case dimensions (small dimensions, N < 16 cases) +EDGE_CASE_DIMS = [ + (32, 32, 32), # Small square + (64, 16, 128), # Small N + (16, 64, 128), # Small M + (128, 8, 256), # N < 16 + (8, 128, 256), # M < 16 + (12, 12, 512), # Small M and N + (15, 17, 512), # Weird and small M and N + (19, 13, 512), # Weird and small M and N + (128, 64, 12), # Small K +] + +# Skinny matrix dimensions (stress tests) +SKINNY_DIMS = [ + (16, 16, 4096), # Very large K + (32, 32, 8192), # Large K +] + +# Data types to test +DTYPES = [torch.bfloat16, torch.float16] + +# Whether to test with torch.compile +USE_COMPILE = [False, True] + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_forward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.matmul forward pass matches torch.mm.""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # tritonblas result + result = matmul_fn(a, b) + + # torch reference + expected = torch.mm(a, b) + + # Check forward correctness with relaxed tolerance for low precision + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_backward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.matmul backward pass produces correct gradients.""" + torch.manual_seed(42) + + # Create inputs with requires_grad for tritonblas + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + # Clone for torch reference + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward pass + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + # Backward pass with same upstream gradient + grad_output = torch.randn_like(result) + result.backward(grad_output) + result_ref.backward(grad_output) + + # Check gradients match + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1, + msg="a gradient mismatch") + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1, + msg="b gradient mismatch") + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", SKINNY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_skinny_matrices(m, n, k, dtype, use_compile): + """Test matmul with skinny matrices (large K dimension).""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_with_grad_raises(use_compile): + """Test that matmul with out=... raises RuntimeError when autograd is enabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + with pytest.raises(RuntimeError, match="don't support automatic differentiation"): + matmul_fn(a, b, out=out) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_without_grad_works(use_compile): + """Test that matmul with out=... works when autograd is disabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Should work with torch.no_grad() + with torch.no_grad(): + result = matmul_fn(a, b, out=out) + + # In-place path returns None (custom ops don't support aliasing) + assert result is None, "in-place matmul should return None" + + # Verify correctness against torch + expected = torch.mm(a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_output_correctness(use_compile): + """Test that matmul in-place mode produces correct results.""" + torch.manual_seed(42) + m, n, k = 128, 256, 512 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + with torch.no_grad(): + matmul_fn(a, b, out=out) + + expected = torch.mm(a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_no_grad_tensors(use_compile): + """Test matmul works when input tensors don't require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=False) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + result = matmul_fn(a, b) + expected = torch.mm(a, b) + + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_partial_grad(use_compile): + """Test matmul when only some inputs require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + # Only a requires grad + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone() + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", [False, True]) +def test_matmul_streamk_modes(enable_streamk, use_compile): + """Test matmul with different streamk settings.""" + torch.manual_seed(42) + m, n, k = 256, 256, 256 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward + result = matmul_fn(a, b, enable_streamk=enable_streamk) + result_ref = torch.mm(a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) From c8731b3cebc385b6487af1ebce615b174c4eaa01 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Wed, 4 Feb 2026 18:35:36 -0500 Subject: [PATCH 09/13] Add temp fix for GPU runtime context errors There is a bug in testing at the moment which causes CUDA RuntimeErrors due to muliple forked processes trying to initialize the GPU runtime during torch.compile runs. The temporary fix is to force Inductor to run single-threaded after which the error is gone and all tests pass, but changing the multiprocessing methodology to 'spawn' over 'fork' may be a better long-term solution (or submitting a bug with torch because this wasn't happening before). --- tests/test_addmm_correctness.py | 6 ++++++ tests/test_matmul_correctness.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/tests/test_addmm_correctness.py b/tests/test_addmm_correctness.py index b24aeb6..5e3e8be 100644 --- a/tests/test_addmm_correctness.py +++ b/tests/test_addmm_correctness.py @@ -20,6 +20,12 @@ # Also disable caches so every compile is fresh and new issues are caught. # Note this causes a single UserWarning that notes caches are disabled. torch._inductor.config.force_disable_caches = True +# FIXME: Inductor seems to be initializing multiple CUDA runtimes somehow in +# relation to some of triton's new features which is causing errors unrelated to +# tritonBLAS. The error tells you to change the multiplrocessing strategy to +# 'spawn' but that actually doesn't fix the issue - you have to force +# single-threaded compilation. This needs to be fixed upstream in torch/triton. +torch._inductor.config.compile_threads = 1 # Standard test dimensions STANDARD_DIMS = [ diff --git a/tests/test_matmul_correctness.py b/tests/test_matmul_correctness.py index 1178b48..d555ba5 100644 --- a/tests/test_matmul_correctness.py +++ b/tests/test_matmul_correctness.py @@ -20,6 +20,12 @@ # Also disable caches so every compile is fresh and new issues are caught. # Note this causes a single UserWarning that notes caches are disabled. torch._inductor.config.force_disable_caches = True +# FIXME: Inductor seems to be initializing multiple CUDA runtimes somehow in +# relation to some of triton's new features which is causing errors unrelated to +# tritonBLAS. The error tells you to change the multiplrocessing strategy to +# 'spawn' but that actually doesn't fix the issue - you have to force +# single-threaded compilation. This needs to be fixed upstream in torch/triton. +torch._inductor.config.compile_threads = 1 # Standard test dimensions STANDARD_DIMS = [ From be974a536266f4ff1e2347aebcbaa386b8d9e980 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Fri, 6 Feb 2026 13:40:51 -0500 Subject: [PATCH 10/13] Update CI to build container with newer torch --- .github/workflows/test.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf1404c..f731991 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,7 +12,7 @@ jobs: runs-on: [self-hosted, amd-gpu] container: - image: rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1 + image: rocm/dev-ubuntu-24.04:7.2-complete options: --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined steps: @@ -30,11 +30,16 @@ jobs: - name: Install system dependencies run: | apt-get update - apt-get install -y git - - - name: Install Python dependencies + apt-get install -y git python3.12 python3.12-venv python3-pip python3.12-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 + + - name: Install PyTorch (ROCm) run: | python3 -m pip install --upgrade pip + pip3 install torch --index-url https://download.pytorch.org/whl/rocm7.1 + + - name: Install Python dependencies + run: | pip3 install -U triton pip3 install -e . From 2da21b65c51c22ae1cefe1e9b46f95e02e0661f0 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Fri, 6 Feb 2026 13:52:37 -0500 Subject: [PATCH 11/13] Rearrange CI steps to install git before checkout --- .github/workflows/test.yml | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f731991..df749f0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,31 +16,34 @@ jobs: options: --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - submodules: recursive - - - name: Set up environment - run: | - echo "Setting up ROCm environment..." - export ROCM_PATH=/opt/rocm - export PATH=$ROCM_PATH/bin:$PATH - - name: Install system dependencies run: | apt-get update apt-get install -y git python3.12 python3.12-venv python3-pip python3.12-dev update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 - - name: Install PyTorch (ROCm) + - name: Set up environment + run: | + echo "Setting up ROCm environment..." + export ROCM_PATH=/opt/rocm + export PATH=$ROCM_PATH/bin:$PATH + + - name: Install PyTorch with ROCm support run: | python3 -m pip install --upgrade pip pip3 install torch --index-url https://download.pytorch.org/whl/rocm7.1 - - name: Install Python dependencies + - name: Install Triton run: | pip3 install -U triton + + - name: Checkout tritonBLAS code + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install tritonBLAS + run: | pip3 install -e . - name: Verify installation From 166e978560e7153fd9830b5d783fcb04094c12a7 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Fri, 6 Feb 2026 14:04:37 -0500 Subject: [PATCH 12/13] Remove CI container marker locking pip installs --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index df749f0..3af37d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,8 @@ jobs: apt-get update apt-get install -y git python3.12 python3.12-venv python3-pip python3.12-dev update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 + # Remove PEP 668 externally-managed marker so pip works in this disposable container since we're not using a virtual environment + rm -f /usr/lib/python3.12/EXTERNALLY-MANAGED - name: Set up environment run: | From a8138c25e1001c5fd8f90ba27dae1152c5ac4698 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Fri, 6 Feb 2026 14:18:34 -0500 Subject: [PATCH 13/13] Remove pip upgrade because pip is now apt-managed --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3af37d2..fede161 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,6 @@ jobs: - name: Install PyTorch with ROCm support run: | - python3 -m pip install --upgrade pip pip3 install torch --index-url https://download.pytorch.org/whl/rocm7.1 - name: Install Triton