-
Notifications
You must be signed in to change notification settings - Fork 46
Add a simple aot.export
entrypoint for one-short nn.Module export.
#84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
25a06b8
Add a simple `aot.export` entrypoint for one-short nn.Module export.
stellaraccident 09a25f1
Black
stellaraccident 09735b0
Remove stray prints
stellaraccident 722329c
Add example test runner
stellaraccident acfb3b8
Pin back to CPU pytorch.
stellaraccident 4de076f
Try to pin again
stellaraccident de9d05f
Pin again
stellaraccident 5da8874
Pin more
stellaraccident File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
import shark_turbine.aot as aot | ||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layer0 = nn.Linear(8, 8, bias=True) | ||
self.layer1 = nn.Linear(8, 4, bias=True) | ||
self.layer2 = nn.Linear(4, 2, bias=True) | ||
self.layer3 = nn.Linear(2, 2, bias=True) | ||
|
||
def forward(self, x: torch.Tensor): | ||
x = self.layer0(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer1(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer2(x) | ||
x = torch.sigmoid(x) | ||
x = self.layer3(x) | ||
return x | ||
|
||
|
||
model = MLP() | ||
example_x = torch.empty(97, 8, dtype=torch.float32) | ||
exported = aot.export(model, example_x) | ||
exported.print_readable() | ||
compiled_binary = exported.compile(save_to=None) | ||
|
||
|
||
def infer(): | ||
import numpy as np | ||
import iree.runtime as rt | ||
|
||
config = rt.Config("local-task") | ||
vmm = rt.load_vm_module( | ||
rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), | ||
config, | ||
) | ||
x = np.random.rand(97, 8).astype(np.float32) | ||
y = vmm.main(x) | ||
print(y.to_host()) | ||
|
||
|
||
infer() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Optional, Sequence, Union | ||
import functools | ||
import io | ||
from pathlib import Path | ||
import platform | ||
|
||
import torch | ||
|
||
from iree.compiler.api import ( | ||
Invocation, | ||
Session, | ||
Source, | ||
Output, | ||
) | ||
|
||
from .builtins import * | ||
from .compiled_module import ( | ||
CompiledModule, | ||
ExportProcDef, | ||
) | ||
from .support.ir_imports import ( | ||
Context, | ||
Operation, | ||
) | ||
from .support.procedural import ( | ||
AbstractTypedef, | ||
) | ||
|
||
|
||
_is_windows = platform.system() == "Windows" | ||
|
||
|
||
SaveableTarget = Union[str, Path, None, Output] | ||
|
||
|
||
class ExportOutput: | ||
"""Wrapper around a CompiledModule produced by `export`.""" | ||
|
||
def __init__( | ||
self, | ||
session: Session, | ||
compiled_module: CompiledModule, | ||
*, | ||
importer_uses_session: bool = False, | ||
): | ||
self.session = session | ||
self.session.set_flags("--iree-input-type=torch") | ||
self.compiled_module = compiled_module | ||
self._importer_uses_session = importer_uses_session | ||
|
||
@property | ||
def mlir_module(self) -> Operation: | ||
"""Gets the MLIR module resulting from the last compilation phase.""" | ||
return CompiledModule.get_mlir_module(self.compiled_module) | ||
|
||
def print_readable(self, large_elements_limit: int = 50): | ||
"""Prints a human readable version of the current compilation IR.""" | ||
self.mlir_module.print(large_elements_limit=large_elements_limit) | ||
|
||
def save_mlir(self, file_path: Union[str, Path]): | ||
"""Saves the current compilation IR to a path on disk. | ||
|
||
Args: | ||
file_path: Path to save the file. If it has a ".mlirbc" | ||
extension, it will be saved as bytecode. Otherwise as | ||
text. | ||
""" | ||
file_path = Path(file_path) | ||
with open(file_path, "wb") as f: | ||
if file_path.suffix == ".mlirbc": | ||
self.mlir_module.write_bytecode(f) | ||
else: | ||
self.mlir_module.print(f, binary=True) | ||
|
||
def _run_import(self): | ||
CompiledModule.run_import(self.compiled_module) | ||
|
||
def compile( | ||
self, | ||
save_to: SaveableTarget, | ||
*, | ||
target_backends: Union[str, Sequence[str]] = ("llvm-cpu",), | ||
) -> Optional[memoryview]: | ||
"""Compiles the exported program to an executable binary. | ||
|
||
Args: | ||
save_to: Where to save the compiled binary. Can be one of: | ||
None: outputs to a memory buffer and return the API Output. | ||
(str, Path): Outputs to a file | ||
Output: Raw compiler API Output object to save to. | ||
target_backends: A comma-delimitted string of IREE target backends or | ||
a sequence of strings. | ||
Returns: | ||
None unless if `save_to=None`, in which case, we return the backing compiler API | ||
Ouptut object. It can be queried for its backing memory via its `map_memory()` | ||
method. | ||
""" | ||
return_memory_view = False | ||
if save_to is None: | ||
output = Output.open_membuffer() | ||
return_memory_view = True | ||
elif isinstance(save_to, (str, Path)): | ||
save_to = Path(save_to) | ||
output = Output.open_file(str(save_to)) | ||
else: | ||
assert isinstance(output, Output) | ||
output = save_to | ||
|
||
target_backends = ( | ||
target_backends | ||
if isinstance(target_backends, str) | ||
else ",".join(target_backends) | ||
) | ||
inv = self.session.invocation() | ||
if self._importer_uses_session: | ||
inv.import_module(self.mlir_module) | ||
else: | ||
# Some platforms can't share the context across the importer and | ||
# session (cough: Windows). Round-trip in this case. | ||
buffer_io = io.BytesIO() | ||
self.mlir_module.write_bytecode(buffer_io) | ||
buffer = buffer_io.getvalue() | ||
source = Source.wrap_buffer(self.session, buffer) | ||
inv.parse_source(source) | ||
inv.enable_console_diagnostics() | ||
|
||
# TODO: Don't use flags to set the target backends: set module attributes. | ||
self.session.set_flags(f"--iree-hal-target-backends={target_backends}") | ||
if not inv.execute(): | ||
raise RuntimeError("Compilation failed: See diagnostics") | ||
|
||
inv.output_vm_bytecode(output) | ||
output.keep() | ||
if return_memory_view: | ||
return output | ||
else: | ||
return None | ||
|
||
|
||
# Decorator which explicitly exports a function. | ||
# TODO: Make this a public API on CompiledModule. | ||
def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef: | ||
if f is None: | ||
return functools.partial(export_proc, signature=signature) | ||
return ExportProcDef(f.__name__, f, signature=signature) | ||
|
||
|
||
def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput: | ||
"""One shot export of an nn.Module. | ||
|
||
This is a very restrictive API vs the lower level `CompiledModule` | ||
facility. It is suitable for one-shot modules, with a single | ||
entrypoint and static example arguments where no additional | ||
configuration is needed for mutable parameters/buffers or state | ||
management. Dynamic shape constraints are also not presently | ||
exposed via this API, but we expect to allow this in the future. | ||
|
||
Args: | ||
mdl: The nn.Module to export. | ||
*example_args: Example tensors. | ||
|
||
Returns: | ||
An ExportOutput object that wraps the compilation and provides | ||
easy access. | ||
""" | ||
signature = [abstractify(t) for t in example_args] | ||
|
||
class Exported(CompiledModule, export_name=mdl._get_name()): | ||
params = export_parameters(mdl) | ||
|
||
@export_proc(signature=signature) | ||
def main(self, *args): | ||
return jittable(mdl.forward)(*args) | ||
|
||
session = Session() | ||
# There are some bugs with respect to Session/context interop that we | ||
# haven't squashed yet. For now, default everyone to round-tripping | ||
# via bytecode vs sharing the context between the importer/compiler. | ||
importer_uses_session = False and not _is_windows | ||
if importer_uses_session: | ||
context = session.context | ||
else: | ||
context = Context() | ||
|
||
cm = Exported(context=context, import_to="import") | ||
return ExportOutput(session, cm, importer_uses_session=importer_uses_session) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html | ||
--pre | ||
torch==2.1.0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright 2023 Nod Labs, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import logging | ||
from pathlib import Path | ||
import sys | ||
import subprocess | ||
import unittest | ||
|
||
REPO_DIR = Path(__file__).resolve().parent.parent.parent | ||
|
||
|
||
def _run(local_path: str): | ||
path = REPO_DIR / local_path | ||
subprocess.check_call([sys.executable, str(path)]) | ||
|
||
|
||
class AOTMLPTest(unittest.TestCase): | ||
def testMLPExportSimple(self): | ||
_run("examples/aot_mlp/mlp_export_simple.py") | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig(level=logging.DEBUG) | ||
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html | ||
--pre | ||
torchvision |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.