-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable decompositions for aot.export and torch.export helpers. (#574)
This was accidentally left out of the PyTorch 2.3 upgrade rework. Adds APIs to manage decompositions with a context manager so that we don't thread it through every API permutation. Exposes a bug in PyTorch 2.3 for decompositions with programs that use python registered ops: pytorch/pytorch#122752 (I have included a really gross workaround).
- Loading branch information
1 parent
b73c5c3
commit 55e8703
Showing
9 changed files
with
287 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import contextlib | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from ..dynamo.decompositions import ( | ||
_current, | ||
_extend_context_manager, | ||
DecompositionOpsList, | ||
DecompositionTable, | ||
) | ||
|
||
__all__ = [ | ||
"current_aot_decompositions", | ||
"extend_aot_decompositions", | ||
] | ||
|
||
|
||
def current_aot_decompositions() -> DecompositionTable: | ||
"""Gets the current decomposition table for AOT.""" | ||
return _current("aot") | ||
|
||
|
||
def extend_aot_decompositions( | ||
*, | ||
from_current: bool = True, | ||
add_ops: Optional[DecompositionOpsList] = None, | ||
remove_ops: Optional[DecompositionOpsList] = None | ||
): | ||
"""Context manager which extends the list of decompositions used for AOT.""" | ||
return _extend_context_manager( | ||
"aot", from_current=from_current, add_ops=add_ops, remove_ops=remove_ops | ||
) | ||
|
||
|
||
############################################################################### | ||
# Workarounds | ||
############################################################################### | ||
|
||
|
||
def _patch_op_dispatch(op): | ||
if torch.__version__ >= "2.3.0" and torch.__version__ < "2.4": | ||
# Around the torch 2.3.0 release cut, there was a regression such that | ||
# running decompositions in a functionalized context did not work | ||
# with Python registered ops. The issue is that they have an incomplete | ||
# list of mode handler registrations and cannot handle the | ||
# FunctionalTensorMode. Since we only have a handful of these, and | ||
# since we can assume that for the sake of expediency, functional | ||
# dispatch is basically the same as fake tensor dispatch, we just | ||
# take the fake tensor registration and dup it onto the functional | ||
# registration. | ||
# Note that the torch._higher_order_ops.auto_functionalize is registered | ||
# in Python and is itself broken, it needs to be monkey patched. | ||
# See: https://github.com/pytorch/pytorch/issues/122752 | ||
from torch._subclasses.fake_tensor import FakeTensorMode | ||
from torch._subclasses.functional_tensor import FunctionalTensorMode | ||
|
||
t = op.python_key_mode_table | ||
if FunctionalTensorMode not in t: | ||
handler = t[FakeTensorMode] | ||
t[FunctionalTensorMode] = handler | ||
|
||
|
||
_patched_op_dispatch_for_export = False | ||
|
||
|
||
def _patch_op_dispatch_for_export(): | ||
global _patched_op_dispatch_for_export | ||
if _patched_op_dispatch_for_export: | ||
return | ||
_patched_op_dispatch_for_export = True | ||
import torch._higher_order_ops.auto_functionalize | ||
|
||
_patch_op_dispatch(torch._higher_order_ops.auto_functionalize.auto_functionalized) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Callable, Dict, List, Optional, Sequence, Union | ||
|
||
import contextlib | ||
import threading | ||
|
||
import torch | ||
from torch._decomp import get_decompositions, remove_decompositions | ||
|
||
DecompositionTable = Dict[torch._ops.OperatorBase, Callable] | ||
DecompositionOpsList = Sequence[ | ||
Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket] | ||
] | ||
|
||
# Manages "scopes" for decompositions used. Each unique scope is an attribute on | ||
# the _decomp_local. If the attribute is missing, then the default | ||
# decompositions are used. The scope "aot" is used for all AOT cases. | ||
_decomp_local = threading.local() | ||
|
||
|
||
def _get_decomp_stack(scope: str) -> List[DecompositionTable]: | ||
try: | ||
return getattr(_decomp_local, scope) | ||
except AttributeError: | ||
stack: List[DecompositionTable] = [] | ||
setattr(_decomp_local, scope, stack) | ||
return stack | ||
|
||
|
||
def _current(scope: str) -> DecompositionTable: | ||
"""Gets the current decomposition table (which may be the default).""" | ||
stack = _get_decomp_stack(scope) | ||
if stack: | ||
return dict(stack[-1]) | ||
else: | ||
return dict(DEFAULT_DECOMPOSITION_TABLE) | ||
|
||
|
||
@contextlib.contextmanager | ||
def _extend_context_manager( | ||
scope: str, | ||
*, | ||
from_current: bool = True, | ||
add_ops: Optional[DecompositionOpsList] = None, | ||
remove_ops: Optional[DecompositionOpsList] = None | ||
): | ||
table: DecompositionTable | ||
if from_current: | ||
table = dict(_current(scope)) | ||
else: | ||
table = {} | ||
if add_ops: | ||
table.update(get_decompositions(add_ops)) | ||
if remove_ops: | ||
remove_decompositions(table, remove_ops) # type: ignore | ||
stack = _get_decomp_stack(scope) | ||
stack.append(table) | ||
try: | ||
yield table | ||
finally: | ||
popped = stack.pop() | ||
assert ( | ||
popped is table | ||
), "contextmanager unbalanced: popped different that pushed" | ||
|
||
|
||
def _get_default_decomposition_ops() -> DecompositionOpsList: | ||
aten = torch.ops.aten | ||
# default decompositions pulled from SHARK / torch._decomp | ||
return [ | ||
aten.embedding_dense_backward, | ||
aten.native_layer_norm_backward, | ||
aten.slice_backward, | ||
aten.select_backward, | ||
aten.norm.ScalarOpt_dim, | ||
aten.native_group_norm, | ||
aten.upsample_bilinear2d.vec, | ||
aten.split.Tensor, | ||
aten.split_with_sizes, | ||
aten.native_layer_norm, | ||
aten.masked_fill.Tensor, | ||
aten.masked_fill.Scalar, | ||
aten.t, | ||
aten.addmm, | ||
# decompositions that aid us in handling nn.BatchNorm2d | ||
aten._native_batch_norm_legit_functional, | ||
aten._native_batch_norm_legit_no_training, | ||
aten._native_batch_norm_legit, | ||
aten._native_batch_norm_legit.no_stats, | ||
aten.squeeze.dims, | ||
# decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions | ||
aten.soft_margin_loss, | ||
aten.im2col, | ||
aten._euclidean_dist, | ||
aten.index_copy, | ||
aten.index_copy_, | ||
aten.grid_sampler_2d, | ||
aten.log_sigmoid_forward, | ||
aten.unsafe_split.Tensor, | ||
aten.binary_cross_entropy, | ||
aten.dot, | ||
aten._adaptive_avg_pool2d, | ||
aten._prelu_kernel, | ||
aten.full, | ||
aten._log_softmax, | ||
aten.nll_loss_forward, | ||
aten.nll_loss_backward, | ||
aten._to_copy, | ||
aten._log_softmax_backward_data, | ||
aten.lift_fresh_copy.default, | ||
aten._unsafe_index.Tensor, | ||
aten.unbind.int, | ||
# decompositions added manually in this file | ||
aten._scaled_dot_product_flash_attention.default, | ||
] | ||
|
||
|
||
# Some older APIs still use an op list instead of a table. | ||
DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops() | ||
|
||
# The table of default decompositions. | ||
DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions( | ||
DEFAULT_DECOMPOSITIONS | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.