
Conversation

neoblizz (Member) commented Feb 5, 2026

Motivation

Have workgroups dynamically claim tile IDs instead of relying on a fixed partitioning of tiles across compute units.

Copilot AI review requested due to automatic review settings (February 5, 2026 20:34)

Copilot AI left a comment


Pull request overview

This PR introduces a work-stealing-based persistent GEMM kernel that dynamically allocates tile IDs across compute units instead of using fixed partitioning. The implementation uses per-XCD (chiplet) atomic counters to reduce contention compared to global atomic operations. The work-stealing kernel is exposed as an opt-in feature through a new work_stealing parameter in the matmul APIs.

Changes:

  • Added MatmulConfig class to pre-allocate and manage GPU buffers for kernel launches (tile counters, stream-K locks/partials)
  • Implemented work-stealing kernel with per-XCD atomic tile counters in persistent_gemm_work_stealing.py (a sketch of the counter scheme follows this list)
  • Extended all matmul APIs with optional work_stealing and config parameters to support the new kernel
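
For orientation, here is a minimal sketch of the per-XCD counter scheme described above, assuming a round-robin program-id-to-XCD mapping and a NUM_XCDS-slot int32 counter buffer zeroed before launch; the helper name and signature are illustrative, not the PR's exact code.

```python
import triton
import triton.language as tl

@triton.jit
def _claim_next_tile(counters_ptr, xcd_base, NUM_XCDS: tl.constexpr):
    """Device helper: claim the next tile from this program's per-XCD counter."""
    pid = tl.program_id(0)
    xcd = pid % NUM_XCDS  # assumed round-robin CU -> XCD placement
    # Each increment contends only on this chiplet's counter slot, instead of
    # every CU hammering a single global counter.
    local_tile_idx = tl.atomic_add(counters_ptr + xcd, 1)
    return xcd_base + local_tile_idx  # local index -> global tile id
```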

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 22 comments.

| File | Description |
| --- | --- |
| include/tritonblas/matmul.py | Added MatmulConfig class for buffer management; integrated work_stealing parameter and ws_persistent_matmul kernel; refactored buffer allocation to use config objects |
| include/tritonblas/kernels/persistent_gemm_work_stealing.py | New work-stealing kernel implementation with per-XCD atomic counters and dynamic tile assignment |
| include/tritonblas/kernels/__init__.py | Exported ws_persistent_matmul kernel |
| include/tritonblas/__init__.py | Exported MatmulConfig and matmul_preamble to the public API |
| tests/test_work_stealing.py | Standalone test with custom module loading to verify work-stealing kernel correctness and performance |
| benchmarks/benchmark_work_stealing.py | Comprehensive benchmark comparing work-stealing against static persistent, stream-K, and torch.matmul |


Comment on lines +271 to +274
```python
    config: Optional[MatmulConfig] = None,
):
    cfg = config or _get_default_config()
```

Copilot AI commented Feb 5, 2026


The work_stealing parameter is added to persistent_matmul_lt but not to streamk_matmul_lt, creating an inconsistent API. If work-stealing is mutually exclusive with stream-K, this should be documented. If they could potentially be combined in the future, the parameter should be present (even if ignored) for API consistency. Consider either: (1) adding a comment explaining why work_stealing is only available for persistent matmul, or (2) adding the parameter to streamk_matmul_lt with a check that raises an error if both are enabled.

Suggested change

```diff
     config: Optional[MatmulConfig] = None,
+    work_stealing: bool = False,
 ):
     cfg = config or _get_default_config()
+    # Work-stealing and Stream-K are mutually exclusive. Expose the parameter
+    # here for API consistency with persistent_matmul_lt but disallow enabling
+    # it for the Stream-K matmul path.
+    if work_stealing:
+        raise ValueError(
+            "work_stealing is not supported for streamk_matmul_lt; "
+            "work-stealing is only available for persistent matmul."
+        )
```

Comment on lines +53 to +64
```python
    def reset(self):
        """Reset all mutable state (tile counter + stream-K buffers)."""
        self.reset_tile_counter()
        self.reset_streamk()

    def reset_tile_counter(self):
        """Zero the work-stealing tile counter."""
        self.tile_counter.zero_()

    def reset_streamk(self):
        """Zero the stream-K locks (P is overwritten, no need to clear)."""
        self.locks.zero_()
```
Copilot AI commented Feb 5, 2026


When reusing config buffers across multiple kernel launches, there's a potential race condition if multiple matmul operations are launched asynchronously on different CUDA streams without proper synchronization. The reset_tile_counter() and reset_streamk() methods modify shared buffers that could be in use by a previous kernel. Consider documenting that MatmulConfig instances should not be shared across concurrent operations, or add stream-aware synchronization.
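
A minimal usage sketch of the pattern this comment recommends: give each stream its own MatmulConfig instead of sharing one. matmul_preamble and the config= parameter come from this PR; the top-level matmul call, tensor shapes, and two-stream structure are illustrative assumptions.

```python
import torch
# `matmul` stands in for whichever config-accepting matmul API you call.
from tritonblas import matmul_preamble, matmul

a0 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b0 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
a1, b1 = a0.clone(), b0.clone()

s0, s1 = torch.cuda.Stream(), torch.cuda.Stream()
cfg0, cfg1 = matmul_preamble(), matmul_preamble()  # one config per stream

with torch.cuda.stream(s0):
    c0 = matmul(a0, b0, config=cfg0)  # cfg0's buffers are touched only on s0
with torch.cuda.stream(s1):
    c1 = matmul(a1, b1, config=cfg1)  # cfg1's buffers are touched only on s1
torch.cuda.synchronize()
```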

Comment on lines +97 to +101
```python
def _get_default_config() -> MatmulConfig:
    global _default_config
    if _default_config is None:
        _default_config = matmul_preamble()
    return _default_config
```
Copilot AI commented Feb 5, 2026


The lazy initialization of _default_config is not thread-safe. In a multi-threaded environment, multiple threads could simultaneously check if _default_config is None and all attempt to create a new config, potentially leading to race conditions, wasted GPU memory allocation, or inconsistent state. Consider using a lock or threading.Lock() to protect the initialization, or use a thread-safe lazy initialization pattern like functools.lru_cache with a sentinel value.
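
One way to address this, sketched below with double-checked locking; MatmulConfig and matmul_preamble are the PR's own names, and only the lock is new.

```python
import threading

_default_config = None
_default_config_lock = threading.Lock()

def _get_default_config() -> MatmulConfig:
    global _default_config
    if _default_config is None:  # fast path: no lock once initialized
        with _default_config_lock:
            if _default_config is None:  # re-check under the lock so that
                _default_config = matmul_preamble()  # only one thread allocates
    return _default_config
```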

Comment on lines +175 to +230
```python
def main():
    torch.manual_seed(42)
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    print(f"Device: {props.name} (CUs: {props.multi_processor_count})")
    print(f"HIP_VISIBLE_DEVICES = {os.environ.get('HIP_VISIBLE_DEVICES', '<not set>')}")
    print(f"Per-XCD counters: {NUM_XCDS}")
    print()

    # ── Correctness ───────────────────────────────────────────────────
    print("=" * 68)
    print("Correctness (per-XCD work-stealing kernel vs torch.matmul)")
    print("=" * 68)
    all_pass = True
    for m, n, k in [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
        (8192, 8192, 8192),
    ]:
        try:
            ok = test_correctness(m, n, k)
            all_pass &= ok
        except Exception as e:
            print(f" [ERROR] {m}x{n}x{k}: {e}")
            import traceback; traceback.print_exc()
            all_pass = False

    # ── Throughput ────────────────────────────────────────────────────
    print()
    print("=" * 68)
    print("Throughput (per-XCD work-stealing kernel)")
    print("=" * 68)
    for m, n, k in [
        (1024, 1024, 1024),
        (4096, 4096, 4096),
        (8192, 8192, 8192),
    ]:
        try:
            bench_throughput(m, n, k)
        except Exception as e:
            print(f" [ERROR] {m}x{n}x{k}: {e}")
            import traceback; traceback.print_exc()

    print()
    if all_pass:
        print("All correctness tests PASSED.")
    else:
        print("Some correctness tests FAILED.")
        sys.exit(1)


if __name__ == "__main__":
    main()
```
Copilot AI commented Feb 5, 2026


The test file uses a standalone main() function with manual test execution instead of following pytest conventions. Other test files in the codebase (e.g., tests/test_matmul.py, tests/test_matmul_lt.py) use @pytest.mark.parametrize decorators and pytest test discovery. This inconsistency makes it harder to integrate these tests into the test suite and prevents selective test execution. Consider refactoring to use proper pytest fixtures and parametrization like the existing tests.
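
A minimal sketch of the pytest-style rewrite being suggested, assuming the script's test_correctness helper (which returns True on success) can be imported; the test name and SHAPES constant are illustrative.

```python
import pytest
import torch

SHAPES = [(256, 256, 256), (512, 512, 512), (1024, 1024, 1024),
          (2048, 2048, 2048), (4096, 4096, 4096), (8192, 8192, 8192)]

@pytest.mark.parametrize("m,n,k", SHAPES)
def test_work_stealing_matches_torch(m, n, k):
    torch.manual_seed(42)
    # test_correctness compares the work-stealing kernel against torch.matmul
    assert test_correctness(m, n, k)
```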

Comment on lines +84 to +188
```python
for _ in range(tiles_per_xcd):
    if local_tile_idx < tiles_this_xcd:
        # Map local index → global tile_id
        tile_id = xcd_base + local_tile_idx

        # GROUP_SIZE_M swizzle → (pid_m, pid_n)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rk = tl.arange(0, BLOCK_SIZE_K)
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

        if BIAS:
            bias_ = bias_ptr + rm * stride_bias
            bias = tl.load(bias_, mask=rm < M, other=0.0)

        loop_k = tl.cdiv(K, BLOCK_SIZE_K)
        if not EVEN_K:
            loop_k -= 1
        tl.assume(loop_k > 1)

        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        for k in range(0, loop_k):
            if stride_ak == 1:
                a = tl.load(tl.multiple_of(A_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_A)
            else:
                a = tl.load(tl.multiple_of(A_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_A)

            if stride_bk == 1:
                b = tl.load(tl.multiple_of(B_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_B)
            else:
                b = tl.load(tl.multiple_of(B_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_B)

            # Conditional dot product precision based on quantization mode
            if QUANTIZED:
                acc += tl.dot(a, b, input_precision="ieee")
            else:
                acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        if not EVEN_K:
            k = loop_k
            rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
            if stride_ak == 1:
                A_BASE = tl.multiple_of(A_BASE, (1, 16))
            else:
                A_BASE = tl.multiple_of(A_BASE, (16, 1))

            if stride_bk == 1:
                B_BASE = tl.multiple_of(B_BASE, (16, 1))
            else:
                B_BASE = tl.multiple_of(B_BASE, (1, 16))
            a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0, cache_modifier=CACHE_MODIFIER_A)
            b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0, cache_modifier=CACHE_MODIFIER_B)

            if QUANTIZED:
                acc += tl.dot(a, b, input_precision="ieee")
            else:
                acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)

        # Conditional scaling for quantized mode
        if QUANTIZED:
            # Create pointers for the scale tensors and load them
            rm_A_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) % M
            rn_B_scale = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) % N
            A_scale = tl.load(A_scale_ptr + rm_A_scale)
            B_scale = tl.load(B_scale_ptr + rn_B_scale)
            acc *= A_scale[:, None] * B_scale[None, :]

        # Unified bias handling
        if BIAS:
            if QUANTIZED:
                bias_float = bias.to(tl.float32)
                c = acc + bias_float[:, None]
                c = c.to(C.type.element_ty)
            else:
                c = acc.to(C.type.element_ty)
                c += bias[:, None]
        else:
            c = acc.to(C.type.element_ty)

        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        c_mask = (rm[:, None] < M) & (rn[None, :] < N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        tl.store(C_, c, c_mask)

    # Grab next tile from this XCD's counter
    local_tile_idx = tl.atomic_add(counter_ptr, 1)
```
Copilot AI commented Feb 5, 2026


The loop for _ in range(tiles_per_xcd) iterates exactly tiles_per_xcd times, but the actual number of tiles available for this XCD could be less (stored in tiles_this_xcd). While the if local_tile_idx < tiles_this_xcd check prevents processing invalid tiles, the loop continues to make unnecessary atomic increments even after all valid tiles are exhausted. This could cause contention and performance degradation. Consider using for _ in range(tiles_this_xcd) or an early break when local_tile_idx >= tiles_this_xcd.
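
A sketch of the early-exit variant this comment describes, reusing the kernel's own names (counter_ptr, tiles_this_xcd, xcd_base); Triton's while loop lets the program stop issuing atomics once its XCD's counter is exhausted.

```python
# Fragment replacing the fixed-count for-loop above (sketch, not the PR's code):
local_tile_idx = tl.atomic_add(counter_ptr, 1)
while local_tile_idx < tiles_this_xcd:
    tile_id = xcd_base + local_tile_idx
    # ... swizzle, K-loop, and epilogue exactly as in the excerpt above ...
    # Claim the next tile; once the counter runs past tiles_this_xcd the loop
    # exits and this program issues no further atomic increments.
    local_tile_idx = tl.atomic_add(counter_ptr, 1)
```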

```python
    # Reset all per-XCD tile counters before each launch.
    cfg.reset_tile_counter()

    kk = ws_persistent_matmul[(grids,)](
```
Copilot AI commented Feb 5, 2026


Variable kk is not used.

Suggested change

```diff
-    kk = ws_persistent_matmul[(grids,)](
+    ws_persistent_matmul[(grids,)](
```

```python
    grids = total_tiles

    # TODO: Support other matmul algs.
    kk = persistent_matmul[(grids,)](
```
Copilot AI commented Feb 5, 2026


Variable kk is not used.


```python
import os
import sys
import time
```
Copilot AI commented Feb 5, 2026


Import of 'time' is not used.

Suggested change

```diff
-import time
```

```python
import importlib.util
import torch
import triton
import triton.language as tl
```
Copilot AI commented Feb 5, 2026


Import of 'tl' is not used.

Suggested change

```diff
-import triton.language as tl
```

```python
import types
import importlib.util
import torch
import triton
```
Copilot AI commented Feb 5, 2026


Import of 'triton' is not used.

Suggested change

```diff
-import triton
```
