-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a simple
aot.export
entrypoint for one-short nn.Module export. (#…
…84) * Adds full compilation support (as opposed to just stopping at the exported IR). * Adds a sample which exercises the entire flow.
- Loading branch information
1 parent
bf80c1e
commit 6056884
Showing
8 changed files
with
278 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
import shark_turbine.aot as aot | ||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layer0 = nn.Linear(8, 8, bias=True) | ||
self.layer1 = nn.Linear(8, 4, bias=True) | ||
self.layer2 = nn.Linear(4, 2, bias=True) | ||
self.layer3 = nn.Linear(2, 2, bias=True) | ||
|
||
def forward(self, x: torch.Tensor): | ||
x = self.layer0(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer1(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer2(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer3(x) | ||
return x | ||
|
||
|
||
model = MLP() | ||
example_x = torch.empty(97, 8, dtype=torch.float32) | ||
exported = aot.export(model, example_x) | ||
exported.print_readable() | ||
compiled_binary = exported.compile(save_to=None) | ||
|
||
|
||
def infer(): | ||
import numpy as np | ||
import iree.runtime as rt | ||
|
||
config = rt.Config("local-task") | ||
vmm = rt.load_vm_module( | ||
rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), | ||
config, | ||
) | ||
x = np.random.rand(97, 8).astype(np.float32) | ||
y = vmm.main(x) | ||
print(y.to_host()) | ||
|
||
|
||
infer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Optional, Sequence, Union | ||
import functools | ||
import io | ||
from pathlib import Path | ||
import platform | ||
|
||
import torch | ||
|
||
from iree.compiler.api import ( | ||
Invocation, | ||
Session, | ||
Source, | ||
Output, | ||
) | ||
|
||
from .builtins import * | ||
from .compiled_module import ( | ||
CompiledModule, | ||
ExportProcDef, | ||
) | ||
from .support.ir_imports import ( | ||
Context, | ||
Operation, | ||
) | ||
from .support.procedural import ( | ||
AbstractTypedef, | ||
) | ||
|
||
|
||
_is_windows = platform.system() == "Windows" | ||
|
||
|
||
SaveableTarget = Union[str, Path, None, Output] | ||
|
||
|
||
class ExportOutput: | ||
"""Wrapper around a CompiledModule produced by `export`.""" | ||
|
||
def __init__( | ||
self, | ||
session: Session, | ||
compiled_module: CompiledModule, | ||
*, | ||
importer_uses_session: bool = False, | ||
): | ||
self.session = session | ||
self.session.set_flags("--iree-input-type=torch") | ||
self.compiled_module = compiled_module | ||
self._importer_uses_session = importer_uses_session | ||
|
||
@property | ||
def mlir_module(self) -> Operation: | ||
"""Gets the MLIR module resulting from the last compilation phase.""" | ||
return CompiledModule.get_mlir_module(self.compiled_module) | ||
|
||
def print_readable(self, large_elements_limit: int = 50): | ||
"""Prints a human readable version of the current compilation IR.""" | ||
self.mlir_module.print(large_elements_limit=large_elements_limit) | ||
|
||
def save_mlir(self, file_path: Union[str, Path]): | ||
"""Saves the current compilation IR to a path on disk. | ||
Args: | ||
file_path: Path to save the file. If it has a ".mlirbc" | ||
extension, it will be saved as bytecode. Otherwise as | ||
text. | ||
""" | ||
file_path = Path(file_path) | ||
with open(file_path, "wb") as f: | ||
if file_path.suffix == ".mlirbc": | ||
self.mlir_module.write_bytecode(f) | ||
else: | ||
self.mlir_module.print(f, binary=True) | ||
|
||
def _run_import(self): | ||
CompiledModule.run_import(self.compiled_module) | ||
|
||
def compile( | ||
self, | ||
save_to: SaveableTarget, | ||
*, | ||
target_backends: Union[str, Sequence[str]] = ("llvm-cpu",), | ||
) -> Optional[memoryview]: | ||
"""Compiles the exported program to an executable binary. | ||
Args: | ||
save_to: Where to save the compiled binary. Can be one of: | ||
None: outputs to a memory buffer and return the API Output. | ||
(str, Path): Outputs to a file | ||
Output: Raw compiler API Output object to save to. | ||
target_backends: A comma-delimitted string of IREE target backends or | ||
a sequence of strings. | ||
Returns: | ||
None unless if `save_to=None`, in which case, we return the backing compiler API | ||
Ouptut object. It can be queried for its backing memory via its `map_memory()` | ||
method. | ||
""" | ||
return_memory_view = False | ||
if save_to is None: | ||
output = Output.open_membuffer() | ||
return_memory_view = True | ||
elif isinstance(save_to, (str, Path)): | ||
save_to = Path(save_to) | ||
output = Output.open_file(str(save_to)) | ||
else: | ||
assert isinstance(output, Output) | ||
output = save_to | ||
|
||
target_backends = ( | ||
target_backends | ||
if isinstance(target_backends, str) | ||
else ",".join(target_backends) | ||
) | ||
inv = self.session.invocation() | ||
if self._importer_uses_session: | ||
inv.import_module(self.mlir_module) | ||
else: | ||
# Some platforms can't share the context across the importer and | ||
# session (cough: Windows). Round-trip in this case. | ||
buffer_io = io.BytesIO() | ||
self.mlir_module.write_bytecode(buffer_io) | ||
buffer = buffer_io.getvalue() | ||
source = Source.wrap_buffer(self.session, buffer) | ||
inv.parse_source(source) | ||
inv.enable_console_diagnostics() | ||
|
||
# TODO: Don't use flags to set the target backends: set module attributes. | ||
self.session.set_flags(f"--iree-hal-target-backends={target_backends}") | ||
if not inv.execute(): | ||
raise RuntimeError("Compilation failed: See diagnostics") | ||
|
||
inv.output_vm_bytecode(output) | ||
output.keep() | ||
if return_memory_view: | ||
return output | ||
else: | ||
return None | ||
|
||
|
||
# Decorator which explicitly exports a function. | ||
# TODO: Make this a public API on CompiledModule. | ||
def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef: | ||
if f is None: | ||
return functools.partial(export_proc, signature=signature) | ||
return ExportProcDef(f.__name__, f, signature=signature) | ||
|
||
|
||
def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput: | ||
"""One shot export of an nn.Module. | ||
This is a very restrictive API vs the lower level `CompiledModule` | ||
facility. It is suitable for one-shot modules, with a single | ||
entrypoint and static example arguments where no additional | ||
configuration is needed for mutable parameters/buffers or state | ||
management. Dynamic shape constraints are also not presently | ||
exposed via this API, but we expect to allow this in the future. | ||
Args: | ||
mdl: The nn.Module to export. | ||
*example_args: Example tensors. | ||
Returns: | ||
An ExportOutput object that wraps the compilation and provides | ||
easy access. | ||
""" | ||
signature = [abstractify(t) for t in example_args] | ||
|
||
class Exported(CompiledModule, export_name=mdl._get_name()): | ||
params = export_parameters(mdl) | ||
|
||
@export_proc(signature=signature) | ||
def main(self, *args): | ||
return jittable(mdl.forward)(*args) | ||
|
||
session = Session() | ||
# There are some bugs with respect to Session/context interop that we | ||
# haven't squashed yet. For now, default everyone to round-tripping | ||
# via bytecode vs sharing the context between the importer/compiler. | ||
importer_uses_session = False and not _is_windows | ||
if importer_uses_session: | ||
context = session.context | ||
else: | ||
context = Context() | ||
|
||
cm = Exported(context=context, import_to="import") | ||
return ExportOutput(session, cm, importer_uses_session=importer_uses_session) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html | ||
--pre | ||
torch==2.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import logging | ||
from pathlib import Path | ||
import sys | ||
import subprocess | ||
import unittest | ||
|
||
REPO_DIR = Path(__file__).resolve().parent.parent.parent | ||
|
||
|
||
def _run(local_path: str): | ||
path = REPO_DIR / local_path | ||
subprocess.check_call([sys.executable, str(path)]) | ||
|
||
|
||
class AOTMLPTest(unittest.TestCase): | ||
def testMLPExportSimple(self): | ||
_run("examples/aot_mlp/mlp_export_simple.py") | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig(level=logging.DEBUG) | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html | ||
--pre | ||
torchvision |