forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_annotated_fn_args.py
129 lines (110 loc) · 4.28 KB
/
gen_annotated_fn_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
For the procedural tests of __torch_function__, this script exports the
method names and signatures needed by the tests in
test/test_overrides.py. Example invocation:
python -m tools.autograd.gen_annotated_fn_args \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
$OUTPUT_DIR \
tools/autograd
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/testing/_internal/generated
"""
import argparse
import os
import textwrap
from collections import defaultdict
from typing import Any, Dict, List, Sequence
import torchgen.api.python as python
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import Argument, BaseOperatorName, NativeFunction
from torchgen.utils import FileManager
from .gen_python_functions import (
is_py_fft_function,
is_py_linalg_function,
is_py_nn_function,
is_py_special_function,
is_py_torch_function,
is_py_variable_method,
should_generate_py_binding,
)
def gen_annotated(
    native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
    """Render ``annotated_fn_args.py`` from the native function definitions.

    Every native function that should get a Python binding is matched
    against a fixed list of (predicate, namespace) pairs.  For each match,
    one line of the form ``<namespace>.<annotated args>`` is collected, and
    the collected lines are substituted into the
    ``annotated_fn_args.py.in`` template.

    Args:
        native_yaml_path: path handed to ``parse_native_yaml`` as the
            native functions file.
        tags_yaml_path: path handed to ``parse_native_yaml`` as the tags
            file.
        out: directory the rendered file is installed into.
        autograd_dir: directory whose ``templates`` subdirectory holds the
            ``.in`` template.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = parsed.native_functions

    # (predicate, namespace prefix) pairs, processed in this order.
    namespace_by_predicate = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )

    annotated_args: List[str] = []
    for accepts, namespace in namespace_by_predicate:
        # Bucket by base operator name so that functions sharing a base name
        # are emitted next to each other, in order of first appearance.
        by_base_name: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
        for fn in native_functions:
            if should_generate_py_binding(fn) and accepts(fn):
                by_base_name[fn.func.name.name].append(fn)
        annotated_args.extend(
            f"{namespace}.{gen_annotated_args(fn)}"
            for overloads in by_base_name.values()
            for fn in overloads
        )

    def _template_env() -> Dict[str, Any]:
        # NOTE(review): the indent prefix is reproduced exactly as it appears
        # in this file; confirm it matches what the template expects.
        return {
            "annotated_args": textwrap.indent("\n".join(annotated_args), " "),
        }

    fm = FileManager(
        install_dir=out,
        template_dir=os.path.join(autograd_dir, "templates"),
        dry_run=False,
    )
    fm.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        _template_env,
    )
@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
    """Return one ``<name>: [<arg dicts>],`` line describing ``f``'s arguments.

    Only arguments without a default value are described.  Each is rendered
    as a dict with the keys ``is_kwarg_only`` (the string ``"True"`` or
    ``"False"``), ``name``, ``simple_type`` and, when
    ``python.argument_type_size`` returns a truthy value, ``size``.
    Positional arguments come first; keyword-only arguments follow unless
    the function name is on the kwargs exclusion list.
    """
    # functions that currently don't work with kwargs in test_overrides.py
    kwargs_excluded = (
        "diagonal",
        "round_",
        "round",
        "scatter_",
    )

    def _describe(arg: Argument, *, is_kwarg_only: bool) -> Dict[str, Any]:
        # Key insertion order is significant: it is what repr() emits below.
        entry: Dict[str, Any] = {
            "is_kwarg_only": str(is_kwarg_only),
            "name": arg.name,
            "simple_type": python.argument_type_str(arg.type, simple_type=True),
        }
        size = python.argument_type_size(arg.type)
        if size:
            entry["size"] = size
        return entry

    def _without_default(
        args: Sequence[Argument], *, is_kwarg_only: bool
    ) -> List[Dict[str, Any]]:
        # Arguments that carry a default are left out of the annotation.
        return [
            _describe(arg, is_kwarg_only=is_kwarg_only)
            for arg in args
            if arg.default is None
        ]

    op_name = f.func.name.name
    described = _without_default(f.func.arguments.flat_positional, is_kwarg_only=False)
    if str(op_name) not in kwargs_excluded:
        described.extend(
            _without_default(f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
        )
    return f"{op_name}: {described!r},"
def main() -> None:
    """Parse the four positional command-line paths and run ``gen_annotated``."""
    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")

    # (dest, metavar, help) for each positional, in command-line order.
    positionals = (
        ("native_functions", "NATIVE", "path to native_functions.yaml"),
        ("tags", "TAGS", "path to tags.yaml"),
        ("out", "OUT", "path to output directory"),
        ("autograd", "AUTOGRAD", "path to template directory"),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    opts = parser.parse_args()
    gen_annotated(opts.native_functions, opts.tags, opts.out, opts.autograd)
# Script entry point: run the generator only when executed directly
# (e.g. via ``python -m``), not when imported.
if __name__ == "__main__":
    main()