Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions benchmarks/gemm/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def ref_program(self, A: torch.Tensor, B: torch.Tensor):
if self.trans_B:
B = B.T
return torch.matmul(A, B)

def baseline_profile(self, *inputs, warmup=100, rep=100, device="cuda:0"):
    """Profile the torch reference implementation on ``inputs``.

    Delegates to the parent class's ``baseline_profile``, passing
    ``self.ref_program`` as the callable and ``"torch"`` as the backend.

    Args:
        *inputs: Positional inputs forwarded to the parent profiler.
        warmup: Forwarded unchanged as ``warmup``.
        rep: Forwarded unchanged as ``rep``.
        device: Forwarded unchanged as ``device``.

    Returns:
        Whatever the parent ``baseline_profile`` returns.
    """
    # NOTE(review): the parent method reportedly prints its own
    # "===== Profiling {backend} =====" header, which would make this
    # line redundant -- confirm against the base class before removing.
    print("===== Profiling MatMul torch backend =====")
    profile_options = dict(backend="torch", warmup=warmup, rep=rep, device=device)
    return super().baseline_profile(self.ref_program, *inputs, **profile_options)
Comment on lines +38 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The print statement here is redundant, as the parent baseline_profile method already prints a descriptive header (===== Profiling {backend} =====). Removing this line and the extra blank line will make the output cleaner and the code more concise.

Suggested change
def baseline_profile(self, *inputs, warmup=100, rep=100, device="cuda:0"):
print("===== Profiling MatMul torch backend =====")
return super().baseline_profile(self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device)
def baseline_profile(self, *inputs, warmup=100, rep=100, device="cuda:0"):
return super().baseline_profile(self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device)



class matmul_benchmark(Benchmark):
Expand Down
49 changes: 49 additions & 0 deletions benchmarks/input_params/gemm_new.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
M,N,K,dtype,transA,transB
1,16384,16384,float16,False,False
1,18432,7168,bfloat16,False,False
1,7168,18432,float16,False,False
128,16384,4096,bfloat16,False,False
128,18432,7168,float16,False,False
128,7168,18432,bfloat16,False,False
4096,16384,16384,float16,False,False
4096,18432,7168,bfloat16,False,False
4096,7168,18432,float16,False,False
16384,16384,16384,bfloat16,False,False
16384,18432,7168,float16,False,False
16384,7168,18432,bfloat16,False,False
1,16384,16384,float16,False,True
1,18432,7168,bfloat16,False,True
1,7168,18432,float16,False,True
128,16384,4096,bfloat16,False,True
128,18432,7168,float16,False,True
128,7168,18432,bfloat16,False,True
4096,16384,16384,float16,False,True
4096,18432,7168,bfloat16,False,True
4096,7168,18432,float16,False,True
16384,16384,16384,bfloat16,False,True
16384,18432,7168,float16,False,True
16384,7168,18432,bfloat16,False,True
1,16384,16384,float16,True,False
1,18432,7168,bfloat16,True,False
1,7168,18432,float16,True,False
128,16384,4096,bfloat16,True,False
128,18432,7168,float16,True,False
128,7168,18432,bfloat16,True,False
4096,16384,16384,float16,True,False
4096,18432,7168,bfloat16,True,False
4096,7168,18432,float16,True,False
16384,16384,16384,bfloat16,True,False
16384,18432,7168,float16,True,False
16384,7168,18432,bfloat16,True,False
1,16384,16384,float16,True,True
1,18432,7168,bfloat16,True,True
1,7168,18432,float16,True,True
128,16384,4096,bfloat16,True,True
128,18432,7168,float16,True,True
128,7168,18432,bfloat16,True,True
4096,16384,16384,float16,True,True
4096,18432,7168,bfloat16,True,True
4096,7168,18432,float16,True,True
16384,16384,16384,bfloat16,True,True
16384,18432,7168,float16,True,True
16384,7168,18432,bfloat16,True,True
11 changes: 11 additions & 0 deletions benchmarks/input_params/gqa_new.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
batch,seq_len,heads,heads_kv,dim,causal,dtype
16,2048,64,4,128,False,bfloat16
8,4096,64,4,128,False,bfloat16
4,8192,64,4,128,False,bfloat16
2,16384,64,4,128,False,bfloat16
1,32768,64,4,128,False,bfloat16
16,2048,64,4,128,True,bfloat16
8,4096,64,4,128,True,bfloat16
4,8192,64,4,128,True,bfloat16
2,16384,64,4,128,True,bfloat16
1,32768,64,4,128,True,bfloat16
11 changes: 11 additions & 0 deletions benchmarks/input_params/mha_new.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
batch,seq_len,heads,dim,dtype,causal
16,2048,16,128,float16,FALSE
8,4096,16,128,float16,FALSE
4,8192,16,128,float16,FALSE
2,16384,16,128,float16,FALSE
1,32768,16,128,float16,FALSE
16,2048,16,128,float16,TRUE
8,4096,16,128,float16,TRUE
4,8192,16,128,float16,TRUE
2,16384,16,128,float16,TRUE
1,32768,16,128,float16,TRUE
29 changes: 29 additions & 0 deletions benchmarks/profile/profile_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import argparse
from top.ops import Gemm
from top.utils import str2dtype
from benchmarks import gemm_benchmark


def test_gemm(M, N, K, dtype, trans_A=False, trans_B=False, tune=False):
    """Check and profile a ``Gemm`` op against its benchmark baseline.

    Builds a ``Gemm`` op and a matching ``gemm_benchmark`` from the same
    problem description, generates inputs from the benchmark, then runs
    the benchmark's ``check``, ``profile`` and ``baseline_profile`` steps
    in that order.

    Args:
        M, N, K: Problem sizes passed to both ``Gemm`` and ``gemm_benchmark``.
        dtype: Data type passed to both constructors.
        trans_A: Passed through as ``trans_A`` to both constructors.
        trans_B: Passed through as ``trans_B`` to both constructors.
        tune: Passed to ``Gemm`` only; the benchmark does not receive it.
    """
    layout = dict(trans_A=trans_A, trans_B=trans_B)
    gemm_op = Gemm(M, N, K, dtype=dtype, tune=tune, **layout)
    bench = gemm_benchmark(M, N, K, dtype, **layout)

    tensors = bench.gen_inputs()
    # Order matters: check first, then profile the op, then the baseline.
    bench.check(gemm_op, *tensors)
    bench.profile(gemm_op, *tensors)
    bench.baseline_profile(*tensors)


if __name__ == "__main__":
    # Command-line entry point: parse the GEMM problem description and
    # hand it to test_gemm. Arguments are registered in the same order
    # as before so the generated --help output is unchanged.
    cli = argparse.ArgumentParser()

    # Integer problem dimensions, each defaulting to 1024.
    for flag, text in (('--M', 'M'), ('--N', 'N'), ('--K', 'K')):
        cli.add_argument(flag, type=int, default=1024, help=text)

    cli.add_argument(
        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')

    # Boolean switches, all off unless given on the command line.
    switches = (
        ('--trans_A', 'transpose input A'),
        ('--trans_B', 'transpose input B'),
        ('--tune', 'enable autotune'),
    )
    for flag, text in switches:
        cli.add_argument(flag, action='store_true', default=False, help=text)

    args = cli.parse_args()

    # str2dtype maps the dtype name chosen on the CLI to the value test_gemm expects.
    test_gemm(
        args.M,
        args.N,
        args.K,
        str2dtype[args.dtype],
        trans_A=args.trans_A,
        trans_B=args.trans_B,
        tune=args.tune,
    )
11 changes: 9 additions & 2 deletions benchmarks/profile/profile_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@ def build_gemm_cmd(args_dict):
str(args_dict['M']), '--N',
str(args_dict['N']), '--K',
str(args_dict['K']), '--dtype',
str(args_dict['dtype'])
str(args_dict['dtype']), '--tune'
]

if args_dict.get('trans_A', False):
cmd_args.append('--trans_A')

if args_dict.get('trans_B', False):
cmd_args.append('--trans_B')
Comment on lines +21 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are two issues with the current implementation for building command arguments:

  1. The --tune flag is added unconditionally. This prevents running benchmarks without autotuning. It's better to make this configurable via the input CSV.
  2. The checks for trans_A and trans_B are incorrect. csv.DictReader reads all values as strings, and any non-empty string is truthy, so a check like if args_dict.get('trans_A', False) passes even when the value is the string 'False', leading to incorrect benchmark execution. The check should explicitly compare the lower-cased value against 'true'.

The suggested change below fixes both issues by making all three flags (tune, trans_A, trans_B) conditional and using a safe string comparison. This will allow you to control these flags from your gemm.csv file.

Suggested change
str(args_dict['dtype']), '--tune'
]
if args_dict.get('trans_A', False):
cmd_args.append('--trans_A')
if args_dict.get('trans_B', False):
cmd_args.append('--trans_B')
str(args_dict['dtype'])
]
if args_dict.get('tune', 'false').lower() == 'true':
cmd_args.append('--tune')
if args_dict.get('trans_A', 'false').lower() == 'true':
cmd_args.append('--trans_A')
if args_dict.get('trans_B', 'false').lower() == 'true':
cmd_args.append('--trans_B')


return cmd_args


Expand Down Expand Up @@ -210,7 +217,7 @@ def run_test_script(script_path, args_dict):

try:
# Run script and capture output
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
result = subprocess.run(cmd, capture_output=True, text=True, timeout=3000)
if result.returncode != 0:
print(f"Error running script: {result.stderr}")
return None
Expand Down
Loading