Skip to content

Commit 942308c

Browse files
authored
[tagging] Add tagging system to tags.yaml (#625)
1 parent 2a14064 commit 942308c

File tree

8 files changed

+1073
-78
lines changed

8 files changed

+1073
-78
lines changed

benchmarks/tagging/ast_analyzer.py

Lines changed: 171 additions & 51 deletions
Large diffs are not rendered by default.

benchmarks/tagging/run.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from os.path import abspath, exists
1010
from pathlib import Path
1111

12+
import yaml
13+
1214
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
1315

1416
logger = logging.getLogger(__name__)
@@ -31,41 +33,62 @@ def setup_tritonbench_cwd():
3133
sys.path.append(tritonbench_dir)
3234
return original_dir
3335

36+
3437
setup_tritonbench_cwd()
3538

39+
import inspect
40+
3641
from tritonbench.operators import list_operators
37-
from tritonbench.utils.run_utils import load_operator_by_args
3842
from tritonbench.utils.operator_utils import get_backends_for_operator
39-
import inspect
40-
from .ast_analyzer import build_backend_callees, trace_callees
43+
from tritonbench.utils.run_utils import load_operator_by_args
44+
45+
try:
46+
from ast_analyzer import build_backend_callees, trace_callees
47+
except ImportError:
48+
from .ast_analyzer import build_backend_callees, trace_callees
49+
4150

4251
def get_parser():
43-
parser = argparse.ArgumentParser(
44-
description="Trace op backends to generate tags."
45-
)
52+
parser = argparse.ArgumentParser(description="Trace op backends to generate tags.")
4653
parser.add_argument(
4754
"--op",
4855
type=str,
4956
help="Op name to trace. If unspecified, trace all ops.",
5057
)
58+
parser.add_argument(
59+
"--only",
60+
type=str,
61+
help="Only trace the specified backend. If unspecified, trace all backends.",
62+
)
5163
parser.add_argument(
5264
"--output",
5365
type=str,
5466
default="",
55-
help="Output file path.",
67+
help="Output file path. If none, print to stdout.",
5668
)
5769
return parser
5870

71+
5972
def prevalidate_backends(backend_edges):
6073
op_with_tags = {}
6174
# heuristic: do not search torch.nn, torch.compile, and xformers backends
6275
for backend, callees in backend_edges.items():
63-
if "torch.compile" in callees or any(["torch._inductor" in callee for callee in callees]):
76+
if "torch.compile" in callees or any(
77+
["torch._inductor" in callee for callee in callees]
78+
):
6479
op_with_tags[backend] = {"tags": ["pt2"]}
6580
elif any(["torch.nn" in callee for callee in callees]):
6681
op_with_tags[backend] = {"tags": ["aten"]}
6782
elif any(["xformers" in callee for callee in callees]):
6883
op_with_tags[backend] = {"tags": ["xformers"]}
84+
elif any([callee.startswith("torch.ops.") for callee in callees]):
85+
custom_op_category = [
86+
callee[callee.rfind(".") + 1 :]
87+
for callee in callees
88+
if callee.startswith("torch.ops.")
89+
]
90+
op_with_tags[backend] = {"tags": custom_op_category + ["native_custom_ops"]}
91+
6992
return op_with_tags
7093

7194

@@ -77,7 +100,11 @@ def trace_op(op):
77100
module_name = opbench.__module__
78101
with open(opbench_file, "r") as f:
79102
source = f.read()
80-
backends = get_backends_for_operator(opbench.name)
103+
backends = (
104+
get_backends_for_operator(opbench.name)
105+
if not args.only
106+
else args.only.split(",")
107+
)
81108
backend_edges = build_backend_callees(
82109
source=source,
83110
filename=opbench_file_name,
@@ -86,35 +113,61 @@ def trace_op(op):
86113
)
87114
assert len(backend_edges) == len(backends)
88115
op_with_tags[op] = prevalidate_backends(backend_edges)
89-
remaining_backends = [backend for backend in backends if backend not in op_with_tags[op]]
116+
remaining_backends = [
117+
backend for backend in backends if backend not in op_with_tags[op]
118+
]
90119
# for backends without tags, we need to trace their callees to find tags
91120
# trace the callees of each backend, and return their tags
92121
for backend in remaining_backends:
93122
# special case for torch.compile
94123
callees = backend_edges[backend]
95-
base_module_name = module_name[:module_name.rfind(".")]
96-
callees_with_module = [(callee, base_module_name) for callee in callees]
124+
base_module_name = module_name[: module_name.rfind(".")]
125+
callees_with_module: list[tuple[Unknown, Unknown]] = [
126+
(callee, base_module_name) for callee in callees
127+
]
97128
op_with_tags[op][backend] = trace_callees(callees_with_module)
98129
# postprocess: add human heuristics
99130
if "liger" in backend:
100131
if not op_with_tags[op][backend]:
101132
op_with_tags[op][backend] = {"tags": []}
102-
op_with_tags[op][backend]["tags"].append("liger")
133+
op_with_tags[op][backend]["tags"].extend(["liger"])
134+
if "triton" not in op_with_tags[op][backend]["tags"]:
135+
op_with_tags[op][backend]["tags"].append("triton")
136+
if "tlx_" in backend:
137+
if not op_with_tags[op][backend]:
138+
op_with_tags[op][backend] = {"tags": []}
139+
op_with_tags[op][backend]["tags"].extend(["tlx"])
103140
if "eager" in backend or "aten" in backend:
104141
if not op_with_tags[op][backend]:
105142
op_with_tags[op][backend] = {"tags": []}
106143
op_with_tags[op][backend]["tags"].append("aten")
107144
return op_with_tags
108145

109146

147+
UNSUPPORTED_OPS = [
148+
"fp8_fused_quant_gemm_rowwise",
149+
"fp32_to_mx4",
150+
"flex_attention",
151+
"mx4_to_fp32",
152+
]
153+
110154
if __name__ == "__main__":
111155
parser = get_parser()
112156
args = parser.parse_args()
113157
if not args.op:
114158
ops = list_operators()
115159
else:
116160
ops = [args.op]
161+
print(f"Running tagging test on ops: {ops}...")
117162
results = {}
118163
for op in ops:
164+
# deadloop on flex_attention
165+
if op in UNSUPPORTED_OPS:
166+
continue
119167
results.update(trace_op(op))
120-
print(results)
168+
if not args.output:
169+
print(results)
170+
else:
171+
with open(args.output, "w") as f:
172+
f.write(yaml.safe_dump(results))
173+
print("success!")

tools/tilelang/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
tilelang==0.1.3
22
Cython
33
decorator
4+
cloudpickle

0 commit comments

Comments
 (0)