45 changes: 45 additions & 0 deletions benchmarks/common.py
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
"""
Shared GPU/CU utilities for TritonBLAS benchmarks.
"""
import ctypes

import torch # type: ignore


def get_num_xcds(device_id: int = 0) -> int:
"""Query the number of XCDs (chiplets) via HIP runtime."""
try:
hip = ctypes.cdll.LoadLibrary("libamdhip64.so")
except OSError:
return 1
    try:
        # AMD-specific device attribute: number of XCCs (compute chiplets).
        hipDeviceAttributeNumberOfXccs = 10018
        xcc_count = ctypes.c_int(0)
        status = hip.hipDeviceGetAttribute(
            ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id
        )
        # Fall back to 1 if the query fails, so callers never divide by zero.
        if status != 0 or xcc_count.value <= 0:
            return 1
        return xcc_count.value
    except Exception:
        return 1


def get_cu_info(device_id: int = 0):
"""Return (total_cus, num_xcds, cus_per_xcd) for the current device."""
total_cus = torch.cuda.get_device_properties(device_id).multi_processor_count
num_xcds = get_num_xcds(device_id)
cus_per_xcd = total_cus // num_xcds
return total_cus, num_xcds, cus_per_xcd


def build_balanced_hex_mask(remove_per_xcd: int, num_xcds: int, cus_per_xcd: int) -> str:
    """Build a ROC_GLOBAL_CU_MASK hex string that disables ``remove_per_xcd`` CUs on every XCD."""
    if remove_per_xcd == 0:
        return ""
    total_cus = num_xcds * cus_per_xcd
    # Start from an all-ones mask (every CU enabled).
    mask = (1 << total_cus) - 1
    # Clear the lowest remove_per_xcd * num_xcds bits. Assuming CU-mask bits are
    # assigned round-robin across XCDs (bit j -> XCD j % num_xcds), this disables
    # the same number of CUs on each XCD, keeping the mask balanced.
    for i in range(remove_per_xcd):
        base = num_xcds * i
        for xcd in range(num_xcds):
            bit = base + xcd
            mask &= ~(1 << bit)
    return f"0x{mask:x}"
116 changes: 86 additions & 30 deletions benchmarks/torch_matmul.py
@@ -1,9 +1,22 @@
#!/usr/bin/env python3
"""
Benchmark torch.matmul performance.

Supports an optional --cu-sweep mode (MI300X) that re-invokes this script as
subprocesses with ROC_GLOBAL_CU_MASK set, producing results with an
``active_cus`` column.
"""
import json
import os
import subprocess
import sys
import yaml
import argparse
import torch
import csv

from common import get_cu_info, build_balanced_hex_mask


def str_to_dtype(dtype_str: str) -> torch.dtype:
"""
@@ -69,8 +82,6 @@ def bench_matmul(input_yaml: str):
if transB == "N":
B = B.T # Apply transpose to B if transB is "N"

# Initialize tensors with the appropriate dimensions

# Warm-up iterations
for _ in range(20):
_ = torch.matmul(A, B)
@@ -88,30 +99,20 @@
elapsed_ms = start_event.elapsed_time(end_event) # time in milliseconds
times.append(elapsed_ms)

# Calculate mean execution time (ms) and derive performance in TFLOPS.
# Calculate mean execution time (ms) and derive performance.
mean_ms = sum(times) / len(times)
# Compute FLOPS count: 2 * m * n * k operations (each multiply-add counts as 2 operations) scaled to tera (1e-12)
flops = 2 * m * n * k * 1e-12
tflops = flops / (mean_ms * 1e-3)
gflops = (2 * m * n * k) / (mean_ms * 1e-3) / 1e9

print(
f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype} perf={tflops}"
f"m={m}, n={n}, k={k}, in_dtype={in_dtype}, out_dtype={out_dtype} perf={gflops:.1f} GFLOPS"
)

# Calculate bytes processed: considering both A, B, and the output tensor.
bytes_fn = lambda: (A.element_size() * (m * k + n * k)) + (
m * n * A.element_size()
)

# Collect the metrics in a dictionary for later CSV output.
metrics = {
"m": m,
"n": n,
"k": k,
"mnk": m * n * k,
"bytes": bytes_fn(),
"flops": flops,
"tflops": tflops,
"gflops": gflops,
"ms": mean_ms,
"in_dtype": str(in_dtype),
"out_dtype": str(out_dtype),
"transA": transA,
@@ -122,21 +123,18 @@ def bench_matmul(input_yaml: str):
return benchmark_results


def _build_child_cmd(args):
"""Build a subprocess command from parsed args (excludes --cu-sweep and --output-csv)."""
cmd = [sys.executable, os.path.abspath(__file__)]
cmd += ["--input-yaml", args.input_yaml]
return cmd


def write_csv(filename: str, results):
"""Write the benchmark results to a CSV file."""
fieldnames = [
"m",
"n",
"k",
"mnk",
"bytes",
"flops",
"tflops",
"in_dtype",
"out_dtype",
"transA",
"transB",
]
fieldnames = ["m", "n", "k", "gflops", "ms", "in_dtype", "out_dtype", "transA", "transB"]
if results and "active_cus" in results[0]:
fieldnames.insert(0, "active_cus")
with open(filename, mode="w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
@@ -161,9 +159,67 @@ def write_csv(filename: str, results):
default="",
help="Filename for CSV output (if not specified, CSV output is disabled).",
)
parser.add_argument(
"--cu-sweep", action="store_true",
help="Run a balanced CU-mask sweep (MI300X). Re-invokes this script as "
"subprocesses with ROC_GLOBAL_CU_MASK set.",
)
parser.add_argument(
"--cu-sweep-max-remove", type=int, default=34,
help="Max CUs to remove per XCD (default 34, minimum 4 CUs/XCD left).",
)

# Hidden: used by cu-sweep parent to tag subprocess results
parser.add_argument("--_active-cus", type=int, default=None, help=argparse.SUPPRESS)

args = parser.parse_args()

if args.cu_sweep:
full_cus, num_xcds, cus_per_xcd = get_cu_info()
all_results = []
child_base = _build_child_cmd(args)
max_remove = min(args.cu_sweep_max_remove, cus_per_xcd - 1)

for r in range(max_remove + 1):
active = full_cus - r * num_xcds
mask = build_balanced_hex_mask(r, num_xcds, cus_per_xcd)

child_cmd = child_base + ["--_active-cus", str(active)]

env = os.environ.copy()
if mask:
env["ROC_GLOBAL_CU_MASK"] = mask

proc = subprocess.run(
child_cmd, capture_output=True, text=True, env=env,
)
if proc.returncode != 0:
print(f"[active_cus={active}] subprocess failed:", file=sys.stderr)
sys.stderr.write(proc.stderr[-500:] if len(proc.stderr) > 500 else proc.stderr)
continue

step_results = json.loads(proc.stdout)
all_results.extend(step_results)

if args.output_csv:
write_csv(args.output_csv, all_results)
sys.exit(0)

is_worker = args._active_cus is not None

if is_worker:
# Suppress prints when running as a subprocess
import io
sys.stdout = io.StringIO()

benchmark_results = bench_matmul(args.input_yaml)

if is_worker:
sys.stdout = sys.__stdout__
for row in benchmark_results:
row["active_cus"] = args._active_cus
print(json.dumps(benchmark_results))
sys.exit(0)

if args.output_csv:
write_csv(args.output_csv, benchmark_results)
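
For context, a minimal sketch of post-processing a sweep CSV; the filename and the summary itself are illustrative, and only the active_cus/gflops column names come from write_csv above:

import csv
from collections import defaultdict

# A sweep CSV is typically produced with (paths illustrative):
#   python benchmarks/torch_matmul.py --input-yaml gemm.yaml --cu-sweep --output-csv cu_sweep.csv
# Report the best GFLOPS observed at each active-CU count.
best = defaultdict(float)
with open("cu_sweep.csv", newline="") as f:
    for row in csv.DictReader(f):
        cus = int(row["active_cus"])
        best[cus] = max(best[cus], float(row["gflops"]))

for cus in sorted(best, reverse=True):
    print(f"active_cus={cus}: {best[cus]:.1f} GFLOPS")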