Work Stealing Scheduler Kernel #65
Conversation
Pull request overview
Adds an opt-in per-XCD work-stealing scheduler for the persistent GEMM kernel and wires it into the public matmul API with reusable preallocated buffers.
Changes:
- Introduces a new work-stealing persistent GEMM kernel (`ws_persistent_matmul`) using per-XCD atomic tile counters.
- Adds `MatmulConfig`/`matmul_preamble` to manage reusable GPU buffers and enables `work_stealing` in the matmul entrypoints.
- Adds standalone smoke/benchmark scripts for work-stealing and CPU→GPU atomic signaling experiments.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/test_work_stealing.py | Adds a standalone smoke test/benchmark driver for the new work-stealing kernel. |
| tests/test_cpu_gpu_atomics.py | Adds an experimental script for CPU-driven signaling to a Triton “scheduler” kernel. |
| include/tritonblas/matmul.py | Adds MatmulConfig + preamble and integrates work_stealing into matmul call paths. |
| include/tritonblas/kernels/persistent_gemm_work_stealing.py | Implements the per-XCD atomic-counter work-stealing GEMM kernel. |
| include/tritonblas/kernels/__init__.py | Exposes the new ws_persistent_matmul kernel in the package API. |
| include/tritonblas/__init__.py | Exports MatmulConfig and matmul_preamble at the top-level package. |
| benchmarks/benchmark_work_stealing.py | Adds a benchmark harness comparing work-stealing vs persistent vs stream-k vs torch. |
```python
)
return passed
```
Copilot AI commented on Feb 8, 2026:
pytest will collect test_correctness as a test, but it has parameters (m, n, k, dtype) that pytest will treat as fixtures, causing a collection/runtime error. Convert this to a proper parametrized pytest test (e.g., @pytest.mark.parametrize with fixed shapes) or rename it (and the file) to avoid pytest collection if it’s intended to be a standalone smoke script.
Suggested change:

```python
# Mark as helper-only so pytest does not attempt to collect it as a test.
test_correctness.__test__ = False
```
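If the function is meant to stay a collected test instead, a minimal parametrized sketch is shown below; the shapes, dtypes, and the `tritonblas.matmul(..., work_stealing=True)` call are assumptions for illustration, not code from this PR:

```python
import pytest
import torch

import tritonblas  # assumed top-level import; adjust to the actual package layout


# Hypothetical parametrization with fixed shapes so pytest no longer treats
# (m, n, k, dtype) as missing fixtures.
@pytest.mark.parametrize("m,n,k", [(128, 128, 128), (512, 1024, 256)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_correctness(m, n, k, dtype):
    a = torch.randn((m, k), device="cuda", dtype=dtype)
    b = torch.randn((k, n), device="cuda", dtype=dtype)
    # work_stealing=True exercises the new scheduler path added in this PR.
    c = tritonblas.matmul(a, b, work_stealing=True)
    torch.testing.assert_close(c, a @ b, rtol=2e-2, atol=2e-2)
```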
tests/test_cpu_gpu_atomics.py (Outdated)
```python
flags_h = hip_check(hip.hipMalloc(sch_grid * sys.getsizeof(live)))
# Casting flags_h to a typed pointer, for content access
flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * sch_grid))
print(f'Flags (init):')
for i in range(0, sch_grid):
    flags_typed_ptr.contents[i] = live
    print(f'{flags_typed_ptr.contents[i]}')

flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(sch_grid,))
flags_h_tensor = torch.from_numpy(flags_h_np_array)

sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32).contiguous()

print(f'Scheduler kernel started')
with torch.cuda.stream(sch_stream):
    sch[(sch_grid, 1, 1)](flags_h_tensor, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH,
                          num_warps=NUM_WARPS, num_stages=NUM_STAGES)

time.sleep(1)
```
Copilot AI commented on Feb 8, 2026:
As written, this file is very likely to break test discovery/environments: (1) it’s under tests/ and imports hip unconditionally (will fail on non-ROCm setups during pytest collection), and (2) it allocates device memory via hipMalloc but then creates a CPU tensor via torch.from_numpy(...) and passes it to a Triton kernel—this does not create a CUDA/HIP device tensor and won’t provide a valid device pointer. Also, sys.getsizeof(live) measures Python object size, not sizeof(int32), so the allocation size is incorrect. Consider moving this to benchmarks/ or examples/, and if you need a GPU-visible flag tensor, allocate it as a CUDA tensor (torch.empty(..., device='cuda', dtype=...)) and update it via appropriate device/host transfer APIs.
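A minimal sketch of the device-side alternative suggested here, reusing the snippet's `sch_grid` and `live` names (defined as placeholders below); the pinned-host copy is one possible way to drive CPU→GPU signaling and is not taken from the PR:

```python
import torch

sch_grid, live = 8, 1  # placeholder values mirroring the snippet above

# Allocate the flags directly on the GPU so the Triton kernel receives a valid
# device pointer; the element size comes from the dtype, not sys.getsizeof.
flags = torch.full((sch_grid,), live, device="cuda", dtype=torch.int32)

# CPU-side updates go through an explicit host->device copy from pinned memory
# rather than ctypes writes into hipMalloc'd memory.
host_flags = torch.zeros(sch_grid, dtype=torch.int32, pin_memory=True)
flags.copy_(host_flags, non_blocking=True)
```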
```python
def __init__(self, device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE,
             max_xcds: int = _DEFAULT_MAX_XCDS):
    props = torch.cuda.get_device_properties(device)
    self.device = device
    self.num_sms: int = props.multi_processor_count
    self.max_block_size: int = max_block_size
    self.max_xcds: int = max_xcds
```
Copilot AI commented on Feb 8, 2026:
torch.cuda.get_device_properties(...) typically expects an int device index or a torch.device, not a string like 'cuda'/'cuda:0'. This can raise at runtime when users call matmul_preamble()/MatmulConfig(...). Use torch.device(device) (and/or torch.cuda.current_device() when device == 'cuda') when querying properties, while keeping the string for tensor allocations if desired.
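A sketch of the defensive property query this comment describes, keeping the string device for tensor allocation; names here are illustrative only:

```python
import torch


def _query_device_props(device: str = "cuda"):
    # Resolve a bare "cuda" (no index) to the current device index before
    # asking the CUDA/HIP runtime for properties.
    dev = torch.device(device)
    index = dev.index if dev.index is not None else torch.cuda.current_device()
    return torch.cuda.get_device_properties(index)


props = _query_device_props("cuda")
num_sms = props.multi_processor_count
```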
```python
self.max_xcds: int = max_xcds

# Work-stealing per-XCD tile counters
self.tile_counter = torch.zeros(max_xcds, device=device, dtype=torch.int32)
```
Copilot AI commented on Feb 8, 2026:
The work-stealing kernel indexes tile_counter[xcd_id] with xcd_id = pid % NUM_XCDS, but tile_counter is allocated with length max_xcds (default 16) while NUM_XCDS is passed as num_xcds from the selector. If NUM_XCDS > cfg.tile_counter.numel(), the kernel will do out-of-bounds pointer arithmetic on tile_counter (undefined behavior / crash). Add a hard check (e.g., assert NUM_XCDS <= cfg.max_xcds), or derive NUM_XCDS from cfg.tile_counter.numel() / an explicit config.num_xcds field so the kernel launch and buffer sizing are guaranteed consistent.
Suggested change:

```diff
-self.max_xcds: int = max_xcds
-# Work-stealing per-XCD tile counters
-self.tile_counter = torch.zeros(max_xcds, device=device, dtype=torch.int32)
+# Ensure the tile_counter buffer is never smaller than the maximum
+# XCD count that kernels/selectors may assume (default headroom).
+self.max_xcds: int = max(max_xcds, _DEFAULT_MAX_XCDS)
+# Work-stealing per-XCD tile counters
+self.tile_counter = torch.zeros(self.max_xcds, device=device, dtype=torch.int32)
```
```python
if work_stealing:
    # Work-stealing: launch grid = num CUs, tiles assigned dynamically
    # via per-XCD atomic counters.
    grids = cfg.num_sms

    # Reset all per-XCD tile counters before each launch.
    cfg.reset_tile_counter()

    kk = ws_persistent_matmul[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        None,  # TODO: Enable bias.
        cfg.tile_counter,  # Per-XCD tile counters (int32[max_xcds])
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        0,  # TODO: Enable bias stride.
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=grids,
        NUM_XCDS=num_xcds,
```
Copilot AI commented on Feb 8, 2026:
The work-stealing kernel indexes tile_counter[xcd_id] with xcd_id = pid % NUM_XCDS, but tile_counter is allocated with length max_xcds (default 16) while NUM_XCDS is passed as num_xcds from the selector. If NUM_XCDS > cfg.tile_counter.numel(), the kernel will do out-of-bounds pointer arithmetic on tile_counter (undefined behavior / crash). Add a hard check (e.g., assert NUM_XCDS <= cfg.max_xcds), or derive NUM_XCDS from cfg.tile_counter.numel() / an explicit config.num_xcds field so the kernel launch and buffer sizing are guaranteed consistent.
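A minimal sketch of the consistency check described above, placed just before the launch; `cfg` and `num_xcds` follow the snippet, and clamping is only one of the two options the comment mentions:

```python
# Fail fast if the selector's XCD count exceeds the preallocated counter buffer.
assert num_xcds <= cfg.tile_counter.numel(), (
    f"NUM_XCDS={num_xcds} exceeds tile_counter size {cfg.tile_counter.numel()}"
)

# Alternatively, derive the kernel's XCD count from the buffer itself so the
# launch and the allocation can never disagree.
num_xcds = min(num_xcds, cfg.tile_counter.numel())
```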
```python
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
rk = tl.arange(0, BLOCK_SIZE_K)
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
```
Copilot AI commented on Feb 8, 2026:
Using % M / % N wraps out-of-range lanes back into valid indices. That makes rm < M and rn < N always true, causing edge tiles (when M or N aren’t multiples of the block sizes) to write wrapped results into the wrong rows/cols (overwriting valid outputs). Remove the modulo wrapping for rm/rn and use proper masks for loads/stores based on the non-wrapped offsets.
```python
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
c_mask = (rm[:, None] < M) & (rn[None, :] < N)
C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(C_, c, c_mask)
```
Copilot AI commented on Feb 8, 2026:
Using % M / % N wraps out-of-range lanes back into valid indices. That makes rm < M and rn < N always true, causing edge tiles (when M or N aren’t multiples of the block sizes) to write wrapped results into the wrong rows/cols (overwriting valid outputs). Remove the modulo wrapping for rm/rn and use proper masks for loads/stores based on the non-wrapped offsets.
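A sketch of the masked epilogue this comment calls for, using the kernel's existing names; the bare store (no bias or scaling) is simplified relative to the actual kernel:

```python
# Recompute row/column offsets for the store WITHOUT the modulo wrap, so the
# mask genuinely drops out-of-range lanes on edge tiles instead of letting
# wrapped indices overwrite valid rows/cols.
rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_mask = (rm[:, None] < M) & (rn[None, :] < N)
C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(C_, c, mask=c_mask)
```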
```python
loop_k = tl.cdiv(K, BLOCK_SIZE_K)
if not EVEN_K:
    loop_k -= 1
tl.assume(loop_k > 1)
```
Copilot AI commented on Feb 8, 2026:
tl.assume(loop_k > 1) is not generally true and can be false for small-K cases (e.g., K <= BLOCK_SIZE_K gives loop_k == 1, and with EVEN_K == False it becomes 0). Incorrect tl.assume statements can lead to miscompilation/incorrect codegen. Remove this assumption or replace it with a condition that always holds (or restructure the loop logic so assumptions match the valid range).
Suggested change:

```diff
-tl.assume(loop_k > 1)
```
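If an assumption is still useful for codegen, only a bound that always holds should be asserted; a sketch under the same loop setup, not the PR's fix:

```python
loop_k = tl.cdiv(K, BLOCK_SIZE_K)
if not EVEN_K:
    loop_k -= 1
# loop_k can legitimately be 0 or 1 for small K (e.g. K <= BLOCK_SIZE_K), so
# only a non-negative bound is safe to assume here.
tl.assume(loop_k >= 0)
```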