Fix StreamK Kernel Bias #67
Conversation
Pull request overview
Updates StreamK behavior to fix bias application in the StreamK GEMM kernel and expands correctness tests to exercise both Persistent and StreamK paths (including under torch.compile).
Changes:
- Add enable_streamk parametrization to matmul/addmm correctness tests (a sketch follows this list).
- Fix StreamK GEMM bias indexing/broadcast to apply bias along the N (columns) dimension.
- Remove now-redundant dedicated "streamk modes" tests by folding their coverage into the existing parametrized tests.
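As a rough illustration of how the enable_streamk parametrization can be threaded through the tests (a hedged sketch only; the actual test names, shapes, and dtypes in the PR may differ):

```python
import pytest
import torch
import tritonblas

# Whether to enable StreamK (vs. Persistent path)
ENABLE_STREAMK = [False, True]

@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
@pytest.mark.parametrize("m, n, k", [(256, 256, 256)])  # illustrative shape only
def test_matmul_correctness(m, n, k, enable_streamk):
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)
    # Pass the flag straight through to tritonblas.matmul, as described above
    result = tritonblas.matmul(a, b, enable_streamk=enable_streamk)
    result_ref = torch.mm(a, b)
    torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
```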
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/test_matmul_correctness.py | Parametrize tests with enable_streamk and pass flag into tritonblas.matmul. |
| tests/test_addmm_correctness.py | Parametrize tests with enable_streamk and pass flag into tritonblas.addmm. |
| include/tritonblas/kernels/streamk_gemm.py | Correct bias loads/broadcasting to be column-wise (N dimension) in StreamK GEMM. |
```python
# Whether to enable StreamK (vs. Persistent path)
ENABLE_STREAMK = [False, True]
```
Copilot AI · Feb 9, 2026
ENABLE_STREAMK is phrased differently from the existing USE_COMPILE flag. For consistency/readability in the test matrix, consider renaming this to something like USE_STREAMK (or USE_ENABLE_STREAMK) so the parametrization flags follow the same convention.
```python
result = matmul_fn(a, b, enable_streamk=enable_streamk)
result_ref = torch.mm(a_ref, b_ref)

torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# Backward
result.sum().backward()
result_ref.sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
```
Copilot AI · Feb 9, 2026
This block appears to have absorbed the old test_matmul_streamk_modes, but it no longer asserts forward correctness (result vs result_ref) and it also dropped the b.grad comparison that the removed test previously covered. If this test is intended to validate StreamK mode correctness, please re-add the forward assert_close(result, result_ref, ...) and (when b requires grad in this test) restore the b.grad assertion as well; otherwise, split the concerns back into separate tests so partial-grad semantics and StreamK-mode semantics are both explicitly validated.
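A minimal sketch of the two pieces this comment asks to restore, assuming b requires grad in this test and reusing the tolerances from the surrounding diff:

```python
# Forward correctness for the StreamK path (restored)
torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# After result.sum().backward() / result_ref.sum().backward():
# gradient check for b (restored; assumes b.requires_grad in this test)
torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
```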
```python
result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
result_ref = torch.addmm(bias_ref, a_ref, b_ref)

torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# Backward
result.sum().backward()
result_ref.sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
```
Copilot AI · Feb 9, 2026
Similar to the matmul test, this looks like it now covers what test_addmm_streamk_modes used to, but the forward assertion and the b.grad/bias.grad comparisons from the removed test are missing. If the goal is to validate StreamK behavior end-to-end, please restore the forward assert_close(result, result_ref, ...) and the gradient comparisons for all tensors that require grad in this test.
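As with the matmul test, a short sketch of the end-to-end checks this comment suggests, assuming a, b, and bias all require grad and keeping the tolerances from the diff:

```python
# Forward correctness
torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)

# Backward
result.sum().backward()
result_ref.sum().backward()

# Gradients for every tensor that requires grad in this test
torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
```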
```python
# Unified bias handling for Stream-K section
# Bias is applied along N dimension (columns), broadcast across M (rows)
if BIAS:
    # Split bias for top and bottom halves
    bias_top = bias[:BLOCK_SIZE_M // 2]
    bias_bottom = bias[BLOCK_SIZE_M // 2:]
    # Load bias for left and right halves (N dimension) with explicit indices
    # Apply modulo N to match the behavior of the original rn computation
    rn_bias_left = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N // 2)) % N
    rn_bias_right = (pid_n * BLOCK_SIZE_N + BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2)) % N

    bias_top_reshaped = tl.reshape(bias_top, (BLOCK_SIZE_M // 2, 1))
    bias_bottom_reshaped = tl.reshape(bias_bottom, (BLOCK_SIZE_M // 2, 1))
    bias_left = tl.load(bias_ptr + rn_bias_left * stride_bias, mask=rn_bias_left < N, other=0.0)
    bias_right = tl.load(bias_ptr + rn_bias_right * stride_bias, mask=rn_bias_right < N, other=0.0)

    # Reshape for broadcasting: (1, N//2) to broadcast across M rows
    bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
```
Copilot AI · Feb 9, 2026
In this Stream-K section, bias is (now correctly) loaded earlier as a full BLOCK_SIZE_N vector, but this block re-loads bias twice (bias_left/bias_right) and also introduces % N plus a redundant mask (after modulo, indices are already in-range for N>0). To reduce global memory traffic and simplify the indexing, consider reusing the already-loaded bias vector and slicing/splitting it into left/right halves for broadcasting, rather than recomputing indices and issuing additional tl.load operations.
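A rough sketch of the reuse this comment suggests, assuming the full-width bias vector loaded earlier in the kernel is available under a name like bias_n (hypothetical; shape (BLOCK_SIZE_N,), loaded with an rn < N mask), and mirroring the slicing style already used for bias_top/bias_bottom:

```python
# Reuse the already-loaded BLOCK_SIZE_N bias vector instead of two extra tl.loads
bias_left = bias_n[:BLOCK_SIZE_N // 2]
bias_right = bias_n[BLOCK_SIZE_N // 2:]
# Reshape to (1, BLOCK_SIZE_N // 2) so each half broadcasts across the M rows
bias_left_reshaped = tl.reshape(bias_left, (1, BLOCK_SIZE_N // 2))
bias_right_reshaped = tl.reshape(bias_right, (1, BLOCK_SIZE_N // 2))
```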
LGTM - my other PR (#66) should (?) fix the compile issues. The actual issue - the StreamK bias support - is fixed here.
Motivation
The StreamK GEMM kernel was missing bias addition in the tile aggregation section. When multiple processing elements (PEs) contribute partial results to a single output tile, the bias was not being applied after aggregation. This caused incorrect
addmm results when using StreamK mode with bias tensors.
Technical Details
Modified include/tritonblas/kernels/streamk_gemm.py to add bias handling in the Stream-K tile aggregation section. After aggregating accumulator quadrants (acc00, acc01, acc10, acc11) from multiple PEs, bias is now loaded for the left and right halves of the N dimension, reshaped for broadcasting to (1, N//2), and added to each quadrant. The fix handles both quantized mode (bias converted to float32 before adding) and non-quantized mode (bias added directly).
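Schematically, the per-quadrant addition described above looks roughly like the following (a simplified sketch, not the literal kernel code; the QUANTIZED flag name is hypothetical):

```python
# After acc00/acc01/acc10/acc11 have been aggregated from the contributing PEs
if BIAS:
    if QUANTIZED:
        # Quantized mode: convert bias to float32 before adding
        acc00 += bias_left_reshaped.to(tl.float32)
        acc01 += bias_right_reshaped.to(tl.float32)
        acc10 += bias_left_reshaped.to(tl.float32)
        acc11 += bias_right_reshaped.to(tl.float32)
    else:
        # Non-quantized mode: add bias directly
        acc00 += bias_left_reshaped
        acc01 += bias_right_reshaped
        acc10 += bias_left_reshaped
        acc11 += bias_right_reshaped
```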
Test Plan
Improve StreamK testing.
Test Result
===== 64 failed, 196 passed, 15 warnings in 170.99s (0:02:50) =====
The 64 failures are due to globally allocated locks breaking torch.compile, which will be addressed in another PR.
Submission Checklist