99from os .path import abspath , exists
1010from pathlib import Path
1111
12+ import yaml
13+
1214CURRENT_DIR = os .path .dirname (os .path .abspath (__file__ ))
1315
1416logger = logging .getLogger (__name__ )
@@ -31,41 +33,62 @@ def setup_tritonbench_cwd():
3133 sys .path .append (tritonbench_dir )
3234 return original_dir
3335
36+
3437setup_tritonbench_cwd ()
3538
39+ import inspect
40+
3641from tritonbench .operators import list_operators
37- from tritonbench .utils .run_utils import load_operator_by_args
3842from tritonbench .utils .operator_utils import get_backends_for_operator
39- import inspect
40- from .ast_analyzer import build_backend_callees , trace_callees
43+ from tritonbench .utils .run_utils import load_operator_by_args
44+
45+ try :
46+ from ast_analyzer import build_backend_callees , trace_callees
47+ except ImportError :
48+ from .ast_analyzer import build_backend_callees , trace_callees
49+
4150
4251def get_parser ():
43- parser = argparse .ArgumentParser (
44- description = "Trace op backends to generate tags."
45- )
52+ parser = argparse .ArgumentParser (description = "Trace op backends to generate tags." )
4653 parser .add_argument (
4754 "--op" ,
4855 type = str ,
4956 help = "Op name to trace. If unspecified, trace all ops." ,
5057 )
58+ parser .add_argument (
59+ "--only" ,
60+ type = str ,
61+ help = "Only trace the specified backend. If unspecified, trace all backends." ,
62+ )
5163 parser .add_argument (
5264 "--output" ,
5365 type = str ,
5466 default = "" ,
55- help = "Output file path." ,
67+ help = "Output file path. If none, print to stdout. " ,
5668 )
5769 return parser
5870
71+
5972def prevalidate_backends (backend_edges ):
6073 op_with_tags = {}
6174 # heuristic: do not search torch.nn, torch.compile, and xformers backends
6275 for backend , callees in backend_edges .items ():
63- if "torch.compile" in callees or any (["torch._inductor" in callee for callee in callees ]):
76+ if "torch.compile" in callees or any (
77+ ["torch._inductor" in callee for callee in callees ]
78+ ):
6479 op_with_tags [backend ] = {"tags" : ["pt2" ]}
6580 elif any (["torch.nn" in callee for callee in callees ]):
6681 op_with_tags [backend ] = {"tags" : ["aten" ]}
6782 elif any (["xformers" in callee for callee in callees ]):
6883 op_with_tags [backend ] = {"tags" : ["xformers" ]}
84+ elif any ([callee .startswith ("torch.ops." ) for callee in callees ]):
85+ custom_op_category = [
86+ callee [callee .rfind ("." ) + 1 :]
87+ for callee in callees
88+ if callee .startswith ("torch.ops." )
89+ ]
90+ op_with_tags [backend ] = {"tags" : custom_op_category + ["native_custom_ops" ]}
91+
6992 return op_with_tags
7093
7194
@@ -77,7 +100,11 @@ def trace_op(op):
77100 module_name = opbench .__module__
78101 with open (opbench_file , "r" ) as f :
79102 source = f .read ()
80- backends = get_backends_for_operator (opbench .name )
103+ backends = (
104+ get_backends_for_operator (opbench .name )
105+ if not args .only
106+ else args .only .split ("," )
107+ )
81108 backend_edges = build_backend_callees (
82109 source = source ,
83110 filename = opbench_file_name ,
@@ -86,35 +113,61 @@ def trace_op(op):
86113 )
87114 assert len (backend_edges ) == len (backends )
88115 op_with_tags [op ] = prevalidate_backends (backend_edges )
89- remaining_backends = [backend for backend in backends if backend not in op_with_tags [op ]]
116+ remaining_backends = [
117+ backend for backend in backends if backend not in op_with_tags [op ]
118+ ]
90119 # for backends without tags, we need to trace their callees to find tags
91120 # trace the callees of each backend, and return their tags
92121 for backend in remaining_backends :
93122 # special case for torch.compile
94123 callees = backend_edges [backend ]
95- base_module_name = module_name [:module_name .rfind ("." )]
96- callees_with_module = [(callee , base_module_name ) for callee in callees ]
124+ base_module_name = module_name [: module_name .rfind ("." )]
125+ callees_with_module : list [tuple [Unknown , Unknown ]] = [
126+ (callee , base_module_name ) for callee in callees
127+ ]
97128 op_with_tags [op ][backend ] = trace_callees (callees_with_module )
98129 # postprocess: add human heuristics
99130 if "liger" in backend :
100131 if not op_with_tags [op ][backend ]:
101132 op_with_tags [op ][backend ] = {"tags" : []}
102- op_with_tags [op ][backend ]["tags" ].append ("liger" )
133+ op_with_tags [op ][backend ]["tags" ].extend (["liger" ])
134+ if "triton" not in op_with_tags [op ][backend ]["tags" ]:
135+ op_with_tags [op ][backend ]["tags" ].append ("triton" )
136+ if "tlx_" in backend :
137+ if not op_with_tags [op ][backend ]:
138+ op_with_tags [op ][backend ] = {"tags" : []}
139+ op_with_tags [op ][backend ]["tags" ].extend (["tlx" ])
103140 if "eager" in backend or "aten" in backend :
104141 if not op_with_tags [op ][backend ]:
105142 op_with_tags [op ][backend ] = {"tags" : []}
106143 op_with_tags [op ][backend ]["tags" ].append ("aten" )
107144 return op_with_tags
108145
109146
147+ UNSUPPORTED_OPS = [
148+ "fp8_fused_quant_gemm_rowwise" ,
149+ "fp32_to_mx4" ,
150+ "flex_attention" ,
151+ "mx4_to_fp32" ,
152+ ]
153+
110154if __name__ == "__main__" :
111155 parser = get_parser ()
112156 args = parser .parse_args ()
113157 if not args .op :
114158 ops = list_operators ()
115159 else :
116160 ops = [args .op ]
161+ print (f"Running tagging test on ops: { ops } ..." )
117162 results = {}
118163 for op in ops :
164+ # deadloop on flex_attention
165+ if op in UNSUPPORTED_OPS :
166+ continue
119167 results .update (trace_op (op ))
120- print (results )
168+ if not args .output :
169+ print (results )
170+ else :
171+ with open (args .output , "w" ) as f :
172+ f .write (yaml .safe_dump (results ))
173+ print ("success!" )
0 commit comments