From 4c6f299c40035aa661af82c64bf2c207438f65b1 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 21 Nov 2025 14:22:52 +0800 Subject: [PATCH 1/2] support real input --- benchmarks/flash_attn/mha.py | 28 +++++++++++++++++++++------- tests/ops/test_mha.py | 10 ++++++++-- top/utils/utils.py | 27 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/benchmarks/flash_attn/mha.py b/benchmarks/flash_attn/mha.py index 7b7a3438..8eb96453 100644 --- a/benchmarks/flash_attn/mha.py +++ b/benchmarks/flash_attn/mha.py @@ -3,6 +3,7 @@ import torch from torch.nn import functional as F from torch.nn.attention import sdpa_kernel, SDPBackend +from top.utils.utils import load_input_from_path class mha_fwd_benchmark(Benchmark): @@ -27,13 +28,26 @@ def total_flops(self): def total_memory(self): return 4 * self.batch * self.heads * self.seq_len * self.dim * self.dtype.itemsize - def gen_inputs(self): - Q = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) - K = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) - V = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + def gen_inputs(self, input_path=None): + if input_path is None: + # gen random inputs + Q = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + K = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + V = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + else: + # Load input data from file paths + paths = input_path.split(';') + if len(paths) != 3: + raise ValueError(f"Expected 3 input paths for Q, K, V, but got {len(paths)}") + + # Load Q, K, V + expected_shape = (self.batch, self.seq_len, self.heads, self.dim) + Q = load_input_from_path(paths[0], expected_shape, self.dtype) + K = load_input_from_path(paths[1], expected_shape, self.dtype) + V = load_input_from_path(paths[2], expected_shape, self.dtype) return Q, K, V def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor): diff --git a/tests/ops/test_mha.py b/tests/ops/test_mha.py index 726fe43e..5360fb3a 100644 --- a/tests/ops/test_mha.py +++ b/tests/ops/test_mha.py @@ -4,11 +4,11 @@ from benchmarks import mha_fwd_benchmark, mha_bwd_benchmark -def test_mha_fwd(B, S, H, D, causal, dtype, tune=False): +def test_mha_fwd(B, S, H, D, causal, dtype, tune=False, input_path=None): op = mha_fwd(B, H, S, D, causal, dtype, tune=tune) benchmark = mha_fwd_benchmark(B, H, S, D, causal, dtype) - inputs = benchmark.gen_inputs() + inputs = benchmark.gen_inputs(input_path) benchmark.check(op, *inputs) benchmark.profile(op, *inputs) @@ -34,6 +34,12 @@ def test_mha_bwd(B, S, H, D, causal, dtype, tune=False): parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') parser.add_argument( '--disable_bwd', action='store_false', default=True, help='when test fwd profile') + parser.add_argument( + '--input_path', + type=str, + default=None, + help='Path to real input data. Use ";" to separate multiple paths. If None, random inputs will be generated.' + ) args = parser.parse_args() test_mha_fwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype], diff --git a/top/utils/utils.py b/top/utils/utils.py index a872e6ab..d8faabf2 100644 --- a/top/utils/utils.py +++ b/top/utils/utils.py @@ -1,3 +1,4 @@ +import os import torch # A mapping from string dtype names to torch dtypes @@ -62,3 +63,29 @@ def is_hopper(): def get_sm_version(): major, minor = torch.cuda.get_device_capability() return major * 10 + minor + + +def _load_input_from_path(path, expected_shape, dtype, device='cuda'): + """ + 从文件路径加载输入数据的公共函数 + + Args: + path: 文件路径 + expected_shape: 期望的张量形状 + dtype: 数据类型 + device: 设备类型 + + Returns: + 加载的张量 + """ + if not os.path.exists(path): + raise FileNotFoundError(f"Input file not found: {path}") + + tensor = torch.load(path) + + if tensor.shape != expected_shape: + raise ValueError( + f"Shape mismatch: expected {expected_shape}, got {tensor.shape} from {path}") + + tensor = tensor.to(dtype=dtype, device=device) + return tensor From 08b45f719a0877b88e5125057012ed381a375522 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 21 Nov 2025 18:01:05 +0800 Subject: [PATCH 2/2] support real input --- benchmarks/flash_attn/mha.py | 2 ++ tests/ops/test_mha.py | 8 ++++---- top/utils/utils.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/benchmarks/flash_attn/mha.py b/benchmarks/flash_attn/mha.py index 8eb96453..5e346360 100644 --- a/benchmarks/flash_attn/mha.py +++ b/benchmarks/flash_attn/mha.py @@ -31,6 +31,7 @@ def total_memory(self): def gen_inputs(self, input_path=None): if input_path is None: # gen random inputs + print("Gen random inputs!") Q = torch.randn( self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) K = torch.randn( @@ -39,6 +40,7 @@ def gen_inputs(self, input_path=None): self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) else: # Load input data from file paths + print("Gen inputs from file!") paths = input_path.split(';') if len(paths) != 3: raise ValueError(f"Expected 3 input paths for Q, K, V, but got {len(paths)}") diff --git a/tests/ops/test_mha.py b/tests/ops/test_mha.py index 5360fb3a..4f81f759 100644 --- a/tests/ops/test_mha.py +++ b/tests/ops/test_mha.py @@ -13,11 +13,11 @@ def test_mha_fwd(B, S, H, D, causal, dtype, tune=False, input_path=None): benchmark.profile(op, *inputs) -def test_mha_bwd(B, S, H, D, causal, dtype, tune=False): +def test_mha_bwd(B, S, H, D, causal, dtype, tune=False, input_path=None): op = mha_bwd(B, H, S, D, causal, dtype, tune=tune) benchmark = mha_bwd_benchmark(B, H, S, D, causal, dtype) - inputs = benchmark.gen_inputs() + inputs = benchmark.gen_inputs(input_path) benchmark.check(op, *inputs) benchmark.profile(op, *inputs) @@ -43,7 +43,7 @@ def test_mha_bwd(B, S, H, D, causal, dtype, tune=False): args = parser.parse_args() test_mha_fwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype], - args.tune) + args.tune, args.input_path) if args.disable_bwd: test_mha_bwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype], args.tune) + str2dtype[args.dtype], args.tune, args.input_path) diff --git a/top/utils/utils.py b/top/utils/utils.py index d8faabf2..e2918f75 100644 --- a/top/utils/utils.py +++ b/top/utils/utils.py @@ -65,7 +65,7 @@ def get_sm_version(): return major * 10 + minor -def _load_input_from_path(path, expected_shape, dtype, device='cuda'): +def load_input_from_path(path, expected_shape, dtype, device='cuda'): """ 从文件路径加载输入数据的公共函数