From 60568842ecc53965f3f08d2f8f844d78e1d14801 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 9 Oct 2023 12:25:44 -0700 Subject: [PATCH] Add a simple `aot.export` entrypoint for one-short nn.Module export. (#84) * Adds full compilation support (as opposed to just stopping at the exported IR). * Adds a sample which exercises the entire flow. --- .github/workflows/test.yml | 6 +- examples/aot_mlp/mlp_export_simple.py | 53 +++++++ python/shark_turbine/aot/__init__.py | 1 + python/shark_turbine/aot/exporter.py | 192 ++++++++++++++++++++++++++ python/shark_turbine/dynamo/passes.py | 1 - pytorch-cpu-requirements.txt | 1 - tests/examples/aot_mlp_test.py | 28 ++++ torchvision-requirements.txt | 1 - 8 files changed, 278 insertions(+), 5 deletions(-) create mode 100644 examples/aot_mlp/mlp_export_simple.py create mode 100644 python/shark_turbine/aot/exporter.py create mode 100644 tests/examples/aot_mlp_test.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9d938f5df..bc2558358 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,8 +27,10 @@ jobs: - name: Sync source deps run: | python -m pip install --upgrade pip - pip install --upgrade -r requirements.txt - pip install --upgrade -e .[torch,testing] + pip install --index-url https://download.pytorch.org/whl/cpu \ + -r pytorch-cpu-requirements.txt \ + -r torchvision-requirements.txt + pip install -e .[testing] - name: Run tests run: | diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py new file mode 100644 index 000000000..675f55c22 --- /dev/null +++ b/examples/aot_mlp/mlp_export_simple.py @@ -0,0 +1,53 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import torch.nn as nn + +import shark_turbine.aot as aot + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layer0 = nn.Linear(8, 8, bias=True) + self.layer1 = nn.Linear(8, 4, bias=True) + self.layer2 = nn.Linear(4, 2, bias=True) + self.layer3 = nn.Linear(2, 2, bias=True) + + def forward(self, x: torch.Tensor): + x = self.layer0(x) + x = torch.sigmoid(x) + x = self.layer1(x) + x = torch.sigmoid(x) + x = self.layer2(x) + x = torch.sigmoid(x) + x = self.layer3(x) + return x + + +model = MLP() +example_x = torch.empty(97, 8, dtype=torch.float32) +exported = aot.export(model, example_x) +exported.print_readable() +compiled_binary = exported.compile(save_to=None) + + +def infer(): + import numpy as np + import iree.runtime as rt + + config = rt.Config("local-task") + vmm = rt.load_vm_module( + rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), + config, + ) + x = np.random.rand(97, 8).astype(np.float32) + y = vmm.main(x) + print(y.to_host()) + + +infer() diff --git a/python/shark_turbine/aot/__init__.py b/python/shark_turbine/aot/__init__.py index 070ea33bb..7bf2432e4 100644 --- a/python/shark_turbine/aot/__init__.py +++ b/python/shark_turbine/aot/__init__.py @@ -5,5 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .compiled_module import CompiledModule +from .exporter import export from .builtins import * diff --git a/python/shark_turbine/aot/exporter.py b/python/shark_turbine/aot/exporter.py new file mode 100644 index 000000000..f93a3e153 --- /dev/null +++ b/python/shark_turbine/aot/exporter.py @@ -0,0 +1,192 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional, Sequence, Union +import functools +import io +from pathlib import Path +import platform + +import torch + +from iree.compiler.api import ( + Invocation, + Session, + Source, + Output, +) + +from .builtins import * +from .compiled_module import ( + CompiledModule, + ExportProcDef, +) +from .support.ir_imports import ( + Context, + Operation, +) +from .support.procedural import ( + AbstractTypedef, +) + + +_is_windows = platform.system() == "Windows" + + +SaveableTarget = Union[str, Path, None, Output] + + +class ExportOutput: + """Wrapper around a CompiledModule produced by `export`.""" + + def __init__( + self, + session: Session, + compiled_module: CompiledModule, + *, + importer_uses_session: bool = False, + ): + self.session = session + self.session.set_flags("--iree-input-type=torch") + self.compiled_module = compiled_module + self._importer_uses_session = importer_uses_session + + @property + def mlir_module(self) -> Operation: + """Gets the MLIR module resulting from the last compilation phase.""" + return CompiledModule.get_mlir_module(self.compiled_module) + + def print_readable(self, large_elements_limit: int = 50): + """Prints a human readable version of the current compilation IR.""" + self.mlir_module.print(large_elements_limit=large_elements_limit) + + def save_mlir(self, file_path: Union[str, Path]): + """Saves the current compilation IR to a path on disk. + + Args: + file_path: Path to save the file. If it has a ".mlirbc" + extension, it will be saved as bytecode. Otherwise as + text. + """ + file_path = Path(file_path) + with open(file_path, "wb") as f: + if file_path.suffix == ".mlirbc": + self.mlir_module.write_bytecode(f) + else: + self.mlir_module.print(f, binary=True) + + def _run_import(self): + CompiledModule.run_import(self.compiled_module) + + def compile( + self, + save_to: SaveableTarget, + *, + target_backends: Union[str, Sequence[str]] = ("llvm-cpu",), + ) -> Optional[memoryview]: + """Compiles the exported program to an executable binary. + + Args: + save_to: Where to save the compiled binary. Can be one of: + None: outputs to a memory buffer and return the API Output. + (str, Path): Outputs to a file + Output: Raw compiler API Output object to save to. + target_backends: A comma-delimitted string of IREE target backends or + a sequence of strings. + Returns: + None unless if `save_to=None`, in which case, we return the backing compiler API + Ouptut object. It can be queried for its backing memory via its `map_memory()` + method. + """ + return_memory_view = False + if save_to is None: + output = Output.open_membuffer() + return_memory_view = True + elif isinstance(save_to, (str, Path)): + save_to = Path(save_to) + output = Output.open_file(str(save_to)) + else: + assert isinstance(output, Output) + output = save_to + + target_backends = ( + target_backends + if isinstance(target_backends, str) + else ",".join(target_backends) + ) + inv = self.session.invocation() + if self._importer_uses_session: + inv.import_module(self.mlir_module) + else: + # Some platforms can't share the context across the importer and + # session (cough: Windows). Round-trip in this case. + buffer_io = io.BytesIO() + self.mlir_module.write_bytecode(buffer_io) + buffer = buffer_io.getvalue() + source = Source.wrap_buffer(self.session, buffer) + inv.parse_source(source) + inv.enable_console_diagnostics() + + # TODO: Don't use flags to set the target backends: set module attributes. + self.session.set_flags(f"--iree-hal-target-backends={target_backends}") + if not inv.execute(): + raise RuntimeError("Compilation failed: See diagnostics") + + inv.output_vm_bytecode(output) + output.keep() + if return_memory_view: + return output + else: + return None + + +# Decorator which explicitly exports a function. +# TODO: Make this a public API on CompiledModule. +def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef: + if f is None: + return functools.partial(export_proc, signature=signature) + return ExportProcDef(f.__name__, f, signature=signature) + + +def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput: + """One shot export of an nn.Module. + + This is a very restrictive API vs the lower level `CompiledModule` + facility. It is suitable for one-shot modules, with a single + entrypoint and static example arguments where no additional + configuration is needed for mutable parameters/buffers or state + management. Dynamic shape constraints are also not presently + exposed via this API, but we expect to allow this in the future. + + Args: + mdl: The nn.Module to export. + *example_args: Example tensors. + + Returns: + An ExportOutput object that wraps the compilation and provides + easy access. + """ + signature = [abstractify(t) for t in example_args] + + class Exported(CompiledModule, export_name=mdl._get_name()): + params = export_parameters(mdl) + + @export_proc(signature=signature) + def main(self, *args): + return jittable(mdl.forward)(*args) + + session = Session() + # There are some bugs with respect to Session/context interop that we + # haven't squashed yet. For now, default everyone to round-tripping + # via bytecode vs sharing the context between the importer/compiler. + importer_uses_session = False and not _is_windows + if importer_uses_session: + context = session.context + else: + context = Context() + + cm = Exported(context=context, import_to="import") + return ExportOutput(session, cm, importer_uses_session=importer_uses_session) diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 7b3aaf475..3671e550f 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -43,7 +43,6 @@ torch.ops.aten.nll_loss_backward, torch.ops.aten._to_copy, torch.ops.aten._log_softmax_backward_data, - ] diff --git a/pytorch-cpu-requirements.txt b/pytorch-cpu-requirements.txt index 4cbbbfa0b..403689468 100644 --- a/pytorch-cpu-requirements.txt +++ b/pytorch-cpu-requirements.txt @@ -1,3 +1,2 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torch==2.1.0 diff --git a/tests/examples/aot_mlp_test.py b/tests/examples/aot_mlp_test.py new file mode 100644 index 000000000..d55e0b9c6 --- /dev/null +++ b/tests/examples/aot_mlp_test.py @@ -0,0 +1,28 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from pathlib import Path +import sys +import subprocess +import unittest + +REPO_DIR = Path(__file__).resolve().parent.parent.parent + + +def _run(local_path: str): + path = REPO_DIR / local_path + subprocess.check_call([sys.executable, str(path)]) + + +class AOTMLPTest(unittest.TestCase): + def testMLPExportSimple(self): + _run("examples/aot_mlp/mlp_export_simple.py") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index e5630a854..e38d8d008 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,2 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision