 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
-from typing import Any, Dict, Optional, Tuple, Callable, Union
-import torch
-from torch._C import _disabled_torch_function_impl
-import torch.utils._pytree as pytree
-from torch.fx import Tracer, GraphModule
-import torch.fx as fx
-from torch.fx.passes.shape_prop import _extract_tensor_metadata
-from contextlib import contextmanager
+__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
+from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose

-aten = torch.ops.aten
-
-CURRENT_DECOMPOSITION_TABLE = {}
-
-
-@contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()
-    try:
-        yield
-    finally:
-        del guard
-
-
-@contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
-    try:
-        yield CURRENT_DECOMPOSITION_TABLE
-    finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
-
-
-class PythonTensor(torch.Tensor):
-    elem: torch.Tensor
-
-    __slots__ = ['elem', 'proxy']
-
-    @staticmethod
-    def __new__(cls, elem, proxy):
-        # Wrapping something in PythonTensor implicitly detaches
-        # gradients. If something required grad, we will collect it as if it
-        # were a leaf. A consequence of detaching in this way is you
-        # need to maintain a parameter cache when translating tensors
-        # into PythonTensor, so you don't create multiple copies of
-        # a gradient (they are aliased, but they would count as independent
-        # leaves). An alternate strategy would be to avoid implicitly
-        # detaching and instead "catch" gradients as they exit the
-        # PythonTensor boundary.
-        # assert not elem.requires_grad or not torch.is_grad_enabled()
-
-        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-        r.proxy = proxy
-        if elem.is_sparse:
-            proxy.node.meta['tensor_meta'] = {}
-        else:
-            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
-        return r
-
-    def __repr__(self):
-        with no_dispatch():
-            return f"PythonTensor({self.as_subclass(torch.Tensor)})"
-
-    __torch_function__ = _disabled_torch_function_impl
-
-    def __deepcopy__(self, memo):
-        return self.clone()
-
-    @classmethod
-    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
-        func = func_overload.overloadpacket
-        if func_overload in CURRENT_DECOMPOSITION_TABLE:
-            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
-        # Commenting this out for now since it causes some spurious failures (such as error checking)
-        # if func == aten._local_scalar_dense:
-        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
-        #                        "It's likely that this is caused by data-dependent control flow or similar.")
-
-        def unwrap_proxy(e):
-            return e.proxy if isinstance(e, PythonTensor) else e
-
-        proxy_args = pytree.tree_map(unwrap_proxy, args)
-        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
-
-        proxy_out = func(*proxy_args, **proxy_kwargs)
-
-        # Kind of a hacky way to test if an op is in-place or not
-        if func.__name__[-1] == "_" and func.__name__[0] != "_":
-            args[0].proxy = proxy_out
-            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
-
-        with no_dispatch():
-            real_out = func_overload(*args, **kwargs)
-
-        def wrap_with_proxy(e, proxy):
-            # Some ops (like native_batch_norm_backward) return undefined tensors that get
-            # converted into None in python.
-            # As the function signature expects tensors, if we directly return these None
-            # tensors back to C++, we'll error.
-            if e is None:
-                e = torch.empty(())
-            if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy)
-            else:
-                return e
-        if isinstance(real_out, tuple):
-            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
-        elif isinstance(real_out, list):
-            return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
-        elif isinstance(real_out, torch.Tensor):
-            return wrap_with_proxy(real_out, proxy_out)
-        else:
-            return real_out
-
-
-class PythonKeyTracer(Tracer):
-    def __init__(self):
-        super().__init__()
-
-    def call_module(
-            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
-    ) -> Any:
-        return forward(*args, **kwargs)
-
-    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
-        if isinstance(attr_val, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if attr_val is p:
-                    if n not in parameter_proxy_cache:
-                        proxy = self.create_proxy('get_attr', n, (), {})
-                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
-                    return parameter_proxy_cache[n]
-            return attr_val
-        return attr_val
-
-    # We need to do this so that parameters entering the `make_fx` context have
-    # a reference to them (and also have requires_grad set on them correctly).
-    # I'm not actually sure if this is the right thing to do ...
-    def create_arg(self, a: Any):
-        if isinstance(a, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if a is p:
-                    return self.create_node('get_attr', n, (), {})
-            qualname: Optional[str] = None
-
-            if not qualname:
-                i = 0
-                while True:
-                    qualname = f'_param_constant{i}'
-                    if not hasattr(self.root, qualname):
-                        break
-                    i += 1
-                setattr(self.root, qualname, a)
-
-            return self.create_node('get_attr', qualname, (), {})
-        return super().create_arg(a)
-
-
-def pythonkey_trace(
-        root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None
-) -> GraphModule:
-    tracer = PythonKeyTracer()
-    graph = tracer.trace(root, concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return GraphModule(tracer.root, graph, name)
-
-
-def wrap_key(f, inps):
-    flat_inps, inp_spec = pytree.tree_flatten(inps)
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        flat_args, args_spec = pytree.tree_flatten(args)
-        assert(len(flat_args) == len(flat_inps))
-        for idx, arg in enumerate(flat_args):
-            if isinstance(flat_inps[idx], torch.Tensor):
-                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
-            else:
-                flat_args[idx] = flat_inps[idx]
-
-        tree_args = pytree.tree_unflatten(flat_args, args_spec)
-        out = f(*tree_args)
-        flat_outs, out_spec = pytree.tree_flatten(out)
-        for idx in range(len(flat_outs)):
-            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], PythonTensor):
-                flat_outs[idx] = flat_outs[idx].proxy
-        return pytree.tree_unflatten(flat_outs, out_spec)
-
-    return wrapped
-
-
-def make_fx(f, decomposition_table=None):
-    if decomposition_table is None:
-        decomposition_table = {}
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        phs = pytree.tree_map(lambda x: fx.PH, args)
-        with pythonkey_decompose(decomposition_table):
-            t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs))
-        return t
-
-    return wrapped
+pythonkey_decompose = decompose
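
For context on what this file now provides: make_fx (re-exported from torch.fx.experimental.proxy_tensor, as the new import above shows) still takes a callable and an optional decomposition table and returns a wrapper that, when called on example inputs, yields an fx.GraphModule recording the traced ATen ops. The sketch below is illustrative only: it assumes a PyTorch build that ships proxy_tensor, and the function f, the sigmoid_decomp helper, and the aten.sigmoid.default table key are made-up examples, not part of this commit.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical decomposition: express sigmoid through exp so that tracing
# records several primitive ops instead of a single aten.sigmoid call.
def sigmoid_decomp(x):
    return 1 / (1 + torch.exp(-x))

decomposition_table = {torch.ops.aten.sigmoid.default: sigmoid_decomp}

def f(x):
    return torch.sigmoid(x).sum()

# Tracing runs f on the example input, recording each ATen call into an
# fx.Graph; the returned GraphModule can be printed or re-executed like f.
traced = make_fx(f, decomposition_table=decomposition_table)(torch.randn(8))
print(traced.graph)
print(traced(torch.randn(8)))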