diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf550a1b1..b38c210e7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,12 +39,11 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + pip install -r core/pytorch-cpu-requirements.txt pip install --upgrade \ -r core/requirements.txt \ - -r mypy-requirements.txt + -r mypy-requirements.txt \ + -r serving/requirements.txt pip install -e core[testing] -e serving[testing] - name: Run core tests diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 18ba9ac73..abdf8f17b 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -38,12 +38,10 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt - pip install --upgrade -r core/requirements.txt - pip install -e core[testing] - pip install -e models + pip install -r core/pytorch-cpu-requirements.txt + pip install --pre --upgrade -r core/requirements.txt + pip install --pre -e core[testing] + pip install --pre -e models - name: Show current free memory run: | diff --git a/.github/workflows/test_sdxl.yml b/.github/workflows/test_sdxl.yml index 5babfcbe1..5b60acc07 100644 --- a/.github/workflows/test_sdxl.yml +++ b/.github/workflows/test_sdxl.yml @@ -31,8 +31,7 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade -r core/requirements.txt pip install -e core[testing,torch-cpu-nightly] pip install --upgrade -r models/requirements.txt diff --git a/MANIFEST.in b/MANIFEST.in index faa55e3f7..1ea7c0669 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include README.md include requirements.txt include pytorch-cpu-requirements.txt -include torchvision-requirements.txt include version_info.json diff --git a/README.md b/README.md index ef982d26a..555cdaee9 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,7 @@ pip install shark-turbine The above does install some unecessary cuda/cudnn packages for cpu use. 
To avoid this you can specify pytorch-cpu and install via: ``` -pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt +pip install -r core/pytorch-cpu-requirements.txt pip install shark-turbine ``` diff --git a/core/examples/aot_mlp/mlp_export_dynamic.py b/core/examples/aot_mlp/mlp_export_dynamic.py index 66ca38554..cd8636554 100644 --- a/core/examples/aot_mlp/mlp_export_dynamic.py +++ b/core/examples/aot_mlp/mlp_export_dynamic.py @@ -49,7 +49,12 @@ def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)): ) -exported = aot.export(CompiledMLP) +batch = torch.export.Dim("batch") +exported = aot.export( + model, + args=(torch.empty([2, 97, 8], dtype=torch.float32),), + dynamic_shapes={"x": {0: batch}}, +) # Note that dynamic Torch IR is created below. exported.print_readable() diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt index 23866262c..9d22d2559 100644 --- a/core/iree-requirements.txt +++ b/core/iree-requirements.txt @@ -1,2 +1,2 @@ -iree-compiler==20240311.828 -iree-runtime==20240311.828 +iree-compiler==20240327.844 +iree-runtime==20240327.844 diff --git a/core/pytorch-cpu-requirements.txt b/core/pytorch-cpu-requirements.txt index 92e78464b..e4fa5c795 100644 --- a/core/pytorch-cpu-requirements.txt +++ b/core/pytorch-cpu-requirements.txt @@ -1,3 +1,3 @@ --pre -torch==2.1.0 -mpmath==1.3.0 +--index-url https://download.pytorch.org/whl/test/cpu +-r pytorch-requirements.txt diff --git a/core/pytorch-requirements.txt b/core/pytorch-requirements.txt new file mode 100644 index 000000000..63fc21602 --- /dev/null +++ b/core/pytorch-requirements.txt @@ -0,0 +1,3 @@ +torch==2.3.0 +torchaudio +torchvision diff --git a/core/requirements.txt b/core/requirements.txt index 128012cb7..3265a2b99 100644 --- a/core/requirements.txt +++ b/core/requirements.txt @@ -4,6 +4,5 @@ # versions, not specific). 
-f https://openxla.github.io/iree/pip-release-links.html --r pytorch-cpu-requirements.txt --r torchvision-requirements.txt +-r pytorch-requirements.txt -r iree-requirements.txt diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index 12942b22c..06d26dd35 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -9,18 +9,14 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +import warnings + import torch from torch._decomp import get_decompositions import torch._dynamo as dynamo -from torch.export import ( - Constraint, - dynamic_dim, -) from torch.fx import ( - Graph, GraphModule, ) -from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._pytree import ( tree_flatten, tree_unflatten, @@ -148,7 +144,7 @@ def __init__( *, decompose_ops: Optional[List[Any]] = None, decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None, - constraints: Optional[List[Constraint]] = None, + constraints: Optional[List[Any]] = None, function_name: Optional[str] = None, passes: Sequence[str] = DEFAULT_PASSES, ): @@ -176,7 +172,7 @@ def resolve_call( self, proc_trace: IrTrace, *py_args, - constraints: Optional[List[Constraint]] = None, + constraints: Optional[List[Any]] = None, **py_kwargs, ): type_converter = proc_trace.module_builder.native_type_converter @@ -188,6 +184,17 @@ def resolve_call( if self.constraints is not None: constraints.extend(self.constraints) + export_kwargs = {} + if len(constraints) > 0: + warnings.warn( + "Compiling program with the old PyTorch constraints system " + "for dynamic shapes is deprecated and will break on PyTorch " + "nightlies after the 2.3 release cut (expect either a PyTorch " + "warning or exception to follow)", + DeprecationWarning, + ) + export_kwargs["constraints"] = constraints + # Convert procedural trace values to things that Dynamo can handle. flat_py_args, args_tree = tree_flatten((py_args, py_kwargs)) flat_pytorch_args = [] @@ -220,8 +227,8 @@ def flat_wrapped_f(*args): transformed_f, aten_graph=True, decomposition_table=self.decomposition_table, - constraints=constraints, assume_static_by_default=True, + **export_kwargs, # type: ignore ) logger.debug("Invoking dynamo trace") gm, guards = exported_f(*flat_pytorch_args) @@ -315,7 +322,7 @@ def flat_wrapped_f(*args): tree_py_results = tree_unflatten(flat_py_results, out_spec) return tree_py_results - def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]: + def _split_py_arg(self, arg, constraints: List[Any]) -> Tuple[Value, Any]: if isinstance(arg, IrTensor): meta_tensor, meta_constraints = arg._to_meta_tensor() constraints.extend(meta_constraints) diff --git a/core/shark_turbine/aot/compiled_module.py b/core/shark_turbine/aot/compiled_module.py index f9b01e255..aa8e687c4 100644 --- a/core/shark_turbine/aot/compiled_module.py +++ b/core/shark_turbine/aot/compiled_module.py @@ -15,6 +15,8 @@ import weakref import sys +from torch.export import ExportedProgram + from . 
import builtins from ..support.ir_imports import ( @@ -35,6 +37,8 @@ current_ir_trace, ) +from .support.procedural.exported_program import import_exported_program + from .support.ir_utils import ( ModuleBuilder, ) @@ -130,7 +134,28 @@ def __repr__(self): return f"" -Exportable = Union[ExportProcDef, PyOnlyDef, GlobalsDef] +class ExportedProgramDef: + def __init__( + self, + ep: ExportedProgram, + *, + export_name: Optional[str] = None, + public: bool = False, + ): + self.export_name = export_name + self.exported_program = ep + self.public = public + + def copy(self) -> "ExportedProgramDef": + return ExportedProgramDef( + self.exported_program, export_name=self.export_name, public=self.public + ) + + def __repr__(self): + return f"" + + +Exportable = Union[ExportProcDef, ExportedProgramDef, PyOnlyDef, GlobalsDef] class CompiledModuleClassInfo: @@ -155,6 +180,15 @@ def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]: self.all_exports.items(), ) # type: ignore + @property + def exported_programs( + self, + ) -> Generator[Tuple[str, ExportedProgramDef], None, None]: + return filter( + lambda kv_tuple: isinstance(kv_tuple[1], ExportedProgramDef), + self.all_exports.items(), + ) # type: ignore + @property def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]: return filter( @@ -175,6 +209,12 @@ def def_attribute(self, key, value): if isinstance(value, builtins.jittable): value = PyOnlyDef(value) + # Promote a torch ExportedProgram to an ExportedProgramDef. + if isinstance(value, ExportedProgram): + value = ExportedProgramDef( + value, export_name=key, public=not key.startswith("_") + ) + # Detect our own descriptors. if isinstance(value, GlobalsDef): logging.debug("DEFINE GLOBALS: %s = %r", key, value) @@ -186,11 +226,17 @@ def def_attribute(self, key, value): value.export_name = key self.add_export(key, value) return value - if isinstance(value, PyOnlyDef): logging.debug("DEFINE PY_ONLY: %s = %r", key, value) self.add_export(key, value) return value + if isinstance(value, ExportedProgramDef): + if value.export_name is None: + value = value.copy() + value.export_name = key + logging.debug("DEFINE EXPORTED_PROGRAM: %r", value.export_name) + self.add_export(key, value) + return value # Infer if it is an exported function. if callable(value) and inspect.isfunction(value): @@ -542,6 +588,17 @@ def __new__( for key, py_def in info.class_info.py_only_defs: info.shadow_dict[key] = py_def.py_value + # Instantiate exported programs. + # TODO: This should be done in two phases along with export_procs + # in order to enable dependence. + for key, ep_def in info.class_info.exported_programs: + info.shadow_dict[key] = import_exported_program( + module_builder, + ep_def.exported_program, + symbol_name=ep_def.export_name or "main", + symbol_visibility=None if ep_def.public else "private", + ) + # Instantiate procs. # TODO: This should be done in two phases, first binding the symbols # and then defining them, enabling dependence. diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py index 2bb746df2..509f584a1 100644 --- a/core/shark_turbine/aot/exporter.py +++ b/core/shark_turbine/aot/exporter.py @@ -4,8 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Optional, Sequence, Union -import functools +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import io from pathlib import Path import platform @@ -27,18 +26,14 @@ from .compiled_module import ( CompiledModule, CompiledModuleMeta, - ExportProcDef, ImportPhase, ) -from .support.procedural import ( - AbstractTypedef, -) _is_windows = platform.system() == "Windows" -ModuleLike = Union[torch.nn.Module, CompiledModuleMeta] +ModuleLike = Union[torch.nn.Module, CompiledModuleMeta, torch.export.ExportedProgram] SaveableTarget = Union[str, Path, None, Output] @@ -150,48 +145,89 @@ def compile( return None -# Decorator which explicitly exports a function. -# TODO: Make this a public API on CompiledModule. -# See https://github.com/nod-ai/SHARK-Turbine/issues/126 -def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> Any: - if f is None: - return functools.partial(export_proc, signature=signature) - return ExportProcDef(f.__name__, f, signature=signature) - +def export( + mdl: ModuleLike, + *example_args: torch.Tensor, + args: Optional[tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Dict[str, Any] | Tuple[Any] | List[Any] | None = None, + external_params: bool = False, +) -> ExportOutput: + """One-shot export of an nn.Module, CompiledModule, or ExportedProgram. -def export(mdl: ModuleLike, *example_args: torch.Tensor) -> ExportOutput: - """One shot export of an nn.Module. + This function behaves differently based on the type of the `mdl` argument: - This is a very restrictive API vs the lower level `CompiledModule` - facility. It is suitable for one-shot modules, with a single - entrypoint and static example arguments where no additional - configuration is needed for mutable parameters/buffers or state - management. Dynamic shape constraints are also not presently - exposed via this API, but we expect to allow this in the future. + * nn.Module: The module is traced with torch.export.export passing it + `args`, `kwargs`, and `dynamic_shapes`. + * CompiledModule: The module is imported to IR. Additional arguments are + illegal in this case. + * torch.export.ExportedProgram: A pre-exported program can be passed and + it will be used to construct a single-entrypoint module. Args: mdl: The nn.Module to export. *example_args: Example tensors. + args: Example arguments to torch.export (if present, then *example_args + must be empty). + kwargs: Example keyword arguments. + dynamic_shapes: Dynamic shape specs to pass to torch.export. + external_params: Whether to declare parameters as external vs inlining + contents. Returns: An ExportOutput object that wraps the compilation and provides easy access. """ TransformedModule: Any - if isinstance(mdl, torch.nn.Module): + if isinstance(mdl, torch.export.ExportedProgram): + if ( + len(example_args) > 0 + or args is not None + or kwargs is not None + or dynamic_shapes is not None + ): + raise ValueError( + "If passing an ExportedProgram to aot.export, cannot also pass " + "args, example_args, kwargs, or dynamic_shapes" + ) + + class EpExported(CompiledModule, export_name=mdl.graph_module._get_name()): + params = export_global_tree( + dict(list(mdl.named_parameters())), external=external_params + ) + main = mdl + + TransformedModule = EpExported + elif isinstance(mdl, torch.nn.Module): + # Normalize arguments for torch.export.
+ if args is None: + args = example_args + elif len(example_args) > 0: + raise ValueError( + "Cannot pass args= and positional example_args at the same time" + ) nn_module = mdl - signature = [abstractify(t) for t in example_args] + exported_program = torch.export.export( + nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) class Exported(CompiledModule, export_name=nn_module._get_name()): - params = export_parameters(nn_module) - - @export_proc(signature=signature) - def main(self, *args): - return jittable(nn_module.forward)(*args) + params = export_parameters(nn_module, external=external_params) + main = exported_program TransformedModule = Exported else: assert isinstance(mdl, CompiledModuleMeta) + if ( + len(example_args) > 0 + or args is not None + or kwargs is not None + or dynamic_shapes is not None + ): + raise ValueError( + "If passing a CompiledModule to aot.export, cannot also pass " + "args, example_args, kwargs, or dynamic_shapes" + ) TransformedModule = mdl session = Session() diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py new file mode 100644 index 000000000..4fd0c166c --- /dev/null +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -0,0 +1,280 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# Portions Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Dict, List, Optional + +import inspect + +import torch + +from torch.utils._pytree import ( + tree_flatten, + tree_unflatten, +) + +try: + from torch.utils._pytree import treespec_pprint +except ImportError: + # torch < 2.3 does not include this.
+ treespec_pprint = lambda x: repr(x) # type: ignore + +from iree.compiler.extras.fx_importer import ( + FxImporter, + FxImporterHooks, + GraphNodeImporter, +) + +from ....support.logging import aot_logger as logger + +from ....support.ir_imports import ( + func_d, + util_d, + FlatSymbolRefAttr, + FunctionType, + IrType, + Operation, + StringAttr, + TypeAttr, + Value, +) + +from ..ir_utils import ( + ModuleBuilder, +) + +from .base import ( + CallableIntrinsic, +) + +from .primitives import ( + IrImmediateTensor, + IrTensor, +) + +from .tracer import ( + IrTrace, +) + + +class ExportedProgramIntrinsic(CallableIntrinsic): + def __init__( + self, + entry_func_op: Operation, + entry_sig: torch.export.ModuleCallSignature, + user_output_dtypes: List[Optional[torch.dtype]], + ): + self.entry_func_op = entry_func_op + self.entry_sig = entry_sig + self.user_output_dtypes = user_output_dtypes + + @property + def function_type(self) -> FunctionType: + return TypeAttr(self.entry_func_op.attributes["function_type"]).value + + @property + def function_symbol(self) -> StringAttr: + return StringAttr(self.entry_func_op.attributes["sym_name"]) + + @property + def function_visibility(self) -> StringAttr: + return StringAttr(self.entry_func_op.attributes["sym_visibility"]) + + def resolve_call( + self, + proc_trace: IrTrace, + *py_args, + **py_kwargs, + ): + visibility = self.function_visibility + if visibility.value != "private": + raise ValueError( + f"Currently, only private ExportedPrograms can be called: " + f"{self.function_symbol} is {visibility}" + ) + + # Flatten and convert py args to torch IR values by converting to + # the canonical tree structure for args + # (tuple of list of args, dict of kwargs). + flat_py_args, args_tree = tree_flatten(((list(py_args),), py_kwargs)) + if args_tree != self.entry_sig.in_spec: + raise ValueError( + f"Mismatched arguments to exported program. \n" + f" Got: {treespec_pprint(args_tree)}\n" + f" Expected: {treespec_pprint(self.entry_sig.in_spec)} " + ) + function_type = self.function_type + flat_ir_args = [ + self._py_to_torch_ir(proc_trace, py_arg, torch_type) + for py_arg, torch_type in zip(flat_py_args, function_type.inputs) + ] + + # Call. + with proc_trace.ip, proc_trace.loc: + flat_ir_results = func_d.CallOp( + function_type.results, + FlatSymbolRefAttr.get(self.function_symbol.value), + flat_ir_args, + ).results + + # Convert torch IR values to python. + flat_py_results = [ + self._torch_ir_to_py(proc_trace, ir_value, dtype) + for ir_value, dtype in zip(flat_ir_results, self.user_output_dtypes) + ] + + return tree_unflatten(flat_py_results, self.entry_sig.out_spec) + + def _py_to_torch_ir( + self, proc_trace: IrTrace, py_value, torch_type: IrType + ) -> Value: + type_converter = proc_trace.module_builder.native_type_converter + if isinstance(py_value, IrTensor): + # TODO: Allow certain static info casts. 
+ return type_converter.materialize_native_to_torch( + py_value.ir_value, torch_type + ) + else: + raise ValueError( + f"Unsupported type in arguments of call to ExportedProgram: " + f"{type(py_value)}: {py_value}" + ) + + def _torch_ir_to_py( + self, proc_trace: IrTrace, ir_value: Value, dtype: Optional[torch.dtype] + ): + type_converter = proc_trace.module_builder.native_type_converter + native_ir_value = type_converter.materialize_torch_to_native(ir_value) + if dtype is not None: + return IrImmediateTensor(native_ir_value, dtype) + else: + raise TypeError( + f"Unknown PyTorch->IREE value mapping for ExportedProgram output: " + f"{native_ir_value}" + ) + + +def import_exported_program( + module_builder: ModuleBuilder, + exported_program: torch.export.ExportedProgram, + symbol_name: str, + symbol_visibility: Optional[str], +) -> ExportedProgramIntrinsic: + fx_importer = _create_fx_importer(module_builder) + entry_func_op = fx_importer.import_program( + exported_program, func_name=symbol_name, func_visibility=symbol_visibility + ) + + module_call_graph = exported_program.module_call_graph + assert len(module_call_graph) >= 1, "Expected at least one module call signature" + entry_module_call_entry = module_call_graph[0] + assert ( + entry_module_call_entry.fqn == "" + ), "Expected first module call entry to be unnamed" + + # We want additional torch-level metadata about any user outputs. + # This will help us create a true python fake without loss of information. + # TODO: It is unclear how much switchiness is actually needed here as + # modern use is pretty constrained. Potentially streamline the body of + # the for loop once done with full test cases available. + user_output_dtypes: list[Optional[torch.dtype]] = [] + node_map: Dict[str, torch.fx.Node] = { + n.name: n for n in exported_program.graph.nodes + } + for user_output in exported_program.graph_signature.user_outputs: + output_node = node_map[user_output] + tensor_meta = output_node.meta.get("tensor_meta") + fake_val = output_node.meta.get("val") + dtype = None + if tensor_meta is not None: + dtype = tensor_meta.dtype + elif fake_val is not None: + dtype = fake_val.dtype + user_output_dtypes.append(dtype) + + return ExportedProgramIntrinsic( + entry_func_op, entry_module_call_entry.signature, user_output_dtypes + ) + + +class _Hooks(FxImporterHooks): + def __init__(self, module_builder: ModuleBuilder): + self.module_builder = module_builder + + def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]: + module_builder = self.module_builder + + # We support resolution of tracked reference types. Currently this + # only includes Tensors. All others we let the importer do what it + # is going to do. + if not isinstance(literal, torch.Tensor): + return None + + # See if we know about it. + mapping = module_builder.global_ref_tracker.track(literal) + if mapping.is_empty: + # If it is unknown, just let the default importer take it on. + return None + + # Already materialized. + logger.debug("Resolved defined global for literal %r", mapping) + materialized_global: MaterializedGlobal = mapping.value # type: ignore + + # Emit a global load and conversion. 
+ vtensor_type = gni._cc.tensor_to_vtensor_type(literal) + loaded_value = util_d.GlobalLoadOp( + materialized_global.ir_type, materialized_global.symbol_name + ).result + converted_value = Operation.create( + "torch_c.from_builtin_tensor", + results=[vtensor_type], + operands=[loaded_value], + ).result + return converted_value + + +# In https://github.com/llvm/torch-mlir/pull/3046, the FxImporter was +# extended to accept a "module_op" as an Operation (vs a Module). Switch for +# compatibility. +_fx_importer_accepts_module_op = ( + "module_op" in inspect.getfullargspec(FxImporter).kwonlyargs +) + + +def _create_fx_importer(module_builder: ModuleBuilder) -> FxImporter: + hooks = _Hooks(module_builder) + if _fx_importer_accepts_module_op: + # New path. + return FxImporter( + module_op=module_builder.module_op, + config_check=False, + py_attr_tracker=module_builder.fx_py_attr_tracker, + hooks=hooks, + ) + else: + # Legacy path. + class FakeModule: + def __init__(self, op): + self._op = module_builder.module_op + + @property + def context(self): + return self._op.context + + @property + def operation(self): + return self._op + + @property + def body(self): + return self._op.regions[0].blocks[0] + + return FxImporter( + module=FakeModule(module_builder.module_op), + config_check=False, + py_attr_tracker=module_builder.fx_py_attr_tracker, + hooks=hooks, + ) diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py index 5a9a7d16b..18220910f 100644 --- a/core/shark_turbine/dynamo/passes.py +++ b/core/shark_turbine/dynamo/passes.py @@ -1,7 +1,6 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions -from shark_turbine.dynamo import utils from torch.func import functionalize from typing import List, Optional diff --git a/core/shark_turbine/dynamo/utils.py b/core/shark_turbine/dynamo/utils.py deleted file mode 100644 index 05035e803..000000000 --- a/core/shark_turbine/dynamo/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -from torch._prims_common.wrappers import out_wrapper -from torch._prims_common import ( - DeviceLikeType, - TensorLikeType, -) -import torch._refs as _refs -from torch._decomp import get_decompositions, register_decomposition -from torch import Tensor -from typing import Dict, List, Tuple, Optional - - -if torch.__version__ < "2.2.0": - # Torch versions prior to 2.2.0 lacked some decompositions, which we - # add manually. 
- @register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) - def scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: - dtype = query.dtype - batchSize, num_head, qSize, headSize = ( - query.shape[0], - query.shape[1], - query.shape[2], - query.shape[3], - ) - - logsumexp = torch.empty( - [batchSize, qSize, num_head, headSize], dtype=torch.float - ) - cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - max_q, max_k = 0, 0 - philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - debug_attn_mask = torch.empty( - [], - dtype=query.dtype, - device="cpu", - requires_grad=query.requires_grad, - ) - output, _ = torch.ops.aten._scaled_dot_product_attention_math.default( - query, key, value, None, dropout_p, is_causal, None, scale=scale - ) - output = output.transpose(1, 2).contiguous( - memory_format=torch.contiguous_format - ) - return ( - output.transpose(1, 2), - logsumexp, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - debug_attn_mask, - ) - - -# manually add decomposition to bypass the error that comes -# from VAE encode(inp).latent_dist.sample() failing to symbolically -# trace from torch fx. -# Expected Torch stable version: > 2.1.0 -# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 -# temporary Torch fix: https://github.com/pytorch/pytorch/issues/107170 -@register_decomposition(torch.ops.aten.randn.generator) -@out_wrapper() -def randn_generator( - *shape, - generator: Optional[torch.Generator] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[DeviceLikeType] = None, - layout: Optional[torch.layout] = None, - requires_grad: bool = False, - pin_memory: bool = False, -) -> TensorLikeType: - # We should eventually support the generator overload. - # However, if someone passes in a None generator explicitly, - # we can jut fall back to randn.default - if generator is None: - return _refs.randn( - *shape, - dtype=dtype, - device=device, - layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, - ) - return NotImplemented diff --git a/core/shark_turbine/ops/iree.py b/core/shark_turbine/ops/iree.py index 093c6c77e..e28826db8 100644 --- a/core/shark_turbine/ops/iree.py +++ b/core/shark_turbine/ops/iree.py @@ -50,29 +50,13 @@ def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]): @CustomOp.register(library=IREE_LIBRARY) class trace_tensor(CustomOp): - signature = "trace_tensor(str trace_key, Tensor tensor) -> ()" + signature = "trace_tensor(str trace_key, Tensor(a!) 
tensor) -> ()" def select(self, ksel: KernelSelection): ksel.attr_str(0) - ksel.arg_tensor(1) + ksel.arg_tensor(1, inplace_tied=True) def generate(self, ksel: KernelSelection, kb: KernelBuilder): key = cast(AttrArg, ksel.arg_descs[0]) _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) - kb.yield_results() - - -@CustomOp.register(library=IREE_LIBRARY) -class trace_tensors(CustomOp): - signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()" - - def select(self, ksel: KernelSelection): - ksel.attr_str(0) - ksel.arg_tensor_list(1) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - key = cast(AttrArg, ksel.arg_descs[0]) - ts = kb.arg_bindings[1] - if len(ts) >= 1: - _emit_tensor_trace(kb, cast(str, key.v), ts) - kb.yield_results() + kb.yield_results(kb.arg_bindings[1]) diff --git a/core/shark_turbine/runtime/op_reg/base.py b/core/shark_turbine/runtime/op_reg/base.py index e7fc20338..3e2b84992 100644 --- a/core/shark_turbine/runtime/op_reg/base.py +++ b/core/shark_turbine/runtime/op_reg/base.py @@ -239,6 +239,7 @@ class KernelSelection(ABC): __slots__ = [ "arg_descs", + "inplace_tied_arg_descs", "op", "result_descs", "variant", @@ -247,6 +248,7 @@ class KernelSelection(ABC): def __init__(self, op: CustomOp, arg_arity: int): self.op = op self.arg_descs = cast(list[Optional[ArgDescriptor]], arg_arity * [None]) + self.inplace_tied_arg_descs: list[ArgDescriptor] = [] self.result_descs: list[ArgDescriptor] = [] self.variant: str = "default" @@ -295,12 +297,16 @@ def spec_key(self) -> str: ) from e @abstractmethod - def arg_tensor(self, arg: int) -> "TensorArg": + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": """Declares an argument to allow any ranked tensor and to specialize for each rank and dtype. Returns the argument descriptor, which can be used to further inspect or constrain the selection. It will default to allowing all dimensions to be dynamic. + + If inplace_tied is True, then this argument participates in in-place + semantics. The kernel must yield the result-mutated after all normal + results in the order declared. """ ... @@ -354,7 +360,7 @@ def __init__(self, op: CustomOp, args: list[Any]): super().__init__(op, len(args)) self.args = args - def arg_tensor(self, arg: int) -> "TensorArg": + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": arg_descs = self.arg_descs arg_value = self.args[arg] assert arg_descs[arg] is None, f"Already constrained argument {arg}" @@ -362,6 +368,8 @@ def arg_tensor(self, arg: int) -> "TensorArg": arg_value, Tensor ), f"Argument type mismatch from Torch for {arg}: Expected tensor, got {type(arg_value)}" arg_descs[arg] = desc = TensorArg(arg_value) + if inplace_tied: + self.inplace_tied_arg_descs.append(desc) return desc def arg_tensor_list(self, arg: int) -> "TensorListArg": @@ -676,7 +684,7 @@ def __init__( # Assemble result types. 
result_types = [] - for d in ksel.result_descs: + for d in (*ksel.result_descs, *ksel.inplace_tied_arg_descs): if not d.is_list: if d.ir_arity == 1: result_types.append(IrType.parse(d.mlir_type_asm)) @@ -744,6 +752,11 @@ def create_module( def yield_results(self, *results: Value): """Yields results of the kernel computation.""" assert not self.yielded, "yield_results has already been called" + ksel = self.ksel + expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) + assert ( + len(results) == expected_count + ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" with self.ip, Location.unknown(): func_d.ReturnOp(results) self.yielded = True diff --git a/core/shark_turbine/transforms/general/custom_op_expansion.py b/core/shark_turbine/transforms/general/custom_op_expansion.py index dae04d905..0a191dc2a 100644 --- a/core/shark_turbine/transforms/general/custom_op_expansion.py +++ b/core/shark_turbine/transforms/general/custom_op_expansion.py @@ -124,7 +124,7 @@ def __init__( self.results = results self.type_converter = type_converter - def arg_tensor(self, arg: int) -> TensorArg: + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg: # This is annoying: We have to go from the Torch MLIR type system to the # original torch.tensor Python type system. We do this by way of the native # type converter because it has the mapping pathway we need. This is one of the @@ -154,6 +154,8 @@ def arg_tensor(self, arg: int) -> TensorArg: ) t = torch.empty(rtt.shape, dtype=dtype, device="meta") arg_descs[arg] = desc = TensorArg(t) + if inplace_tied: + self.inplace_tied_arg_descs.append(desc) return desc def arg_tensor_list(self, arg: int) -> TensorListArg: @@ -235,6 +237,11 @@ def __init__( def yield_results(self, *results: Value): """Yields results of the kernel computation.""" assert not self.yielded, "yield_results has already been called" + ksel = self.ksel + expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) + assert ( + len(results) == expected_count + ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" with self.ip, self.location: torch_op_results: list[Value] = list(self.torch_op.results) assert len(results) == len( diff --git a/core/tests/aot/api_test.py b/core/tests/aot/api_test.py index ef13738ac..2bf6afabd 100644 --- a/core/tests/aot/api_test.py +++ b/core/tests/aot/api_test.py @@ -14,6 +14,7 @@ from shark_turbine.aot import * import torch +import torch.nn as nn class GeneralAPI(unittest.TestCase): @@ -71,6 +72,55 @@ def foobar(self): print(module_str) +class ExportAPI(unittest.TestCase): + def testStaticNNModule(self): + mdl = SimpleParams() + exported = export(mdl, args=(torch.empty([128, 20]),)) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertIn("dense_resource", asm) + + def testDynamicNNModule(self): + mdl = SimpleParams() + batch = torch.export.Dim("batch") + exported = export( + mdl, args=(torch.empty([128, 20]),), dynamic_shapes={"x": {0: batch}} + ) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertIn( + "func.func @main(%arg0: !torch.vtensor<[?,20],f32>) -> !torch.vtensor<[?,30],f32>", + asm, + ) + + def testExternalParamsNNModule(self): + mdl = SimpleParams() + exported = export(mdl, args=(torch.empty([128, 20]),), external_params=True) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertNotIn("dense_resource", asm) + self.assertIn("util.global.load", 
asm) + + def testTorchExportedProgram(self): + mdl = SimpleParams() + prg = torch.export.export(mdl, args=(torch.empty([128, 20]),)) + exported = export(prg, external_params=True) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertNotIn("dense_resource", asm) + self.assertIn("util.global private @_params.classifier.weight", asm) + self.assertIn("util.global private @_params.classifier.bias", asm) + + +class SimpleParams(nn.Module): + def __init__(self): + super().__init__() + self.classifier = nn.Linear(20, 30) + + def forward(self, x): + return self.classifier(x) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/core/tests/aot/compiled_exported_program_test.py b/core/tests/aot/compiled_exported_program_test.py new file mode 100644 index 000000000..0f79111c8 --- /dev/null +++ b/core/tests/aot/compiled_exported_program_test.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch +import torch.nn as nn + +from iree.compiler.ir import ( + Context, +) + +from shark_turbine.aot import * +from shark_turbine.aot.builtins import * + + +class TorchExportTests(unittest.TestCase): + def testImportPhases(self): + class MyModule(torch.nn.Module): + def forward(self): + ... + + fxb = FxProgramsBuilder(MyModule()) + + @fxb.export_program( + args=([torch.empty([3, 2]), torch.empty([1, 2])],), + kwargs={"foobar": torch.empty([3, 1])}, + ) + def compute(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + class ExportedProcModule(CompiledModule): + _compute = compute + + def foobar( + self, + t1=AbstractTensor(3, 2), + t2=AbstractTensor(1, 2), + t3=AbstractTensor(3, 1), + ): + return self._compute(t1, t2, foobar=t3) + + inst = ExportedProcModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("func.func private @_compute", module_str) + self.assertIn("func.func @foobar", module_str) + + def testMultiPublic(self): + class MyModule(torch.nn.Module): + def forward(self): + ... 
+ + fxb = FxProgramsBuilder(MyModule()) + + @fxb.export_program( + args=([torch.empty([3, 2]), torch.empty([1, 2])],), + kwargs={"foobar": torch.empty([3, 1])}, + ) + def _compute1(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + @fxb.export_program( + args=([torch.empty([5]), torch.empty([5])],), + kwargs={"foobar": torch.empty([5])}, + ) + def _compute2(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + class ExportedPublicModule(CompiledModule): + compute1 = _compute1 + compute2 = _compute2 + + inst = ExportedPublicModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("func.func @compute1", module_str) + self.assertIn("func.func @compute2", module_str) + + def testParametersAsGlobals(self): + fxb = FxProgramsBuilder(SimpleParams()) + + @fxb.export_program( + args=(torch.empty([128, 20]),), + ) + def _compute1(module, x): + return module.forward(x) + + class ParamsAsGlobalsModule(CompiledModule): + params = export_parameters(fxb.root_module) + compute1 = _compute1 + compute2 = _compute1 + + inst = ParamsAsGlobalsModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn( + "util.global private @_params.classifier.weight {noinline}", module_str + ) + self.assertIn( + "util.global private @_params.classifier.bias {noinline}", module_str + ) + # Should only be two. + self.assertEqual(2, module_str.count("util.global private")) + # And two loads each loads. + self.assertEqual( + 2, module_str.count("util.global.load @_params.classifier.weight") + ) + self.assertEqual( + 2, module_str.count("util.global.load @_params.classifier.bias") + ) + + +class SimpleParams(nn.Module): + def __init__(self): + super().__init__() + self.classifier = nn.Linear(20, 30) + + def forward(self, x): + return self.classifier(x) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/core/tests/aot/jittable_test.py b/core/tests/aot/jittable_test.py index 0b3cabfa8..6419c0bd4 100644 --- a/core/tests/aot/jittable_test.py +++ b/core/tests/aot/jittable_test.py @@ -73,7 +73,7 @@ def compute(*, a, b): print(module_str) def testDynamicDims(self): - class ProcArgsModule(CompiledModule): + class DynamicDimsModule(CompiledModule): def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): return self.compute( a, @@ -87,7 +87,7 @@ def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): def compute(a, b): return a * b - inst = ProcArgsModule(context=Context(), import_to=None) + inst = DynamicDimsModule(context=Context(), import_to=None) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) diff --git a/core/tests/ops/iree_test.py b/core/tests/ops/iree_test.py index f10643026..b41647d65 100644 --- a/core/tests/ops/iree_test.py +++ b/core/tests/ops/iree_test.py @@ -17,12 +17,6 @@ def testTrace(self): t = torch.randn(3, 4) ops.iree.trace_tensor("TEST", t) - def testTraceList(self): - t1 = torch.randn(3, 4) - t2 = torch.randn(1, 8) - ops.iree.trace_tensors("TEST 2", [t1, t2]) - ops.iree.trace_tensors("TEST 1", [t1]) - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/core/tests/runtime/op_reg/kernel_aot_test.py b/core/tests/runtime/op_reg/kernel_aot_test.py index 48c7f59f1..0d31edbd3 100644 --- 
a/core/tests/runtime/op_reg/kernel_aot_test.py +++ b/core/tests/runtime/op_reg/kernel_aot_test.py @@ -49,6 +49,7 @@ def testTrace(self): print("CUSTOM OP CONVERTED:") module_asm = str(prog.mlir_module) + print(module_asm) self.assertIn('flow.tensor.trace "LAYER0"', module_asm) self.assertIn('flow.tensor.trace "LAYER1"', module_asm) self.assertIn('flow.tensor.trace "LAYER3"', module_asm) diff --git a/core/torchvision-requirements.txt b/core/torchvision-requirements.txt deleted file mode 100644 index e38d8d008..000000000 --- a/core/torchvision-requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ ---pre -torchvision diff --git a/models/turbine_models/custom_models/README.md b/models/turbine_models/custom_models/README.md index 98aa347b1..d56214257 100644 --- a/models/turbine_models/custom_models/README.md +++ b/models/turbine_models/custom_models/README.md @@ -7,8 +7,7 @@ cd SHARK-Turbine python -m venv turbine_venv && source turbine_venv/bin/activate pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade -r core/requirements.txt pip install -e core pip install -e models @@ -39,4 +38,4 @@ python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Ll 2) Interactive CLI chat mode. (just add a --chat_mode flag) ``` python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan hf_auth_token=your_hf_token --chat_mode -``` \ No newline at end of file +``` diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py index ab5d228ce..c2ecc4b48 100644 --- a/models/turbine_models/tests/stateless_llama_test.py +++ b/models/turbine_models/tests/stateless_llama_test.py @@ -188,6 +188,9 @@ def test_streaming_vmfb_comparison(self): ) check_output_string(torch_str, turbine_str) + # See: https://github.com/nod-ai/SHARK-Turbine/issues/560 + # Developed issues related to the pytorch 2.3 upgrade. + @unittest.expectedFailure def test_rerotated_torch_comparison(self): torch_str = llm_runner.run_torch_llm( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", diff --git a/serving/requirements.txt b/serving/requirements.txt index 3c9503df4..3cb469b2c 100644 --- a/serving/requirements.txt +++ b/serving/requirements.txt @@ -1,2 +1,3 @@ fastapi>=0.109.2 uvicorn>=0.27.0 +requests
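
For reference, a minimal usage sketch of the `aot.export` entry point as changed in this diff. The module, tensor shapes, and flag values are illustrative assumptions; the calls mirror `core/tests/aot/api_test.py` and `core/examples/aot_mlp/mlp_export_dynamic.py`.

```python
# Illustrative sketch only: exercises the export paths added in this change.
import torch
import torch.nn as nn
import shark_turbine.aot as aot


class SimpleParams(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(20, 30)

    def forward(self, x):
        return self.classifier(x)


# nn.Module path: traced with torch.export, with an optional dynamic batch dim.
batch = torch.export.Dim("batch")
exported = aot.export(
    SimpleParams(),
    args=(torch.empty([4, 20]),),
    dynamic_shapes={"x": {0: batch}},
    external_params=True,  # declare params as util.globals instead of inlining
)
exported.print_readable()

# Pre-exported path: a torch.export.ExportedProgram can be passed directly.
prg = torch.export.export(SimpleParams(), args=(torch.empty([4, 20]),))
exported2 = aot.export(prg, external_params=True)
mlir_asm = str(exported2.mlir_module)
```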