Work-Stealing-based Persistent Kernel #64
base: main
Conversation
Pull request overview
This PR introduces a work-stealing-based persistent GEMM kernel that dynamically allocates tile IDs across compute units instead of using fixed partitioning. The implementation uses per-XCD (chiplet) atomic counters to reduce contention compared to global atomic operations. The work-stealing kernel is exposed as an opt-in feature through a new work_stealing parameter in the matmul APIs.
Changes:
- Added `MatmulConfig` class to pre-allocate and manage GPU buffers for kernel launches (tile counters, stream-K locks/partials)
- Implemented the work-stealing kernel with per-XCD atomic tile counters in `persistent_gemm_work_stealing.py`
- Extended all matmul APIs with optional `work_stealing` and `config` parameters to support the new kernel (a usage sketch follows this list)
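For illustration, an opt-in call might look roughly like the sketch below. The entry point `tritonblas.matmul`, the `matmul_preamble()` factory, and the keyword arguments are inferred from the file summary in this PR and should be treated as assumptions, not the confirmed final API.

```python
import torch
import tritonblas  # hypothetical import; names follow the PR description

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Pre-allocate launch buffers (tile counters, stream-K locks/partials) once
# and reuse them across calls.
cfg = tritonblas.matmul_preamble()

# Opt in to the work-stealing persistent kernel.
c = tritonblas.matmul(a, b, work_stealing=True, config=cfg)
```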
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 22 comments.
| File | Description |
|---|---|
| include/tritonblas/matmul.py | Added MatmulConfig class for buffer management; integrated work_stealing parameter and ws_persistent_matmul kernel; refactored buffer allocation to use config objects |
| include/tritonblas/kernels/persistent_gemm_work_stealing.py | New work-stealing kernel implementation with per-XCD atomic counters and dynamic tile assignment |
| include/tritonblas/kernels/__init__.py | Exported ws_persistent_matmul kernel |
| include/tritonblas/__init__.py | Exported MatmulConfig and matmul_preamble to public API |
| tests/test_work_stealing.py | Standalone test with custom module loading to test work-stealing kernel correctness and performance |
| benchmarks/benchmark_work_stealing.py | Comprehensive benchmark comparing work-stealing against static persistent, stream-K, and torch.matmul |
```python
    config: Optional[MatmulConfig] = None,
):
    cfg = config or _get_default_config()
```
Copilot AI (Feb 5, 2026)
The work_stealing parameter is added to persistent_matmul_lt but not to streamk_matmul_lt, creating an inconsistent API. If work-stealing is mutually exclusive with stream-K, this should be documented. If they could potentially be combined in the future, the parameter should be present (even if ignored) for API consistency. Consider either: (1) adding a comment explaining why work_stealing is only available for persistent matmul, or (2) adding the parameter to streamk_matmul_lt with a check that raises an error if both are enabled.
Suggested change:

```python
    config: Optional[MatmulConfig] = None,
    work_stealing: bool = False,
):
    cfg = config or _get_default_config()
    # Work-stealing and Stream-K are mutually exclusive. Expose the parameter
    # here for API consistency with persistent_matmul_lt but disallow enabling
    # it for the Stream-K matmul path.
    if work_stealing:
        raise ValueError(
            "work_stealing is not supported for streamk_matmul_lt; "
            "work-stealing is only available for persistent matmul."
        )
```
```python
    def reset(self):
        """Reset all mutable state (tile counter + stream-K buffers)."""
        self.reset_tile_counter()
        self.reset_streamk()

    def reset_tile_counter(self):
        """Zero the work-stealing tile counter."""
        self.tile_counter.zero_()

    def reset_streamk(self):
        """Zero the stream-K locks (P is overwritten, no need to clear)."""
        self.locks.zero_()
```
Copilot AI (Feb 5, 2026)
When reusing config buffers across multiple kernel launches, there's a potential race condition if multiple matmul operations are launched asynchronously on different CUDA streams without proper synchronization. The reset_tile_counter() and reset_streamk() methods modify shared buffers that could be in use by a previous kernel. Consider documenting that MatmulConfig instances should not be shared across concurrent operations, or add stream-aware synchronization.
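One way to follow that guidance, reusing the hypothetical call signature from the sketch above, is to keep a separate MatmulConfig per stream so the counters and locks reset by one launch are never the buffers still in flight on another stream:

```python
import torch
import tritonblas  # hypothetical import; names follow the PR description

a1 = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b1 = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
a2 = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b2 = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# One config per stream: no shared tile counters or stream-K locks.
cfg1 = tritonblas.matmul_preamble()
cfg2 = tritonblas.matmul_preamble()

s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
with torch.cuda.stream(s1):
    c1 = tritonblas.matmul(a1, b1, work_stealing=True, config=cfg1)
with torch.cuda.stream(s2):
    c2 = tritonblas.matmul(a2, b2, work_stealing=True, config=cfg2)
torch.cuda.synchronize()
```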
```python
def _get_default_config() -> MatmulConfig:
    global _default_config
    if _default_config is None:
        _default_config = matmul_preamble()
    return _default_config
```
Copilot AI (Feb 5, 2026)
The lazy initialization of _default_config is not thread-safe. In a multi-threaded environment, multiple threads could simultaneously see _default_config as None and each create a new config, leading to race conditions, wasted GPU memory allocation, or inconsistent state. Consider protecting the initialization with a threading.Lock, or use a thread-safe lazy-initialization pattern such as functools.lru_cache on a zero-argument factory function.
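A minimal sketch of the locking suggestion, assuming the function lives in the same module as MatmulConfig and matmul_preamble:

```python
import threading

_default_config = None
_default_config_lock = threading.Lock()

def _get_default_config() -> MatmulConfig:
    """Lazily build the shared default config; safe under concurrent callers."""
    global _default_config
    if _default_config is None:          # fast path once initialized
        with _default_config_lock:
            if _default_config is None:  # re-check under the lock
                _default_config = matmul_preamble()
    return _default_config
```

The double-checked pattern keeps the common path lock-free while guaranteeing matmul_preamble() runs at most once.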
```python
def main():
    torch.manual_seed(42)
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    print(f"Device: {props.name} (CUs: {props.multi_processor_count})")
    print(f"HIP_VISIBLE_DEVICES = {os.environ.get('HIP_VISIBLE_DEVICES', '<not set>')}")
    print(f"Per-XCD counters: {NUM_XCDS}")
    print()

    # ── Correctness ───────────────────────────────────────────────────
    print("=" * 68)
    print("Correctness (per-XCD work-stealing kernel vs torch.matmul)")
    print("=" * 68)
    all_pass = True
    for m, n, k in [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
        (8192, 8192, 8192),
    ]:
        try:
            ok = test_correctness(m, n, k)
            all_pass &= ok
        except Exception as e:
            print(f"  [ERROR] {m}x{n}x{k}: {e}")
            import traceback; traceback.print_exc()
            all_pass = False

    # ── Throughput ────────────────────────────────────────────────────
    print()
    print("=" * 68)
    print("Throughput (per-XCD work-stealing kernel)")
    print("=" * 68)
    for m, n, k in [
        (1024, 1024, 1024),
        (4096, 4096, 4096),
        (8192, 8192, 8192),
    ]:
        try:
            bench_throughput(m, n, k)
        except Exception as e:
            print(f"  [ERROR] {m}x{n}x{k}: {e}")
            import traceback; traceback.print_exc()

    print()
    if all_pass:
        print("All correctness tests PASSED.")
    else:
        print("Some correctness tests FAILED.")
        sys.exit(1)


if __name__ == "__main__":
    main()
```
Copilot AI (Feb 5, 2026)
The test file uses a standalone main() function with manual test execution instead of following pytest conventions. Other test files in the codebase (e.g., tests/test_matmul.py, tests/test_matmul_lt.py) use @pytest.mark.parametrize decorators and pytest test discovery. This inconsistency makes it harder to integrate these tests into the test suite and prevents selective test execution. Consider refactoring to use proper pytest fixtures and parametrization like the existing tests.
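A possible pytest-style rewrite of the correctness check, assuming the work-stealing path is reachable through a public `tritonblas.matmul(..., work_stealing=True)` entry point (the actual test loads the kernel module manually, so the import path here is an assumption):

```python
import pytest
import torch
import tritonblas  # hypothetical import path

@pytest.mark.parametrize(
    "m,n,k",
    [(256, 256, 256), (1024, 1024, 1024), (4096, 4096, 4096), (8192, 8192, 8192)],
)
def test_work_stealing_matches_torch(m, n, k):
    torch.manual_seed(42)
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)
    c = tritonblas.matmul(a, b, work_stealing=True)
    torch.testing.assert_close(c, torch.matmul(a, b), rtol=1e-2, atol=1e-2)
```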
```python
    for _ in range(tiles_per_xcd):
        if local_tile_idx < tiles_this_xcd:
            # Map local index → global tile_id
            tile_id = xcd_base + local_tile_idx

            # GROUP_SIZE_M swizzle → (pid_m, pid_n)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m
            tl.assume(pid_m >= 0)
            tl.assume(pid_n >= 0)

            rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
            rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
            rk = tl.arange(0, BLOCK_SIZE_K)
            rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
            rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

            if BIAS:
                bias_ = bias_ptr + rm * stride_bias
                bias = tl.load(bias_, mask=rm < M, other=0.0)

            loop_k = tl.cdiv(K, BLOCK_SIZE_K)
            if not EVEN_K:
                loop_k -= 1
            tl.assume(loop_k > 1)

            acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
            for k in range(0, loop_k):
                if stride_ak == 1:
                    a = tl.load(tl.multiple_of(A_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_A)
                else:
                    a = tl.load(tl.multiple_of(A_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_A)

                if stride_bk == 1:
                    b = tl.load(tl.multiple_of(B_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_B)
                else:
                    b = tl.load(tl.multiple_of(B_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_B)

                # Conditional dot product precision based on quantization mode
                if QUANTIZED:
                    acc += tl.dot(a, b, input_precision="ieee")
                else:
                    acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
                A_BASE += BLOCK_SIZE_K * stride_ak
                B_BASE += BLOCK_SIZE_K * stride_bk

            if not EVEN_K:
                k = loop_k
                rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
                B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
                if stride_ak == 1:
                    A_BASE = tl.multiple_of(A_BASE, (1, 16))
                else:
                    A_BASE = tl.multiple_of(A_BASE, (16, 1))

                if stride_bk == 1:
                    B_BASE = tl.multiple_of(B_BASE, (16, 1))
                else:
                    B_BASE = tl.multiple_of(B_BASE, (1, 16))
                a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0, cache_modifier=CACHE_MODIFIER_A)
                b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0, cache_modifier=CACHE_MODIFIER_B)

                if QUANTIZED:
                    acc += tl.dot(a, b, input_precision="ieee")
                else:
                    acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)

            # Conditional scaling for quantized mode
            if QUANTIZED:
                # Create pointers for the scale tensors and load them
                rm_A_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) % M
                rn_B_scale = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) % N
                A_scale = tl.load(A_scale_ptr + rm_A_scale)
                B_scale = tl.load(B_scale_ptr + rn_B_scale)
                acc *= A_scale[:, None] * B_scale[None, :]

            # Unified bias handling
            if BIAS:
                if QUANTIZED:
                    bias_float = bias.to(tl.float32)
                    c = acc + bias_float[:, None]
                    c = c.to(C.type.element_ty)
                else:
                    c = acc.to(C.type.element_ty)
                    c += bias[:, None]
            else:
                c = acc.to(C.type.element_ty)

            rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
            rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
            rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
            rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
            c_mask = (rm[:, None] < M) & (rn[None, :] < N)
            C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
            tl.store(C_, c, c_mask)

            # Grab next tile from this XCD's counter
            local_tile_idx = tl.atomic_add(counter_ptr, 1)
```
Copilot AI (Feb 5, 2026)
The loop for _ in range(tiles_per_xcd) iterates exactly tiles_per_xcd times, but the actual number of tiles available for this XCD could be less (stored in tiles_this_xcd). While the if local_tile_idx < tiles_this_xcd check prevents processing invalid tiles, the loop continues to make unnecessary atomic increments even after all valid tiles are exhausted. This could cause contention and performance degradation. Consider using for _ in range(tiles_this_xcd) or an early break when local_tile_idx >= tiles_this_xcd.
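The bounded-loop idea can be illustrated with a stripped-down claim-only kernel. This is a sketch, not the PR's GEMM kernel: the `pid % NUM_XCDS` mapping and the 8-XCD count are assumptions, and the tile "work" is reduced to recording which program claimed it.

```python
import torch
import triton
import triton.language as tl

NUM_XCDS = 8  # assumption: MI300-class GPU with 8 XCDs

@triton.jit
def claim_tiles_kernel(
    counter_ptr,    # one int32 work-stealing counter per XCD
    owner_ptr,      # records which program claimed each tile
    total_tiles,
    tiles_per_xcd,
    NUM_XCDS: tl.constexpr,
):
    pid = tl.program_id(0)
    xcd = pid % NUM_XCDS                       # assumed round-robin pid -> XCD mapping
    xcd_base = xcd * tiles_per_xcd
    # The trailing XCDs may own fewer than tiles_per_xcd tiles.
    tiles_this_xcd = tl.minimum(tiles_per_xcd, total_tiles - xcd_base)

    local_tile_idx = tl.atomic_add(counter_ptr + xcd, 1)
    # Bounding the loop by tiles_this_xcd (not tiles_per_xcd) stops a program
    # from issuing further atomic increments once its XCD's tiles run out.
    for _ in range(tiles_this_xcd):
        if local_tile_idx < tiles_this_xcd:
            tile_id = xcd_base + local_tile_idx
            tl.store(owner_ptr + tile_id, pid)  # stand-in for the GEMM tile work
            local_tile_idx = tl.atomic_add(counter_ptr + xcd, 1)

def demo():
    total_tiles = 1024
    tiles_per_xcd = (total_tiles + NUM_XCDS - 1) // NUM_XCDS
    counters = torch.zeros(NUM_XCDS, dtype=torch.int32, device="cuda")
    owner = torch.full((total_tiles,), -1, dtype=torch.int32, device="cuda")
    claim_tiles_kernel[(304,)](counters, owner, total_tiles, tiles_per_xcd, NUM_XCDS=NUM_XCDS)
    assert (owner >= 0).all()  # every tile was claimed by exactly one program
```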
```python
    # Reset all per-XCD tile counters before each launch.
    cfg.reset_tile_counter()

    kk = ws_persistent_matmul[(grids,)](
```
Copilot AI (Feb 5, 2026)
Variable kk is not used.
Suggested change:

```python
    ws_persistent_matmul[(grids,)](
```
```python
    grids = total_tiles

    # TODO: Support other matmul algs.
    kk = persistent_matmul[(grids,)](
```
Copilot AI (Feb 5, 2026)
Variable kk is not used.
```python
import os
import sys
import time
```
Copilot AI (Feb 5, 2026)
Import of 'time' is not used.
Suggested change: remove the unused `import time` line.
```python
import importlib.util
import torch
import triton
import triton.language as tl
```
Copilot AI (Feb 5, 2026)
Import of 'tl' is not used.
Suggested change: remove the unused `import triton.language as tl` line.
```python
import types
import importlib.util
import torch
import triton
```
Copilot AI (Feb 5, 2026)
Import of 'triton' is not used.
Suggested change: remove the unused `import triton` line.
Motivation
Dynamically hand out tile IDs to workgroups (work-stealing) instead of using a fixed, static partitioning.