From 2373b38883829e6940995b68ee38c8369c5fadf3 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Thu, 5 Feb 2026 19:46:44 +0000 Subject: [PATCH 1/8] ... --- benchmarks/benchmark_work_stealing.py | 399 ++++++++++++++++++ include/tritonblas/__init__.py | 1 + include/tritonblas/kernels/__init__.py | 6 +- .../kernels/persistent_gemm_work_stealing.py | 174 ++++++++ include/tritonblas/matmul.py | 263 +++++++++--- tests/test_work_stealing.py | 234 ++++++++++ 6 files changed, 1011 insertions(+), 66 deletions(-) create mode 100644 benchmarks/benchmark_work_stealing.py create mode 100644 include/tritonblas/kernels/persistent_gemm_work_stealing.py create mode 100644 tests/test_work_stealing.py diff --git a/benchmarks/benchmark_work_stealing.py b/benchmarks/benchmark_work_stealing.py new file mode 100644 index 0000000..5a6664d --- /dev/null +++ b/benchmarks/benchmark_work_stealing.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +""" +Benchmark: work-stealing persistent GEMM vs. static persistent GEMM vs. Stream-K vs. torch.matmul + +Uses importlib bootstrap to load kernels directly, bypassing the full tritonblas +import (which requires triton.constexpr_function not available in older builds). + +Usage: + HIP_VISIBLE_DEVICES=6 python benchmarks/benchmark_work_stealing.py +""" + +import os +import sys +import time +import types +import importlib.util +import torch +import triton +import triton.language as tl +from math import ceil + +# --------------------------------------------------------------------------- +# Bootstrap: load kernels without triggering stages/__init__.py +# --------------------------------------------------------------------------- +_root = os.path.join(os.path.dirname(__file__), "..", "include", "tritonblas") +_kernels_dir = os.path.join(_root, "kernels") +_stages_dir = os.path.join(_kernels_dir, "stages") +_indexing_dir = os.path.join(_stages_dir, "indexing") + + +def _load_module(fqn, filepath, package_path=None): + spec = importlib.util.spec_from_file_location(fqn, filepath) + mod = importlib.util.module_from_spec(spec) + if package_path is not None: + mod.__path__ = [package_path] + sys.modules[fqn] = mod + spec.loader.exec_module(mod) + return mod + + +def _make_stub_package(fqn, path): + pkg = types.ModuleType(fqn) + pkg.__path__ = [path] + pkg.__package__ = fqn + sys.modules[fqn] = pkg + return pkg + + +# Stub packages +_make_stub_package("tritonblas", _root) +_make_stub_package("tritonblas.kernels", _kernels_dir) +_make_stub_package("tritonblas.kernels.stages", _stages_dir) +_make_stub_package("tritonblas.kernels.stages.indexing", _indexing_dir) + +# Load pid_transforms (pure @triton.jit, no constexpr_function) +_load_module( + "tritonblas.kernels.stages.indexing.pid_transforms", + os.path.join(_indexing_dir, "pid_transforms.py"), +) + +# Load kernels +_mono_mod = _load_module( + "tritonblas.kernels.persistent_gemm_monolithic", + os.path.join(_kernels_dir, "persistent_gemm_monolithic.py"), +) +_ws_mod = _load_module( + "tritonblas.kernels.persistent_gemm_work_stealing", + os.path.join(_kernels_dir, "persistent_gemm_work_stealing.py"), +) +_sk_mod = _load_module( + "tritonblas.kernels.streamk_gemm", + os.path.join(_kernels_dir, "streamk_gemm.py"), +) + +persistent_matmul = _mono_mod.persistent_matmul +ws_persistent_matmul = _ws_mod.ws_persistent_matmul +streamk_matmul = _sk_mod.streamk_matmul + + +# --------------------------------------------------------------------------- +# Launch helpers +# --------------------------------------------------------------------------- +def 
_common_params(A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS): + M, K = A.shape + _, N = B.shape + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + chunk_size = GROUP_M * GROUP_M + chunk_size = min(chunk_size, max(1, total_tiles // NUM_XCDS)) + return M, K, N, total_tiles, even_k, chunk_size + + +def launch_persistent(A, B, C, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): + """Original static-partition persistent GEMM (monolithic).""" + M, K, N, total_tiles, even_k, chunk_size = _common_params( + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + ) + grids = total_tiles + persistent_matmul[(grids,)]( + A, B, C, + None, None, None, # scale, bias + M, N, K, + A.stride(0), B.stride(1), C.stride(0), C.stride(1), 0, + stride_ak=A.stride(1), stride_bk=B.stride(0), + BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=GROUP_M, NUM_SMS=grids, NUM_XCDS=NUM_XCDS, + CHUNK_SIZE=chunk_size, BIAS=False, EVEN_K=even_k, + CACHE_MODIFIER_A=None, CACHE_MODIFIER_B=None, QUANTIZED=False, + num_stages=2, num_warps=8, waves_per_eu=0, + matrix_instr_nonkdim=16, kpack=1, + ) + + +def launch_work_stealing(A, B, C, tile_counter, num_sms, + BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): + """Work-stealing persistent GEMM (atomic counter).""" + M, K, N, total_tiles, even_k, chunk_size = _common_params( + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + ) + grids = num_sms + tile_counter.zero_() + ws_persistent_matmul[(grids,)]( + A, B, C, + None, None, None, # scale, bias + tile_counter, + M, N, K, + A.stride(0), B.stride(1), C.stride(0), C.stride(1), 0, + stride_ak=A.stride(1), stride_bk=B.stride(0), + BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=GROUP_M, NUM_SMS=grids, NUM_XCDS=NUM_XCDS, + CHUNK_SIZE=chunk_size, BIAS=False, EVEN_K=even_k, + CACHE_MODIFIER_A=None, CACHE_MODIFIER_B=None, QUANTIZED=False, + num_stages=2, num_warps=8, waves_per_eu=0, + matrix_instr_nonkdim=16, kpack=1, + ) + + +def launch_streamk(A, B, C, locks, P, sk_grid, + BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): + """Stream-K persistent GEMM.""" + M, K, N, total_tiles, even_k, chunk_size = _common_params( + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + ) + # StreamK tiles = remainder tiles that need cooperative decomposition + streamk_tiles = total_tiles % sk_grid if sk_grid > 0 else 0 + + chunk_size_sk = GROUP_M * GROUP_M + chunk_size_sk = min(chunk_size_sk, max(1, sk_grid // NUM_XCDS)) + + locks[:sk_grid].zero_() + streamk_matmul[(sk_grid,)]( + A, B, C, + None, None, None, # scale, bias + P[:sk_grid, :BLK_M * BLK_N], + locks[:sk_grid], + M, N, K, + A.stride(0), B.stride(1), C.stride(0), C.stride(1), 0, + stride_ak=A.stride(1), stride_bk=B.stride(0), + BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=GROUP_M, NUM_SMS=sk_grid, NUM_XCDS=NUM_XCDS, + CHUNK_SIZE=chunk_size_sk, STREAMK_TILES=streamk_tiles, + BIAS=False, EVEN_K=even_k, + CACHE_MODIFIER_A=None, CACHE_MODIFIER_B=None, QUANTIZED=False, + num_stages=2, num_warps=8, waves_per_eu=0, + matrix_instr_nonkdim=16, kpack=1, + ) + + +def launch_torch(A, B, C): + """torch.matmul (rocBLAS/hipBLAS backend).""" + torch.matmul(A, B, out=C) + + +# --------------------------------------------------------------------------- +# Simple Stream-K grid heuristic (mirrors origami logic) +# --------------------------------------------------------------------------- +def compute_sk_grid(M, N, 
K, BLK_M, BLK_N, BLK_K, cu_count): + tiles = ceil(M / BLK_M) * ceil(N / BLK_N) + sk_grid = tiles + split_factors = [8, 6, 4, 3, 2, 1] + tile_fractions = [0.0, 0.5, 0.125, 0.2, 0.25, 1.0 / 3.0] + iters_per_tile = max(1, ceil(K / BLK_K)) + + if tiles > cu_count: + min_even_tiles = tiles / cu_count + for frac in tile_fractions: + frac_grid = int((tiles / (min_even_tiles + frac)) + 0.5) + if frac_grid <= cu_count: + sk_grid = frac_grid + break + elif tiles < cu_count: + for factor in split_factors: + split_grid = tiles * factor + iters_per_cu = iters_per_tile // factor + if split_grid <= cu_count and iters_per_cu >= 8: + sk_grid = split_grid + break + + if tiles % sk_grid != 0: + sk_grid = tiles + + if tiles >= cu_count: + last_wave_remainder = tiles % cu_count + if last_wave_remainder < 128 and last_wave_remainder > 0 and cu_count in [304, 80, 64]: + sk_grid = 256 if cu_count == 304 else 64 + + return sk_grid + + +# --------------------------------------------------------------------------- +# Benchmark harness +# --------------------------------------------------------------------------- +def bench(fn, warmup=25, iters=50): + """Return median runtime in ms using triton.testing.do_bench.""" + return triton.testing.do_bench(fn, warmup=warmup, rep=iters) + + +def main(): + torch.manual_seed(42) + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + NUM_SMS = props.multi_processor_count + NUM_XCDS = 8 # MI300X + + print(f"Device : {props.name}") + print(f"CUs (SMs) : {NUM_SMS}") + print(f"HIP_VISIBLE : {os.environ.get('HIP_VISIBLE_DEVICES', '')}") + print() + + # Pre-allocate work-stealing counter + Stream-K buffers once + tile_counter = torch.zeros(1, device="cuda", dtype=torch.int32) + max_grid = NUM_SMS * 2 # generous upper bound for SK grid + block_area = 128 * 128 + locks = torch.zeros(max_grid, device="cuda", dtype=torch.uint8) + P = torch.zeros(max_grid, block_area, device="cuda", dtype=torch.float32) + + BLK_M, BLK_N, BLK_K, GROUP_M = 128, 128, 64, 8 + dtype = torch.float16 + + # Problem sizes to benchmark + sizes = [ + # Square + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (8192, 8192, 8192), + # Rectangular (common LLM shapes) + (1, 4096, 4096), + (4, 4096, 4096), + (16, 4096, 4096), + (32, 4096, 4096), + (64, 4096, 4096), + (128, 4096, 4096), + (256, 4096, 4096), + (512, 4096, 4096), + (1024, 4096, 4096), + (2048, 4096, 4096), + (4096, 4096, 11008), + (4096, 11008, 4096), + (8192, 8192, 4096), + (8192, 4096, 8192), + ] + + # Header + hdr = ( + f"{'M':>6} {'N':>6} {'K':>6} │ " + f"{'Persistent':>12} {'WorkSteal':>12} {'StreamK':>12} {'torch.mm':>12} │ " + f"{'WS/Pers':>8} {'WS/SK':>8} {'WS/Torch':>8}" + ) + sep = "─" * len(hdr) + print(sep) + print(f"{'':>20} │ {'── Time (ms) ──':^51} │ {'── Speedup ──':^26}") + print(hdr) + print(sep) + + results = [] + + for M, N, K in sizes: + A = torch.randn(M, K, device="cuda", dtype=dtype) + B = torch.randn(N, K, device="cuda", dtype=dtype).T # K x N contiguous + C_pers = torch.zeros(M, N, device="cuda", dtype=dtype) + C_ws = torch.zeros(M, N, device="cuda", dtype=dtype) + C_sk = torch.zeros(M, N, device="cuda", dtype=dtype) + C_ref = torch.zeros(M, N, device="cuda", dtype=dtype) + + even_k = K % BLK_K == 0 + total_tiles_m = triton.cdiv(M, BLK_M) + total_tiles_n = triton.cdiv(N, BLK_N) + total_tiles = total_tiles_m * total_tiles_n + + # Skip tiny sizes where tiles < 1 (e.g. 
M=1 with BLK_M=128 still gives 1 tile) + sk_grid = compute_sk_grid(M, N, K, BLK_M, BLK_N, BLK_K, NUM_SMS) + # Clamp stream-K grid to our pre-allocated buffer size + sk_grid = min(sk_grid, max_grid) + + # ── Benchmark each variant ────────────────────────────────── + try: + ms_pers = bench(lambda: launch_persistent( + A, B, C_pers, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + except Exception as e: + ms_pers = float("nan") + + try: + ms_ws = bench(lambda: launch_work_stealing( + A, B, C_ws, tile_counter, NUM_SMS, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + except Exception as e: + ms_ws = float("nan") + + try: + ms_sk = bench(lambda: launch_streamk( + A, B, C_sk, locks, P, sk_grid, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + except Exception as e: + ms_sk = float("nan") + + ms_torch = bench(lambda: launch_torch(A, B, C_ref)) + + # ── Speedups (> 1.0 means work-stealing is faster) ───────── + su_pers = ms_pers / ms_ws if ms_ws > 0 else float("nan") + su_sk = ms_sk / ms_ws if ms_ws > 0 else float("nan") + su_torch = ms_torch / ms_ws if ms_ws > 0 else float("nan") + + # ── TFLOP/s ──────────────────────────────────────────────── + flops = 2.0 * M * N * K + def to_tflops(ms): + return flops / (ms * 1e-3) / 1e12 if ms > 0 else 0 + + row = { + "M": M, "N": N, "K": K, + "persistent_ms": ms_pers, + "work_stealing_ms": ms_ws, + "streamk_ms": ms_sk, + "torch_ms": ms_torch, + "persistent_tflops": to_tflops(ms_pers), + "work_stealing_tflops": to_tflops(ms_ws), + "streamk_tflops": to_tflops(ms_sk), + "torch_tflops": to_tflops(ms_torch), + "speedup_vs_pers": su_pers, + "speedup_vs_sk": su_sk, + "speedup_vs_torch": su_torch, + } + results.append(row) + + # Format ms; mark NaN + def fmt_ms(v): + return f"{v:12.4f}" if v == v else f"{'N/A':>12}" + def fmt_su(v): + return f"{v:8.3f}" if v == v else f"{'N/A':>8}" + + print( + f"{M:>6} {N:>6} {K:>6} │ " + f"{fmt_ms(ms_pers)} {fmt_ms(ms_ws)} {fmt_ms(ms_sk)} {fmt_ms(ms_torch)} │ " + f"{fmt_su(su_pers)} {fmt_su(su_sk)} {fmt_su(su_torch)}" + ) + + print(sep) + + # ── Summary in TFLOP/s ────────────────────────────────────────── + print() + print(sep) + print(f"{'':>20} │ {'── TFLOP/s ──':^51} │") + hdr2 = ( + f"{'M':>6} {'N':>6} {'K':>6} │ " + f"{'Persistent':>12} {'WorkSteal':>12} {'StreamK':>12} {'torch.mm':>12} │" + ) + print(hdr2) + print(sep) + for r in results: + def fmt_tf(v): + return f"{v:12.2f}" if v > 0 else f"{'N/A':>12}" + print( + f"{r['M']:>6} {r['N']:>6} {r['K']:>6} │ " + f"{fmt_tf(r['persistent_tflops'])} {fmt_tf(r['work_stealing_tflops'])} " + f"{fmt_tf(r['streamk_tflops'])} {fmt_tf(r['torch_tflops'])} │" + ) + print(sep) + + # ── Geometric mean speedup ────────────────────────────────────── + import math + valid_pers = [r["speedup_vs_pers"] for r in results if r["speedup_vs_pers"] == r["speedup_vs_pers"] and r["speedup_vs_pers"] > 0] + valid_sk = [r["speedup_vs_sk"] for r in results if r["speedup_vs_sk"] == r["speedup_vs_sk"] and r["speedup_vs_sk"] > 0] + valid_torch= [r["speedup_vs_torch"]for r in results if r["speedup_vs_torch"]== r["speedup_vs_torch"]and r["speedup_vs_torch"] > 0] + + def geomean(xs): + return math.exp(sum(math.log(x) for x in xs) / len(xs)) if xs else float("nan") + + print() + print("Geometric-mean speedup of Work-Stealing over:") + print(f" Persistent (static) : {geomean(valid_pers):.4f}x") + print(f" Stream-K : {geomean(valid_sk):.4f}x") + print(f" torch.matmul : {geomean(valid_torch):.4f}x") + print() + + +if __name__ == "__main__": + main() diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py index 
6f9365f..b2360bb 100644 --- a/include/tritonblas/__init__.py +++ b/include/tritonblas/__init__.py @@ -1,4 +1,5 @@ from .matmul import matmul, matmul_a8w8 from .matmul import matmul_lt, matmul_a8w8_lt from .matmul import matmul_fp4 +from .matmul import MatmulConfig, matmul_preamble from .origami import OrigamiMatmulSelector diff --git a/include/tritonblas/kernels/__init__.py b/include/tritonblas/kernels/__init__.py index b7ea70c..81bbd57 100644 --- a/include/tritonblas/kernels/__init__.py +++ b/include/tritonblas/kernels/__init__.py @@ -4,6 +4,7 @@ This package contains specific GEMM kernel implementations: - persistent_gemm: Persistent (data-parallel) GEMM kernel using composable stages - persistent_gemm_monolithic: Monolithic persistent GEMM kernel (legacy, for debugging) +- persistent_gemm_work_stealing: Work-stealing persistent GEMM kernel - streamk_gemm: Stream-K GEMM kernel for load balancing - stages: Composable kernel building blocks @@ -23,6 +24,9 @@ # Use composable stages version (default) from .persistent_gemm import persistent_matmul +# Work-stealing kernel (opt-in via work_stealing=True in matmul calls) +from .persistent_gemm_work_stealing import ws_persistent_matmul + # Stream-K kernel is always the same from .streamk_gemm import streamk_matmul @@ -32,4 +36,4 @@ # Export stages submodule from . import stages -__all__ = ['persistent_matmul', 'streamk_matmul', 'fp4_matmul', 'stages'] +__all__ = ['persistent_matmul', 'ws_persistent_matmul', 'streamk_matmul', 'fp4_matmul', 'stages'] diff --git a/include/tritonblas/kernels/persistent_gemm_work_stealing.py b/include/tritonblas/kernels/persistent_gemm_work_stealing.py new file mode 100644 index 0000000..7ef946d --- /dev/null +++ b/include/tritonblas/kernels/persistent_gemm_work_stealing.py @@ -0,0 +1,174 @@ +""" +Work-stealing persistent GEMM kernel. + +Instead of statically partitioning tiles across workgroups (for tile_id in +range(pid, total_tiles, NUM_SMS)), each WG dynamically grabs the next +available tile via a global atomic counter. This naturally load-balances +when some WGs arrive late to the party. +""" + +import triton +import triton.language as tl +import torch + +from .stages.indexing.pid_transforms import chiplet_transform_chunked + +@triton.jit() +def ws_persistent_matmul( + A, + B, + C, + A_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 + B_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 + bias_ptr, + tile_counter, # Global atomic counter for work-stealing (int32[1]) + M, + N, + K, + stride_am, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + CACHE_MODIFIER_A: tl.constexpr, + CACHE_MODIFIER_B: tl.constexpr, + QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16 + ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32, +): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + # ── Work-stealing with chiplet swizzle ──────────────────────────────── + # 1. 
Grab a raw tile index from a single global atomic counter. + # 2. Swizzle it through chiplet_transform_chunked so that consecutive + # tile_ids land on the same XCD → better L2 locality. + # 3. The GROUP_SIZE_M decomposition below turns the swizzled tile_id + # into (pid_m, pid_n). + tile_id = tl.atomic_add(tile_counter, 1) + for _ in range(total_tiles): + if tile_id < total_tiles: + # Chiplet-aware swizzle + if NUM_XCDS != 1: + tile_id = chiplet_transform_chunked(tile_id, total_tiles, NUM_XCDS, CHUNK_SIZE) + + # GROUP_SIZE_M swizzle → (pid_m, pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + if BIAS: + bias_ = bias_ptr + rm * stride_bias + bias = tl.load(bias_, mask=rm < M, other=0.0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + tl.assume(loop_k > 1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + if stride_ak == 1: + a = tl.load(tl.multiple_of(A_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_A) + else: + a = tl.load(tl.multiple_of(A_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_A) + + if stride_bk == 1: + b = tl.load(tl.multiple_of(B_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_B) + else: + b = tl.load(tl.multiple_of(B_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_B) + + # Conditional dot product precision based on quantization mode + if QUANTIZED: + acc += tl.dot(a, b, input_precision="ieee") + else: + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + if stride_ak == 1: + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + else: + A_BASE = tl.multiple_of(A_BASE, (16, 1)) + + if stride_bk == 1: + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + else: + B_BASE = tl.multiple_of(B_BASE, (1, 16)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0, cache_modifier=CACHE_MODIFIER_A) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0, cache_modifier=CACHE_MODIFIER_B) + + if QUANTIZED: + acc += tl.dot(a, b, input_precision="ieee") + else: + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + + # Conditional scaling for quantized mode + if QUANTIZED: + # Create pointers for the scale tensors and load them + rm_A_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) % M + rn_B_scale = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) % N + A_scale = tl.load(A_scale_ptr + rm_A_scale) + B_scale = tl.load(B_scale_ptr + rn_B_scale) + acc *= A_scale[:, None] * B_scale[None, :] + + # Unified bias handling + if BIAS: + if QUANTIZED: + bias_float = bias.to(tl.float32) + c = acc + bias_float[:, None] + c = 
c.to(C.type.element_ty) + else: + c = acc.to(C.type.element_ty) + c += bias[:, None] + else: + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + + # Grab next tile + tile_id = tl.atomic_add(tile_counter, 1) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 9b23986..5aa8995 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -3,21 +3,96 @@ import random import functools import time -from .kernels import persistent_matmul, streamk_matmul +from .kernels import persistent_matmul, ws_persistent_matmul, streamk_matmul from .kernels.fp4_matmul import fp4_matmul from .origami import OrigamiMatmulSelector from typing import Dict, Tuple, Optional _tensor_cache = {} -current_device_index = torch.cuda.current_device() -current_device = torch.cuda.get_device_properties(current_device_index) -MAX_SMS = current_device.multi_processor_count + # TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4 -MAX_BLOCK_SIZE = 65536 +_DEFAULT_MAX_BLOCK_SIZE = 65536 + + +class MatmulConfig: + """ + Pre-allocated GPU buffers and device metadata for GEMM kernel launches. + + Create one via :func:`matmul_preamble` and pass it to any ``matmul*`` + function. Call :meth:`reset` (or the more targeted helpers) between + launches to zero out mutable state. + + Attributes: + num_sms: Number of SMs / CUs on the device. + max_block_size: Maximum tile footprint (BLOCK_M * BLOCK_N). + tile_counter: ``int32[1]`` atomic counter for work-stealing. + locks: ``uint8[num_sms]`` stream-K lock array. + P: ``float32[num_sms, max_block_size]`` stream-K partial buffer. + """ + + def __init__(self, device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE): + props = torch.cuda.get_device_properties(device) + self.device = device + self.num_sms: int = props.multi_processor_count + self.max_block_size: int = max_block_size + + # Work-stealing tile counter + self.tile_counter = torch.zeros(1, device=device, dtype=torch.int32) + + # Stream-K buffers + self.locks = torch.empty(self.num_sms, device=device, dtype=torch.uint8) + self.P = torch.empty(self.num_sms, max_block_size, device=device, dtype=torch.float32) + + # ------------------------------------------------------------------ + # Reset helpers + # ------------------------------------------------------------------ + def reset(self): + """Reset all mutable state (tile counter + stream-K buffers).""" + self.reset_tile_counter() + self.reset_streamk() + + def reset_tile_counter(self): + """Zero the work-stealing tile counter.""" + self.tile_counter.zero_() + + def reset_streamk(self): + """Zero the stream-K locks (P is overwritten, no need to clear).""" + self.locks.zero_() + + def __repr__(self): + return ( + f"MatmulConfig(device={self.device!r}, num_sms={self.num_sms}, " + f"max_block_size={self.max_block_size})" + ) + + +def matmul_preamble(device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE) -> MatmulConfig: + """ + Allocate all GPU-side buffers needed by the tritonBLAS GEMM kernels. + + Call this once (e.g. during model init) and pass the returned config + into ``matmul``, ``matmul_lt``, ``matmul_a8w8``, etc. 
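+
+    A minimal usage sketch (shapes, dtypes, and sizes here are illustrative
+    only)::
+
+        cfg = matmul_preamble()                      # allocate buffers once
+        a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+        b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+        c = torch.empty(4096, 4096, device="cuda", dtype=torch.float16)
+        matmul(a, b, c, work_stealing=True, config=cfg)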
+ + Args: + device: CUDA device string (default ``"cuda"``). + max_block_size: Maximum tile footprint (default 65536 = 256*256). + + Returns: + A :class:`MatmulConfig` ready for kernel launches. + """ + return MatmulConfig(device=device, max_block_size=max_block_size) + -# Global pre-allocated buffers -_global_locks = torch.empty(MAX_SMS, device="cuda", dtype=torch.uint8) -_global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32) +# Lazy default config -- created on first use so import-time CUDA init is +# deferred until somebody actually calls a matmul function. +_default_config: Optional[MatmulConfig] = None + + +def _get_default_config() -> MatmulConfig: + global _default_config + if _default_config is None: + _default_config = matmul_preamble() + return _default_config # Function will behave like an LRU-Cache of heuristic results @@ -55,7 +130,11 @@ def persistent_matmul_lt( a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, + work_stealing: bool = False, + config: Optional[MatmulConfig] = None, ): + cfg = config or _get_default_config() + assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape @@ -84,49 +163,93 @@ def persistent_matmul_lt( CACHE_MODIFIER_A = None CACHE_MODIFIER_B = None - # Run in Data-parallel mode. - grids = total_tiles - # Set chunk size to same area as L2 tiles. chunk_size = gsize_m * gsize_m - chunk_size = min(chunk_size, total_programs // num_xcds) - - # TODO: Support other matmul algs. - kk = persistent_matmul[(grids,)]( - a, - b, - c, - a_scale if quantized else None, # A_scale_ptr - b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. - M, - N, - K, - a.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - 0, # TODO: Enable bias stride. - stride_ak=a.stride(1), - stride_bk=b.stride(0), - BLOCK_SIZE_M=BLK_M, - BLOCK_SIZE_N=BLK_N, - BLOCK_SIZE_K=BLK_K, - GROUP_SIZE_M=gsize_m, - NUM_SMS=total_programs, - NUM_XCDS=num_xcds, - CHUNK_SIZE=chunk_size, - BIAS=False, - EVEN_K=even_k, - CACHE_MODIFIER_A=CACHE_MODIFIER_A, - CACHE_MODIFIER_B=CACHE_MODIFIER_B, - QUANTIZED=quantized, - num_stages=num_stages, - num_warps=num_warps, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - ) + chunk_size = min(chunk_size, max(1, total_programs // num_xcds)) + + if work_stealing: + # Work-stealing: launch grid = num CUs, tiles assigned dynamically. + grids = cfg.num_sms + + # Reset the tile counter before each launch. + cfg.reset_tile_counter() + + kk = ws_persistent_matmul[(grids,)]( + a, + b, + c, + a_scale if quantized else None, # A_scale_ptr + b_scale if quantized else None, # B_scale_ptr + None, # TODO: Enable bias. + cfg.tile_counter, # Work-stealing tile counter + M, + N, + K, + a.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + 0, # TODO: Enable bias stride. + stride_ak=a.stride(1), + stride_bk=b.stride(0), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=grids, + NUM_XCDS=num_xcds, + CHUNK_SIZE=chunk_size, + BIAS=False, + EVEN_K=even_k, + CACHE_MODIFIER_A=CACHE_MODIFIER_A, + CACHE_MODIFIER_B=CACHE_MODIFIER_B, + QUANTIZED=quantized, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) + else: + # Default: data-parallel mode – one WG per tile. + grids = total_tiles + + # TODO: Support other matmul algs. 
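+        # One workgroup per output tile: each WG derives its tile directly
+        # from its pid, so no tile counter (and no counter reset) is needed.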
+ kk = persistent_matmul[(grids,)]( + a, + b, + c, + a_scale if quantized else None, # A_scale_ptr + b_scale if quantized else None, # B_scale_ptr + None, # TODO: Enable bias. + M, + N, + K, + a.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + 0, # TODO: Enable bias stride. + stride_ak=a.stride(1), + stride_bk=b.stride(0), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=total_programs, + NUM_XCDS=num_xcds, + CHUNK_SIZE=chunk_size, + BIAS=False, + EVEN_K=even_k, + CACHE_MODIFIER_A=CACHE_MODIFIER_A, + CACHE_MODIFIER_B=CACHE_MODIFIER_B, + QUANTIZED=quantized, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) return c @@ -139,7 +262,10 @@ def streamk_matmul_lt( a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, + config: Optional[MatmulConfig] = None, ): + cfg = config or _get_default_config() + assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape @@ -180,13 +306,13 @@ def streamk_matmul_lt( grids = total_programs_streamk block_size = BLK_M * BLK_N - # Use global buffers with optimized zeroing - if grids <= MAX_SMS and block_size <= MAX_BLOCK_SIZE: - locks = _global_locks[:grids] - P = _global_P[:grids, :block_size] + # Use config buffers; fall back to fresh allocation for oversized grids. + if grids <= cfg.num_sms and block_size <= cfg.max_block_size: + locks = cfg.locks[:grids] + P = cfg.P[:grids, :block_size] else: - locks = torch.empty(grids, device="cuda", dtype=torch.uint8) - P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32) + locks = torch.empty(grids, device=cfg.device, dtype=torch.uint8) + P = torch.empty(grids, block_size, device=cfg.device, dtype=torch.float32) # Set chunk size to same area as L2 tiles. 
chunk_size = gsize_m * gsize_m @@ -234,31 +360,36 @@ def streamk_matmul_lt( return c def matmul_lt( - a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, + enable_streamk=False, work_stealing=False, config: Optional[MatmulConfig] = None, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" if enable_streamk: - return streamk_matmul_lt(a, b, c, selector) + return streamk_matmul_lt(a, b, c, selector, config=config) else: - return persistent_matmul_lt(a, b, c, selector) + return persistent_matmul_lt(a, b, c, selector, work_stealing=work_stealing, config=config) def matmul_a8w8_lt( - a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False + a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, + c: torch.Tensor, selector, enable_streamk=False, work_stealing=False, + config: Optional[MatmulConfig] = None, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return streamk_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True, config=config) else: - return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True, work_stealing=work_stealing, config=config) def matmul( a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, enable_streamk=False, + work_stealing=False, sk_grid=None, + config: Optional[MatmulConfig] = None, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape @@ -266,9 +397,9 @@ def matmul( selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid) + return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid, config=config) else: - return persistent_matmul_lt(a, b, c, selector) + return persistent_matmul_lt(a, b, c, selector, work_stealing=work_stealing, config=config) def matmul_a8w8( a: torch.Tensor, @@ -277,7 +408,9 @@ def matmul_a8w8( b_scale: torch.Tensor, c: torch.Tensor, enable_streamk=False, + work_stealing=False, sk_grid=None, + config: Optional[MatmulConfig] = None, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape @@ -285,9 +418,9 @@ def matmul_a8w8( selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid, a_scale=a_scale, b_scale=b_scale, quantized=True) + return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid, a_scale=a_scale, b_scale=b_scale, quantized=True, config=config) else: - return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True, work_stealing=work_stealing, config=config) def matmul_fp4( a: torch.Tensor, diff --git a/tests/test_work_stealing.py b/tests/test_work_stealing.py new file mode 100644 index 0000000..7131c67 --- /dev/null +++ b/tests/test_work_stealing.py @@ -0,0 +1,234 @@ +""" +Standalone smoke-test for the work-stealing persistent GEMM kernel. 
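+Run it directly, optionally pinning a device first, e.g.:
+
+    HIP_VISIBLE_DEVICES=6 python tests/test_work_stealing.py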
+ +Directly imports the work-stealing kernel to avoid the stages/streamk import +chain that requires a newer Triton with `constexpr_function`. +""" + +import os +import sys +import time +import types +import importlib.util +import torch +import triton + +# --------------------------------------------------------------------------- +# Bootstrap: load only the pieces the work-stealing kernel needs, bypassing +# the full tritonblas package init (which pulls in stages/__init__.py that +# requires triton.constexpr_function not available in this Triton build). +# --------------------------------------------------------------------------- +_kernels_dir = os.path.join( + os.path.dirname(__file__), "..", "include", "tritonblas", "kernels", +) +_stages_dir = os.path.join(_kernels_dir, "stages") +_indexing_dir = os.path.join(_stages_dir, "indexing") + + +def _load_module(fqn, filepath, package_path=None): + """Load a single .py file and register it in sys.modules.""" + spec = importlib.util.spec_from_file_location(fqn, filepath) + mod = importlib.util.module_from_spec(spec) + if package_path is not None: + mod.__path__ = [package_path] + sys.modules[fqn] = mod + spec.loader.exec_module(mod) + return mod + + +def _make_stub_package(fqn, path): + """Register a stub package so relative imports can traverse it.""" + pkg = types.ModuleType(fqn) + pkg.__path__ = [path] + pkg.__package__ = fqn + sys.modules[fqn] = pkg + return pkg + + +# Stub packages (just enough for the relative import chain) +_make_stub_package("tritonblas", os.path.join(_kernels_dir, "..")) +_make_stub_package("tritonblas.kernels", _kernels_dir) +_make_stub_package("tritonblas.kernels.stages", _stages_dir) +_make_stub_package("tritonblas.kernels.stages.indexing", _indexing_dir) + +# Load pid_transforms (pure @triton.jit, no constexpr_function dependency) +_load_module( + "tritonblas.kernels.stages.indexing.pid_transforms", + os.path.join(_indexing_dir, "pid_transforms.py"), +) + +# Now load the work-stealing kernel — its relative import will resolve +_ws_mod = _load_module( + "tritonblas.kernels.persistent_gemm_work_stealing", + os.path.join(_kernels_dir, "persistent_gemm_work_stealing.py"), +) +ws_persistent_matmul = _ws_mod.ws_persistent_matmul + + +def make_tile_counter(device="cuda"): + """Allocate a fresh work-stealing tile counter.""" + return torch.zeros(1, device=device, dtype=torch.int32) + + +def run_ws_persistent_matmul(A, B, C, tile_counter, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8): + """Launch the work-stealing persistent kernel.""" + M, K = A.shape + _, N = B.shape + + props = torch.cuda.get_device_properties(A.device) + NUM_SMS = props.multi_processor_count + + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + + NUM_XCDS = 8 + chunk_size = GROUP_M * GROUP_M + chunk_size = min(chunk_size, max(1, total_tiles // NUM_XCDS)) + + # Grid = number of CUs (work-stealing) + grids = NUM_SMS + + # Reset counter + tile_counter.zero_() + + ws_persistent_matmul[(grids,)]( + A, B, C, + None, # A_scale_ptr + None, # B_scale_ptr + None, # bias_ptr + tile_counter, + M, N, K, + A.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + 0, # bias stride + stride_ak=A.stride(1), + stride_bk=B.stride(0), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=GROUP_M, + NUM_SMS=grids, + NUM_XCDS=NUM_XCDS, + CHUNK_SIZE=chunk_size, + BIAS=False, + EVEN_K=even_k, + CACHE_MODIFIER_A=None, + CACHE_MODIFIER_B=None, + 
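+        # fp16 inputs: no A/B scale tensors, so QUANTIZED stays False below.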
QUANTIZED=False, + num_stages=2, + num_warps=8, + waves_per_eu=0, + matrix_instr_nonkdim=16, + kpack=1, + ) + + +def test_correctness(m, n, k, dtype=torch.float16): + """Run a single test and compare against torch.matmul.""" + A = torch.randn(m, k, device="cuda", dtype=dtype) + B = torch.randn(n, k, device="cuda", dtype=dtype).T + C = torch.zeros(m, n, device="cuda", dtype=dtype) + + ref = torch.matmul(A, B) + + tile_counter = make_tile_counter() + run_ws_persistent_matmul(A, B, C, tile_counter) + torch.cuda.synchronize() + + max_diff = (C - ref).abs().max().item() + mean_diff = (C - ref).abs().mean().item() + passed = torch.allclose(C, ref, atol=1e-1, rtol=1e-2) + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] {m:>5}x{n:<5}x{k:<5} " + f"max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}" + ) + return passed + + +def bench_throughput(m, n, k, dtype=torch.float16, warmup=5, iters=20): + """Quick throughput benchmark.""" + A = torch.randn(m, k, device="cuda", dtype=dtype) + B = torch.randn(n, k, device="cuda", dtype=dtype).T + C = torch.zeros(m, n, device="cuda", dtype=dtype) + tile_counter = make_tile_counter() + + # Warmup + for _ in range(warmup): + run_ws_persistent_matmul(A, B, C, tile_counter) + torch.cuda.synchronize() + + # Timed + start = time.perf_counter() + for _ in range(iters): + run_ws_persistent_matmul(A, B, C, tile_counter) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + avg_ms = (elapsed / iters) * 1000 + flops = 2.0 * m * n * k + tflops = (flops / (avg_ms / 1000)) / 1e12 + print(f" {m:>5}x{n:<5}x{k:<5} avg={avg_ms:7.3f} ms {tflops:6.2f} TFLOP/s") + return avg_ms + + +def main(): + torch.manual_seed(42) + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + print(f"Device: {props.name} (CUs: {props.multi_processor_count})") + print(f"HIP_VISIBLE_DEVICES = {os.environ.get('HIP_VISIBLE_DEVICES', '')}") + print() + + # ── Correctness ─────────────────────────────────────────────────── + print("=" * 68) + print("Correctness (work-stealing kernel vs torch.matmul)") + print("=" * 68) + all_pass = True + for m, n, k in [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (8192, 8192, 8192), + ]: + try: + ok = test_correctness(m, n, k) + all_pass &= ok + except Exception as e: + print(f" [ERROR] {m}x{n}x{k}: {e}") + import traceback; traceback.print_exc() + all_pass = False + + # ── Throughput ──────────────────────────────────────────────────── + print() + print("=" * 68) + print("Throughput (work-stealing kernel)") + print("=" * 68) + for m, n, k in [ + (1024, 1024, 1024), + (4096, 4096, 4096), + (8192, 8192, 8192), + ]: + try: + bench_throughput(m, n, k) + except Exception as e: + print(f" [ERROR] {m}x{n}x{k}: {e}") + import traceback; traceback.print_exc() + + print() + if all_pass: + print("All correctness tests PASSED.") + else: + print("Some correctness tests FAILED.") + sys.exit(1) + + +if __name__ == "__main__": + main() From ecc5d731549a9d11c4c2107d4fb48c3e17c55964 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Thu, 5 Feb 2026 19:58:59 +0000 Subject: [PATCH 2/8] 8-way spread contention. 
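
Each workgroup now pulls work from an atomic counter private to its XCD
instead of a single global counter, so a CU only contends with the other
CUs on the same chiplet. Rough sketch of the mapping implemented by the
kernel in this patch (Python-style pseudocode; names match the kernel):

    xcd_id        = pid % NUM_XCDS                  # pids are round-robin over XCDs
    tiles_per_xcd = cdiv(total_tiles, NUM_XCDS)     # contiguous tile range per XCD
    xcd_base      = xcd_id * tiles_per_xcd
    xcd_end       = min(xcd_base + tiles_per_xcd, total_tiles)
    local_idx     = atomic_add(tile_counter[xcd_id], 1)
    if local_idx < xcd_end - xcd_base:
        tile_id   = xcd_base + local_idx            # then GROUP_SIZE_M swizzle -> (pid_m, pid_n)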
--- benchmarks/benchmark_work_stealing.py | 58 +++++++++---------- .../kernels/persistent_gemm_work_stealing.py | 54 ++++++++++------- include/tritonblas/matmul.py | 28 +++++---- tests/test_work_stealing.py | 26 ++++----- 4 files changed, 89 insertions(+), 77 deletions(-) diff --git a/benchmarks/benchmark_work_stealing.py b/benchmarks/benchmark_work_stealing.py index 5a6664d..398893a 100644 --- a/benchmarks/benchmark_work_stealing.py +++ b/benchmarks/benchmark_work_stealing.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 """ -Benchmark: work-stealing persistent GEMM vs. static persistent GEMM vs. Stream-K vs. torch.matmul +Benchmark: per-XCD work-stealing persistent GEMM vs. static persistent GEMM + vs. Stream-K vs. torch.matmul Uses importlib bootstrap to load kernels directly, bypassing the full tritonblas import (which requires triton.constexpr_function not available in older builds). @@ -80,7 +81,10 @@ def _make_stub_package(fqn, path): # --------------------------------------------------------------------------- # Launch helpers # --------------------------------------------------------------------------- -def _common_params(A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS): +NUM_XCDS = 8 # MI300X + + +def _common_params(A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M): M, K = A.shape _, N = B.shape total_blocks_M = triton.cdiv(M, BLK_M) @@ -92,10 +96,10 @@ def _common_params(A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS): return M, K, N, total_tiles, even_k, chunk_size -def launch_persistent(A, B, C, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): +def launch_persistent(A, B, C, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8): """Original static-partition persistent GEMM (monolithic).""" M, K, N, total_tiles, even_k, chunk_size = _common_params( - A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M ) grids = total_tiles persistent_matmul[(grids,)]( @@ -114,10 +118,10 @@ def launch_persistent(A, B, C, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XC def launch_work_stealing(A, B, C, tile_counter, num_sms, - BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): - """Work-stealing persistent GEMM (atomic counter).""" + BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8): + """Per-XCD work-stealing persistent GEMM.""" M, K, N, total_tiles, even_k, chunk_size = _common_params( - A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M ) grids = num_sms tile_counter.zero_() @@ -130,7 +134,7 @@ def launch_work_stealing(A, B, C, tile_counter, num_sms, stride_ak=A.stride(1), stride_bk=B.stride(0), BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, GROUP_SIZE_M=GROUP_M, NUM_SMS=grids, NUM_XCDS=NUM_XCDS, - CHUNK_SIZE=chunk_size, BIAS=False, EVEN_K=even_k, + BIAS=False, EVEN_K=even_k, CACHE_MODIFIER_A=None, CACHE_MODIFIER_B=None, QUANTIZED=False, num_stages=2, num_warps=8, waves_per_eu=0, matrix_instr_nonkdim=16, kpack=1, @@ -138,12 +142,11 @@ def launch_work_stealing(A, B, C, tile_counter, num_sms, def launch_streamk(A, B, C, locks, P, sk_grid, - BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8, NUM_XCDS=8): + BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8): """Stream-K persistent GEMM.""" M, K, N, total_tiles, even_k, chunk_size = _common_params( - A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS + A, B, C, BLK_M, BLK_N, BLK_K, GROUP_M ) - # StreamK tiles = remainder tiles that need cooperative decomposition streamk_tiles = total_tiles % sk_grid if sk_grid > 0 else 0 chunk_size_sk = GROUP_M * GROUP_M @@ -222,16 +225,17 @@ def main(): device = 
torch.cuda.current_device() props = torch.cuda.get_device_properties(device) NUM_SMS = props.multi_processor_count - NUM_XCDS = 8 # MI300X print(f"Device : {props.name}") print(f"CUs (SMs) : {NUM_SMS}") + print(f"XCDs : {NUM_XCDS}") + print(f"CUs per XCD : {NUM_SMS // NUM_XCDS}") print(f"HIP_VISIBLE : {os.environ.get('HIP_VISIBLE_DEVICES', '')}") print() - # Pre-allocate work-stealing counter + Stream-K buffers once - tile_counter = torch.zeros(1, device="cuda", dtype=torch.int32) - max_grid = NUM_SMS * 2 # generous upper bound for SK grid + # Pre-allocate per-XCD tile counters + Stream-K buffers + tile_counter = torch.zeros(NUM_XCDS, device="cuda", dtype=torch.int32) + max_grid = NUM_SMS * 2 block_area = 128 * 128 locks = torch.zeros(max_grid, device="cuda", dtype=torch.uint8) P = torch.zeros(max_grid, block_area, device="cuda", dtype=torch.float32) @@ -239,7 +243,7 @@ def main(): BLK_M, BLK_N, BLK_K, GROUP_M = 128, 128, 64, 8 dtype = torch.float16 - # Problem sizes to benchmark + # Problem sizes sizes = [ # Square (256, 256, 256), @@ -268,7 +272,7 @@ def main(): # Header hdr = ( f"{'M':>6} {'N':>6} {'K':>6} │ " - f"{'Persistent':>12} {'WorkSteal':>12} {'StreamK':>12} {'torch.mm':>12} │ " + f"{'Persistent':>12} {'WS-perXCD':>12} {'StreamK':>12} {'torch.mm':>12} │ " f"{'WS/Pers':>8} {'WS/SK':>8} {'WS/Torch':>8}" ) sep = "─" * len(hdr) @@ -281,38 +285,31 @@ def main(): for M, N, K in sizes: A = torch.randn(M, K, device="cuda", dtype=dtype) - B = torch.randn(N, K, device="cuda", dtype=dtype).T # K x N contiguous + B = torch.randn(N, K, device="cuda", dtype=dtype).T C_pers = torch.zeros(M, N, device="cuda", dtype=dtype) C_ws = torch.zeros(M, N, device="cuda", dtype=dtype) C_sk = torch.zeros(M, N, device="cuda", dtype=dtype) C_ref = torch.zeros(M, N, device="cuda", dtype=dtype) - even_k = K % BLK_K == 0 - total_tiles_m = triton.cdiv(M, BLK_M) - total_tiles_n = triton.cdiv(N, BLK_N) - total_tiles = total_tiles_m * total_tiles_n - - # Skip tiny sizes where tiles < 1 (e.g. 
M=1 with BLK_M=128 still gives 1 tile) sk_grid = compute_sk_grid(M, N, K, BLK_M, BLK_N, BLK_K, NUM_SMS) - # Clamp stream-K grid to our pre-allocated buffer size sk_grid = min(sk_grid, max_grid) # ── Benchmark each variant ────────────────────────────────── try: ms_pers = bench(lambda: launch_persistent( - A, B, C_pers, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + A, B, C_pers, BLK_M, BLK_N, BLK_K, GROUP_M)) except Exception as e: ms_pers = float("nan") try: ms_ws = bench(lambda: launch_work_stealing( - A, B, C_ws, tile_counter, NUM_SMS, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + A, B, C_ws, tile_counter, NUM_SMS, BLK_M, BLK_N, BLK_K, GROUP_M)) except Exception as e: ms_ws = float("nan") try: ms_sk = bench(lambda: launch_streamk( - A, B, C_sk, locks, P, sk_grid, BLK_M, BLK_N, BLK_K, GROUP_M, NUM_XCDS)) + A, B, C_sk, locks, P, sk_grid, BLK_M, BLK_N, BLK_K, GROUP_M)) except Exception as e: ms_sk = float("nan") @@ -344,7 +341,6 @@ def to_tflops(ms): } results.append(row) - # Format ms; mark NaN def fmt_ms(v): return f"{v:12.4f}" if v == v else f"{'N/A':>12}" def fmt_su(v): @@ -364,7 +360,7 @@ def fmt_su(v): print(f"{'':>20} │ {'── TFLOP/s ──':^51} │") hdr2 = ( f"{'M':>6} {'N':>6} {'K':>6} │ " - f"{'Persistent':>12} {'WorkSteal':>12} {'StreamK':>12} {'torch.mm':>12} │" + f"{'Persistent':>12} {'WS-perXCD':>12} {'StreamK':>12} {'torch.mm':>12} │" ) print(hdr2) print(sep) @@ -388,7 +384,7 @@ def geomean(xs): return math.exp(sum(math.log(x) for x in xs) / len(xs)) if xs else float("nan") print() - print("Geometric-mean speedup of Work-Stealing over:") + print("Geometric-mean speedup of per-XCD Work-Stealing over:") print(f" Persistent (static) : {geomean(valid_pers):.4f}x") print(f" Stream-K : {geomean(valid_sk):.4f}x") print(f" torch.matmul : {geomean(valid_torch):.4f}x") diff --git a/include/tritonblas/kernels/persistent_gemm_work_stealing.py b/include/tritonblas/kernels/persistent_gemm_work_stealing.py index 7ef946d..b5cdd66 100644 --- a/include/tritonblas/kernels/persistent_gemm_work_stealing.py +++ b/include/tritonblas/kernels/persistent_gemm_work_stealing.py @@ -1,17 +1,24 @@ """ -Work-stealing persistent GEMM kernel. +Work-stealing persistent GEMM kernel with per-XCD atomic counters. Instead of statically partitioning tiles across workgroups (for tile_id in range(pid, total_tiles, NUM_SMS)), each WG dynamically grabs the next -available tile via a global atomic counter. This naturally load-balances -when some WGs arrive late to the party. +available tile via an atomic counter that is local to its XCD. + +PIDs are assigned round-robin across XCDs: + pid 0 → XCD 0, pid 1 → XCD 1, …, pid 7 → XCD 7, pid 8 → XCD 0, … + +The tile space is partitioned into contiguous per-XCD regions: + XCD i owns tiles [i * tiles_per_xcd, min((i+1) * tiles_per_xcd, total_tiles)) + +Each XCD has its own atomic counter (tile_counter[xcd_id]) so CUs only +contend with the ~38 other CUs on the same die, not all 304 CUs. 
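+
+For example, a 4096x4096 output with 128x128 tiles has 32 * 32 = 1024 tiles;
+with NUM_XCDS = 8 each XCD owns a contiguous range of ceil(1024 / 8) = 128
+tiles, so XCD 0 serves tiles [0, 128), XCD 1 serves [128, 256), and so on.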
""" import triton import triton.language as tl import torch -from .stages.indexing.pid_transforms import chiplet_transform_chunked @triton.jit() def ws_persistent_matmul( @@ -21,7 +28,7 @@ def ws_persistent_matmul( A_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 B_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 bias_ptr, - tile_counter, # Global atomic counter for work-stealing (int32[1]) + tile_counter, # Per-XCD atomic counters (int32[NUM_XCDS]) M, N, K, @@ -38,7 +45,6 @@ def ws_persistent_matmul( GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr, - CHUNK_SIZE: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr, CACHE_MODIFIER_A: tl.constexpr, @@ -46,10 +52,22 @@ def ws_persistent_matmul( QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16 ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32, ): + pid = tl.program_id(0) + + # ── Per-XCD work-stealing ────────────────────────────────────────── + # PIDs are round-robin across XCDs: 0→XCD0, 1→XCD1, …, 7→XCD7, 8→XCD0… + xcd_id = pid % NUM_XCDS + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n + # Partition tiles into contiguous per-XCD regions. + tiles_per_xcd = tl.cdiv(total_tiles, NUM_XCDS) + xcd_base = xcd_id * tiles_per_xcd + xcd_end = tl.minimum(xcd_base + tiles_per_xcd, total_tiles) + tiles_this_xcd = xcd_end - xcd_base + tl.assume(stride_am > 0) tl.assume(stride_ak > 0) tl.assume(stride_bn > 0) @@ -59,18 +77,14 @@ def ws_persistent_matmul( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 - # ── Work-stealing with chiplet swizzle ──────────────────────────────── - # 1. Grab a raw tile index from a single global atomic counter. - # 2. Swizzle it through chiplet_transform_chunked so that consecutive - # tile_ids land on the same XCD → better L2 locality. - # 3. The GROUP_SIZE_M decomposition below turns the swizzled tile_id - # into (pid_m, pid_n). - tile_id = tl.atomic_add(tile_counter, 1) - for _ in range(total_tiles): - if tile_id < total_tiles: - # Chiplet-aware swizzle - if NUM_XCDS != 1: - tile_id = chiplet_transform_chunked(tile_id, total_tiles, NUM_XCDS, CHUNK_SIZE) + # Per-XCD atomic counter — only CUs on the same die contend. + counter_ptr = tile_counter + xcd_id + local_tile_idx = tl.atomic_add(counter_ptr, 1) + + for _ in range(tiles_per_xcd): + if local_tile_idx < tiles_this_xcd: + # Map local index → global tile_id + tile_id = xcd_base + local_tile_idx # GROUP_SIZE_M swizzle → (pid_m, pid_n) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -170,5 +184,5 @@ def ws_persistent_matmul( C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn tl.store(C_, c, c_mask) - # Grab next tile - tile_id = tl.atomic_add(tile_counter, 1) + # Grab next tile from this XCD's counter + local_tile_idx = tl.atomic_add(counter_ptr, 1) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 5aa8995..442479a 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -12,6 +12,7 @@ # TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4 _DEFAULT_MAX_BLOCK_SIZE = 65536 +_DEFAULT_MAX_XCDS = 16 # Covers MI300X (8 XCDs) with headroom for future chips class MatmulConfig: @@ -25,19 +26,22 @@ class MatmulConfig: Attributes: num_sms: Number of SMs / CUs on the device. max_block_size: Maximum tile footprint (BLOCK_M * BLOCK_N). - tile_counter: ``int32[1]`` atomic counter for work-stealing. 
+ max_xcds: Maximum number of XCDs (chiplets) supported. + tile_counter: ``int32[max_xcds]`` per-XCD atomic counters for work-stealing. locks: ``uint8[num_sms]`` stream-K lock array. P: ``float32[num_sms, max_block_size]`` stream-K partial buffer. """ - def __init__(self, device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE): + def __init__(self, device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE, + max_xcds: int = _DEFAULT_MAX_XCDS): props = torch.cuda.get_device_properties(device) self.device = device self.num_sms: int = props.multi_processor_count self.max_block_size: int = max_block_size + self.max_xcds: int = max_xcds - # Work-stealing tile counter - self.tile_counter = torch.zeros(1, device=device, dtype=torch.int32) + # Work-stealing per-XCD tile counters + self.tile_counter = torch.zeros(max_xcds, device=device, dtype=torch.int32) # Stream-K buffers self.locks = torch.empty(self.num_sms, device=device, dtype=torch.uint8) @@ -62,11 +66,12 @@ def reset_streamk(self): def __repr__(self): return ( f"MatmulConfig(device={self.device!r}, num_sms={self.num_sms}, " - f"max_block_size={self.max_block_size})" + f"max_block_size={self.max_block_size}, max_xcds={self.max_xcds})" ) -def matmul_preamble(device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE) -> MatmulConfig: +def matmul_preamble(device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLOCK_SIZE, + max_xcds: int = _DEFAULT_MAX_XCDS) -> MatmulConfig: """ Allocate all GPU-side buffers needed by the tritonBLAS GEMM kernels. @@ -76,11 +81,12 @@ def matmul_preamble(device: str = "cuda", max_block_size: int = _DEFAULT_MAX_BLO Args: device: CUDA device string (default ``"cuda"``). max_block_size: Maximum tile footprint (default 65536 = 256*256). + max_xcds: Maximum XCD count for per-XCD counters (default 16). Returns: A :class:`MatmulConfig` ready for kernel launches. """ - return MatmulConfig(device=device, max_block_size=max_block_size) + return MatmulConfig(device=device, max_block_size=max_block_size, max_xcds=max_xcds) # Lazy default config -- created on first use so import-time CUDA init is @@ -168,10 +174,11 @@ def persistent_matmul_lt( chunk_size = min(chunk_size, max(1, total_programs // num_xcds)) if work_stealing: - # Work-stealing: launch grid = num CUs, tiles assigned dynamically. + # Work-stealing: launch grid = num CUs, tiles assigned dynamically + # via per-XCD atomic counters. grids = cfg.num_sms - # Reset the tile counter before each launch. + # Reset all per-XCD tile counters before each launch. cfg.reset_tile_counter() kk = ws_persistent_matmul[(grids,)]( @@ -181,7 +188,7 @@ def persistent_matmul_lt( a_scale if quantized else None, # A_scale_ptr b_scale if quantized else None, # B_scale_ptr None, # TODO: Enable bias. - cfg.tile_counter, # Work-stealing tile counter + cfg.tile_counter, # Per-XCD tile counters (int32[max_xcds]) M, N, K, @@ -198,7 +205,6 @@ def persistent_matmul_lt( GROUP_SIZE_M=gsize_m, NUM_SMS=grids, NUM_XCDS=num_xcds, - CHUNK_SIZE=chunk_size, BIAS=False, EVEN_K=even_k, CACHE_MODIFIER_A=CACHE_MODIFIER_A, diff --git a/tests/test_work_stealing.py b/tests/test_work_stealing.py index 7131c67..c6dacf9 100644 --- a/tests/test_work_stealing.py +++ b/tests/test_work_stealing.py @@ -1,5 +1,6 @@ """ -Standalone smoke-test for the work-stealing persistent GEMM kernel. +Standalone smoke-test for the work-stealing persistent GEMM kernel +with per-XCD atomic counters. 
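+The launcher below allocates NUM_XCDS counters and zeroes them before every
+launch, matching what MatmulConfig.reset_tile_counter() does in matmul.py.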
Directly imports the work-stealing kernel to avoid the stages/streamk import chain that requires a newer Triton with `constexpr_function`. @@ -64,33 +65,28 @@ def _make_stub_package(fqn, path): ) ws_persistent_matmul = _ws_mod.ws_persistent_matmul +NUM_XCDS = 8 # MI300X + def make_tile_counter(device="cuda"): - """Allocate a fresh work-stealing tile counter.""" - return torch.zeros(1, device=device, dtype=torch.int32) + """Allocate per-XCD work-stealing tile counters.""" + return torch.zeros(NUM_XCDS, device=device, dtype=torch.int32) def run_ws_persistent_matmul(A, B, C, tile_counter, BLK_M=128, BLK_N=128, BLK_K=64, GROUP_M=8): - """Launch the work-stealing persistent kernel.""" + """Launch the work-stealing persistent kernel with per-XCD counters.""" M, K = A.shape _, N = B.shape props = torch.cuda.get_device_properties(A.device) NUM_SMS = props.multi_processor_count - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - NUM_XCDS = 8 - chunk_size = GROUP_M * GROUP_M - chunk_size = min(chunk_size, max(1, total_tiles // NUM_XCDS)) - # Grid = number of CUs (work-stealing) grids = NUM_SMS - # Reset counter + # Reset all per-XCD counters tile_counter.zero_() ws_persistent_matmul[(grids,)]( @@ -113,7 +109,6 @@ def run_ws_persistent_matmul(A, B, C, tile_counter, BLK_M=128, BLK_N=128, BLK_K= GROUP_SIZE_M=GROUP_M, NUM_SMS=grids, NUM_XCDS=NUM_XCDS, - CHUNK_SIZE=chunk_size, BIAS=False, EVEN_K=even_k, CACHE_MODIFIER_A=None, @@ -183,11 +178,12 @@ def main(): props = torch.cuda.get_device_properties(device) print(f"Device: {props.name} (CUs: {props.multi_processor_count})") print(f"HIP_VISIBLE_DEVICES = {os.environ.get('HIP_VISIBLE_DEVICES', '')}") + print(f"Per-XCD counters: {NUM_XCDS}") print() # ── Correctness ─────────────────────────────────────────────────── print("=" * 68) - print("Correctness (work-stealing kernel vs torch.matmul)") + print("Correctness (per-XCD work-stealing kernel vs torch.matmul)") print("=" * 68) all_pass = True for m, n, k in [ @@ -209,7 +205,7 @@ def main(): # ── Throughput ──────────────────────────────────────────────────── print() print("=" * 68) - print("Throughput (work-stealing kernel)") + print("Throughput (per-XCD work-stealing kernel)") print("=" * 68) for m, n, k in [ (1024, 1024, 1024), From b4ee5b1c7f68580e89ae304cbd9a1a41703618c7 Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Fri, 6 Feb 2026 00:06:26 -0500 Subject: [PATCH 3/8] Adding a test for CPU & GPU atomics --- tests/test_cpu_gpu_atomics.py | 231 ++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/test_cpu_gpu_atomics.py diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py new file mode 100644 index 0000000..17a96a9 --- /dev/null +++ b/tests/test_cpu_gpu_atomics.py @@ -0,0 +1,231 @@ +""" +Concurrent Triton kernels on AMD/ROCm without deprecated 'stream='. + +Key points: + * Use torch.cuda.Stream() and 'with torch.cuda.stream(s): ...' to choose streams. + * Triton launches are async; measure with explicit synchronization. + * AMD guidance: prefer num_stages >= 2 with the current stream pipeliner, + and size BLOCK_SIZE/num_warps to leave headroom for overlap. + +References: + - PyTorch HIP semantics reuse torch.cuda API on AMD (streams, etc.). + - torch.cuda.StreamContext enqueues ops on the chosen stream. + - Triton kernels run asynchronously; torch.cuda.synchronize() is appropriate. 
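+
+Typical launch pattern (illustrative only; grid, x, and n are placeholders):
+
+    s = torch.cuda.Stream()
+    with torch.cuda.stream(s):
+        gemm[(grid,)](x, n, ITERS=200, BLOCK_SIZE=256)  # enqueued on s, returns immediately
+    torch.cuda.synchronize()                            # wait before timing or reading results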
+""" + +import ctypes +import math +import sys +import time + +import array +import numpy as np + +import torch +from hip import hip, hiprtc + +import triton +import triton.language as tl + +def hip_check(call_result): + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + elif ( + isinstance(err, hiprtc.hiprtcResult) + and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS + ): + raise RuntimeError(str(err)) + return result + + +# ------------------------- +# Triton kernels +# ------------------------- + +@triton.jit +def sch(live_flags_ptr, x_ptr, n_elements, + ITERS: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + live = 1 + done = 0 + + while tl.atomic_add(live_flags_ptr + pid, 0) == live: + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + for _ in range(ITERS): + x = x * 1.000000119 + 0.000000137 + + tl.store(x_ptr + offsets, x, mask=mask) + + ret = tl.inline_asm_elementwise( + asm="""s_sleep 128""", + constraints=("=s"), + args=[], + dtype=tl.int64, + is_pure=False, + pack=1, + ) + +@triton.jit +def gemm(x_ptr, n_elements, + ITERS: tl.constexpr, BLOCK_SIZE: tl.constexpr): + # 1D program index + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + # simple compute loop + for _ in range(ITERS): + x = x * 1.000000119 + 0.000000137 + tl.store(x_ptr + offsets, x, mask=mask) + + +@triton.jit +def comm(x_ptr, n_elements, + ITERS: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + for _ in range(ITERS): + x = tl.sin(x) + 1.000000119 * x + 0.000000071 + tl.store(x_ptr + offsets, x, mask=mask) + + +def main(): + # ------------------------- + # Tunables for AMD GPUs + # ------------------------- + N = 32 * 1024 * 1024 # elements + ITERS_A = 200 + ITERS_B = 200 + BLOCK_SIZE = 256 # elements per Triton "program"/CTA + NUM_WARPS = 4 # try 4~8; AMD wavefront is 64-wide + NUM_STAGES = 2 # AMD's current stream pipeliner favors >=2 + + # Load libatomic + libatomic = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libatomic.so.1.2.0") # Adjust path as needed + # Define the function signature + # __atomic_fetch_add_4(int *ptr, int val, int memorder) + libatomic.__atomic_fetch_add_4.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int] + libatomic.__atomic_fetch_add_4.restype = ctypes.c_int + + assert torch.cuda.is_available(), "Need a ROCm-enabled PyTorch build with an AMD GPU" + dev = torch.cuda.current_device() + prop = torch.cuda.get_device_properties(dev) + backend = "HIP" if getattr(torch.version, "hip", None) else "CUDA" + print(f"Device: {prop.name} | backend: {backend} | total_mem: {prop.total_memory/1e9:.1f} GB") + if backend != "HIP": + print("Warning: This script targets AMD/ROCm, but a non-HIP backend was detected.") + + # Data + a = torch.linspace(0, 1, N, device="cuda", dtype=torch.float32).contiguous() + b = torch.linspace(1, 2, N, device="cuda", dtype=torch.float32).contiguous() + + sch_stream = torch.cuda.Stream() + + # Two independent streams (ROCm uses torch.cuda.* too) + gemm_stream = torch.cuda.Stream() + comm_stream = torch.cuda.Stream() + + num_xcds 
= 8 + sch_grid = num_xcds * BLOCK_SIZE + + done = 0 + live = 1 + + flags = array.array("I", [live for i in range(0, sch_grid)]) + # allocate a Pointer class to be passed to the GPU kernel, flags_h is a void* + flags_h = hip_check(hip.hipHostMalloc(sch_grid * sys.getsizeof(live), 1)) + flags_h.fromObj(flags) # initialize the storage pointed by flags_h + # casting flags_h to a typed pointer to access the pointed contents + flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * sch_grid)) + print(f'Flags (init):') + for i in range(0, sch_grid): + print(f'{flags_typed_ptr.contents[i]}') + + flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(sch_grid,)) + flags_h_tensor = torch.from_numpy(flags_h_np_array) + + print(f'Scheduler kernel started') + with torch.cuda.stream(sch_stream): + sch[(sch_grid, 1, 1)](flags_h_tensor, a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + + # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst + MEMORDER_RELAXED = 0 + # Stop the scheduler kernel + print(f'Flags (__atomic_fetch_add flags_h to signal the GPU kernel to proceed):') + for i in range(0, sch_grid): + ptr = ctypes.cast(ctypes.byref(flags_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) + prev = libatomic.__atomic_fetch_add_4(ptr, done, MEMORDER_RELAXED) + print(f'{prev} {flags_typed_ptr.contents[i]}') + + sch_stream.synchronize() + print(f'Scheduler kernel done') + + # Grid: one program per BLOCK_SIZE chunk + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + + # ------------------------- + # Warm-up (JIT & cache) + # ------------------------- + with torch.cuda.stream(gemm_stream): + gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + with torch.cuda.stream(comm_stream): + comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + # Triton is async; explicit sync for clean timing + gemm_stream.synchronize(); comm_stream.synchronize() + print("Warm-up complete.\n") + + # ------------------------- + # Sequential timing (A then B) + # ------------------------- + t0 = time.perf_counter() + with torch.cuda.stream(gemm_stream): + gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + gemm_stream.synchronize() + with torch.cuda.stream(comm_stream): + comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + comm_stream.synchronize() + t_seq = time.perf_counter() - t0 + print(f"Sequential total time: {t_seq:.3f} s") + + # ------------------------- + # Concurrent timing (A || B) + # ------------------------- + t0 = time.perf_counter() + with torch.cuda.stream(gemm_stream): + gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + with torch.cuda.stream(comm_stream): + comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + # wait for both + gemm_stream.synchronize(); comm_stream.synchronize() + t_conc = time.perf_counter() - t0 + print(f"Concurrent total time: {t_conc:.3f} s") + + # Check a couple of results + print("\nResults (samples):") + print(f"A[123456] = {a[123_456].item():.6f}") + print(f"B[234567] = {b[234_567].item():.6f}") + + print("\nTip: If there's little or no overlap, reduce NUM_WARPS or BLOCK_SIZE " + "to leave headroom for both kernels to co-reside on the 
GPU. " + "On AMD, keep num_stages>=2 for the current stream pipeliner.") + + +if __name__ == "__main__": + main() From 7d29bf04821ebb7c5f0a414c885a6fdf3f0802c2 Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Sun, 8 Feb 2026 01:18:25 +0000 Subject: [PATCH 4/8] Changing to one live flag per XCD & cleanup --- tests/test_cpu_gpu_atomics.py | 96 ++++++++++++++--------------------- 1 file changed, 39 insertions(+), 57 deletions(-) diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py index 17a96a9..cb5c89d 100644 --- a/tests/test_cpu_gpu_atomics.py +++ b/tests/test_cpu_gpu_atomics.py @@ -1,18 +1,3 @@ -""" -Concurrent Triton kernels on AMD/ROCm without deprecated 'stream='. - -Key points: - * Use torch.cuda.Stream() and 'with torch.cuda.stream(s): ...' to choose streams. - * Triton launches are async; measure with explicit synchronization. - * AMD guidance: prefer num_stages >= 2 with the current stream pipeliner, - and size BLOCK_SIZE/num_warps to leave headroom for overlap. - -References: - - PyTorch HIP semantics reuse torch.cuda API on AMD (streams, etc.). - - torch.cuda.StreamContext enqueues ops on the chosen stream. - - Triton kernels run asynchronously; torch.cuda.synchronize() is appropriate. -""" - import ctypes import math import sys @@ -41,27 +26,30 @@ def hip_check(call_result): raise RuntimeError(str(err)) return result - -# ------------------------- # Triton kernels -# ------------------------- - @triton.jit -def sch(live_flags_ptr, x_ptr, n_elements, - ITERS: tl.constexpr, BLOCK_SIZE: tl.constexpr): +def sch(live_flags_ptr, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tid = tl.arange(0, BLOCK_SIZE) + offsets = pid * BLOCK_SIZE + tid mask = offsets < n_elements + sync_mask = tid == 0 + live = 1 done = 0 - while tl.atomic_add(live_flags_ptr + pid, 0) == live: + live_flags_ptr_scalar = live_flags_ptr + pid + live_flags_ptr_block = tl.broadcast_to(live_flags_ptr_scalar, [BLOCK_SIZE]) + zero_i32 = tl.full([BLOCK_SIZE], 0, dtype=tl.int32) - x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - for _ in range(ITERS): - x = x * 1.000000119 + 0.000000137 + flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="cta") + flag_only_leader = tl.where(sync_mask, flag_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) + flag = tl.sum(flag_only_leader, axis=0) + while flag == live: + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + x += 1 tl.store(x_ptr + offsets, x, mask=mask) ret = tl.inline_asm_elementwise( @@ -73,6 +61,10 @@ def sch(live_flags_ptr, x_ptr, n_elements, pack=1, ) + flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="cta") + flag_only_leader = tl.where(sync_mask, flag_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) + flag = tl.sum(flag_only_leader, axis=0) + @triton.jit def gemm(x_ptr, n_elements, ITERS: tl.constexpr, BLOCK_SIZE: tl.constexpr): @@ -100,19 +92,17 @@ def comm(x_ptr, n_elements, def main(): - # ------------------------- - # Tunables for AMD GPUs - # ------------------------- N = 32 * 1024 * 1024 # elements ITERS_A = 200 ITERS_B = 200 - BLOCK_SIZE = 256 # elements per Triton "program"/CTA - NUM_WARPS = 4 # try 4~8; AMD wavefront is 64-wide + BLOCK_SIZE_SCH = 256 # elements per scheduler WGs + BLOCK_SIZE = 256 # elements per WG + NUM_WARPS = 4 # try 4~8; MI300X wavefront is 64-wide NUM_STAGES = 2 # AMD's current stream pipeliner favors >=2 # Load libatomic libatomic = 
ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libatomic.so.1.2.0") # Adjust path as needed - # Define the function signature + # Function signature for __atomic_fetch_add_4 # __atomic_fetch_add_4(int *ptr, int val, int memorder) libatomic.__atomic_fetch_add_4.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int] libatomic.__atomic_fetch_add_4.restype = ctypes.c_int @@ -131,65 +121,64 @@ def main(): sch_stream = torch.cuda.Stream() - # Two independent streams (ROCm uses torch.cuda.* too) + # Independent streams gemm_stream = torch.cuda.Stream() comm_stream = torch.cuda.Stream() num_xcds = 8 - sch_grid = num_xcds * BLOCK_SIZE + sch_grid = num_xcds done = 0 live = 1 - flags = array.array("I", [live for i in range(0, sch_grid)]) - # allocate a Pointer class to be passed to the GPU kernel, flags_h is a void* - flags_h = hip_check(hip.hipHostMalloc(sch_grid * sys.getsizeof(live), 1)) - flags_h.fromObj(flags) # initialize the storage pointed by flags_h - # casting flags_h to a typed pointer to access the pointed contents + # Flags passed to the GPU kernel, flags_h is a void* + flags_h = hip_check(hip.hipMalloc(sch_grid * sys.getsizeof(live))) + # Casting flags_h to a typed pointer, for content access flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * sch_grid)) print(f'Flags (init):') for i in range(0, sch_grid): + flags_typed_ptr.contents[i] = live print(f'{flags_typed_ptr.contents[i]}') flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(sch_grid,)) flags_h_tensor = torch.from_numpy(flags_h_np_array) + sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32).contiguous() + print(f'Scheduler kernel started') with torch.cuda.stream(sch_stream): - sch[(sch_grid, 1, 1)](flags_h_tensor, a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + sch[(sch_grid, 1, 1)](flags_h_tensor, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH, num_warps=NUM_WARPS, num_stages=NUM_STAGES) + time.sleep(1) + # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst MEMORDER_RELAXED = 0 - # Stop the scheduler kernel + # Signal the scheduler kernel to complete print(f'Flags (__atomic_fetch_add flags_h to signal the GPU kernel to proceed):') for i in range(0, sch_grid): ptr = ctypes.cast(ctypes.byref(flags_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) - prev = libatomic.__atomic_fetch_add_4(ptr, done, MEMORDER_RELAXED) + prev = libatomic.__atomic_fetch_add_4(ptr, -live, MEMORDER_RELAXED) print(f'{prev} {flags_typed_ptr.contents[i]}') sch_stream.synchronize() print(f'Scheduler kernel done') + print(f'sch_comp[0]: {sch_comp[0]}') - # Grid: one program per BLOCK_SIZE chunk grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - # ------------------------- # Warm-up (JIT & cache) - # ------------------------- with torch.cuda.stream(gemm_stream): gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) with torch.cuda.stream(comm_stream): comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) - # Triton is async; explicit sync for clean timing + gemm_stream.synchronize(); comm_stream.synchronize() print("Warm-up complete.\n") - # ------------------------- # Sequential timing (A then B) - # ------------------------- t0 = time.perf_counter() with torch.cuda.stream(gemm_stream): gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, @@ -202,9 +191,7 @@ def main(): t_seq = time.perf_counter() - 
t0 print(f"Sequential total time: {t_seq:.3f} s") - # ------------------------- # Concurrent timing (A || B) - # ------------------------- t0 = time.perf_counter() with torch.cuda.stream(gemm_stream): gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, @@ -219,13 +206,8 @@ def main(): # Check a couple of results print("\nResults (samples):") - print(f"A[123456] = {a[123_456].item():.6f}") - print(f"B[234567] = {b[234_567].item():.6f}") - - print("\nTip: If there's little or no overlap, reduce NUM_WARPS or BLOCK_SIZE " - "to leave headroom for both kernels to co-reside on the GPU. " - "On AMD, keep num_stages>=2 for the current stream pipeliner.") - + print(f"A[0] = {a[0].item():.6f}") + print(f"B[0] = {b[0].item():.6f}") if __name__ == "__main__": main() From 28ca8c8eb4485866e636cf11b691a5a6ee75b137 Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Sun, 8 Feb 2026 13:30:08 -0500 Subject: [PATCH 5/8] Request CU release from the scheduler kernel --- tests/test_cpu_gpu_atomics.py | 179 +++++++++++++++++++++++----------- 1 file changed, 122 insertions(+), 57 deletions(-) diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py index cb5c89d..5087a21 100644 --- a/tests/test_cpu_gpu_atomics.py +++ b/tests/test_cpu_gpu_atomics.py @@ -28,7 +28,7 @@ def hip_check(call_result): # Triton kernels @triton.jit -def sch(live_flags_ptr, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): +def sch(live_flags_ptr, req_res_ptr, req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) tid = tl.arange(0, BLOCK_SIZE) offsets = pid * BLOCK_SIZE + tid @@ -43,7 +43,7 @@ def sch(live_flags_ptr, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): live_flags_ptr_block = tl.broadcast_to(live_flags_ptr_scalar, [BLOCK_SIZE]) zero_i32 = tl.full([BLOCK_SIZE], 0, dtype=tl.int32) - flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="cta") + flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="sys") flag_only_leader = tl.where(sync_mask, flag_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) flag = tl.sum(flag_only_leader, axis=0) @@ -52,6 +52,19 @@ def sch(live_flags_ptr, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): x += 1 tl.store(x_ptr + offsets, x, mask=mask) + # check reqs for release of resources + req_res_ptr_scalar = req_res_ptr + pid + req_res_ptr_block = tl.broadcast_to(req_res_ptr_scalar, [BLOCK_SIZE]) + + req_res_v = tl.atomic_add(req_res_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="gpu") + req_res_only_leader = tl.where(sync_mask, req_res_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) + req_res = tl.sum(req_res_only_leader, axis=0) + if req_res != 0: + for i in range(req_res): + req_wgs_ptr_scalar = req_wgs_ptr + pid * num_wgs_per_xcd + i + req_wgs_ptr_block = tl.broadcast_to(req_wgs_ptr_scalar, [BLOCK_SIZE]) + tl.atomic_add(req_wgs_ptr_block, 1, mask=sync_mask, sem="acquire", scope="gpu") + ret = tl.inline_asm_elementwise( asm="""s_sleep 128""", constraints=("=s"), @@ -61,21 +74,37 @@ def sch(live_flags_ptr, x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pack=1, ) - flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="cta") + flag_v = tl.atomic_add(live_flags_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="sys") flag_only_leader = tl.where(sync_mask, flag_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) flag = tl.sum(flag_only_leader, axis=0) @triton.jit -def gemm(x_ptr, n_elements, - ITERS: tl.constexpr, 
BLOCK_SIZE: tl.constexpr): +def gemm(req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elements, + ITERS: tl.constexpr, ITERS_PER_CHKPNT: tl.constexpr, BLOCK_SIZE: tl.constexpr): # 1D program index pid = tl.program_id(axis=0) + tid = tl.arange(0, BLOCK_SIZE) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements + sync_mask = tid == 0 + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + i = 0; req_wgs = 0 # simple compute loop - for _ in range(ITERS): + while req_wgs == 0 and i < ITERS: + for _ in range(ITERS_PER_CHKPNT): + x = x * 1.000000119 + 0.000000137 + i += 1 + + req_wgs_ptr_scalar = req_wgs_ptr + pid * num_wgs_per_xcd + pid % num_wgs_per_xcd + req_wgs_ptr_block = tl.broadcast_to(req_wgs_ptr_scalar, [BLOCK_SIZE]) + req_wgs_v = tl.atomic_add(req_wgs_ptr_block, 0, mask=sync_mask, sem="acquire", scope="gpu") # read + req_wgs_only_leader = tl.where(sync_mask, req_wgs_v, tl.zeros([BLOCK_SIZE], dtype=req_wgs_v.dtype)) + req_wgs = tl.sum(req_wgs_only_leader, axis=0) + for _ in range(ITERS_PER_CHKPNT): x = x * 1.000000119 + 0.000000137 + i += 1 + tl.store(x_ptr + offsets, x, mask=mask) @@ -93,12 +122,13 @@ def comm(x_ptr, n_elements, def main(): N = 32 * 1024 * 1024 # elements - ITERS_A = 200 - ITERS_B = 200 - BLOCK_SIZE_SCH = 256 # elements per scheduler WGs - BLOCK_SIZE = 256 # elements per WG - NUM_WARPS = 4 # try 4~8; MI300X wavefront is 64-wide - NUM_STAGES = 2 # AMD's current stream pipeliner favors >=2 + ITERS_GEMM = 1024 * 1024 * 1024 + ITERS_PER_CHKPNT = 256 + ITERS_COMM = 1024 + BLOCK_SIZE_SCH = 256 # scheduler WG size + BLOCK_SIZE = 256 # WG size + NUM_WARPS = 4 + NUM_STAGES = 2 # Load libatomic libatomic = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libatomic.so.1.2.0") # Adjust path as needed @@ -111,7 +141,9 @@ def main(): dev = torch.cuda.current_device() prop = torch.cuda.get_device_properties(dev) backend = "HIP" if getattr(torch.version, "hip", None) else "CUDA" - print(f"Device: {prop.name} | backend: {backend} | total_mem: {prop.total_memory/1e9:.1f} GB") + hip_prop = hip.hipDeviceProp_t(); hip.hipGetDeviceProperties(hip_prop, 0) + num_cus = prop.multi_processor_count + print(f"Device: {prop.name} | num_cus: {num_cus} | backend: {backend} | total_mem: {prop.total_memory/1e9:.1f} GB") if backend != "HIP": print("Warning: This script targets AMD/ROCm, but a non-HIP backend was detected.") @@ -126,53 +158,24 @@ def main(): comm_stream = torch.cuda.Stream() num_xcds = 8 - sch_grid = num_xcds - - done = 0 - live = 1 - - # Flags passed to the GPU kernel, flags_h is a void* - flags_h = hip_check(hip.hipMalloc(sch_grid * sys.getsizeof(live))) - # Casting flags_h to a typed pointer, for content access - flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * sch_grid)) - print(f'Flags (init):') - for i in range(0, sch_grid): - flags_typed_ptr.contents[i] = live - print(f'{flags_typed_ptr.contents[i]}') - - flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(sch_grid,)) - flags_h_tensor = torch.from_numpy(flags_h_np_array) - - sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32).contiguous() - - print(f'Scheduler kernel started') - with torch.cuda.stream(sch_stream): - sch[(sch_grid, 1, 1)](flags_h_tensor, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH, - num_warps=NUM_WARPS, num_stages=NUM_STAGES) - - time.sleep(1) - - # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst - MEMORDER_RELAXED = 0 - # Signal the scheduler kernel to complete - 
print(f'Flags (__atomic_fetch_add flags_h to signal the GPU kernel to proceed):') - for i in range(0, sch_grid): - ptr = ctypes.cast(ctypes.byref(flags_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) - prev = libatomic.__atomic_fetch_add_4(ptr, -live, MEMORDER_RELAXED) - print(f'{prev} {flags_typed_ptr.contents[i]}') + num_wgs_per_xcd = 31 # assuming MI355X, 32 WGs / XCD, 1 for the sch kernel and 31 persistent + num_sch_wgs = num_xcds - sch_stream.synchronize() - print(f'Scheduler kernel done') - print(f'sch_comp[0]: {sch_comp[0]}') + req_wgs_ptr = torch.zeros(num_xcds * num_wgs_per_xcd, device="cuda", dtype=torch.int32).contiguous() - grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + # this deadlocks because the next two kernels are taking up all the CUs + # grid = (num_cus, 1, 1) + grid = (num_cus - num_sch_wgs, 1, 1) + grid_comm = (num_xcds, 1, 1) + grid_sch = (num_sch_wgs, 1, 1) + print("Warm-up\n") # Warm-up (JIT & cache) with torch.cuda.stream(gemm_stream): - gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + gemm[grid](req_wgs_ptr, num_wgs_per_xcd, a, N, ITERS=ITERS_GEMM, ITERS_PER_CHKPNT=ITERS_PER_CHKPNT, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) with torch.cuda.stream(comm_stream): - comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) gemm_stream.synchronize(); comm_stream.synchronize() @@ -181,11 +184,11 @@ def main(): # Sequential timing (A then B) t0 = time.perf_counter() with torch.cuda.stream(gemm_stream): - gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + gemm[grid](req_wgs_ptr, num_wgs_per_xcd, a, N, ITERS=ITERS_GEMM, ITERS_PER_CHKPNT=ITERS_PER_CHKPNT, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) gemm_stream.synchronize() with torch.cuda.stream(comm_stream): - comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) comm_stream.synchronize() t_seq = time.perf_counter() - t0 @@ -194,10 +197,12 @@ def main(): # Concurrent timing (A || B) t0 = time.perf_counter() with torch.cuda.stream(gemm_stream): - gemm[grid](a, N, ITERS=ITERS_A, BLOCK_SIZE=BLOCK_SIZE, + gemm[grid](req_wgs_ptr, num_wgs_per_xcd, a, N, ITERS=ITERS_GEMM, ITERS_PER_CHKPNT=ITERS_PER_CHKPNT, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) + # Request the GEMM kernel to release of 1 CU per XCD + with torch.cuda.stream(comm_stream): - comm[grid](b, N, ITERS=ITERS_B, BLOCK_SIZE=BLOCK_SIZE, + comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) # wait for both gemm_stream.synchronize(); comm_stream.synchronize() @@ -209,5 +214,65 @@ def main(): print(f"A[0] = {a[0].item():.6f}") print(f"B[0] = {b[0].item():.6f}") + done = 0 + live = 1 + + # Flags passed to the GPU kernel, flags_h is a void* + flags_h = hip_check(hip.hipMalloc(num_sch_wgs * sys.getsizeof(live))) + # Casting flags_h to a typed pointer, for content access + flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * num_sch_wgs)) + print(f'Flags (init):') + for i in range(0, num_sch_wgs): + flags_typed_ptr.contents[i] = live + print(f'{flags_typed_ptr.contents[i]}') + + flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(num_sch_wgs,)) + flags_h_tensor = torch.from_numpy(flags_h_np_array) + + # Requested resources passed to 
the GPU kernel, req_res_h is a void* + req_res_h = hip_check(hip.hipMalloc(num_sch_wgs * sys.getsizeof(live))) + # Casting req_res_h to a typed pointer, for content access + req_res_typed_ptr = ctypes.cast(req_res_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * num_sch_wgs)) + print(f'Flags (init):') + for i in range(0, num_sch_wgs): + req_res_typed_ptr.contents[i] = 0 + print(f'{req_res_typed_ptr.contents[i]}') + + req_res_h_np_array = np.ctypeslib.as_array(req_res_typed_ptr, shape=(num_sch_wgs,)) + req_res_h_tensor = torch.from_numpy(req_res_h_np_array) + + sch_comp = torch.ones(num_xcds * BLOCK_SIZE, device="cuda", dtype=torch.float32).contiguous() + + print(f'Scheduler kernel started') + with torch.cuda.stream(sch_stream): + sch[grid_sch](flags_h_tensor, req_res_h_tensor, req_wgs_ptr, num_wgs_per_xcd, sch_comp, num_xcds * BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE_SCH, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + + # Concurrent timing (A || B) + t0 = time.perf_counter() + with torch.cuda.stream(gemm_stream): + gemm[grid](req_wgs_ptr, num_wgs_per_xcd, a, N, ITERS=ITERS_GEMM, ITERS_PER_CHKPNT=ITERS_PER_CHKPNT, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + with torch.cuda.stream(comm_stream): + comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, num_stages=NUM_STAGES) + # wait for both + gemm_stream.synchronize(); comm_stream.synchronize() + t_conc = time.perf_counter() - t0 + print(f"With scheduler kernel concurrent total time: {t_conc:.3f} s") + + # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst + MEMORDER_RELAXED = 0 + # Signal the scheduler kernel to complete + print(f'Flags (__atomic_fetch_add flags_h to signal the GPU kernel to proceed):') + for i in range(0, num_sch_wgs): + ptr = ctypes.cast(ctypes.byref(flags_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) + prev = libatomic.__atomic_fetch_add_4(ptr, -live, MEMORDER_RELAXED) + print(f'{prev} {flags_typed_ptr.contents[i]}') + + sch_stream.synchronize() + print(f'Scheduler kernel done') + print(f'sch_comp[0]: {sch_comp[0]}') + if __name__ == "__main__": main() From a65daee676a94c8354301c6b69ddf35d0abb908c Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Mon, 9 Feb 2026 12:43:12 -0500 Subject: [PATCH 6/8] Removing extra loop --- tests/test_cpu_gpu_atomics.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py index 5087a21..e27f28f 100644 --- a/tests/test_cpu_gpu_atomics.py +++ b/tests/test_cpu_gpu_atomics.py @@ -101,9 +101,6 @@ def gemm(req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elements, req_wgs_v = tl.atomic_add(req_wgs_ptr_block, 0, mask=sync_mask, sem="acquire", scope="gpu") # read req_wgs_only_leader = tl.where(sync_mask, req_wgs_v, tl.zeros([BLOCK_SIZE], dtype=req_wgs_v.dtype)) req_wgs = tl.sum(req_wgs_only_leader, axis=0) - for _ in range(ITERS_PER_CHKPNT): - x = x * 1.000000119 + 0.000000137 - i += 1 tl.store(x_ptr + offsets, x, mask=mask) From 3a675e84aa6ecb4a008cc9b3f7b4de6212c533fe Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Mon, 9 Feb 2026 13:17:09 -0500 Subject: [PATCH 7/8] Adding req to release CUs --- tests/test_cpu_gpu_atomics.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py index e27f28f..8b8e7f7 100644 --- a/tests/test_cpu_gpu_atomics.py +++ b/tests/test_cpu_gpu_atomics.py @@ -176,7 +176,7 @@ def main(): 
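The CU-release request added in this patch is driven from the host through GCC's libatomic, so the update the GPU kernels poll for is a genuine atomic RMW rather than a plain Python assignment. A standalone sketch of that call pattern on ordinary host memory (illustrative; the library path is copied from the test and may differ per system):

import ctypes

libatomic = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libatomic.so.1.2.0")  # adjust path as needed
libatomic.__atomic_fetch_add_4.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int]
libatomic.__atomic_fetch_add_4.restype = ctypes.c_int

MEMORDER_RELAXED = 0                 # __ATOMIC_RELAXED
flag = ctypes.c_int(1)               # stands in for one request/live slot
prev = libatomic.__atomic_fetch_add_4(ctypes.byref(flag), -1, MEMORDER_RELAXED)
print(prev, flag.value)              # prints: 1 0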
num_warps=NUM_WARPS, num_stages=NUM_STAGES) gemm_stream.synchronize(); comm_stream.synchronize() - print("Warm-up complete.\n") + print("Warm-up complete\n") # Sequential timing (A then B) t0 = time.perf_counter() @@ -218,7 +218,7 @@ def main(): flags_h = hip_check(hip.hipMalloc(num_sch_wgs * sys.getsizeof(live))) # Casting flags_h to a typed pointer, for content access flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * num_sch_wgs)) - print(f'Flags (init):') + print(f'Scheduler live flags (init):') for i in range(0, num_sch_wgs): flags_typed_ptr.contents[i] = live print(f'{flags_typed_ptr.contents[i]}') @@ -226,11 +226,11 @@ def main(): flags_h_np_array = np.ctypeslib.as_array(flags_typed_ptr, shape=(num_sch_wgs,)) flags_h_tensor = torch.from_numpy(flags_h_np_array) - # Requested resources passed to the GPU kernel, req_res_h is a void* + # Sync var used to request the release of resources (CUs), passed to the GPU kernel, req_res_h is a void* req_res_h = hip_check(hip.hipMalloc(num_sch_wgs * sys.getsizeof(live))) # Casting req_res_h to a typed pointer, for content access req_res_typed_ptr = ctypes.cast(req_res_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * num_sch_wgs)) - print(f'Flags (init):') + print(f'Request release sync vars (init):') for i in range(0, num_sch_wgs): req_res_typed_ptr.contents[i] = 0 print(f'{req_res_typed_ptr.contents[i]}') @@ -250,6 +250,16 @@ def main(): with torch.cuda.stream(gemm_stream): gemm[grid](req_wgs_ptr, num_wgs_per_xcd, a, N, ITERS=ITERS_GEMM, ITERS_PER_CHKPNT=ITERS_PER_CHKPNT, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) + # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst + MEMORDER_RELAXED = 0 + # Signal the scheduler kernel to complete + print(f'Req the GPU scheduler kernel to release CUs (__atomic_fetch_add req_res_h):') + for i in range(0, num_sch_wgs): + ptr = ctypes.cast(ctypes.byref(req_res_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) + comm_wgs = grid_comm[0] * grid_comm[1] * grid_comm[2] + wgs_to_release = comm_wgs // num_sch_wgs + prev = libatomic.__atomic_fetch_add_4(ptr, wgs_to_release, MEMORDER_RELAXED) + print(f'{prev} {flags_typed_ptr.contents[i]}') with torch.cuda.stream(comm_stream): comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) @@ -258,10 +268,8 @@ def main(): t_conc = time.perf_counter() - t0 print(f"With scheduler kernel concurrent total time: {t_conc:.3f} s") - # Memory order: 0 = relaxed, 1 = consume, 2 = acquire, 3 = release, 4 = acq_rel, 5 = seq_cst - MEMORDER_RELAXED = 0 # Signal the scheduler kernel to complete - print(f'Flags (__atomic_fetch_add flags_h to signal the GPU kernel to proceed):') + print(f'Signal the GPU scheduler kernel to stop (__atomic_fetch_add flags_h):') for i in range(0, num_sch_wgs): ptr = ctypes.cast(ctypes.byref(flags_typed_ptr.contents, i * ctypes.sizeof(ctypes.c_int)), ctypes.POINTER(ctypes.c_int)) prev = libatomic.__atomic_fetch_add_4(ptr, -live, MEMORDER_RELAXED) From c57b885467bf12c29c68fa06feef9f0a8977b30a Mon Sep 17 00:00:00 2001 From: Alexandru Dutu Date: Mon, 9 Feb 2026 14:01:43 -0500 Subject: [PATCH 8/8] Reseting req rel sync vars --- tests/test_cpu_gpu_atomics.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_cpu_gpu_atomics.py b/tests/test_cpu_gpu_atomics.py index 8b8e7f7..0f2cbdf 100644 --- a/tests/test_cpu_gpu_atomics.py +++ 
b/tests/test_cpu_gpu_atomics.py @@ -56,7 +56,7 @@ def sch(live_flags_ptr, req_res_ptr, req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elem req_res_ptr_scalar = req_res_ptr + pid req_res_ptr_block = tl.broadcast_to(req_res_ptr_scalar, [BLOCK_SIZE]) - req_res_v = tl.atomic_add(req_res_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="gpu") + req_res_v = tl.atomic_add(req_res_ptr_block, zero_i32, mask=sync_mask, sem="acquire", scope="sys") req_res_only_leader = tl.where(sync_mask, req_res_v, tl.zeros([BLOCK_SIZE], dtype=flag_v.dtype)) req_res = tl.sum(req_res_only_leader, axis=0) if req_res != 0: @@ -64,6 +64,7 @@ def sch(live_flags_ptr, req_res_ptr, req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elem req_wgs_ptr_scalar = req_wgs_ptr + pid * num_wgs_per_xcd + i req_wgs_ptr_block = tl.broadcast_to(req_wgs_ptr_scalar, [BLOCK_SIZE]) tl.atomic_add(req_wgs_ptr_block, 1, mask=sync_mask, sem="acquire", scope="gpu") + tl.atomic_add(req_res_ptr_block, -req_res, mask=sync_mask, sem="acquire", scope="sys") ret = tl.inline_asm_elementwise( asm="""s_sleep 128""", @@ -101,6 +102,8 @@ def gemm(req_wgs_ptr, num_wgs_per_xcd, x_ptr, n_elements, req_wgs_v = tl.atomic_add(req_wgs_ptr_block, 0, mask=sync_mask, sem="acquire", scope="gpu") # read req_wgs_only_leader = tl.where(sync_mask, req_wgs_v, tl.zeros([BLOCK_SIZE], dtype=req_wgs_v.dtype)) req_wgs = tl.sum(req_wgs_only_leader, axis=0) + if req_wgs > 0: + tl.atomic_add(req_wgs_ptr_block, -req_wgs, mask=sync_mask, sem="acquire", scope="gpu") # reset tl.store(x_ptr + offsets, x, mask=mask) @@ -120,7 +123,7 @@ def comm(x_ptr, n_elements, def main(): N = 32 * 1024 * 1024 # elements ITERS_GEMM = 1024 * 1024 * 1024 - ITERS_PER_CHKPNT = 256 + ITERS_PER_CHKPNT = 1024 * 1024 ITERS_COMM = 1024 BLOCK_SIZE_SCH = 256 # scheduler WG size BLOCK_SIZE = 256 # WG size @@ -218,7 +221,7 @@ def main(): flags_h = hip_check(hip.hipMalloc(num_sch_wgs * sys.getsizeof(live))) # Casting flags_h to a typed pointer, for content access flags_typed_ptr = ctypes.cast(flags_h.as_c_void_p(), ctypes.POINTER(ctypes.c_int * num_sch_wgs)) - print(f'Scheduler live flags (init):') + print(f'\nScheduler live flags (init):') for i in range(0, num_sch_wgs): flags_typed_ptr.contents[i] = live print(f'{flags_typed_ptr.contents[i]}') @@ -259,14 +262,18 @@ def main(): comm_wgs = grid_comm[0] * grid_comm[1] * grid_comm[2] wgs_to_release = comm_wgs // num_sch_wgs prev = libatomic.__atomic_fetch_add_4(ptr, wgs_to_release, MEMORDER_RELAXED) - print(f'{prev} {flags_typed_ptr.contents[i]}') + print(f'{prev} {req_res_typed_ptr.contents[i]}') with torch.cuda.stream(comm_stream): comm[grid_comm](b, N, ITERS=ITERS_COMM, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, num_stages=NUM_STAGES) # wait for both gemm_stream.synchronize(); comm_stream.synchronize() t_conc = time.perf_counter() - t0 - print(f"With scheduler kernel concurrent total time: {t_conc:.3f} s") + print(f"Concurrent total time w/ scheduler kernel: {t_conc:.3f} s\n") + + print(f'Read req release sync vars:') + for i in range(0, num_sch_wgs): + print(f'{req_res_typed_ptr.contents[i]}') # Signal the scheduler kernel to complete print(f'Signal the GPU scheduler kernel to stop (__atomic_fetch_add flags_h):')
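Putting the worker side of this protocol together: poll a per-WG request slot every ITERS_PER_CHKPNT iterations, exit the main loop early once it becomes non-zero, and consume (reset) the request on the way out, as this patch does for req_wgs_ptr. A condensed, self-contained sketch of that pattern (illustrative only; the kernel and argument names are assumptions, not the exact code in this test):

import triton
import triton.language as tl

@triton.jit
def checkpointed_worker(stop_ptr, x_ptr, n, ITERS: tl.constexpr,
                        ITERS_PER_CHKPNT: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    i = 0
    stop = tl.atomic_add(stop_ptr + pid, 0)      # read this WG's request slot
    while stop == 0 and i < ITERS:
        for _ in range(ITERS_PER_CHKPNT):        # chunk of work between polls
            x = x * 1.000000119 + 0.000000137
        i += ITERS_PER_CHKPNT
        stop = tl.atomic_add(stop_ptr + pid, 0)
    if stop > 0:
        tl.atomic_add(stop_ptr + pid, -stop)     # consume (reset) the request before exiting
    tl.store(x_ptr + offs, x, mask=mask)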