Conversation


@alexdutu alexdutu commented Feb 6, 2026

No description provided.

@alexdutu alexdutu requested review from Copilot and ryanswann-amd and removed request for Copilot February 6, 2026 05:11
Copilot AI review requested due to automatic review settings February 8, 2026 01:21

Copilot AI left a comment


Pull request overview

Adds an opt-in per-XCD work-stealing scheduler for the persistent GEMM kernel and wires it into the public matmul API with reusable preallocated buffers.

Changes:

  • Introduces a new work-stealing persistent GEMM kernel (ws_persistent_matmul) using per-XCD atomic tile counters.
  • Adds MatmulConfig / matmul_preamble to manage reusable GPU buffers and enables work_stealing in matmul entrypoints.
  • Adds standalone smoke/benchmark scripts for work-stealing and CPU→GPU atomic signaling experiments.
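
For orientation, a rough usage sketch of the opt-in path might look like the following; the exact matmul signature, the name of the work_stealing keyword, and whether the config object must be passed explicitly are assumptions inferred from the file summaries below, not confirmed by this PR.

import torch
import tritonblas

a = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)
b = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)
c = torch.empty(8192, 8192, device="cuda", dtype=torch.float16)

# Hypothetical: allocate reusable buffers (e.g. per-XCD tile counters) once up front ...
cfg = tritonblas.matmul_preamble()
# ... then opt in to the work-stealing scheduler on individual calls.
tritonblas.matmul(a, b, c, work_stealing=True)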

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/test_work_stealing.py Adds a standalone smoke test/benchmark driver for the new work-stealing kernel.
tests/test_cpu_gpu_atomics.py Adds an experimental script for CPU-driven signaling to a Triton “scheduler” kernel.
include/tritonblas/matmul.py Adds MatmulConfig + preamble and integrates work_stealing into matmul call paths.
include/tritonblas/kernels/persistent_gemm_work_stealing.py Implements the per-XCD atomic-counter work-stealing GEMM kernel.
include/tritonblas/kernels/__init__.py Exposes the new ws_persistent_matmul kernel in the package API.
include/tritonblas/__init__.py Exports MatmulConfig and matmul_preamble from the top-level package.
benchmarks/benchmark_work_stealing.py Adds a benchmark harness comparing work-stealing vs persistent vs stream-k vs torch.


)
return passed



Copilot AI Feb 8, 2026


pytest will collect test_correctness as a test, but it has parameters (m, n, k, dtype) that pytest will treat as fixtures, causing a collection/runtime error. Convert this to a proper parametrized pytest test (e.g., @pytest.mark.parametrize with fixed shapes) or rename it (and the file) to avoid pytest collection if it’s intended to be a standalone smoke script.

Suggested change
# Mark as helper-only so pytest does not attempt to collect it as a test.
test_correctness.__test__ = False
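
If the function should instead stay a collectible test, a minimal parametrized sketch could look like this (the shapes, tolerances, and the tritonblas.matmul call are illustrative assumptions, not the PR's actual test body):

import pytest
import torch

import tritonblas

@pytest.mark.parametrize("m,n,k", [(128, 128, 128), (512, 1024, 384)])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_correctness(m, n, k, dtype):
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(k, n, device="cuda", dtype=dtype)
    c = torch.empty(m, n, device="cuda", dtype=dtype)
    tritonblas.matmul(a, b, c, work_stealing=True)  # assumed entrypoint/signature
    torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)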

Comment on lines 135 to 153
flags_h = hip_check(hip.hipMalloc(sch_grid * sys.getsizeof(live)))
# Casting flags_h to a typed pointer, for content access
flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * sch_grid))
print(f'Flags (init):')
for i in range(0, sch_grid):
    flags_typed_ptr.contents[i] = live
    print(f'{flags_typed_ptr.contents[i]}')

flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(sch_grid,))
flags_h_tensor = torch.from_numpy(flags_h_np_array)

sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32).contiguous()

print(f'Scheduler kernel started')
with torch.cuda.stream(sch_stream):
    sch[(sch_grid, 1, 1)](flags_h_tensor, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH,
                          num_warps=NUM_WARPS, num_stages=NUM_STAGES)

time.sleep(1)

Copilot AI Feb 8, 2026


As written, this file is very likely to break test discovery/environments:

  • It lives under tests/ and imports hip unconditionally, which will fail on non-ROCm setups during pytest collection.
  • It allocates device memory via hipMalloc but then creates a CPU tensor via torch.from_numpy(...) and passes it to a Triton kernel; this does not create a CUDA/HIP device tensor and won’t provide a valid device pointer.
  • sys.getsizeof(live) measures the Python object size, not sizeof(int32), so the allocation size is incorrect.

Consider moving this to benchmarks/ or examples/, and if you need a GPU-visible flag tensor, allocate it as a CUDA tensor (torch.empty(..., device='cuda', dtype=...)) and update it via appropriate device/host transfer APIs.
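
For illustration, a device-resident flag buffer along the lines suggested above could look like this (a sketch that reuses sch_grid, live, sch, and the other names from the excerpt; the scheduler kernel's signature is taken from that excerpt as-is):

import torch

# Allocate the flags directly on the GPU instead of via hipMalloc + ctypes.
flags = torch.full((sch_grid,), live, device="cuda", dtype=torch.int32)
sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32)

with torch.cuda.stream(sch_stream):
    sch[(sch_grid, 1, 1)](flags, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH,
                          num_warps=NUM_WARPS, num_stages=NUM_STAGES)

# Host-driven signaling then goes through an ordinary host-to-device copy, e.g.:
flags.copy_(torch.zeros(sch_grid, dtype=torch.int32), non_blocking=True)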

Comment on lines +35 to +41
def __init__(self, device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE,
             max_xcds: int = _DEFAULT_MAX_XCDS):
    props = torch.cuda.get_device_properties(device)
    self.device = device
    self.num_sms: int = props.multi_processor_count
    self.max_block_size: int = max_block_size
    self.max_xcds: int = max_xcds

Copilot AI Feb 8, 2026


torch.cuda.get_device_properties(...) typically expects an int device index or a torch.device, not a string like 'cuda'/'cuda:0'. This can raise at runtime when users call matmul_preamble()/MatmulConfig(...). Use torch.device(device) (and/or torch.cuda.current_device() when device == 'cuda') when querying properties, while keeping the string for tensor allocations if desired.
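
A possible fix, sketched (the helper name is hypothetical):

import torch

def _resolve_device_index(device: str) -> int:
    # Map "cuda" / "cuda:1" style strings to an integer index for property queries.
    dev = torch.device(device)
    return dev.index if dev.index is not None else torch.cuda.current_device()

props = torch.cuda.get_device_properties(_resolve_device_index(device))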

Comment on lines +41 to +44
self.max_xcds: int = max_xcds

# Work-stealing per-XCD tile counters
self.tile_counter = torch.zeros(max_xcds, device=device, dtype=torch.int32)

Copilot AI Feb 8, 2026


The work-stealing kernel indexes tile_counter[xcd_id] with xcd_id = pid % NUM_XCDS, but tile_counter is allocated with length max_xcds (default 16) while NUM_XCDS is passed as num_xcds from the selector. If NUM_XCDS > cfg.tile_counter.numel(), the kernel will do out-of-bounds pointer arithmetic on tile_counter (undefined behavior / crash). Add a hard check (e.g., assert NUM_XCDS <= cfg.max_xcds), or derive NUM_XCDS from cfg.tile_counter.numel() / an explicit config.num_xcds field so the kernel launch and buffer sizing are guaranteed consistent.

Suggested change
- self.max_xcds: int = max_xcds
- # Work-stealing per-XCD tile counters
- self.tile_counter = torch.zeros(max_xcds, device=device, dtype=torch.int32)
+ # Ensure the tile_counter buffer is never smaller than the maximum
+ # XCD count that kernels/selectors may assume (default headroom).
+ self.max_xcds: int = max(max_xcds, _DEFAULT_MAX_XCDS)
+ # Work-stealing per-XCD tile counters
+ self.tile_counter = torch.zeros(self.max_xcds, device=device, dtype=torch.int32)

Comment on lines +176 to +207
if work_stealing:
    # Work-stealing: launch grid = num CUs, tiles assigned dynamically
    # via per-XCD atomic counters.
    grids = cfg.num_sms

    # Reset all per-XCD tile counters before each launch.
    cfg.reset_tile_counter()

    kk = ws_persistent_matmul[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        None,  # TODO: Enable bias.
        cfg.tile_counter,  # Per-XCD tile counters (int32[max_xcds])
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        0,  # TODO: Enable bias stride.
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=grids,
        NUM_XCDS=num_xcds,

Copilot AI Feb 8, 2026


The work-stealing kernel indexes tile_counter[xcd_id] with xcd_id = pid % NUM_XCDS, but tile_counter is allocated with length max_xcds (default 16) while NUM_XCDS is passed as num_xcds from the selector. If NUM_XCDS > cfg.tile_counter.numel(), the kernel will do out-of-bounds pointer arithmetic on tile_counter (undefined behavior / crash). Add a hard check (e.g., assert NUM_XCDS <= cfg.max_xcds), or derive NUM_XCDS from cfg.tile_counter.numel() / an explicit config.num_xcds field so the kernel launch and buffer sizing are guaranteed consistent.
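
For example, a minimal guard before the launch could be (a sketch; assumes num_xcds comes from the tile selector and cfg is the MatmulConfig from the preamble):

# Fail fast if the selector's XCD count exceeds the preallocated counter buffer.
assert num_xcds <= cfg.tile_counter.numel(), (
    f"num_xcds={num_xcds} exceeds tile_counter size {cfg.tile_counter.numel()}"
)
# Or derive it from the buffer so the launch and the sizing can never diverge:
# num_xcds = cfg.tile_counter.numel()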

Comment on lines +99 to +105
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
rk = tl.arange(0, BLOCK_SIZE_K)
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

Copilot AI Feb 8, 2026


Using % M / % N wraps out-of-range lanes back into valid indices. That makes rm < M and rn < N always true, causing edge tiles (when M or N aren’t multiples of the block sizes) to write wrapped results into the wrong rows/cols (overwriting valid outputs). Remove the modulo wrapping for rm/rn and use proper masks for loads/stores based on the non-wrapped offsets.
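
A sketch of the non-wrapped address setup (a fragment of the @triton.jit body, reusing the names from the excerpt above; the K-tail mask on rk is omitted):

rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)  # no "% M" wrap
rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)  # no "% N" wrap
rk = tl.arange(0, BLOCK_SIZE_K)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
# Edge tiles are handled by masking the loads instead of wrapping the indices.
a = tl.load(A_BASE, mask=rm[:, None] < M, other=0.0)
b = tl.load(B_BASE, mask=rn[None, :] < N, other=0.0)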

Comment on lines +179 to +185
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
c_mask = (rm[:, None] < M) & (rn[None, :] < N)
C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(C_, c, c_mask)

Copilot AI Feb 8, 2026


Using % M / % N wraps out-of-range lanes back into valid indices. That makes rm < M and rn < N always true, causing edge tiles (when M or N aren’t multiples of the block sizes) to write wrapped results into the wrong rows/cols (overwriting valid outputs). Remove the modulo wrapping for rm/rn and use proper masks for loads/stores based on the non-wrapped offsets.
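
And the matching store epilogue, sketched against the excerpt above (again a fragment of the kernel body, not a complete kernel):

rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)  # no "% M" wrap
rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)  # no "% N" wrap
c_mask = (rm[:, None] < M) & (rn[None, :] < N)  # now actually masks edge tiles
C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(C_, c, c_mask)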

loop_k = tl.cdiv(K, BLOCK_SIZE_K)
if not EVEN_K:
    loop_k -= 1
tl.assume(loop_k > 1)

Copilot AI Feb 8, 2026


tl.assume(loop_k > 1) is not generally true and can be false for small-K cases (e.g., K <= BLOCK_SIZE_K gives loop_k == 1, and with EVEN_K == False it becomes 0). Incorrect tl.assume statements can lead to miscompilation/incorrect codegen. Remove this assumption or replace it with a condition that always holds (or restructure the loop logic so assumptions match the valid range).

Suggested change
tl.assume(loop_k > 1)
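
If a hint is still wanted here, one alternative that holds for every K ≥ 0 is (a sketch):

loop_k = tl.cdiv(K, BLOCK_SIZE_K)
if not EVEN_K:
    loop_k -= 1
# loop_k can legitimately be 0 (K <= BLOCK_SIZE_K with a remainder) or 1,
# so only non-negativity can be assumed safely.
tl.assume(loop_k >= 0)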
