Enable decompositions for aot.export and torch.export helpers. (#574)
This was accidentally left out of the PyTorch 2.3 upgrade rework.

Adds APIs to manage decompositions with a context manager so that we
don't have to thread them through every API permutation.

Exposes a bug in PyTorch 2.3 when running decompositions on programs that use
Python-registered ops: pytorch/pytorch#122752
(I have included a really gross workaround).
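
A minimal usage sketch of the new context-manager API (illustration only, not
part of this commit; the toy `SmallModel` module, the ops chosen for
`add_ops`/`remove_ops`, and the exact `aot.export` call form are assumptions):

```python
import torch

from shark_turbine import aot
from shark_turbine.aot import current_aot_decompositions, extend_aot_decompositions


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x + 1.0, dim=-1)


# Push an extended decomposition table for the duration of the `with` block.
# aot.export and the other export helpers pick it up via
# current_aot_decompositions(), so no decomposition_table argument has to be
# threaded through every API.
with extend_aot_decompositions(
    from_current=True,  # start from the current (or default) table
    add_ops=[torch.ops.aten.upsample_nearest2d.vec],
    remove_ops=[torch.ops.aten.addmm],
) as table:
    print(torch.ops.aten.addmm.default in table)  # addmm overloads were removed above
    exported = aot.export(SmallModel(), torch.randn(2, 8))

# Outside the block, the previous (or default) table is active again.
print(len(current_aot_decompositions()))
```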
stellaraccident authored Mar 29, 2024
1 parent b73c5c3 commit 55e8703
Showing 9 changed files with 287 additions and 61 deletions.
6 changes: 3 additions & 3 deletions core/shark_turbine/aot/__init__.py
@@ -4,8 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .compiled_module import CompiledModule
from .exporter import export

from .builtins import *
from .compiled_module import CompiledModule
from .decompositions import *
from .exporter import *
from .fx_programs import FxPrograms, FxProgramsBuilder
12 changes: 3 additions & 9 deletions core/shark_turbine/aot/builtins/jittable.py
@@ -28,10 +28,6 @@
    FxImporterHooks,
)

from ...dynamo.passes import (
    DEFAULT_DECOMPOSITIONS,
)

from ...support.ir_imports import (
    FlatSymbolRefAttr,
    FunctionType,
@@ -46,6 +42,7 @@

from ...support.logging import aot_logger as logger

from ..decompositions import current_aot_decompositions
from ..passes import (
    functorch_functionalize,
)
@@ -149,10 +146,7 @@ def __init__(
        passes: Sequence[str] = DEFAULT_PASSES,
    ):
        if decomposition_table is None:
            decomposition_table = {}
        if decompose_ops is None:
            decompose_ops = DEFAULT_DECOMPOSITIONS

            decomposition_table = current_aot_decompositions()
        if decompose_ops:
            decomposition_table.update(get_decompositions(decompose_ops))

@@ -226,7 +220,7 @@ def flat_wrapped_f(*args):
        exported_f = dynamo.export(
            transformed_f,
            aten_graph=True,
            decomposition_table=self.decomposition_table,
            decomposition_table=self.decomposition_table,  # type: ignore
            assume_static_by_default=True,
            **export_kwargs,  # type: ignore
        )
80 changes: 80 additions & 0 deletions core/shark_turbine/aot/decompositions.py
@@ -0,0 +1,80 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import contextlib
from typing import Optional

import torch

from ..dynamo.decompositions import (
    _current,
    _extend_context_manager,
    DecompositionOpsList,
    DecompositionTable,
)

__all__ = [
    "current_aot_decompositions",
    "extend_aot_decompositions",
]


def current_aot_decompositions() -> DecompositionTable:
    """Gets the current decomposition table for AOT."""
    return _current("aot")


def extend_aot_decompositions(
    *,
    from_current: bool = True,
    add_ops: Optional[DecompositionOpsList] = None,
    remove_ops: Optional[DecompositionOpsList] = None
):
    """Context manager which extends the list of decompositions used for AOT."""
    return _extend_context_manager(
        "aot", from_current=from_current, add_ops=add_ops, remove_ops=remove_ops
    )


###############################################################################
# Workarounds
###############################################################################


def _patch_op_dispatch(op):
    if torch.__version__ >= "2.3.0" and torch.__version__ < "2.4":
        # Around the torch 2.3.0 release cut, there was a regression such that
        # running decompositions in a functionalized context did not work
        # with Python-registered ops. The issue is that they have an incomplete
        # list of mode handler registrations and cannot handle the
        # FunctionalTensorMode. Since we only have a handful of these, and
        # since we can assume, for the sake of expediency, that functional
        # dispatch is basically the same as fake tensor dispatch, we just
        # take the fake tensor registration and dup it onto the functional
        # registration.
        # Note that torch._higher_order_ops.auto_functionalize is registered
        # in Python and is itself broken, so it needs to be monkey-patched.
        # See: https://github.com/pytorch/pytorch/issues/122752
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import FunctionalTensorMode

        t = op.python_key_mode_table
        if FunctionalTensorMode not in t:
            handler = t[FakeTensorMode]
            t[FunctionalTensorMode] = handler


_patched_op_dispatch_for_export = False


def _patch_op_dispatch_for_export():
    global _patched_op_dispatch_for_export
    if _patched_op_dispatch_for_export:
        return
    _patched_op_dispatch_for_export = True
    import torch._higher_order_ops.auto_functionalize

    _patch_op_dispatch(torch._higher_order_ops.auto_functionalize.auto_functionalized)
11 changes: 11 additions & 0 deletions core/shark_turbine/aot/exporter.py
@@ -28,7 +28,12 @@
    CompiledModuleMeta,
    ImportPhase,
)
from . import decompositions

__all__ = [
    "export",
    "ExportOutput",
]

_is_windows = platform.system() == "Windows"

@@ -179,6 +184,7 @@ def export(
      easy access.
    """
    TransformedModule: Any
    current_decomps = decompositions.current_aot_decompositions()
    if isinstance(mdl, torch.export.ExportedProgram):
        if (
            len(example_args) > 0
@@ -210,6 +216,11 @@ class EpExported(CompiledModule, export_name=mdl.graph_module._get_name()):
    exported_program = torch.export.export(
        nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
    )
    if current_decomps:
        from .decompositions import _patch_op_dispatch_for_export

        _patch_op_dispatch_for_export()
        exported_program = exported_program.run_decompositions(current_decomps)

    class Exported(CompiledModule, export_name=nn_module._get_name()):
        params = export_parameters(nn_module, external=external_params)
8 changes: 8 additions & 0 deletions core/shark_turbine/aot/fx_programs.py
@@ -20,6 +20,8 @@
import torch
import torch.nn as nn

from .decompositions import current_aot_decompositions

# The dynamic_shapes support showed up in the Torch 2.3 timeframe.
_supports_dynamic_shapes = hasattr(torch.export, "Dim")

@@ -224,6 +226,12 @@ def new_forward(self, *forward_args, **forward_kwargs):
        program = torch.export.export(
            lambda_module, args=args, kwargs=kwargs, **extra_kwargs
        )
        current_decomps = current_aot_decompositions()
        if current_decomps:
            from .decompositions import _patch_op_dispatch_for_export

            _patch_op_dispatch_for_export()
            program = program.run_decompositions(current_decomps)
        fx_builder.programs[name] = program
        return program
129 changes: 129 additions & 0 deletions core/shark_turbine/dynamo/decompositions.py
@@ -0,0 +1,129 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Dict, List, Optional, Sequence, Union

import contextlib
import threading

import torch
from torch._decomp import get_decompositions, remove_decompositions

DecompositionTable = Dict[torch._ops.OperatorBase, Callable]
DecompositionOpsList = Sequence[
    Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]
]

# Manages "scopes" for decompositions used. Each unique scope is an attribute on
# the _decomp_local. If the attribute is missing, then the default
# decompositions are used. The scope "aot" is used for all AOT cases.
_decomp_local = threading.local()


def _get_decomp_stack(scope: str) -> List[DecompositionTable]:
    try:
        return getattr(_decomp_local, scope)
    except AttributeError:
        stack: List[DecompositionTable] = []
        setattr(_decomp_local, scope, stack)
        return stack


def _current(scope: str) -> DecompositionTable:
    """Gets the current decomposition table (which may be the default)."""
    stack = _get_decomp_stack(scope)
    if stack:
        return dict(stack[-1])
    else:
        return dict(DEFAULT_DECOMPOSITION_TABLE)


@contextlib.contextmanager
def _extend_context_manager(
    scope: str,
    *,
    from_current: bool = True,
    add_ops: Optional[DecompositionOpsList] = None,
    remove_ops: Optional[DecompositionOpsList] = None
):
    table: DecompositionTable
    if from_current:
        table = dict(_current(scope))
    else:
        table = {}
    if add_ops:
        table.update(get_decompositions(add_ops))
    if remove_ops:
        remove_decompositions(table, remove_ops)  # type: ignore
    stack = _get_decomp_stack(scope)
    stack.append(table)
    try:
        yield table
    finally:
        popped = stack.pop()
        assert (
            popped is table
        ), "contextmanager unbalanced: popped a different table than was pushed"


def _get_default_decomposition_ops() -> DecompositionOpsList:
    aten = torch.ops.aten
    # default decompositions pulled from SHARK / torch._decomp
    return [
        aten.embedding_dense_backward,
        aten.native_layer_norm_backward,
        aten.slice_backward,
        aten.select_backward,
        aten.norm.ScalarOpt_dim,
        aten.native_group_norm,
        aten.upsample_bilinear2d.vec,
        aten.split.Tensor,
        aten.split_with_sizes,
        aten.native_layer_norm,
        aten.masked_fill.Tensor,
        aten.masked_fill.Scalar,
        aten.t,
        aten.addmm,
        # decompositions that aid us in handling nn.BatchNorm2d
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit.no_stats,
        aten.squeeze.dims,
        # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions
        aten.soft_margin_loss,
        aten.im2col,
        aten._euclidean_dist,
        aten.index_copy,
        aten.index_copy_,
        aten.grid_sampler_2d,
        aten.log_sigmoid_forward,
        aten.unsafe_split.Tensor,
        aten.binary_cross_entropy,
        aten.dot,
        aten._adaptive_avg_pool2d,
        aten._prelu_kernel,
        aten.full,
        aten._log_softmax,
        aten.nll_loss_forward,
        aten.nll_loss_backward,
        aten._to_copy,
        aten._log_softmax_backward_data,
        aten.lift_fresh_copy.default,
        aten._unsafe_index.Tensor,
        aten.unbind.int,
        # decompositions added manually in this file
        aten._scaled_dot_product_flash_attention.default,
    ]


# Some older APIs still use an op list instead of a table.
DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops()

# The table of default decompositions.
DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions(
    DEFAULT_DECOMPOSITIONS
)
50 changes: 2 additions & 48 deletions core/shark_turbine/dynamo/passes.py
@@ -4,53 +4,7 @@
from torch.func import functionalize
from typing import List, Optional

# default decompositions pulled from SHARK / torch._decomp
DEFAULT_DECOMPOSITIONS = [
    torch.ops.aten.embedding_dense_backward,
    torch.ops.aten.native_layer_norm_backward,
    torch.ops.aten.slice_backward,
    torch.ops.aten.select_backward,
    torch.ops.aten.norm.ScalarOpt_dim,
    torch.ops.aten.native_group_norm,
    torch.ops.aten.upsample_bilinear2d.vec,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.split_with_sizes,
    torch.ops.aten.native_layer_norm,
    torch.ops.aten.masked_fill.Tensor,
    torch.ops.aten.masked_fill.Scalar,
    torch.ops.aten.t,
    torch.ops.aten.addmm,
    # decompositions that aid us in handling nn.BatchNorm2d
    torch.ops.aten._native_batch_norm_legit_functional,
    torch.ops.aten._native_batch_norm_legit_no_training,
    torch.ops.aten._native_batch_norm_legit,
    torch.ops.aten._native_batch_norm_legit.no_stats,
    torch.ops.aten.squeeze.dims,
    # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions
    torch.ops.aten.soft_margin_loss,
    torch.ops.aten.im2col,
    torch.ops.aten._euclidean_dist,
    torch.ops.aten.index_copy,
    torch.ops.aten.index_copy_,
    torch.ops.aten.grid_sampler_2d,
    torch.ops.aten.log_sigmoid_forward,
    torch.ops.aten.unsafe_split.Tensor,
    torch.ops.aten.binary_cross_entropy,
    torch.ops.aten.dot,
    torch.ops.aten._adaptive_avg_pool2d,
    torch.ops.aten._prelu_kernel,
    torch.ops.aten.full,
    torch.ops.aten._log_softmax,
    torch.ops.aten.nll_loss_forward,
    torch.ops.aten.nll_loss_backward,
    torch.ops.aten._to_copy,
    torch.ops.aten._log_softmax_backward_data,
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,
    torch.ops.aten.unbind.int,
    # decompositions added manually in this file
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]
from .decompositions import DEFAULT_DECOMPOSITIONS


def apply_decompositions(
@@ -72,4 +26,4 @@ def apply_decompositions(

def turbine_cpu_pass_pipeline(gm: torch.fx.GraphModule, example_inputs):
    decompose_ops = DEFAULT_DECOMPOSITIONS
    return apply_decompositions(gm, example_inputs, decompose_ops)
    return apply_decompositions(gm, example_inputs, decompose_ops)  # type: ignore