Upgrade to PyTorch 2.3. (#546)
As discussed on Discord, this is a significant upgrade because it is the
first stable release that has a fully functional `torch.export.export`
with the preferred dynamic shapes support. It also immediately precedes the
nightlies that completely remove support for the old constraints-based
API, so it is a good point to pause and support both
styles.

This patch makes a number of API changes:

* Issues a deprecation warning if the `constraints=` keyword for `jittable`
is used; when it is not used, the keyword is not passed to PyTorch at all.
This should keep `jittable` compatible with later nightlies unless that
feature is used.
* Adds the ability for a `CompiledModule` to directly have an attribute
of a `torch.export.ExportedProgram`, allowing the user to pre-export
with Torch and then construct a compiled module from that (vs the
`jittable` approach where the `CompiledModule` API was directly invoking
Torch internals to do so). This defaults to exporting as `public` if
given a name not starting with an underscore and private otherwise.
Private ExportedPrograms can be called from procedures just as with
`jittable`.
* `shark_turbine.aot.export()` now accepts a `CompiledModule`, an
`nn.Module`, or a `torch.export.ExportedProgram`. For the last two, a
new `external_params=` bool is available to control whether parameters
are inlined or externalized. For an `nn.Module`, arguments corresponding
to those of `torch.export.export` are added. Internally, for an
`nn.Module`, it simply calls `torch.export.export`. `jittable` is no
longer used internally.
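
Two of the rules above can be sketched without PyTorch. The helper names below (`export_kwargs_for`, `is_public`) are hypothetical and exist only for illustration; this is a minimal sketch of the described behavior, not the code in the patch. It assumes that `constraints` is forwarded (with a `DeprecationWarning`) only when some were supplied, and that visibility is derived from a leading underscore in the attribute name.

```python
import warnings
from typing import Any, Dict, List, Optional


def export_kwargs_for(constraints: Optional[List[Any]]) -> Dict[str, Any]:
    """Build the extra keyword arguments for an export call.

    Hypothetical helper: the `constraints` key is only added (and a
    DeprecationWarning only issued) when constraints were actually given,
    so callers that never use the old API never pass the keyword.
    """
    kwargs: Dict[str, Any] = {}
    if constraints:
        warnings.warn(
            "The constraints-based dynamic shapes API is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs["constraints"] = constraints
    return kwargs


def is_public(attribute_name: str) -> bool:
    """Names without a leading underscore export as public, others as private."""
    return not attribute_name.startswith("_")


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        print(export_kwargs_for(None))      # no keyword forwarded, no warning
        print(export_kwargs_for(["dim0"]))  # keyword forwarded, one warning
        print(len(caught))
    print(is_public("main"), is_public("_helper"))
```

Building the keyword dictionary conditionally, rather than always passing `constraints=None`, is what keeps the call valid on versions where the keyword no longer exists.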

Some attempt has been made to be backwards compatible with Torch 2.1.0.
New features will not work, but we should be able to support a short
buffer window where older pinned systems are not completely broken. The
repository prior to this patch will be branched to `torch_2.1`.

Breaking changes:

* ops.iree.trace_tensors (plural) had to be removed because PyTorch's
auto-functionalization has a TODO around lists of tensors. We can
add a wrapper that takes a list and invokes trace_tensors multiple
times and/or add a `functional_trace_tensors` which works a bit better
with the infra.
* stateless_llama_test.py::test_rerotated_torch_comparison marked as
expectedFailure. Filed #560
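
The list wrapper suggested in the first breaking change could look roughly like the sketch below. Everything here is an assumption made for illustration: `trace_each` is a hypothetical name, `trace_one` stands in for whatever single-tensor trace op is available, its `(key, tensor)` signature is assumed, and the `_<index>` key suffix is an invented convention.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def trace_each(
    trace_one: Callable[[str, T], None],
    key: str,
    tensors: Sequence[T],
) -> None:
    """Trace a list of tensors by issuing one single-tensor trace per element.

    Hypothetical wrapper: each element gets its own key (`key_0`, `key_1`,
    ...) so that no op ever receives a list of tensors.
    """
    for index, tensor in enumerate(tensors):
        trace_one(f"{key}_{index}", tensor)


if __name__ == "__main__":
    # Stand-in for a real trace op: just print what would be traced.
    trace_each(lambda k, t: print(k, t), "activations", ["t0", "t1"])
```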
stellaraccident authored Mar 27, 2024
1 parent b785714 commit b73c5c3
Showing 28 changed files with 669 additions and 201 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/test.yml
@@ -39,12 +39,11 @@ jobs:
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install -r core/pytorch-cpu-requirements.txt
pip install --upgrade \
-r core/requirements.txt \
-r mypy-requirements.txt
-r mypy-requirements.txt \
-r serving/requirements.txt
pip install -e core[testing] -e serving[testing]
- name: Run core tests
10 changes: 4 additions & 6 deletions .github/workflows/test_models.yml
@@ -38,12 +38,10 @@ jobs:
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install --upgrade -r core/requirements.txt
pip install -e core[testing]
pip install -e models
pip install -r core/pytorch-cpu-requirements.txt
pip install --pre --upgrade -r core/requirements.txt
pip install --pre -e core[testing]
pip install --pre -e models
- name: Show current free memory
run: |
3 changes: 1 addition & 2 deletions .github/workflows/test_sdxl.yml
@@ -31,8 +31,7 @@ jobs:
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
-r core/pytorch-cpu-requirements.txt
pip install --upgrade -r core/requirements.txt
pip install -e core[testing,torch-cpu-nightly]
pip install --upgrade -r models/requirements.txt
1 change: 0 additions & 1 deletion MANIFEST.in
@@ -1,5 +1,4 @@
include README.md
include requirements.txt
include pytorch-cpu-requirements.txt
include torchvision-requirements.txt
include version_info.json
4 changes: 1 addition & 3 deletions README.md
@@ -45,9 +45,7 @@ pip install shark-turbine
The above does install some unecessary cuda/cudnn packages for cpu use. To avoid this you
can specify pytorch-cpu and install via:
```
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install -r core/pytorch-cpu-requirements.txt
pip install shark-turbine
```

7 changes: 6 additions & 1 deletion core/examples/aot_mlp/mlp_export_dynamic.py
@@ -49,7 +49,12 @@ def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)):
)


exported = aot.export(CompiledMLP)
batch = torch.export.Dim("batch")
exported = aot.export(
model,
args=(torch.empty([2, 97, 8], dtype=torch.float32),),
dynamic_shapes={"x": {0: batch}},
)
# Note that dynamic Torch IR is created below.
exported.print_readable()

4 changes: 2 additions & 2 deletions core/iree-requirements.txt
@@ -1,2 +1,2 @@
iree-compiler==20240311.828
iree-runtime==20240311.828
iree-compiler==20240327.844
iree-runtime==20240327.844
4 changes: 2 additions & 2 deletions core/pytorch-cpu-requirements.txt
@@ -1,3 +1,3 @@
--pre
torch==2.1.0
mpmath==1.3.0
--index-url https://download.pytorch.org/whl/test/cpu
-r pytorch-requirements.txt
3 changes: 3 additions & 0 deletions core/pytorch-requirements.txt
@@ -0,0 +1,3 @@
torch==2.3.0
torchaudio
torchvision
3 changes: 1 addition & 2 deletions core/requirements.txt
@@ -4,6 +4,5 @@
# versions, not specific).
-f https://openxla.github.io/iree/pip-release-links.html

-r pytorch-cpu-requirements.txt
-r torchvision-requirements.txt
-r pytorch-requirements.txt
-r iree-requirements.txt
27 changes: 17 additions & 10 deletions core/shark_turbine/aot/builtins/jittable.py
@@ -9,18 +9,14 @@

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import warnings

import torch
from torch._decomp import get_decompositions
import torch._dynamo as dynamo
from torch.export import (
Constraint,
dynamic_dim,
)
from torch.fx import (
Graph,
GraphModule,
)
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._pytree import (
tree_flatten,
tree_unflatten,
@@ -148,7 +144,7 @@ def __init__(
*,
decompose_ops: Optional[List[Any]] = None,
decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None,
constraints: Optional[List[Constraint]] = None,
constraints: Optional[List[Any]] = None,
function_name: Optional[str] = None,
passes: Sequence[str] = DEFAULT_PASSES,
):
@@ -176,7 +172,7 @@ def resolve_call(
self,
proc_trace: IrTrace,
*py_args,
constraints: Optional[List[Constraint]] = None,
constraints: Optional[List[Any]] = None,
**py_kwargs,
):
type_converter = proc_trace.module_builder.native_type_converter
@@ -188,6 +184,17 @@ def resolve_call(
if self.constraints is not None:
constraints.extend(self.constraints)

export_kwargs = {}
if len(constraints) > 0:
warnings.warn(
"Compiling program with the old PyTorch constraints system "
"for dynamic shapes is deprecated and will break on PyTorch "
"nightlies after the 2.3 release cut (expect either a PyTorch "
"warning or exception to follow)",
DeprecationWarning,
)
export_kwargs["constraints"] = constraints

# Convert procedural trace values to things that Dynamo can handle.
flat_py_args, args_tree = tree_flatten((py_args, py_kwargs))
flat_pytorch_args = []
@@ -220,8 +227,8 @@ def flat_wrapped_f(*args):
transformed_f,
aten_graph=True,
decomposition_table=self.decomposition_table,
constraints=constraints,
assume_static_by_default=True,
**export_kwargs, # type: ignore
)
logger.debug("Invoking dynamo trace")
gm, guards = exported_f(*flat_pytorch_args)
@@ -315,7 +322,7 @@ def flat_wrapped_f(*args):
tree_py_results = tree_unflatten(flat_py_results, out_spec)
return tree_py_results

def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
def _split_py_arg(self, arg, constraints: List[Any]) -> Tuple[Value, Any]:
if isinstance(arg, IrTensor):
meta_tensor, meta_constraints = arg._to_meta_tensor()
constraints.extend(meta_constraints)
61 changes: 59 additions & 2 deletions core/shark_turbine/aot/compiled_module.py
@@ -15,6 +15,8 @@
import weakref
import sys

from torch.export import ExportedProgram

from . import builtins

from ..support.ir_imports import (
@@ -35,6 +37,8 @@
current_ir_trace,
)

from .support.procedural.exported_program import import_exported_program

from .support.ir_utils import (
ModuleBuilder,
)
@@ -130,7 +134,28 @@ def __repr__(self):
return f"<def {self.export_name}({self.signature})>"


Exportable = Union[ExportProcDef, PyOnlyDef, GlobalsDef]
class ExportedProgramDef:
def __init__(
self,
ep: ExportedProgram,
*,
export_name: Optional[str] = None,
public: bool = False,
):
self.export_name = export_name
self.exported_program = ep
self.public = public

def copy(self) -> "ExportedProgramDef":
return ExportedProgramDef(
self.exported_program, export_name=self.export_name, public=self.public
)

def __repr__(self):
return f"<exported_program {self.exported_program}>"


Exportable = Union[ExportProcDef, ExportedProgramDef, PyOnlyDef, GlobalsDef]


class CompiledModuleClassInfo:
@@ -155,6 +180,15 @@ def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]:
self.all_exports.items(),
) # type: ignore

@property
def exported_programs(
self,
) -> Generator[Tuple[str, ExportedProgramDef], None, None]:
return filter(
lambda kv_tuple: isinstance(kv_tuple[1], ExportedProgramDef),
self.all_exports.items(),
) # type: ignore

@property
def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]:
return filter(
@@ -175,6 +209,12 @@ def def_attribute(self, key, value):
if isinstance(value, builtins.jittable):
value = PyOnlyDef(value)

# Promote a torch ExportedProgram to an ExportedProgramDef.
if isinstance(value, ExportedProgram):
value = ExportedProgramDef(
value, export_name=key, public=not key.startswith("_")
)

# Detect our own descriptors.
if isinstance(value, GlobalsDef):
logging.debug("DEFINE GLOBALS: %s = %r", key, value)
@@ -186,11 +226,17 @@ def def_attribute(self, key, value):
value.export_name = key
self.add_export(key, value)
return value

if isinstance(value, PyOnlyDef):
logging.debug("DEFINE PY_ONLY: %s = %r", key, value)
self.add_export(key, value)
return value
if isinstance(value, ExportedProgramDef):
if value.export_name is None:
value = value.copy()
value.export_name = key
logging.debug("DEFINE EXPORTED_PROGRAM: %r", value.export_name)
self.add_export(key, value)
return value

# Infer if it is an exported function.
if callable(value) and inspect.isfunction(value):
@@ -542,6 +588,17 @@ def __new__(
for key, py_def in info.class_info.py_only_defs:
info.shadow_dict[key] = py_def.py_value

# Instantiate exported programs.
# TODO: This should be done in two phases along with export_procs
# in order to enable dependence.
for key, ep_def in info.class_info.exported_programs:
info.shadow_dict[key] = import_exported_program(
module_builder,
ep_def.exported_program,
symbol_name=ep_def.export_name or "main",
symbol_visibility=None if ep_def.public else "private",
)

# Instantiate procs.
# TODO: This should be done in two phases, first binding the symbols
# and then defining them, enabling dependence.