From 9d50d0bf86e5695818d4c01ff703b781ac8e7937 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 17 Dec 2025 11:18:14 +0800 Subject: [PATCH 1/3] add gemm baseline profile --- benchmarks/gemm/gemm.py | 5 +++++ benchmarks/profile/profile_run.py | 9 ++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/benchmarks/gemm/gemm.py b/benchmarks/gemm/gemm.py index 5c6e60f6..1f64bbe1 100644 --- a/benchmarks/gemm/gemm.py +++ b/benchmarks/gemm/gemm.py @@ -34,6 +34,11 @@ def ref_program(self, A: torch.Tensor, B: torch.Tensor): if self.trans_B: B = B.T return torch.matmul(A, B) + + def baseline_profile(self, *inputs, warmup=100, rep=100, device="cuda:0"): + + print("===== Profiling MatMul torch backend =====") + return super().baseline_profile(self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device) class matmul_benchmark(Benchmark): diff --git a/benchmarks/profile/profile_run.py b/benchmarks/profile/profile_run.py index b0f00c17..b9f79ee8 100644 --- a/benchmarks/profile/profile_run.py +++ b/benchmarks/profile/profile_run.py @@ -18,8 +18,15 @@ def build_gemm_cmd(args_dict): str(args_dict['M']), '--N', str(args_dict['N']), '--K', str(args_dict['K']), '--dtype', - str(args_dict['dtype']) + str(args_dict['dtype']), '--tune' ] + + if args_dict.get('trans_A', False): + cmd_args.append('--trans_A') + + if args_dict.get('trans_B', False): + cmd_args.append('--trans_B') + return cmd_args From 7534d1410b1a73d45454420acc0e26ae93beadbd Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 17 Dec 2025 19:29:44 +0800 Subject: [PATCH 2/3] add new input csv --- benchmarks/input_params/gemm_new.csv | 49 ++++++++++++++++++++++++++++ benchmarks/input_params/gqa_new.csv | 11 +++++++ benchmarks/input_params/mha_new.csv | 11 +++++++ benchmarks/profile/profile_gemm.py | 29 ++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 benchmarks/input_params/gemm_new.csv create mode 100644 
benchmarks/input_params/gqa_new.csv create mode 100644 benchmarks/input_params/mha_new.csv create mode 100644 benchmarks/profile/profile_gemm.py diff --git a/benchmarks/input_params/gemm_new.csv b/benchmarks/input_params/gemm_new.csv new file mode 100644 index 00000000..9b445d28 --- /dev/null +++ b/benchmarks/input_params/gemm_new.csv @@ -0,0 +1,49 @@ +M,N,K,dtype,trans_A,trans_B +1,16384,16384,float16,False,False +1,18432,7168,bfloat16,False,False +1,7168,18432,float16,False,False +128,16384,4096,bfloat16,False,False +128,18432,7168,float16,False,False +128,7168,18432,bfloat16,False,False +4096,16384,16384,float16,False,False +4096,18432,7168,bfloat16,False,False +4096,7168,18432,float16,False,False +16384,16384,16384,bfloat16,False,False +16384,18432,7168,float16,False,False +16384,7168,18432,bfloat16,False,False +1,16384,16384,float16,False,True +1,18432,7168,bfloat16,False,True +1,7168,18432,float16,False,True +128,16384,4096,bfloat16,False,True +128,18432,7168,float16,False,True +128,7168,18432,bfloat16,False,True +4096,16384,16384,float16,False,True +4096,18432,7168,bfloat16,False,True +4096,7168,18432,float16,False,True +16384,16384,16384,bfloat16,False,True +16384,18432,7168,float16,False,True +16384,7168,18432,bfloat16,False,True +1,16384,16384,float16,True,False +1,18432,7168,bfloat16,True,False +1,7168,18432,float16,True,False +128,16384,4096,bfloat16,True,False +128,18432,7168,float16,True,False +128,7168,18432,bfloat16,True,False +4096,16384,16384,float16,True,False +4096,18432,7168,bfloat16,True,False +4096,7168,18432,float16,True,False +16384,16384,16384,bfloat16,True,False +16384,18432,7168,float16,True,False +16384,7168,18432,bfloat16,True,False +1,16384,16384,float16,True,True +1,18432,7168,bfloat16,True,True +1,7168,18432,float16,True,True +128,16384,4096,bfloat16,True,True +128,18432,7168,float16,True,True +128,7168,18432,bfloat16,True,True +4096,16384,16384,float16,True,True +4096,18432,7168,bfloat16,True,True +4096,7168,18432,float16,True,True 
+16384,16384,16384,bfloat16,True,True +16384,18432,7168,float16,True,True +16384,7168,18432,bfloat16,True,True diff --git a/benchmarks/input_params/gqa_new.csv b/benchmarks/input_params/gqa_new.csv new file mode 100644 index 00000000..d86e1346 --- /dev/null +++ b/benchmarks/input_params/gqa_new.csv @@ -0,0 +1,11 @@ +batch,seq_len,heads,heads_kv,dim,causal,dtype +16,2048,64,4,128,False,bfloat16 +8,4096,64,4,128,False,bfloat16 +4,8192,64,4,128,False,bfloat16 +2,16384,64,4,128,False,bfloat16 +1,32768,64,4,128,False,bfloat16 +16,2048,64,4,128,True,bfloat16 +8,4096,64,4,128,True,bfloat16 +4,8192,64,4,128,True,bfloat16 +2,16384,64,4,128,True,bfloat16 +1,32768,64,4,128,True,bfloat16 \ No newline at end of file diff --git a/benchmarks/input_params/mha_new.csv b/benchmarks/input_params/mha_new.csv new file mode 100644 index 00000000..f34da688 --- /dev/null +++ b/benchmarks/input_params/mha_new.csv @@ -0,0 +1,11 @@ +batch,seq_len,heads,dim,dtype,causal +16,2048,16,128,float16,False +8,4096,16,128,float16,False +4,8192,16,128,float16,False +2,16384,16,128,float16,False +1,32768,16,128,float16,False +16,2048,16,128,float16,True +8,4096,16,128,float16,True +4,8192,16,128,float16,True +2,16384,16,128,float16,True +1,32768,16,128,float16,True diff --git a/benchmarks/profile/profile_gemm.py b/benchmarks/profile/profile_gemm.py new file mode 100644 index 00000000..1ef25f56 --- /dev/null +++ b/benchmarks/profile/profile_gemm.py @@ -0,0 +1,29 @@ +import argparse +from top.ops import Gemm +from top.utils import str2dtype +from benchmarks import gemm_benchmark + + +def test_gemm(M, N, K, dtype, trans_A=False, trans_B=False, tune=False): + op = Gemm(M, N, K, trans_A=trans_A, trans_B=trans_B, dtype=dtype, tune=tune) + benchmark = gemm_benchmark(M, N, K, dtype, trans_A=trans_A, trans_B=trans_B) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + benchmark.profile(op, *inputs) + benchmark.baseline_profile(*inputs) + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + parser.add_argument('--M', type=int, default=1024, help='M') + parser.add_argument('--N', type=int, default=1024, help='N') + parser.add_argument('--K', type=int, default=1024, help='K') + parser.add_argument( + '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') + parser.add_argument('--trans_A', action='store_true', default=False, help='transpose input A') + parser.add_argument('--trans_B', action='store_true', default=False, help='transpose input B') + parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + args = parser.parse_args() + + test_gemm(args.M, args.N, args.K, str2dtype[args.dtype], args.trans_A, args.trans_B, args.tune) \ No newline at end of file From 65074f8e782e0b757808cf1324dc55d6b0293831 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Thu, 18 Dec 2025 14:25:35 +0800 Subject: [PATCH 3/3] fix timeout bug: raise profile script timeout from 300s to 3000s --- benchmarks/profile/profile_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/profile/profile_run.py b/benchmarks/profile/profile_run.py index b9f79ee8..c666eb8f 100644 --- a/benchmarks/profile/profile_run.py +++ b/benchmarks/profile/profile_run.py @@ -217,7 +217,7 @@ def run_test_script(script_path, args_dict): try: # Run script and capture output - result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=3000) if result.returncode != 0: print(f"Error running script: {result.stderr}") return None