diff --git a/benchmarks/common.py b/benchmarks/common.py new file mode 100644 index 0000000..f325918 --- /dev/null +++ b/benchmarks/common.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Shared GPU/CU utilities for TritonBLAS benchmarks. +""" +import ctypes + +import torch # type: ignore + + +def get_num_xcds(device_id: int = 0) -> int: + """Query the number of XCDs (chiplets) via HIP runtime.""" + try: + hip = ctypes.cdll.LoadLibrary("libamdhip64.so") + except OSError: + return 1 + try: + hipDeviceAttributeNumberOfXccs = 10018 + xcc_count = ctypes.c_int() + hip.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id) + return xcc_count.value + except Exception: + return 1 + + +def get_cu_info(device_id: int = 0): + """Return (total_cus, num_xcds, cus_per_xcd) for the current device.""" + total_cus = torch.cuda.get_device_properties(device_id).multi_processor_count + num_xcds = get_num_xcds(device_id) + cus_per_xcd = total_cus // num_xcds + return total_cus, num_xcds, cus_per_xcd + + +def build_balanced_hex_mask(remove_per_xcd: int, num_xcds: int, cus_per_xcd: int) -> str: + """Build a ROC_GLOBAL_CU_MASK hex string that removes CUs from the top of every XCD.""" + if remove_per_xcd == 0: + return "" + total_cus = num_xcds * cus_per_xcd + mask = (1 << total_cus) - 1 + for i in range(remove_per_xcd): + base = num_xcds * i + for xcd in range(num_xcds): + bit = base + xcd + mask &= ~(1 << bit) + return f"0x{mask:x}" diff --git a/benchmarks/torch_matmul.py b/benchmarks/torch_matmul.py index 6b3ec2a..ca9b49c 100644 --- a/benchmarks/torch_matmul.py +++ b/benchmarks/torch_matmul.py @@ -1,9 +1,22 @@ #!/usr/bin/env python3 +""" +Benchmark torch.matmul performance. + +Supports an optional --cu-sweep mode (MI300X) that re-invokes this script as +subprocesses with ROC_GLOBAL_CU_MASK set, producing results with an +``active_cus`` column. +""" +import json +import os +import subprocess +import sys import yaml import argparse import torch import csv +from common import get_cu_info, build_balanced_hex_mask + def str_to_dtype(dtype_str: str) -> torch.dtype: """ @@ -69,8 +82,6 @@ def bench_matmul(input_yaml: str): if transB == "N": B = B.T # Apply transpose to B if transB is "N" - # Initialize tensors with the appropriate dimensions - # Warm-up iterations for _ in range(20): _ = torch.matmul(A, B) @@ -88,30 +99,20 @@ def bench_matmul(input_yaml: str): elapsed_ms = start_event.elapsed_time(end_event) # time in milliseconds times.append(elapsed_ms) - # Calculate mean execution time (ms) and derive performance in TFLOPS. + # Calculate mean execution time (ms) and derive performance. mean_ms = sum(times) / len(times) - # Compute FLOPS count: 2 * m * n * k operations (each multiply-add counts as 2 operations) scaled to tera (1e-12) - flops = 2 * m * n * k * 1e-12 - tflops = flops / (mean_ms * 1e-3) + gflops = (2 * m * n * k) / (mean_ms * 1e-3) / 1e9 print( - f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype} perf={tflops}" + f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype} perf={gflops:.1f} GFLOPS" ) - # Calculate bytes processed: considering both A, B, and the output tensor. - bytes_fn = lambda: (A.element_size() * (m * k + n * k)) + ( - m * n * A.element_size() - ) - - # Collect the metrics in a dictionary for later CSV output. 
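# A quick worked example of build_balanced_hex_mask() from benchmarks/common.py
# above — an illustrative sketch, using toy XCD/CU counts so the bit pattern
# stays readable. On MI300X the same routine runs with num_xcds=8 and
# cus_per_xcd=38 (304 CUs total).
from common import build_balanced_hex_mask  # assumes the CWD is benchmarks/

num_xcds, cus_per_xcd = 2, 4                 # 8 CUs total -> full mask 0xff
print(build_balanced_hex_mask(0, num_xcds, cus_per_xcd))  # ""   (no masking applied)
print(build_balanced_hex_mask(1, num_xcds, cus_per_xcd))  # 0xfc (bits 0-1 cleared: one CU per XCD)
print(build_balanced_hex_mask(2, num_xcds, cus_per_xcd))  # 0xf0 (bits 0-3 cleared: two CUs per XCD)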
metrics = { "m": m, "n": n, "k": k, - "mnk": m * n * k, - "bytes": bytes_fn(), - "flops": flops, - "tflops": tflops, + "gflops": gflops, + "ms": mean_ms, "in_dtype": str(in_dtype), "out_dtype": str(out_dtype), "transA": transA, @@ -122,21 +123,18 @@ def bench_matmul(input_yaml: str): return benchmark_results +def _build_child_cmd(args): + """Build a subprocess command from parsed args (excludes --cu-sweep and --output-csv).""" + cmd = [sys.executable, os.path.abspath(__file__)] + cmd += ["--input-yaml", args.input_yaml] + return cmd + + def write_csv(filename: str, results): """Write the benchmark results to a CSV file.""" - fieldnames = [ - "m", - "n", - "k", - "mnk", - "bytes", - "flops", - "tflops", - "in_dtype", - "out_dtype", - "transA", - "transB", - ] + fieldnames = ["m", "n", "k", "gflops", "ms", "in_dtype", "out_dtype", "transA", "transB"] + if results and "active_cus" in results[0]: + fieldnames.insert(0, "active_cus") with open(filename, mode="w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() @@ -161,9 +159,67 @@ def write_csv(filename: str, results): default="", help="Filename for CSV output (if not specified, CSV output is disabled).", ) + parser.add_argument( + "--cu-sweep", action="store_true", + help="Run a balanced CU-mask sweep (MI300X). Re-invokes this script as " + "subprocesses with ROC_GLOBAL_CU_MASK set.", + ) + parser.add_argument( + "--cu-sweep-max-remove", type=int, default=34, + help="Max CUs to remove per XCD (default 34, minimum 4 CUs/XCD left).", + ) + + # Hidden: used by cu-sweep parent to tag subprocess results + parser.add_argument("--_active-cus", type=int, default=None, help=argparse.SUPPRESS) + args = parser.parse_args() + if args.cu_sweep: + full_cus, num_xcds, cus_per_xcd = get_cu_info() + all_results = [] + child_base = _build_child_cmd(args) + max_remove = min(args.cu_sweep_max_remove, cus_per_xcd - 1) + + for r in range(max_remove + 1): + active = full_cus - r * num_xcds + mask = build_balanced_hex_mask(r, num_xcds, cus_per_xcd) + + child_cmd = child_base + ["--_active-cus", str(active)] + + env = os.environ.copy() + if mask: + env["ROC_GLOBAL_CU_MASK"] = mask + + proc = subprocess.run( + child_cmd, capture_output=True, text=True, env=env, + ) + if proc.returncode != 0: + print(f"[active_cus={active}] subprocess failed:", file=sys.stderr) + sys.stderr.write(proc.stderr[-500:] if len(proc.stderr) > 500 else proc.stderr) + continue + + step_results = json.loads(proc.stdout) + all_results.extend(step_results) + + if args.output_csv: + write_csv(args.output_csv, all_results) + sys.exit(0) + + is_worker = args._active_cus is not None + + if is_worker: + # Suppress prints when running as a subprocess + import io + sys.stdout = io.StringIO() + benchmark_results = bench_matmul(args.input_yaml) + if is_worker: + sys.stdout = sys.__stdout__ + for row in benchmark_results: + row["active_cus"] = args._active_cus + print(json.dumps(benchmark_results)) + sys.exit(0) + if args.output_csv: write_csv(args.output_csv, benchmark_results) diff --git a/benchmarks/tritonblas_matmul.py b/benchmarks/tritonblas_matmul.py index 5ab6ac3..fa1a006 100644 --- a/benchmarks/tritonblas_matmul.py +++ b/benchmarks/tritonblas_matmul.py @@ -2,36 +2,61 @@ """ TritonBLAS Unified Matrix Multiplication Benchmark -This benchmark script supports both standard (fp16/bf16/fp32) and quantized (fp8/int8) +This benchmark script supports both standard (fp16/bf16/fp32) and quantized (fp8/int8) matrix multiplication. 
It automatically detects the dtype and uses the appropriate API. + +Supports three kernel modes via CLI flags: + (default) Persistent GEMM + --enable-streamk Stream-K GEMM + --work-stealing Work-stealing persistent GEMM + +Optional --cu-sweep runs a balanced CU-mask sweep (MI300X) by re-invoking this +script as subprocesses with ROC_GLOBAL_CU_MASK set. Results flow into the same +CSV with an extra ``active_cus`` column. """ import argparse import csv +import json +import os import random +import subprocess +import sys import torch # type: ignore import triton # type: ignore -import tritonblas # type: ignore import yaml # type: ignore from tqdm import tqdm # type: ignore -from tritonblas.utils import MatmulInputs, generate_matmul_inputs, str_to_dtype, _is_float8_like # type: ignore +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "include")) +import tritonblas # type: ignore + +from common import get_cu_info, build_balanced_hex_mask # type: ignore +from tritonblas.utils import generate_matmul_inputs, str_to_dtype, _is_float8_like # type: ignore -def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk, init_type="randn"): +def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk, + work_stealing=False, init_type="randn", total_cus=None, + global_atomic=False): """Test matmul with proper input generation - handles both quantized and non-quantized dtypes""" inputs = generate_matmul_inputs(m, n, k, in_dtype, out_dtype, transA, transB, init_type) selector = tritonblas.OrigamiMatmulSelector( - m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device + m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device, + streamk=enable_streamk, total_cus=total_cus, ) + cfg = tritonblas.matmul_preamble(selector) + cfg.global_atomic = global_atomic if inputs.is_quantized: tritonblas.matmul_a8w8_lt( - inputs.A, inputs.B, inputs.scaleA, inputs.scaleB, inputs.C, selector, enable_streamk + inputs.A, inputs.B, inputs.scaleA, inputs.scaleB, inputs.C, selector, cfg, + enable_streamk, work_stealing=work_stealing, ) else: - tritonblas.matmul_lt(inputs.A, inputs.B, inputs.C, selector, enable_streamk) + tritonblas.matmul_lt( + inputs.A, inputs.B, inputs.C, selector, cfg, + enable_streamk, work_stealing=work_stealing, + ) if inputs.is_quantized: acc = torch.matmul(inputs.A.to(torch.float32), inputs.B.to(torch.float32)) @@ -89,6 +114,7 @@ def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk, in print(f"{size_str} Correct✅") + def bench_matmul( input_yaml: str, init_type: str, @@ -97,7 +123,11 @@ def bench_matmul( output_csv=None, write_csv_freq=100, enable_streamk=False, + work_stealing=False, check_correctness=False, + total_cus=None, + counters_per_xcd=None, + global_atomic=False, ): with open(input_yaml, "r") as f: dataset = yaml.safe_load(f) @@ -147,30 +177,44 @@ def bench_matmul( ) ) - # Build a tritonBLAS selector config and launch matmul selector = tritonblas.OrigamiMatmulSelector( m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device, - streamk=enable_streamk + streamk=enable_streamk, total_cus=total_cus, ) - ## TODO + if counters_per_xcd is not None: + selector.COUNTERS_PER_XCD = counters_per_xcd + cfg = tritonblas.matmul_preamble(selector) + cfg.global_atomic = global_atomic config = (selector.block_m, selector.block_n, selector.block_k) - # Use appropriate API based on quantization if inputs.is_quantized: matmul = lambda: tritonblas.matmul_a8w8_lt( - inputs.A, inputs.B, 
inputs.scaleA, inputs.scaleB, inputs.C, selector, enable_streamk + inputs.A, inputs.B, inputs.scaleA, inputs.scaleB, inputs.C, selector, cfg, + enable_streamk, work_stealing=work_stealing, ) else: - matmul = lambda: tritonblas.matmul( - inputs.A, inputs.B, inputs.C, enable_streamk=enable_streamk + matmul = lambda: tritonblas.matmul_lt( + inputs.A, inputs.B, inputs.C, selector, cfg, + enable_streamk, work_stealing=work_stealing, ) - ms = triton.testing.do_bench(matmul, warmup=20, rep=20) + reset = lambda: cfg.reset(streamk=enable_streamk, work_stealing=work_stealing) + ms = tritonblas.do_bench(matmul, reset_fn=reset, n_warmup=20, n_repeat=100) perf = gflops(ms) + # Determine mode string for output + if work_stealing: + mode_str = "work_stealing" + elif enable_streamk: + mode_str = "streamk" + else: + mode_str = "persistent" + if print_verbose: print( - f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype}, init={init_type}, perf={perf}(GFLOPs) selected_tile={selector.block_m}x{selector.block_n}x{selector.block_k}" + f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype}, " + f"init={init_type}, mode={mode_str}, perf={perf}(GFLOPs) " + f"selected_tile={selector.block_m}x{selector.block_n}x{selector.block_k}" ) metrics = { @@ -195,7 +239,7 @@ def bench_matmul( "us": ms / 1000, "alpha": 1, "beta": 0, - "enable_streamk": enable_streamk, + "mode": mode_str, } benchmark_results.append(metrics) @@ -207,36 +251,45 @@ def bench_matmul( if check_correctness: print("correctness: ", end=" ", flush=True) - test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk, init_type) + test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, + enable_streamk, work_stealing=work_stealing, + init_type=init_type, total_cus=total_cus, + global_atomic=global_atomic) return benchmark_results +def _build_child_cmd(args): + """Build a subprocess command from parsed args (excludes --cu-sweep and --output-csv).""" + cmd = [sys.executable, os.path.abspath(__file__)] + cmd += ["--input-yaml", args.input_yaml] + cmd += ["--init_type", args.init_type] + cmd += ["--csv-write-freq", str(args.csv_write_freq)] + if args.shuffle_bench: + cmd.append("--shuffle-bench") + if args.checkcorrectness: + cmd.append("--checkcorrectness") + if args.enable_streamk: + cmd.append("--enable-streamk") + elif args.work_stealing: + cmd.append("--work-stealing") + if args.counters_per_xcd is not None: + cmd += ["--counters-per-xcd", str(args.counters_per_xcd)] + if args.global_atomic: + cmd.append("--global-atomic") + return cmd + + def write_csv(filename: str, results): fieldnames = [ - "m", - "n", - "k", - "mnk", - "macro_tile", - "bytes", - "flops", + "m", "n", "k", "mnk", "macro_tile", "bytes", "flops", "tritonblas_gflops", - "a_type", - "b_type", - "c_type", - "d_type", - "compute_type", - "in_dtype", - "out_dtype", - "init_type", - "transA", - "transB", - "us", - "alpha", - "beta", - "enable_streamk", + "a_type", "b_type", "c_type", "d_type", "compute_type", + "in_dtype", "out_dtype", "init_type", + "transA", "transB", "us", "alpha", "beta", "mode", ] + if results and "active_cus" in results[0]: + fieldnames.insert(0, "active_cus") with open(filename, mode="w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() @@ -245,68 +298,140 @@ def write_csv(filename: str, results): print(f"Benchmark results saved to '{filename}'.") + if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Benchmark matmul performance (supports both standard and 
quantized dtypes) and optionally output performance metrics to a CSV file." + description="Benchmark matmul performance (supports both standard and quantized dtypes) " + "and optionally output performance metrics to a CSV file." ) parser.add_argument( - "--input-yaml", - type=str, - default="../datasets/matmul_random.yaml", + "--input-yaml", type=str, default="../datasets/matmul_random.yaml", help="Input YAML file containing benchmark cases (default: ./matmul_random.yaml).", ) parser.add_argument( - "--output-csv", - type=str, - default="", + "--output-csv", type=str, default="", help="Filename for CSV output (if not specified, CSV output is disabled).", ) parser.add_argument( - "--init_type", - type=str, - default="randn", + "--init_type", type=str, default="randn", choices=["hpl", "trig_float", "zeros", "randn"], help="Tensor initialization type (default: randn).", ) parser.add_argument( - "--shuffle-bench", - action="store_true", + "--shuffle-bench", action="store_true", help="Randomly shuffle the order the benchmark runs", ) parser.add_argument( - "--csv-write-freq", - type=int, - default=1000, + "--csv-write-freq", type=int, default=1000, help="Number of problems to run before writing to csv", ) parser.add_argument( - "--print-verbose", - action="store_true", + "--print-verbose", action="store_true", help="Print detailed information for each benchmark.", ) parser.add_argument( - "--checkcorrectness", - action="store_true", - default=False, + "--checkcorrectness", action="store_true", default=False, help="Check result correctness", ) + + mode_group = parser.add_mutually_exclusive_group() + mode_group.add_argument( + "--enable-streamk", action="store_true", + help="Enable Stream-K mode for matrix multiplication.", + ) + mode_group.add_argument( + "--work-stealing", action="store_true", + help="Enable work-stealing persistent GEMM with per-XCD atomic counters.", + ) + + parser.add_argument( + "--cu-sweep", action="store_true", + help="Run a balanced CU-mask sweep (MI300X). Uses the same --input-yaml " + "shapes and kernel mode. 
Re-invokes this script as subprocesses with " + "ROC_GLOBAL_CU_MASK set; results include an active_cus column.", + ) parser.add_argument( - "--enable-streamk", - action="store_true", - help="Enable Stream-K mode for matrix multiplication (default: False for persistent mode).", + "--cu-sweep-max-remove", type=int, default=34, + help="Max CUs to remove per XCD (default 34, minimum 4 CUs/XCD left).", ) + parser.add_argument( + "--counters-per-xcd", type=int, default=None, + help="Override COUNTERS_PER_XCD for work-stealing (default: use selector value).", + ) + parser.add_argument( + "--global-atomic", action="store_true", + help="Use a single device-wide atomic counter instead of per-XCD counters " + "(only meaningful with --work-stealing).", + ) + + # Hidden: used by cu-sweep parent to tag subprocess results + parser.add_argument("--_active-cus", type=int, default=None, help=argparse.SUPPRESS) + parser.add_argument("--_total-cus", type=int, default=None, help=argparse.SUPPRESS) + args = parser.parse_args() + if args.cu_sweep: + full_cus, num_xcds, cus_per_xcd = get_cu_info() + all_results = [] + child_base = _build_child_cmd(args) + max_remove = min(args.cu_sweep_max_remove, cus_per_xcd - 1) + + for r in range(max_remove + 1): + active = full_cus - r * num_xcds + mask = build_balanced_hex_mask(r, num_xcds, cus_per_xcd) + + child_cmd = child_base + ["--_active-cus", str(active), + "--_total-cus", str(full_cus)] + + env = os.environ.copy() + if mask: + env["ROC_GLOBAL_CU_MASK"] = mask + + proc = subprocess.run( + child_cmd, capture_output=True, text=True, env=env, + ) + if proc.returncode != 0: + print(f"[active_cus={active}] subprocess failed:", file=sys.stderr) + sys.stderr.write(proc.stderr[-500:] if len(proc.stderr) > 500 else proc.stderr) + continue + + step_results = json.loads(proc.stdout) + all_results.extend(step_results) + + if args.output_csv: + write_csv(args.output_csv, all_results) + sys.exit(0) + + is_worker = args._active_cus is not None + + # Suppress prints when running as a CU-sweep subprocess so they + # don't corrupt the JSON payload sent back to the parent. 
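# Sketch of what a single --cu-sweep step boils down to: build a balanced mask,
# export ROC_GLOBAL_CU_MASK, and re-invoke this script as a worker that prints
# its results as a JSON list on stdout. The shape file and the choice of
# removing 4 CUs per XCD are placeholders; run from the benchmarks/ directory.
import json, os, subprocess, sys
from common import get_cu_info, build_balanced_hex_mask

full_cus, num_xcds, cus_per_xcd = get_cu_info()
env = os.environ.copy()
env["ROC_GLOBAL_CU_MASK"] = build_balanced_hex_mask(4, num_xcds, cus_per_xcd)

proc = subprocess.run(
    [sys.executable, "tritonblas_matmul.py",
     "--input-yaml", "../datasets/bench_8k.yaml",
     "--work-stealing",
     "--_active-cus", str(full_cus - 4 * num_xcds),
     "--_total-cus", str(full_cus)],
    capture_output=True, text=True, env=env, check=True,
)
rows = json.loads(proc.stdout)  # each row already carries an "active_cus" field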
+ if is_worker: + import io + sys.stdout = io.StringIO() + benchmark_results = bench_matmul( args.input_yaml, args.init_type, shuffle_benchmark=args.shuffle_bench, - output_csv=args.output_csv, + output_csv=args.output_csv if not is_worker else None, write_csv_freq=args.csv_write_freq, - print_verbose=args.print_verbose, + print_verbose=args.print_verbose if not is_worker else False, enable_streamk=args.enable_streamk, + work_stealing=args.work_stealing, check_correctness=args.checkcorrectness, + total_cus=args._total_cus, + counters_per_xcd=args.counters_per_xcd, + global_atomic=args.global_atomic, ) + if is_worker: + sys.stdout = sys.__stdout__ + # Tag each result with CU count and dump JSON to stdout for parent + for row in benchmark_results: + row["active_cus"] = args._active_cus + print(json.dumps(benchmark_results)) + sys.exit(0) + if args.output_csv: write_csv(args.output_csv, benchmark_results) diff --git a/datasets/bench_8k.yaml b/datasets/bench_8k.yaml new file mode 100644 index 0000000..bedbf62 --- /dev/null +++ b/datasets/bench_8k.yaml @@ -0,0 +1,7 @@ +- m: 8192 + n: 8192 + k: 8192 + in_dtype: float16 + out_dtype: float16 + transA: "N" + transB: "N" diff --git a/examples/example_matmul_lt.py b/examples/example_matmul_lt.py index 15d3982..8ef3a15 100644 --- a/examples/example_matmul_lt.py +++ b/examples/example_matmul_lt.py @@ -12,7 +12,8 @@ def example_matmul(m, n, k): # Run TritonBLAS matmul selector = tritonblas.OrigamiMatmulSelector(m, n, k, A.dtype, B.dtype, C.dtype, A.device) - tritonblas.matmul_lt(A, B, C, selector) + config = tritonblas.matmul_preamble(selector) + tritonblas.matmul_lt(A, B, C, selector, config) # Print result print(C) diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py index 6f9365f..d00c08f 100644 --- a/include/tritonblas/__init__.py +++ b/include/tritonblas/__init__.py @@ -1,4 +1,6 @@ from .matmul import matmul, matmul_a8w8 from .matmul import matmul_lt, matmul_a8w8_lt from .matmul import matmul_fp4 +from .config import MatmulConfig, matmul_preamble +from .bench import do_bench from .origami import OrigamiMatmulSelector diff --git a/include/tritonblas/bench.py b/include/tritonblas/bench.py new file mode 100644 index 0000000..014b859 --- /dev/null +++ b/include/tritonblas/bench.py @@ -0,0 +1,116 @@ +import math +import statistics +import torch + + +def _get_empty_cache_for_benchmark(): + cache_size = 512 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda") + + +def _clear_cache(cache): + cache.zero_() + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(qi) for qi in q] + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = _quantile(times, quantiles) + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) + + +def do_bench( + fn, + reset_fn=lambda: None, + preamble_fn=lambda: None, + n_warmup=25, + n_repeat=100, + quantiles=None, + return_mode="mean", +): + """ + Benchmark a function by timing its execution using CUDA 
events. + + ``reset_fn`` is called before every invocation (warmup and timed) and is + **not** included in the measured time. Use it to zero mutable kernel state + such as ``MatmulConfig.reset()``. + + ``preamble_fn`` is called once before each invocation (after reset) and is + also **not** timed. Use it for any one-time setup that should not be + measured (e.g. ``matmul_preamble``). + + Args: + fn: Function to benchmark. + reset_fn: Called before each invocation to reset kernel state. + preamble_fn: Called before each invocation for setup. + n_warmup: Number of warmup iterations. + n_repeat: Number of timed iterations. + quantiles: Quantiles to return instead of a summary statistic. + return_mode: ``"mean"``, ``"min"``, ``"max"``, ``"median"``, or ``"all"``. + + Returns: + float or list: Timing result(s) in milliseconds. + """ + # Initial sync + single run to compile / warm caches + torch.cuda.synchronize() + preamble_fn() + reset_fn() + fn() + torch.cuda.synchronize() + + cache = _get_empty_cache_for_benchmark() + + start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + + # Warmup + for _ in range(n_warmup): + torch.cuda.synchronize() + reset_fn() + preamble_fn() + _clear_cache(cache) + torch.cuda.synchronize() + fn() + + # Timed runs + for i in range(n_repeat): + torch.cuda.synchronize() + reset_fn() + preamble_fn() + _clear_cache(cache) + torch.cuda.synchronize() + start_event[i].record() + fn() + end_event[i].record() + + torch.cuda.synchronize() + + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + return _summarize_statistics(times, quantiles, return_mode) diff --git a/include/tritonblas/config.py b/include/tritonblas/config.py new file mode 100644 index 0000000..8c40170 --- /dev/null +++ b/include/tritonblas/config.py @@ -0,0 +1,82 @@ +import torch + +# 256-byte separation between atomic counters to avoid false sharing +# across L2 cache lines. Each int32 is 4 bytes → stride = 256 / 4 = 64 elements. +COUNTER_STRIDE = 64 + + +class MatmulConfig: + """ + Pre-allocated GPU buffers for GEMM kernel launches. + + Create via :func:`matmul_preamble` with an ``OrigamiMatmulSelector``. + Buffer sizes are derived from the selector's tile configuration. + + Attributes: + device: ``torch.device`` the buffers live on. + tile_counter: ``int32[num_counters * COUNTER_STRIDE]`` work-stealing + counters, padded to 256B per slot to avoid false sharing. + locks: ``uint8[sk_grid]`` stream-K lock array. + P: ``float32[sk_grid, block_size]`` stream-K partial buffer. + """ + + def __init__(self, device: torch.device, tile_counter: torch.Tensor, + locks: torch.Tensor, P: torch.Tensor, + global_atomic: bool = False): + self.device = device + self.tile_counter = tile_counter + self.locks = locks + self.P = P + self.global_atomic = global_atomic + + def reset(self, streamk: bool = False, work_stealing: bool = False): + """Reset mutable state based on the active kernel mode. + + Args: + streamk: Zero the stream-K lock array. + work_stealing: Zero the work-stealing tile counter. 
+ """ + if work_stealing: + self.tile_counter.zero_() + if streamk: + self.locks.zero_() + self.P.zero_() + + def __repr__(self): + return ( + f"MatmulConfig(device={self.device!r}, " + f"tile_counter={list(self.tile_counter.shape)}, " + f"locks={list(self.locks.shape)}, " + f"P={list(self.P.shape)})" + ) + + +def matmul_preamble(selector, device: torch.device = None) -> MatmulConfig: + """ + Allocate all GPU-side buffers needed by the tritonBLAS GEMM kernels. + + Call this once per problem shape (or once with the largest expected shape) + and pass the returned config into ``matmul_lt``, ``matmul_a8w8_lt``, etc. + + Args: + selector: An ``OrigamiMatmulSelector`` providing tile sizes, XCD count, + stream-K grid, and ``COUNTERS_PER_XCD``. + device: ``torch.device`` for buffer allocation (default: current CUDA device). + + Returns: + A :class:`MatmulConfig` ready for kernel launches. + """ + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + + num_xcds = selector._hardware.NUM_XCD + counters_per_xcd = selector.COUNTERS_PER_XCD + block_size = selector.block_m * selector.block_n + sk_grid = selector.sk_grid + + num_counters = num_xcds * counters_per_xcd + tile_counter = torch.zeros(num_counters * COUNTER_STRIDE, device=device, dtype=torch.int32) + locks = torch.zeros(sk_grid, device=device, dtype=torch.uint8) + P = torch.empty(sk_grid, block_size, device=device, dtype=torch.float32) + + return MatmulConfig(device=device, tile_counter=tile_counter, locks=locks, P=P) diff --git a/include/tritonblas/kernels/__init__.py b/include/tritonblas/kernels/__init__.py index b7ea70c..81bbd57 100644 --- a/include/tritonblas/kernels/__init__.py +++ b/include/tritonblas/kernels/__init__.py @@ -4,6 +4,7 @@ This package contains specific GEMM kernel implementations: - persistent_gemm: Persistent (data-parallel) GEMM kernel using composable stages - persistent_gemm_monolithic: Monolithic persistent GEMM kernel (legacy, for debugging) +- persistent_gemm_work_stealing: Work-stealing persistent GEMM kernel - streamk_gemm: Stream-K GEMM kernel for load balancing - stages: Composable kernel building blocks @@ -23,6 +24,9 @@ # Use composable stages version (default) from .persistent_gemm import persistent_matmul +# Work-stealing kernel (opt-in via work_stealing=True in matmul calls) +from .persistent_gemm_work_stealing import ws_persistent_matmul + # Stream-K kernel is always the same from .streamk_gemm import streamk_matmul @@ -32,4 +36,4 @@ # Export stages submodule from . import stages -__all__ = ['persistent_matmul', 'streamk_matmul', 'fp4_matmul', 'stages'] +__all__ = ['persistent_matmul', 'ws_persistent_matmul', 'streamk_matmul', 'fp4_matmul', 'stages'] diff --git a/include/tritonblas/kernels/persistent_gemm_work_stealing.py b/include/tritonblas/kernels/persistent_gemm_work_stealing.py new file mode 100644 index 0000000..64bd6ff --- /dev/null +++ b/include/tritonblas/kernels/persistent_gemm_work_stealing.py @@ -0,0 +1,210 @@ +""" +Work-stealing persistent GEMM kernel with per-XCD atomic counters. + +Instead of statically partitioning tiles across workgroups (for tile_id in +range(pid, total_tiles, NUM_SMS)), each WG dynamically grabs the next +available tile via an atomic counter that is local to its XCD. 
+ +PIDs are assigned round-robin across XCDs: + pid 0 → XCD 0, pid 1 → XCD 1, …, pid 7 → XCD 7, pid 8 → XCD 0, … + +The tile space is partitioned into contiguous per-XCD regions: + XCD i owns tiles [i * tiles_per_xcd, min((i+1) * tiles_per_xcd, total_tiles)) + +To reduce atomic contention, each XCD uses COUNTERS_PER_XCD independent +counters (default 16). The XCD's tiles are further sub-partitioned: + counter slot j within XCD i owns tiles + [xcd_base + j * tiles_per_slot, xcd_base + min((j+1) * tiles_per_slot, tiles_this_xcd)) + +Each WG picks its slot via: slot = (pid // NUM_XCDS) % COUNTERS_PER_XCD + +With 38 CUs per XCD and 16 slots, only ~2-3 CUs contend on each counter. +""" + +import triton +import triton.language as tl +import torch + +from .stages.indexing.pid_transforms import chiplet_transform + + +@triton.jit() +def ws_persistent_matmul( + A, + B, + C, + A_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 + B_scale_ptr, # Optional: None for fp16/bf16, pointer for int8/fp8 + bias_ptr, + tile_counter, # Per-XCD×slot atomic counters (int32[NUM_XCDS * COUNTERS_PER_XCD]) + M, + N, + K, + stride_am, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + COUNTERS_PER_XCD: tl.constexpr, + COUNTER_STRIDE: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + CACHE_MODIFIER_A: tl.constexpr, + CACHE_MODIFIER_B: tl.constexpr, + QUANTIZED: tl.constexpr = False, # True for int8/fp8, False for fp16/bf16 + ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32, + GLOBAL_ATOMIC: tl.constexpr = False, # True: single device-wide counter +): + pid = tl.program_id(0) + xcd_id = pid % NUM_XCDS + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + tiles_per_xcd = tl.cdiv(total_tiles, NUM_XCDS) + + if GLOBAL_ATOMIC: + # Single device-wide atomic — all CUs contend on one counter. + counter_ptr = tile_counter + bound = total_tiles + else: + # Per-XCD counters with multiple slots to reduce contention. + local_wg_id = pid // NUM_XCDS + slot = local_wg_id % COUNTERS_PER_XCD + + xcd_base = xcd_id * tiles_per_xcd + xcd_end = tl.minimum(xcd_base + tiles_per_xcd, total_tiles) + tiles_this_xcd = xcd_end - xcd_base + + tiles_per_slot = tl.cdiv(tiles_this_xcd, COUNTERS_PER_XCD) + slot_base = slot * tiles_per_slot + slot_end = tl.minimum(slot_base + tiles_per_slot, tiles_this_xcd) + bound = slot_end - slot_base + + # Counters are padded to COUNTER_STRIDE int32 elements (256B) apart + # to avoid false sharing across L2 cache lines. + counter_ptr = tile_counter + (xcd_id * COUNTERS_PER_XCD + slot) * COUNTER_STRIDE + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + raw_idx = tl.atomic_add(counter_ptr, 1, scope="gpu") + + while raw_idx < bound: + # Map raw counter value → global tile_id + if GLOBAL_ATOMIC: + # Chiplet swizzle: remap global sequential index into + # per-XCD tile regions so data stays in the issuing XCD's L2. 
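# Host-side illustration (plain Python, not kernel code) of the per-XCD/slot
# partitioning described in the module docstring: which XCD, counter slot, and
# tile range a given pid draws from. 8 XCDs and 16 counters per XCD are
# assumed, matching the defaults above.
def ws_partition(pid, total_tiles, num_xcds=8, counters_per_xcd=16):
    cdiv = lambda a, b: (a + b - 1) // b
    xcd = pid % num_xcds
    slot = (pid // num_xcds) % counters_per_xcd
    tiles_per_xcd = cdiv(total_tiles, num_xcds)
    xcd_base = xcd * tiles_per_xcd
    tiles_this_xcd = min(xcd_base + tiles_per_xcd, total_tiles) - xcd_base
    tiles_per_slot = cdiv(tiles_this_xcd, counters_per_xcd)
    slot_base = slot * tiles_per_slot
    bound = min(slot_base + tiles_per_slot, tiles_this_xcd) - slot_base
    return xcd, slot, xcd_base + slot_base, bound  # tiles [start, start + bound)

# An 8192^3 GEMM with 256x256 tiles has 32 * 32 = 1024 tiles:
#   ws_partition(0, 1024) -> (0, 0, 0, 8)    pid 0 serves tiles 0..7 of XCD 0
#   ws_partition(9, 1024) -> (1, 1, 136, 8)  pid 9 serves tiles 136..143 (XCD 1, slot 1)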
+ tile_id = chiplet_transform(raw_idx, total_tiles, NUM_XCDS) + else: + tile_id = xcd_base + slot_base + raw_idx + + # GROUP_SIZE_M swizzle → (pid_m, pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + if BIAS: + bias_ = bias_ptr + rm * stride_bias + bias = tl.load(bias_, mask=rm < M, other=0.0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + tl.assume(loop_k > 1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + if stride_ak == 1: + a = tl.load(tl.multiple_of(A_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_A) + else: + a = tl.load(tl.multiple_of(A_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_A) + + if stride_bk == 1: + b = tl.load(tl.multiple_of(B_BASE, (16, 1)), cache_modifier=CACHE_MODIFIER_B) + else: + b = tl.load(tl.multiple_of(B_BASE, (1, 16)), cache_modifier=CACHE_MODIFIER_B) + + if QUANTIZED: + acc += tl.dot(a, b, input_precision="ieee") + else: + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + if stride_ak == 1: + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + else: + A_BASE = tl.multiple_of(A_BASE, (16, 1)) + + if stride_bk == 1: + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + else: + B_BASE = tl.multiple_of(B_BASE, (1, 16)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0, cache_modifier=CACHE_MODIFIER_A) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0, cache_modifier=CACHE_MODIFIER_B) + + if QUANTIZED: + acc += tl.dot(a, b, input_precision="ieee") + else: + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + + if QUANTIZED: + rm_A_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) % M + rn_B_scale = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) % N + A_scale = tl.load(A_scale_ptr + rm_A_scale) + B_scale = tl.load(B_scale_ptr + rn_B_scale) + acc *= A_scale[:, None] * B_scale[None, :] + + if BIAS: + if QUANTIZED: + bias_float = bias.to(tl.float32) + c = acc + bias_float[:, None] + c = c.to(C.type.element_ty) + else: + c = acc.to(C.type.element_ty) + c += bias[:, None] + else: + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + + raw_idx = tl.atomic_add(counter_ptr, 1, scope="gpu") diff --git 
a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 9b23986..26db802 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -3,21 +3,13 @@ import random import functools import time -from .kernels import persistent_matmul, streamk_matmul +from .kernels import persistent_matmul, ws_persistent_matmul, streamk_matmul from .kernels.fp4_matmul import fp4_matmul from .origami import OrigamiMatmulSelector +from .config import MatmulConfig, matmul_preamble, COUNTER_STRIDE from typing import Dict, Tuple, Optional _tensor_cache = {} -current_device_index = torch.cuda.current_device() -current_device = torch.cuda.get_device_properties(current_device_index) -MAX_SMS = current_device.multi_processor_count -# TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4 -MAX_BLOCK_SIZE = 65536 - -# Global pre-allocated buffers -_global_locks = torch.empty(MAX_SMS, device="cuda", dtype=torch.uint8) -_global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32) # Function will behave like an LRU-Cache of heuristic results @@ -52,14 +44,17 @@ def persistent_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, + config: MatmulConfig, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, + work_stealing: bool = False, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape + cfg = config BLK_M = selector.block_m BLK_N = selector.block_n BLK_K = selector.block_k @@ -84,49 +79,94 @@ def persistent_matmul_lt( CACHE_MODIFIER_A = None CACHE_MODIFIER_B = None - # Run in Data-parallel mode. - grids = total_tiles - # Set chunk size to same area as L2 tiles. chunk_size = gsize_m * gsize_m - chunk_size = min(chunk_size, total_programs // num_xcds) - - # TODO: Support other matmul algs. - kk = persistent_matmul[(grids,)]( - a, - b, - c, - a_scale if quantized else None, # A_scale_ptr - b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. - M, - N, - K, - a.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - 0, # TODO: Enable bias stride. - stride_ak=a.stride(1), - stride_bk=b.stride(0), - BLOCK_SIZE_M=BLK_M, - BLOCK_SIZE_N=BLK_N, - BLOCK_SIZE_K=BLK_K, - GROUP_SIZE_M=gsize_m, - NUM_SMS=total_programs, - NUM_XCDS=num_xcds, - CHUNK_SIZE=chunk_size, - BIAS=False, - EVEN_K=even_k, - CACHE_MODIFIER_A=CACHE_MODIFIER_A, - CACHE_MODIFIER_B=CACHE_MODIFIER_B, - QUANTIZED=quantized, - num_stages=num_stages, - num_warps=num_warps, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - ) + chunk_size = min(chunk_size, max(1, total_programs // num_xcds)) + + if work_stealing: + # Work-stealing: launch grid = num CUs, tiles assigned dynamically + # via per-XCD atomic counters. + grids = selector._hardware.N_CU + # grids = total_tiles + + kk = ws_persistent_matmul[(grids,)]( + a, + b, + c, + a_scale if quantized else None, # A_scale_ptr + b_scale if quantized else None, # B_scale_ptr + None, # TODO: Enable bias. + cfg.tile_counter, # Per-XCD×slot tile counters + M, + N, + K, + a.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + 0, # TODO: Enable bias stride. 
+ stride_ak=a.stride(1), + stride_bk=b.stride(0), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=grids, + NUM_XCDS=num_xcds, + COUNTERS_PER_XCD=selector.COUNTERS_PER_XCD, + COUNTER_STRIDE=COUNTER_STRIDE, + BIAS=False, + EVEN_K=even_k, + CACHE_MODIFIER_A=CACHE_MODIFIER_A, + CACHE_MODIFIER_B=CACHE_MODIFIER_B, + QUANTIZED=quantized, + GLOBAL_ATOMIC=cfg.global_atomic, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) + else: + # Default: data-parallel mode – one WG per tile. + grids = total_tiles + + # TODO: Support other matmul algs. + kk = persistent_matmul[(grids,)]( + a, + b, + c, + a_scale if quantized else None, # A_scale_ptr + b_scale if quantized else None, # B_scale_ptr + None, # TODO: Enable bias. + M, + N, + K, + a.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + 0, # TODO: Enable bias stride. + stride_ak=a.stride(1), + stride_bk=b.stride(0), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=total_programs, + NUM_XCDS=num_xcds, + CHUNK_SIZE=chunk_size, + BIAS=False, + EVEN_K=even_k, + CACHE_MODIFIER_A=CACHE_MODIFIER_A, + CACHE_MODIFIER_B=CACHE_MODIFIER_B, + QUANTIZED=quantized, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) return c @@ -135,11 +175,13 @@ def streamk_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, - sk_grid: Optional[int] = None, + config: MatmulConfig, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, ): + cfg = config + assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape @@ -174,19 +216,16 @@ def streamk_matmul_lt( CACHE_MODIFIER_A = None CACHE_MODIFIER_B = None - if sk_grid is not None: - total_programs_streamk = sk_grid - grids = total_programs_streamk block_size = BLK_M * BLK_N - # Use global buffers with optimized zeroing - if grids <= MAX_SMS and block_size <= MAX_BLOCK_SIZE: - locks = _global_locks[:grids] - P = _global_P[:grids, :block_size] + # Use pre-allocated config buffers when they fit; otherwise allocate fresh. + if grids <= cfg.locks.shape[0] and block_size <= cfg.P.shape[1]: + locks = cfg.locks[:grids] + P = cfg.P[:grids, :block_size] else: - locks = torch.empty(grids, device="cuda", dtype=torch.uint8) - P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32) + locks = torch.empty(grids, device=cfg.device, dtype=torch.uint8) + P = torch.empty(grids, block_size, device=cfg.device, dtype=torch.float32) # Set chunk size to same area as L2 tiles. 
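# Rough footprint of the pre-allocated buffers sliced above — an illustrative
# calculation assuming a 256x256 macro-tile (65536 accumulators, the old
# MAX_BLOCK_SIZE) and a 304-workgroup grid (the MI300X CU count):
sk_grid, block_m, block_n = 304, 256, 256
p_bytes = sk_grid * block_m * block_n * 4      # float32 partial-sum buffer P
lock_bytes = sk_grid                           # uint8 stream-K locks
print(f"P: {p_bytes / 2**20:.1f} MiB, locks: {lock_bytes} B")  # P: 76.0 MiB, locks: 304 B
# A config built for a smaller tile or grid falls through to the fresh
# torch.empty() allocations in the else-branch above.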
chunk_size = gsize_m * gsize_m @@ -234,41 +273,46 @@ def streamk_matmul_lt( return c def matmul_lt( - a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, + selector, config: MatmulConfig, + enable_streamk=False, work_stealing=False, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" if enable_streamk: - return streamk_matmul_lt(a, b, c, selector) + return streamk_matmul_lt(a, b, c, selector, config) else: - return persistent_matmul_lt(a, b, c, selector) + return persistent_matmul_lt(a, b, c, selector, config, work_stealing=work_stealing) def matmul_a8w8_lt( - a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False + a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, + c: torch.Tensor, selector, config: MatmulConfig, + enable_streamk=False, work_stealing=False, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return streamk_matmul_lt(a, b, c, selector, config, a_scale=a_scale, b_scale=b_scale, quantized=True) else: - return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return persistent_matmul_lt(a, b, c, selector, config, a_scale=a_scale, b_scale=b_scale, quantized=True, work_stealing=work_stealing) def matmul( a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, enable_streamk=False, - sk_grid=None, + work_stealing=False, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) + config = matmul_preamble(selector) if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid) + return streamk_matmul_lt(a, b, c, selector, config) else: - return persistent_matmul_lt(a, b, c, selector) + return persistent_matmul_lt(a, b, c, selector, config, work_stealing=work_stealing) def matmul_a8w8( a: torch.Tensor, @@ -277,17 +321,18 @@ def matmul_a8w8( b_scale: torch.Tensor, c: torch.Tensor, enable_streamk=False, - sk_grid=None, + work_stealing=False, ): assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) + config = matmul_preamble(selector) if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid, a_scale=a_scale, b_scale=b_scale, quantized=True) + return streamk_matmul_lt(a, b, c, selector, config, a_scale=a_scale, b_scale=b_scale, quantized=True) else: - return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) + return persistent_matmul_lt(a, b, c, selector, config, a_scale=a_scale, b_scale=b_scale, quantized=True, work_stealing=work_stealing) def matmul_fp4( a: torch.Tensor, diff --git a/include/tritonblas/origami.py b/include/tritonblas/origami.py index 291e06c..515b28e 100644 --- a/include/tritonblas/origami.py +++ b/include/tritonblas/origami.py @@ -24,6 +24,8 @@ class OrigamiMatmulSelector: if hasattr(torch, "float8_e4m3fnuz"): dtype_to_str[torch.float8_e4m3fnuz] = "f8" + COUNTERS_PER_XCD = 16 # work-stealing: atomic counter slots per XCD + def __init__( self, m: int, @@ -35,6 +37,7 @@ def __init__( device: torch.device, mx_block_size=0, streamk=False, + total_cus: int = None, 
): # Save tensor sizes self._m = m @@ -89,6 +92,11 @@ def get_dtype_bits(dtype): # Get hardware info from Origami self._hardware = origami.get_hardware_for_device(device.index) + # When running under a CU mask (e.g. cu-sweep), the GPU reports a + # reduced N_CU. Override with the real total so architecture + # detection and config generation use the correct value. + if total_cus is not None: + self._hardware.N_CU = total_cus self._N_CU = self._hardware.N_CU # Create list of Origami config_t objects from defaults. @@ -105,24 +113,33 @@ def get_dtype_bits(dtype): self._problem, self._hardware, self._configs ) + # Heuristic to favor 256x256x64 tile when close~ + if((self._result.config.mt.m == 256 and self._result.config.mt.n != 256) or + (self._result.config.mt.m != 256 and self._result.config.mt.n == 256)): + self._result.config.mt.m = 256 + self._result.config.mt.n = 256 + self._result.config.mt.k = 64 + if streamk: self._grid = self._compute_sk_grid() else: self._grid = self._hardware.N_CU - # Try both workgroup mapping modes for compatibility with Origami Versions - try: - _mapping_mode, self._xcc_workgroup_mapping, self._workgroup_mapping = ( - origami.select_workgroup_mapping( - self._problem, self._hardware, self._result.config, self._grid - ) - ) - except ValueError: - self._xcc_workgroup_mapping, self._workgroup_mapping = ( - origami.select_workgroup_mapping( - self._problem, self._hardware, self._result.config, self._grid - ) - ) + # Handle different origami API versions for workgroup mapping + _wg_result = origami.select_workgroup_mapping( + self._problem, self._hardware, self._result.config, self._grid + ) + if isinstance(_wg_result, tuple): + # Older origami: returns (mode, xcc_mapping, mapping) or (xcc_mapping, mapping) + if len(_wg_result) == 3: + _, self._xcc_workgroup_mapping, self._workgroup_mapping = _wg_result + else: + self._xcc_workgroup_mapping, self._workgroup_mapping = _wg_result + else: + # origami >= 0.1.0: returns workgroup_mapping_t object + self._xcc_workgroup_mapping = _wg_result.wgmxcc + self._workgroup_mapping = _wg_result.wgm + @property def block_m(self): return self._result.config.mt.m @@ -377,7 +394,8 @@ def _infer_matrix_instruction_dimensions(self): # Architecture Detected is not valid if mi_dim == None: raise ValueError( - f"No Valid Matrix Instruction integrated for {element_size_A}-bit or {element_size_B}-bit datatypes" + f"No Valid Matrix Instruction for {self._a_dtype_bitsize}-bit/{self._b_dtype_bitsize}-bit dtypes " + f"on hardware with N_CU={self._hardware.N_CU}" ) return mi_dim diff --git a/setup.py b/setup.py index bd4fcc5..ce0ae11 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,9 @@ def run(self): setup( + name="tritonblas", + version="0.1.0", + package_dir={"": "include"}, cmdclass={"build_ext": CustomBuildExt}, ext_modules=[Extension("_trigger_ext", sources=[])], ) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 85275fd..1fb95f9 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -33,22 +33,23 @@ ], ) @pytest.mark.parametrize( - "enable_streamk", + "mode", [ - False, - True, + "persistent", + "streamk", + "work_stealing", ], ) -def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk): +def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, mode): """Test non-quantized matmul with all transpose combinations using shared input generation utilities.""" init_type = "randn" + enable_streamk = mode == "streamk" + work_stealing = mode == "work_stealing" - # Generate all inputs using shared 
utility (handles transposes automatically) inputs = generate_matmul_inputs(m, n, k, in_dtype, out_dtype, transA, transB, init_type) - # Run TritonBLAS matmul - tritonblas.matmul(inputs.A, inputs.B, inputs.C, enable_streamk=enable_streamk) + tritonblas.matmul(inputs.A, inputs.B, inputs.C, enable_streamk=enable_streamk, + work_stealing=work_stealing) - # Check correctness torch_c = torch.matmul(inputs.A, inputs.B) torch.testing.assert_close(inputs.C.to(out_dtype), torch_c, atol=1, rtol=1) diff --git a/tests/test_matmul_a8w8_lt.py b/tests/test_matmul_a8w8_lt.py index 532192e..0651952 100644 --- a/tests/test_matmul_a8w8_lt.py +++ b/tests/test_matmul_a8w8_lt.py @@ -86,10 +86,13 @@ def test_matmul_a8w8(m, n, k, in_dtype, out_dtype, transA, transB, enable_stream # Scales from generate_matmul_inputs are already 1D: (M,) and (N,) # which is what the kernel expects selector = tritonblas.OrigamiMatmulSelector( - m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device + m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device, + streamk=enable_streamk, ) + config = tritonblas.matmul_preamble(selector) tritonblas.matmul_a8w8_lt( - inputs.A, inputs.B, inputs.scaleA, inputs.scaleB, inputs.C, selector, enable_streamk + inputs.A, inputs.B, inputs.scaleA, inputs.scaleB, inputs.C, selector, config, + enable_streamk, ) # Check correctness using reference computation diff --git a/tests/test_matmul_lt.py b/tests/test_matmul_lt.py index d998c0c..a2924db 100644 --- a/tests/test_matmul_lt.py +++ b/tests/test_matmul_lt.py @@ -34,26 +34,28 @@ ], ) @pytest.mark.parametrize( - "enable_streamk", + "mode", [ - False, - True, + "persistent", + "streamk", + "work_stealing", ], ) -def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, enable_streamk): +def test_matmul(m, n, k, in_dtype, out_dtype, transA, transB, mode): """Test non-quantized matmul with all transpose combinations using shared input generation utilities.""" init_type = "randn" + enable_streamk = mode == "streamk" + work_stealing = mode == "work_stealing" - # Generate all inputs using shared utility (handles transposes automatically) inputs = generate_matmul_inputs(m, n, k, in_dtype, out_dtype, transA, transB, init_type) - # Run TritonBLAS matmul selector = tritonblas.OrigamiMatmulSelector( - m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device + m, n, k, inputs.A.dtype, inputs.B.dtype, inputs.C.dtype, inputs.A.device, + streamk=enable_streamk, ) - tritonblas.matmul_lt(inputs.A, inputs.B, inputs.C, selector, enable_streamk) + config = tritonblas.matmul_preamble(selector) + tritonblas.matmul_lt(inputs.A, inputs.B, inputs.C, selector, config, + enable_streamk, work_stealing=work_stealing) - # Check correctness torch_c = torch.matmul(inputs.A, inputs.B) - # torch.testing.assert_close(inputs.C.to(out_dtype), torch_c, atol=1e-2, rtol=1e-3) torch.testing.assert_close(inputs.C.to(out_dtype), torch_c, atol=1, rtol=1) diff --git a/tools/plot_cu_sweep.py b/tools/plot_cu_sweep.py new file mode 100644 index 0000000..38a794d --- /dev/null +++ b/tools/plot_cu_sweep.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Plot CU-sweep results: one line per kernel variant. 
+ +Usage: + python tools/plot_cu_sweep.py \ + --persistent results/cu_sweep_persistent_full.csv \ + --streamk results/cu_sweep_streamk_full.csv \ + --torch results/cu_sweep_torch_full.csv \ + --ws-cpc 1 results/cu_sweep_ws_cpc1.csv \ + --ws-cpc 2 results/cu_sweep_ws_cpc2.csv \ + --ws-cpc 4 results/cu_sweep_ws_cpc4.csv \ + --ws-cpc 8 results/cu_sweep_ws_cpc8.csv \ + --ws-cpc 16 results/cu_sweep_ws_cpc16.csv \ + -o results/cu_sweep_plot.png +""" +import argparse +import csv +import numpy as np +import matplotlib.pyplot as plt + + +def load_tritonblas_csv(path): + """Load a tritonblas benchmark CSV, return sorted (active_cus, gflops) lists.""" + cus, gflops = [], [] + with open(path) as f: + for row in csv.DictReader(f): + cus.append(int(row["active_cus"])) + gflops.append(float(row["tritonblas_gflops"])) + order = sorted(range(len(cus)), key=lambda i: cus[i]) + return [cus[i] for i in order], [gflops[i] for i in order] + + +def load_torch_csv(path): + """Load a torch.mm benchmark CSV, return sorted (active_cus, gflops) lists.""" + cus, gflops = [], [] + with open(path) as f: + for row in csv.DictReader(f): + cus.append(int(row["active_cus"])) + gflops.append(float(row["gflops"])) + order = sorted(range(len(cus)), key=lambda i: cus[i]) + return [cus[i] for i in order], [gflops[i] for i in order] + + +def main(): + parser = argparse.ArgumentParser(description="Plot CU-sweep benchmark results.") + parser.add_argument("--persistent", type=str, help="Persistent GEMM CSV") + parser.add_argument("--streamk", type=str, help="Stream-K GEMM CSV") + parser.add_argument("--ws", type=str, help="Work-stealing GEMM CSV (single line, legacy)") + parser.add_argument("--ws-cpc", nargs=2, action="append", metavar=("CPC", "CSV"), + help="Work-stealing CSV with counters-per-XCD value, e.g. 
--ws-cpc 4 file.csv") + parser.add_argument("--ws-global", type=str, help="Work-stealing global-atomic CSV") + parser.add_argument("--torch", type=str, help="torch.mm CSV") + parser.add_argument("-o", "--output", type=str, default="cu_sweep_plot.png", + help="Output image path (default: cu_sweep_plot.png)") + parser.add_argument("--title", type=str, default=None, + help="Custom plot title (default: auto-generated)") + args = parser.parse_args() + + fig, ax = plt.subplots(figsize=(14, 8)) + + if args.persistent: + cus, gf = load_tritonblas_csv(args.persistent) + ax.plot(cus, gf, label="Persistent", linewidth=2.5, markersize=5, + color="#2196F3", marker="o") + if args.streamk: + cus, gf = load_tritonblas_csv(args.streamk) + ax.plot(cus, gf, label="Stream-K", linewidth=2.5, markersize=5, + color="#4CAF50", marker="s") + if args.torch: + cus, gf = load_torch_csv(args.torch) + ax.plot(cus, gf, label="torch.mm (hipBLASLt)", linewidth=2.5, markersize=5, + color="#F44336", marker="D") + + ws_colors = { + "1": "#FF9800", + "2": "#9C27B0", + "4": "#00BCD4", + "8": "#795548", + "16": "#E91E63", + } + ws_markers = { + "1": "^", "2": "v", "4": "<", "8": ">", "16": "P", + } + + if args.ws and not args.ws_cpc: + cus, gf = load_tritonblas_csv(args.ws) + ax.plot(cus, gf, label="Work-Stealing", linewidth=2, markersize=5, + color="#FF9800", marker="^") + + if args.ws_global: + cus, gf = load_tritonblas_csv(args.ws_global) + ax.plot(cus, gf, label="Work-Stealing (Global Atomic)", linewidth=2.5, markersize=5, + color="#673AB7", marker="*", linestyle="-.") + + if args.ws_cpc: + for cpc_val, csv_path in sorted(args.ws_cpc, key=lambda x: int(x[0])): + cus, gf = load_tritonblas_csv(csv_path) + color = ws_colors.get(cpc_val, "#607D8B") + marker = ws_markers.get(cpc_val, "x") + ax.plot(cus, gf, label=f"Work-Stealing (Counters/XCD = {cpc_val})", linewidth=1.8, markersize=5, + color=color, marker=marker, linestyle="--") + + ax.set_xlabel("Active CUs", fontsize=13) + ax.set_ylabel("GFLOPS", fontsize=13) + title = args.title if args.title else "FP16 GEMM — CU Sweep (MI300X)" + ax.set_title(title, fontsize=15) + ax.legend(fontsize=11, loc="upper left", ncol=2) + ax.set_xticks(np.arange(32, 312, 8)) + ax.set_xlim(32, 312) + ax.grid(True, alpha=0.3) + ax.tick_params(labelsize=9) + + fig.tight_layout() + fig.savefig(args.output, dpi=150) + print(f"Plot saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/tools/tile_sweep.py b/tools/tile_sweep.py index 2a2a2a6..7e3b990 100755 --- a/tools/tile_sweep.py +++ b/tools/tile_sweep.py @@ -73,7 +73,8 @@ def run_tritonblas_matmul( config = selector.get_config() # (BLK_M, BLK_N, BLK_K, group_size) # Benchmark - matmul_fn = lambda: tritonblas.matmul_lt(A, B, C, selector, False) + cfg = tritonblas.matmul_preamble(selector) + matmul_fn = lambda: tritonblas.matmul_lt(A, B, C, selector, cfg, False) elapsed_ms = triton.testing.do_bench(matmul_fn, warmup=20, rep=200) tflops = perf_ms(elapsed_ms, m, n, k) return tflops, elapsed_ms, config
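Taken together, the new entry points introduced by this change compose as follows. A minimal end-to-end sketch (shapes, dtypes, and the work-stealing choice are arbitrary, and it assumes the tritonblas package is importable): allocate a MatmulConfig once via matmul_preamble, reuse it across launches, and time it with the reset-aware do_bench helper.

import torch
import tritonblas

m = n = k = 8192
A = torch.randn(m, k, dtype=torch.float16, device="cuda")
B = torch.randn(k, n, dtype=torch.float16, device="cuda")
C = torch.empty(m, n, dtype=torch.float16, device="cuda")

# Selector picks tile sizes; matmul_preamble allocates counters/locks/P once.
selector = tritonblas.OrigamiMatmulSelector(m, n, k, A.dtype, B.dtype, C.dtype, A.device)
cfg = tritonblas.matmul_preamble(selector)

run = lambda: tritonblas.matmul_lt(A, B, C, selector, cfg, False, work_stealing=True)
reset = lambda: cfg.reset(work_stealing=True)   # zero tile counters, outside the timed region
ms = tritonblas.do_bench(run, reset_fn=reset, n_warmup=20, n_repeat=100)
print(f"{2 * m * n * k / (ms * 1e-3) / 1e12:.1f} TFLOPS")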