48 changes: 27 additions & 21 deletions include/tritonblas/kernels/streamk_gemm.py
@@ -76,8 +76,8 @@ def streamk_matmul(
     B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
 
     if BIAS:
-        bias_ = bias_ptr + rm * stride_bias
-        bias = tl.load(bias_, mask=rm < M, other=0.0)
+        bias_ = bias_ptr + rn * stride_bias
+        bias = tl.load(bias_, mask=rn < N, other=0.0)
 
     loop_k = tl.cdiv(K, BLOCK_SIZE_K)
     if not EVEN_K:
@@ -140,12 +140,12 @@ def streamk_matmul(
         if QUANTIZED:
             # For quantized mode: convert bias to float32, add to acc, then convert to output dtype
             bias_float = bias.to(tl.float32)
-            c = acc + bias_float[:, None]
+            c = acc + bias_float[None, :]
             c = c.to(C.type.element_ty)
         else:
             # For non-quantized mode: convert acc to output dtype, then add bias
             c = acc.to(C.type.element_ty)
-            c += bias[:, None]
+            c += bias[None, :]
     else:
         c = acc.to(C.type.element_ty)
 
@@ -210,8 +210,8 @@ def streamk_matmul(
         B_BASE = tl.multiple_of(B_BASE, (1, 16))
 
         if BIAS:
-            bias_ = bias_ptr + rm * stride_bias
-            bias = tl.load(bias_, mask=rm < M, other=0.0)
+            bias_ = bias_ptr + rn * stride_bias
+            bias = tl.load(bias_, mask=rn < N, other=0.0)
 
         acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
         for current_iter in range(start_iter, end_iter):
@@ -322,28 +322,34 @@ def streamk_matmul(
                 next_pid += 1
 
         # Unified bias handling for Stream-K section
+        # Bias is applied along N dimension (columns), broadcast across M (rows)
         if BIAS:
-            # Split bias for top and bottom halves
-            bias_top = bias[:BLOCK_SIZE_M // 2]
-            bias_bottom = bias[BLOCK_SIZE_M // 2:]
+            # Load bias for left and right halves (N dimension) with explicit indices
+            # Apply modulo N to match the behavior of the original rn computation
+            rn_bias_left = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N // 2)) % N
+            rn_bias_right = (pid_n * BLOCK_SIZE_N + BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2)) % N
 
-            bias_top_reshaped = tl.reshape(bias_top, (BLOCK_SIZE_M // 2, 1))
-            bias_bottom_reshaped = tl.reshape(bias_bottom, (BLOCK_SIZE_M // 2, 1))
+            bias_left = tl.load(bias_ptr + rn_bias_left * stride_bias, mask=rn_bias_left < N, other=0.0)
+            bias_right = tl.load(bias_ptr + rn_bias_right * stride_bias, mask=rn_bias_right < N, other=0.0)
+
+            # Reshape for broadcasting: (1, N//2) to broadcast across M rows
+            bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
Comment on lines 324 to +336

Copilot AI Feb 9, 2026

In this Stream-K section, bias is (now correctly) loaded earlier as a full BLOCK_SIZE_N vector, but this block re-loads bias twice (bias_left/bias_right) and also introduces % N plus a redundant mask (after modulo, indices are already in-range for N>0). To reduce global memory traffic and simplify the indexing, consider reusing the already-loaded bias vector and slicing/splitting it into left/right halves for broadcasting, rather than recomputing indices and issuing additional tl.load operations.

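A minimal sketch of the suggested rework, assuming the full-width `bias` vector loaded under `if BIAS:` earlier in the Stream-K section is still live at this point; the slicing mirrors the removed top/bottom split, and the variable names are illustrative:

```python
# Hypothetical rework per the comment above: reuse the bias vector that was
# already loaded for this tile instead of issuing two extra tl.load calls.
bias_left = bias[:BLOCK_SIZE_N // 2]   # first half of the N-block
bias_right = bias[BLOCK_SIZE_N // 2:]  # second half of the N-block

# Reshape to (1, N//2) so each half broadcasts across the M rows.
bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
bias_right_reshaped = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2))
```

This trades two global loads (plus the modulo index math) for register reuse of data the kernel already fetched.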
+            bias_right_reshaped = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2))
 
             if QUANTIZED:
                 # For quantized mode: convert bias to float32 before adding
-                bias_top_float = bias_top_reshaped.to(tl.float32)
-                bias_bottom_float = bias_bottom_reshaped.to(tl.float32)
-                acc00 += bias_top_float
-                acc01 += bias_top_float
-                acc10 += bias_bottom_float
-                acc11 += bias_bottom_float
+                bias_left_float = bias_left_reshaped.to(tl.float32)
+                bias_right_float = bias_right_reshaped.to(tl.float32)
+                acc00 += bias_left_float
+                acc01 += bias_right_float
+                acc10 += bias_left_float
+                acc11 += bias_right_float
             else:
                 # For non-quantized mode: add bias directly
-                acc00 += bias_top_reshaped
-                acc01 += bias_top_reshaped
-                acc10 += bias_bottom_reshaped
-                acc11 += bias_bottom_reshaped
+                acc00 += bias_left_reshaped
+                acc01 += bias_right_reshaped
+                acc10 += bias_left_reshaped
+                acc11 += bias_right_reshaped
 
         # Convert to output dtype
         c00 = acc00.to(C.type.element_ty)
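For context on the kernel fix above: `torch.addmm` treats a 1-D bias as a length-N row vector broadcast across the M rows, which is exactly what switching the kernel from `rm`/`[:, None]` to `rn`/`[None, :]` implements. A quick reference check in plain PyTorch (runs on CPU):

```python
import torch

# addmm semantics: C[i, j] = bias[j] + sum_k A[i, k] * B[k, j]
m, n, k = 4, 8, 16
a = torch.randn(m, k)
b = torch.randn(k, n)
bias = torch.randn(n)  # one value per output column

out = torch.addmm(bias, a, b)
ref = a @ b + bias[None, :]  # broadcast along rows, matching the kernel's [None, :]
torch.testing.assert_close(out, ref)
```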
76 changes: 26 additions & 50 deletions tests/test_addmm_correctness.py
@@ -61,11 +61,15 @@
 # Whether to test with torch.compile
 USE_COMPILE = [False, True]
 
+# Whether to enable StreamK (vs. Persistent path)
+ENABLE_STREAMK = [False, True]
+
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
+def test_addmm_forward_correctness(m, n, k, dtype, enable_streamk, use_compile):
     """Test that tritonblas.addmm forward pass matches torch.addmm."""
     torch.manual_seed(42)
 
@@ -78,7 +82,7 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # tritonblas result
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
 
     # torch reference
     expected = torch.addmm(bias, a, b)
@@ -88,9 +92,10 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
+def test_addmm_backward_correctness(m, n, k, dtype, enable_streamk, use_compile):
     """Test that tritonblas.addmm backward pass produces correct gradients."""
     torch.manual_seed(42)
 
@@ -109,7 +114,7 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # Forward pass
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     # Backward pass with same upstream gradient
@@ -127,9 +132,10 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", SKINNY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
+def test_addmm_skinny_matrices(m, n, k, dtype, enable_streamk, use_compile):
     """Test addmm with skinny matrices (large K dimension)."""
     torch.manual_seed(42)
 
@@ -146,7 +152,7 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # Forward
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
@@ -161,7 +167,8 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_with_grad_raises(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_with_grad_raises(enable_streamk, use_compile):
     """Test that addmm with out=... raises RuntimeError when autograd is enabled."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -177,11 +184,12 @@ def test_addmm_inplace_with_grad_raises(use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     with pytest.raises(RuntimeError, match="don't support automatic differentiation"):
-        addmm_fn(bias, a, b, out=out)
+        addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_without_grad_works(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_without_grad_works(enable_streamk, use_compile):
     """Test that addmm with out=... works when autograd is disabled."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -198,7 +206,7 @@ def test_addmm_inplace_without_grad_works(use_compile):
 
     # Should work with torch.no_grad()
     with torch.no_grad():
-        result = addmm_fn(bias, a, b, out=out)
+        result = addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
     # In-place path returns None (custom ops don't support aliasing)
     assert result is None, "in-place addmm should return None"
@@ -209,7 +217,8 @@ def test_addmm_inplace_without_grad_works(use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_output_correctness(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_output_correctness(enable_streamk, use_compile):
     """Test that addmm in-place mode produces correct results."""
     torch.manual_seed(42)
     m, n, k = 128, 256, 512
@@ -225,14 +234,15 @@ def test_addmm_inplace_output_correctness(use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     with torch.no_grad():
-        addmm_fn(bias, a, b, out=out)
+        addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
     expected = torch.addmm(bias, a, b)
     torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_no_grad_tensors(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_no_grad_tensors(enable_streamk, use_compile):
     """Test addmm works when input tensors don't require grad."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -246,14 +256,15 @@ def test_addmm_no_grad_tensors(use_compile):
     if use_compile:
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     expected = torch.addmm(bias, a, b)
 
     torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_partial_grad(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_partial_grad(enable_streamk, use_compile):
     """Test addmm when only some inputs require grad."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -272,45 +283,10 @@ def test_addmm_partial_grad(use_compile):
     if use_compile:
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     result = addmm_fn(bias, a, b)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     result.sum().backward()
     result_ref.sum().backward()
 
     torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
-
-
-@pytest.mark.parametrize("use_compile", USE_COMPILE)
-@pytest.mark.parametrize("enable_streamk", [False, True])
-def test_addmm_streamk_modes(enable_streamk, use_compile):
-    """Test addmm with different streamk settings."""
-    torch.manual_seed(42)
-    m, n, k = 256, 256, 256
-    dtype = torch.bfloat16
-
-    a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True)
-    b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True)
-    bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True)
-
-    a_ref = a.detach().clone().requires_grad_(True)
-    b_ref = b.detach().clone().requires_grad_(True)
-    bias_ref = bias.detach().clone().requires_grad_(True)
-
-    addmm_fn = tritonblas.addmm
-    if use_compile:
-        addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
-
-    # Forward
-    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
-    result_ref = torch.addmm(bias_ref, a_ref, b_ref)
-
-    torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
-
-    # Backward
-    result.sum().backward()
-    result_ref.sum().backward()
-
-    torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
Comment on lines 286 to 292

Copilot AI Feb 9, 2026


Similar to the matmul test, this looks like it now covers what test_addmm_streamk_modes used to, but the forward assertion and the b.grad/bias.grad comparisons from the removed test are missing. If the goal is to validate StreamK behavior end-to-end, please restore the forward assert_close(result, result_ref, ...) and the gradient comparisons for all tensors that require grad in this test.

+    torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
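A sketch of the fuller coverage the comment asks for, assuming the module-level USE_COMPILE/ENABLE_STREAMK lists and imports defined earlier in this file; the test name is hypothetical, and the body mirrors the removed test_addmm_streamk_modes:

```python
@pytest.mark.parametrize("use_compile", USE_COMPILE)
@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
def test_addmm_streamk_end_to_end(enable_streamk, use_compile):
    """Hypothetical restored test: forward match plus all gradient comparisons."""
    torch.manual_seed(42)
    m, n, k = 256, 256, 256
    dtype = torch.bfloat16

    a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True)
    b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True)
    bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True)

    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    addmm_fn = tritonblas.addmm
    if use_compile:
        addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)

    # Forward: result must match the torch reference.
    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
    result_ref = torch.addmm(bias_ref, a_ref, b_ref)
    torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

    # Backward: compare gradients for every tensor that requires grad.
    result.sum().backward()
    result_ref.sum().backward()
    torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
    torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
    torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
```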