From 0f0503b9e2c5bb0b3d324d84c85aa2575354df6d Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Mon, 9 Feb 2026 16:04:19 -0500 Subject: [PATCH 1/2] Fix StreamK GEMM bias --- include/tritonblas/kernels/streamk_gemm.py | 48 ++++++++++++---------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/include/tritonblas/kernels/streamk_gemm.py b/include/tritonblas/kernels/streamk_gemm.py index 632d1a9..3ed864a 100644 --- a/include/tritonblas/kernels/streamk_gemm.py +++ b/include/tritonblas/kernels/streamk_gemm.py @@ -76,8 +76,8 @@ def streamk_matmul( B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn if BIAS: - bias_ = bias_ptr + rm * stride_bias - bias = tl.load(bias_, mask=rm < M, other=0.0) + bias_ = bias_ptr + rn * stride_bias + bias = tl.load(bias_, mask=rn < N, other=0.0) loop_k = tl.cdiv(K, BLOCK_SIZE_K) if not EVEN_K: @@ -140,12 +140,12 @@ def streamk_matmul( if QUANTIZED: # For quantized mode: convert bias to float32, add to acc, then convert to output dtype bias_float = bias.to(tl.float32) - c = acc + bias_float[:, None] + c = acc + bias_float[None, :] c = c.to(C.type.element_ty) else: # For non-quantized mode: convert acc to output dtype, then add bias c = acc.to(C.type.element_ty) - c += bias[:, None] + c += bias[None, :] else: c = acc.to(C.type.element_ty) @@ -210,8 +210,8 @@ def streamk_matmul( B_BASE = tl.multiple_of(B_BASE, (1, 16)) if BIAS: - bias_ = bias_ptr + rm * stride_bias - bias = tl.load(bias_, mask=rm < M, other=0.0) + bias_ = bias_ptr + rn * stride_bias + bias = tl.load(bias_, mask=rn < N, other=0.0) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for current_iter in range(start_iter, end_iter): @@ -322,28 +322,34 @@ def streamk_matmul( next_pid += 1 # Unified bias handling for Stream-K section + # Bias is applied along N dimension (columns), broadcast across M (rows) if BIAS: - # Split bias for top and bottom halves - bias_top = bias[:BLOCK_SIZE_M // 2] - bias_bottom = bias[BLOCK_SIZE_M // 2:] + # Load bias for left and right halves (N dimension) with explicit indices + # Apply modulo N to match the behavior of the original rn computation + rn_bias_left = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N // 2)) % N + rn_bias_right = (pid_n * BLOCK_SIZE_N + BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2)) % N - bias_top_reshaped = tl.reshape(bias_top, (BLOCK_SIZE_M // 2, 1)) - bias_bottom_reshaped = tl.reshape(bias_bottom, (BLOCK_SIZE_M // 2, 1)) + bias_left = tl.load(bias_ptr + rn_bias_left * stride_bias, mask=rn_bias_left < N, other=0.0) + bias_right = tl.load(bias_ptr + rn_bias_right * stride_bias, mask=rn_bias_right < N, other=0.0) + + # Reshape for broadcasting: (1, N//2) to broadcast across M rows + bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2)) + bias_right_reshaped = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2)) if QUANTIZED: # For quantized mode: convert bias to float32 before adding - bias_top_float = bias_top_reshaped.to(tl.float32) - bias_bottom_float = bias_bottom_reshaped.to(tl.float32) - acc00 += bias_top_float - acc01 += bias_top_float - acc10 += bias_bottom_float - acc11 += bias_bottom_float + bias_left_float = bias_left_reshaped.to(tl.float32) + bias_right_float = bias_right_reshaped.to(tl.float32) + acc00 += bias_left_float + acc01 += bias_right_float + acc10 += bias_left_float + acc11 += bias_right_float else: # For non-quantized mode: add bias directly - acc00 += bias_top_reshaped - acc01 += bias_top_reshaped - acc10 += bias_bottom_reshaped - acc11 += bias_bottom_reshaped + acc00 += 
bias_left_reshaped + acc01 += bias_right_reshaped + acc10 += bias_left_reshaped + acc11 += bias_right_reshaped # Convert to output dtype c00 = acc00.to(C.type.element_ty) From 2a9add4d43952b05706f88be045a927bca457be3 Mon Sep 17 00:00:00 2001 From: Alex Underwood Date: Mon, 9 Feb 2026 16:45:50 -0500 Subject: [PATCH 2/2] Permute StreamK tests to expand coverage --- tests/test_addmm_correctness.py | 76 +++++++++++--------------------- tests/test_matmul_correctness.py | 73 +++++++++++------------------- 2 files changed, 52 insertions(+), 97 deletions(-) diff --git a/tests/test_addmm_correctness.py b/tests/test_addmm_correctness.py index 5e3e8be..13488c5 100644 --- a/tests/test_addmm_correctness.py +++ b/tests/test_addmm_correctness.py @@ -61,11 +61,15 @@ # Whether to test with torch.compile USE_COMPILE = [False, True] +# Whether to enable StreamK (vs. Persistent path) +ENABLE_STREAMK = [False, True] + @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_addmm_forward_correctness(m, n, k, dtype, use_compile): +def test_addmm_forward_correctness(m, n, k, dtype, enable_streamk, use_compile): """Test that tritonblas.addmm forward pass matches torch.addmm.""" torch.manual_seed(42) @@ -78,7 +82,7 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile): addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) # tritonblas result - result = addmm_fn(bias, a, b) + result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) # torch reference expected = torch.addmm(bias, a, b) @@ -88,9 +92,10 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_addmm_backward_correctness(m, n, k, dtype, use_compile): +def test_addmm_backward_correctness(m, n, k, dtype, enable_streamk, use_compile): """Test that tritonblas.addmm backward pass produces correct gradients.""" torch.manual_seed(42) @@ -109,7 +114,7 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile): addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) # Forward pass - result = addmm_fn(bias, a, b) + result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) result_ref = torch.addmm(bias_ref, a_ref, b_ref) # Backward pass with same upstream gradient @@ -127,9 +132,10 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", SKINNY_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_addmm_skinny_matrices(m, n, k, dtype, use_compile): +def test_addmm_skinny_matrices(m, n, k, dtype, enable_streamk, use_compile): """Test addmm with skinny matrices (large K dimension).""" torch.manual_seed(42) @@ -146,7 +152,7 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile): addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) # Forward - result = addmm_fn(bias, a, b) + result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) result_ref = torch.addmm(bias_ref, a_ref, b_ref) torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) @@ -161,7 +167,8 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile): 
@pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_addmm_inplace_with_grad_raises(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_addmm_inplace_with_grad_raises(enable_streamk, use_compile): """Test that addmm with out=... raises RuntimeError when autograd is enabled.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -177,11 +184,12 @@ def test_addmm_inplace_with_grad_raises(use_compile): addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) with pytest.raises(RuntimeError, match="don't support automatic differentiation"): - addmm_fn(bias, a, b, out=out) + addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_addmm_inplace_without_grad_works(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_addmm_inplace_without_grad_works(enable_streamk, use_compile): """Test that addmm with out=... works when autograd is disabled.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -198,7 +206,7 @@ def test_addmm_inplace_without_grad_works(use_compile): # Should work with torch.no_grad() with torch.no_grad(): - result = addmm_fn(bias, a, b, out=out) + result = addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk) # In-place path returns None (custom ops don't support aliasing) assert result is None, "in-place addmm should return None" @@ -209,7 +217,8 @@ def test_addmm_inplace_without_grad_works(use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_addmm_inplace_output_correctness(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_addmm_inplace_output_correctness(enable_streamk, use_compile): """Test that addmm in-place mode produces correct results.""" torch.manual_seed(42) m, n, k = 128, 256, 512 @@ -225,14 +234,15 @@ def test_addmm_inplace_output_correctness(use_compile): addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) with torch.no_grad(): - addmm_fn(bias, a, b, out=out) + addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk) expected = torch.addmm(bias, a, b) torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_addmm_no_grad_tensors(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_addmm_no_grad_tensors(enable_streamk, use_compile): """Test addmm works when input tensors don't require grad.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -246,14 +256,15 @@ def test_addmm_no_grad_tensors(use_compile): if use_compile: addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) - result = addmm_fn(bias, a, b) + result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) expected = torch.addmm(bias, a, b) torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_addmm_partial_grad(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_addmm_partial_grad(enable_streamk, use_compile): """Test addmm when only some inputs require grad.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -272,45 +283,10 @@ def test_addmm_partial_grad(use_compile): if use_compile: addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) - result = addmm_fn(bias, a, b) - result_ref = torch.addmm(bias_ref, a_ref, b_ref) - - result.sum().backward() - result_ref.sum().backward() - - torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) - - -@pytest.mark.parametrize("use_compile", 
USE_COMPILE) -@pytest.mark.parametrize("enable_streamk", [False, True]) -def test_addmm_streamk_modes(enable_streamk, use_compile): - """Test addmm with different streamk settings.""" - torch.manual_seed(42) - m, n, k = 256, 256, 256 - dtype = torch.bfloat16 - - a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) - b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) - bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) - - a_ref = a.detach().clone().requires_grad_(True) - b_ref = b.detach().clone().requires_grad_(True) - bias_ref = bias.detach().clone().requires_grad_(True) - - addmm_fn = tritonblas.addmm - if use_compile: - addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) - - # Forward result = addmm_fn(bias, a, b, enable_streamk=enable_streamk) result_ref = torch.addmm(bias_ref, a_ref, b_ref) - torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) - - # Backward result.sum().backward() result_ref.sum().backward() torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) - torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) - torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1) diff --git a/tests/test_matmul_correctness.py b/tests/test_matmul_correctness.py index d555ba5..e7320ef 100644 --- a/tests/test_matmul_correctness.py +++ b/tests/test_matmul_correctness.py @@ -61,11 +61,15 @@ # Whether to test with torch.compile USE_COMPILE = [False, True] +# Whether to enable StreamK (vs. Persistent path) +ENABLE_STREAMK = [False, True] + @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_matmul_forward_correctness(m, n, k, dtype, use_compile): +def test_matmul_forward_correctness(m, n, k, dtype, enable_streamk, use_compile): """Test that tritonblas.matmul forward pass matches torch.mm.""" torch.manual_seed(42) @@ -77,7 +81,7 @@ def test_matmul_forward_correctness(m, n, k, dtype, use_compile): matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) # tritonblas result - result = matmul_fn(a, b) + result = matmul_fn(a, b, enable_streamk=enable_streamk) # torch reference expected = torch.mm(a, b) @@ -87,9 +91,10 @@ def test_matmul_forward_correctness(m, n, k, dtype, use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_matmul_backward_correctness(m, n, k, dtype, use_compile): +def test_matmul_backward_correctness(m, n, k, dtype, enable_streamk, use_compile): """Test that tritonblas.matmul backward pass produces correct gradients.""" torch.manual_seed(42) @@ -106,7 +111,7 @@ def test_matmul_backward_correctness(m, n, k, dtype, use_compile): matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) # Forward pass - result = matmul_fn(a, b) + result = matmul_fn(a, b, enable_streamk=enable_streamk) result_ref = torch.mm(a_ref, b_ref) # Backward pass with same upstream gradient @@ -122,9 +127,10 @@ def test_matmul_backward_correctness(m, n, k, dtype, use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) @pytest.mark.parametrize("m, n, k", SKINNY_DIMS) @pytest.mark.parametrize("dtype", DTYPES) -def test_matmul_skinny_matrices(m, n, 
k, dtype, use_compile): +def test_matmul_skinny_matrices(m, n, k, dtype, enable_streamk, use_compile): """Test matmul with skinny matrices (large K dimension).""" torch.manual_seed(42) @@ -139,7 +145,7 @@ def test_matmul_skinny_matrices(m, n, k, dtype, use_compile): matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) # Forward - result = matmul_fn(a, b) + result = matmul_fn(a, b, enable_streamk=enable_streamk) result_ref = torch.mm(a_ref, b_ref) torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) @@ -153,7 +159,8 @@ def test_matmul_skinny_matrices(m, n, k, dtype, use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_matmul_inplace_with_grad_raises(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_matmul_inplace_with_grad_raises(enable_streamk, use_compile): """Test that matmul with out=... raises RuntimeError when autograd is enabled.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -168,11 +175,12 @@ def test_matmul_inplace_with_grad_raises(use_compile): matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) with pytest.raises(RuntimeError, match="don't support automatic differentiation"): - matmul_fn(a, b, out=out) + matmul_fn(a, b, out=out, enable_streamk=enable_streamk) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_matmul_inplace_without_grad_works(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_matmul_inplace_without_grad_works(enable_streamk, use_compile): """Test that matmul with out=... works when autograd is disabled.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -188,7 +196,7 @@ def test_matmul_inplace_without_grad_works(use_compile): # Should work with torch.no_grad() with torch.no_grad(): - result = matmul_fn(a, b, out=out) + result = matmul_fn(a, b, out=out, enable_streamk=enable_streamk) # In-place path returns None (custom ops don't support aliasing) assert result is None, "in-place matmul should return None" @@ -199,7 +207,8 @@ def test_matmul_inplace_without_grad_works(use_compile): @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_matmul_inplace_output_correctness(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_matmul_inplace_output_correctness(enable_streamk, use_compile): """Test that matmul in-place mode produces correct results.""" torch.manual_seed(42) m, n, k = 128, 256, 512 @@ -214,14 +223,15 @@ def test_matmul_inplace_output_correctness(use_compile): matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) with torch.no_grad(): - matmul_fn(a, b, out=out) + matmul_fn(a, b, out=out, enable_streamk=enable_streamk) expected = torch.mm(a, b) torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_matmul_no_grad_tensors(use_compile): +@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_matmul_no_grad_tensors(enable_streamk, use_compile): """Test matmul works when input tensors don't require grad.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -234,14 +244,15 @@ def test_matmul_no_grad_tensors(use_compile): if use_compile: matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) - result = matmul_fn(a, b) + result = matmul_fn(a, b, enable_streamk=enable_streamk) expected = torch.mm(a, b) torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) @pytest.mark.parametrize("use_compile", USE_COMPILE) -def test_matmul_partial_grad(use_compile): 
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK) +def test_matmul_partial_grad(enable_streamk, use_compile): """Test matmul when only some inputs require grad.""" torch.manual_seed(42) m, n, k = 64, 64, 64 @@ -258,42 +269,10 @@ def test_matmul_partial_grad(use_compile): if use_compile: matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) - result = matmul_fn(a, b) - result_ref = torch.mm(a_ref, b_ref) - - result.sum().backward() - result_ref.sum().backward() - - torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) - - -@pytest.mark.parametrize("use_compile", USE_COMPILE) -@pytest.mark.parametrize("enable_streamk", [False, True]) -def test_matmul_streamk_modes(enable_streamk, use_compile): - """Test matmul with different streamk settings.""" - torch.manual_seed(42) - m, n, k = 256, 256, 256 - dtype = torch.bfloat16 - - a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) - b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) - - a_ref = a.detach().clone().requires_grad_(True) - b_ref = b.detach().clone().requires_grad_(True) - - matmul_fn = tritonblas.matmul - if use_compile: - matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) - - # Forward result = matmul_fn(a, b, enable_streamk=enable_streamk) result_ref = torch.mm(a_ref, b_ref) - torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) - - # Backward result.sum().backward() result_ref.sum().backward() torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) - torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)