
Conversation

@ryanswann-amd (Collaborator) commented on Feb 9, 2026

Motivation

The StreamK GEMM kernel was missing bias addition in the tile aggregation section. When multiple processing elements (PEs) contribute partial results to a single output tile, the bias was not being applied after aggregation. This caused incorrect addmm results when using StreamK mode with bias tensors.

Technical Details

Modified include/tritonblas/kernels/streamk_gemm.py to add bias handling in the Stream-K tile aggregation section. After the accumulator quadrants (acc00, acc01, acc10, acc11) are aggregated from multiple PEs, the bias is now loaded for the left and right halves of the N dimension, reshaped to (1, BLOCK_SIZE_N // 2) for broadcasting across rows, and added to each quadrant. The fix handles both quantized mode (bias converted to float32 before adding) and non-quantized mode (bias added directly).
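
A condensed sketch of the added logic: names such as acc00/acc01/acc10/acc11, bias_ptr, and stride_bias follow the diff quoted later in this thread, while the QUANTIZED flag name is an assumption for illustration.

```python
if BIAS:
    # Column indices for the left and right halves of this tile's N range
    rn_left = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N // 2)) % N
    rn_right = (pid_n * BLOCK_SIZE_N + BLOCK_SIZE_N // 2
                + tl.arange(0, BLOCK_SIZE_N // 2)) % N
    bias_left = tl.load(bias_ptr + rn_left * stride_bias, mask=rn_left < N, other=0.0)
    bias_right = tl.load(bias_ptr + rn_right * stride_bias, mask=rn_right < N, other=0.0)
    # (1, BLOCK_SIZE_N // 2) so each half broadcasts across the M rows
    bias_left = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
    bias_right = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2))
    if QUANTIZED:
        bias_left = bias_left.to(tl.float32)
        bias_right = bias_right.to(tl.float32)
    # Left halves feed the left quadrants, right halves the right quadrants
    acc00 += bias_left
    acc10 += bias_left
    acc01 += bias_right
    acc11 += bias_right
```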

Test Plan

Improved StreamK test coverage: the matmul/addmm correctness tests are now parametrized over enable_streamk so both the Persistent and StreamK paths are exercised.

Test Result

===== 64 failed, 196 passed, 15 warnings in 170.99s (0:02:50) =====
The 64 failures are due to globally allocated locks breaking torch.compile; they will be addressed in another PR.

Submission Checklist

Copilot AI review requested due to automatic review settings February 9, 2026 22:20
@ryanswann-amd changed the title from "Ryaswann/fix sk bias" to "Fix StreamK Kernel Bias" on Feb 9, 2026
Copilot AI (Contributor) left a comment

Pull request overview

Fixes bias application in the StreamK GEMM kernel and expands the correctness tests to exercise both the Persistent and StreamK paths (including under torch.compile).

Changes:

  • Add enable_streamk parametrization to matmul/addmm correctness tests (a minimal sketch of the pattern follows this list).
  • Fix StreamK GEMM bias indexing/broadcast to apply bias along the N (columns) dimension.
  • Remove now-redundant dedicated “streamk modes” tests by folding coverage into existing parametrized tests.
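
A minimal sketch of the parametrization pattern: only the enable_streamk flag, the list of values, and the tolerances come from the diff; the test body, shapes, and dtype are illustrative.

```python
import pytest
import torch
import tritonblas

# Whether to enable StreamK (vs. Persistent path)
ENABLE_STREAMK = [False, True]

@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
def test_matmul_correctness(enable_streamk):
    a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    result = tritonblas.matmul(a, b, enable_streamk=enable_streamk)
    torch.testing.assert_close(result, a @ b, atol=1e-1, rtol=1e-1)
```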

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

  • tests/test_matmul_correctness.py: Parametrize tests with enable_streamk and pass the flag into tritonblas.matmul.
  • tests/test_addmm_correctness.py: Parametrize tests with enable_streamk and pass the flag into tritonblas.addmm.
  • include/tritonblas/kernels/streamk_gemm.py: Correct bias loads/broadcasting to be column-wise (N dimension) in StreamK GEMM.


Comment on lines +64 to +65
# Whether to enable StreamK (vs. Persistent path)
ENABLE_STREAMK = [False, True]
Copilot AI commented on Feb 9, 2026

ENABLE_STREAMK is phrased differently from the existing USE_COMPILE flag. For consistency/readability in the test matrix, consider renaming this to something like USE_STREAMK (or USE_ENABLE_STREAMK) so the parametrization flags follow the same convention.

Comment on lines 272 to 278
result = matmul_fn(a, b, enable_streamk=enable_streamk)
result_ref = torch.mm(a_ref, b_ref)

torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# Backward
result.sum().backward()
result_ref.sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
Copilot AI commented on Feb 9, 2026

This block appears to have absorbed the old test_matmul_streamk_modes, but it no longer asserts forward correctness (result vs result_ref) and it also dropped the b.grad comparison that the removed test previously covered. If this test is intended to validate StreamK mode correctness, please re-add the forward assert_close(result, result_ref, ...) and (when b requires grad in this test) restore the b.grad assertion as well; otherwise, split the concerns back into separate tests so partial-grad semantics and StreamK-mode semantics are both explicitly validated.
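
Concretely, the restored checks might look like this (a sketch; tensor names follow the quoted test, and it assumes b requires grad here):

```python
torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
```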

Comment on lines 286 to 292
result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
result_ref = torch.addmm(bias_ref, a_ref, b_ref)

torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# Backward
result.sum().backward()
result_ref.sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
Copilot AI commented on Feb 9, 2026

Similar to the matmul test, this looks like it now covers what test_addmm_streamk_modes used to, but the forward assertion and the b.grad/bias.grad comparisons from the removed test are missing. If the goal is to validate StreamK behavior end-to-end, please restore the forward assert_close(result, result_ref, ...) and the gradient comparisons for all tensors that require grad in this test.
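
For the addmm variant, the corresponding checks would be (a sketch; assumes b and bias require grad in this test):

```python
torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
```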

Comment on lines 324 to +336
  # Unified bias handling for Stream-K section
+ # Bias is applied along N dimension (columns), broadcast across M (rows)
  if BIAS:
-     # Split bias for top and bottom halves
-     bias_top = bias[:BLOCK_SIZE_M // 2]
-     bias_bottom = bias[BLOCK_SIZE_M // 2:]
+     # Load bias for left and right halves (N dimension) with explicit indices
+     # Apply modulo N to match the behavior of the original rn computation
+     rn_bias_left = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N // 2)) % N
+     rn_bias_right = (pid_n * BLOCK_SIZE_N + BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2)) % N

-     bias_top_reshaped = tl.reshape(bias_top, (BLOCK_SIZE_M // 2, 1))
-     bias_bottom_reshaped = tl.reshape(bias_bottom, (BLOCK_SIZE_M // 2, 1))
+     bias_left = tl.load(bias_ptr + rn_bias_left * stride_bias, mask=rn_bias_left < N, other=0.0)
+     bias_right = tl.load(bias_ptr + rn_bias_right * stride_bias, mask=rn_bias_right < N, other=0.0)

+     # Reshape for broadcasting: (1, N//2) to broadcast across M rows
+     bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
Copilot AI commented on Feb 9, 2026

In this Stream-K section, bias is (now correctly) loaded earlier as a full BLOCK_SIZE_N vector, but this block re-loads bias twice (bias_left/bias_right) and also introduces % N plus a redundant mask (after modulo, indices are already in-range for N>0). To reduce global memory traffic and simplify the indexing, consider reusing the already-loaded bias vector and slicing/splitting it into left/right halves for broadcasting, rather than recomputing indices and issuing additional tl.load operations.
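
One way to implement the suggested reuse (a sketch, assuming the bias vector was already loaded as a full BLOCK_SIZE_N block earlier in the kernel, and a Triton version that provides tl.split, i.e. 3.0 or later):

```python
# Reshape (BLOCK_SIZE_N,) to (2, BLOCK_SIZE_N // 2): row 0 is the left half,
# row 1 the right half. Transpose so the halves sit in the last dimension,
# then split into two (BLOCK_SIZE_N // 2,) vectors without extra tl.loads.
halves = tl.trans(tl.reshape(bias, (2, BLOCK_SIZE_N // 2)))
bias_left, bias_right = tl.split(halves)
bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
bias_right_reshaped = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2))
```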

@asunderwood (Collaborator) left a comment

LGTM - my other PR (#66) should (?) fix the compile issues. The actual issue - the StreamK bias support - is fixed here.

@ryanswann-amd ryanswann-amd merged commit a0f8292 into main Feb 9, 2026
3 of 4 checks passed
