Add a simple aot.export entrypoint for one-shot nn.Module export. (#84)

* Adds full compilation support (as opposed to just stopping at the
exported IR).
* Adds a sample which exercises the entire flow.
stellaraccident authored Oct 9, 2023
1 parent bf80c1e commit 6056884
Showing 8 changed files with 278 additions and 5 deletions.
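As a quick orientation before the file-by-file diff, this is the flow the new entrypoint enables, sketched minimally (the `Linear` module and shapes here are illustrative placeholders; the real, runnable example is `examples/aot_mlp/mlp_export_simple.py` below):

import torch

import shark_turbine.aot as aot

# Any single-entrypoint nn.Module with static input shapes will do.
model = torch.nn.Linear(8, 2)
example_x = torch.empty(4, 8, dtype=torch.float32)

# One-shot export: returns an ExportOutput wrapping the imported MLIR.
exported = aot.export(model, example_x)
exported.print_readable()

# New in this commit: compile fully instead of stopping at the exported IR.
# save_to=None keeps the binary in an in-memory buffer.
compiled_binary = exported.compile(save_to=None)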
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -27,8 +27,10 @@ jobs:
       - name: Sync source deps
         run: |
           python -m pip install --upgrade pip
-          pip install --upgrade -r requirements.txt
-          pip install --upgrade -e .[torch,testing]
+          pip install --index-url https://download.pytorch.org/whl/cpu \
+            -r pytorch-cpu-requirements.txt \
+            -r torchvision-requirements.txt
+          pip install -e .[testing]
       - name: Run tests
         run: |
53 changes: 53 additions & 0 deletions examples/aot_mlp/mlp_export_simple.py
@@ -0,0 +1,53 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn as nn

import shark_turbine.aot as aot


class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = nn.Linear(8, 8, bias=True)
self.layer1 = nn.Linear(8, 4, bias=True)
self.layer2 = nn.Linear(4, 2, bias=True)
self.layer3 = nn.Linear(2, 2, bias=True)

def forward(self, x: torch.Tensor):
x = self.layer0(x)
x = torch.sigmoid(x)
x = self.layer1(x)
x = torch.sigmoid(x)
x = self.layer2(x)
x = torch.sigmoid(x)
x = self.layer3(x)
return x


model = MLP()
example_x = torch.empty(97, 8, dtype=torch.float32)
exported = aot.export(model, example_x)
exported.print_readable()
compiled_binary = exported.compile(save_to=None)


def infer():
import numpy as np
import iree.runtime as rt

config = rt.Config("local-task")
vmm = rt.load_vm_module(
rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
config,
)
x = np.random.rand(97, 8).astype(np.float32)
y = vmm.main(x)
print(y.to_host())


infer()
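A possible variation on the tail of this example (a sketch, not part of this commit): `compile` also accepts a file path, and the saved binary can then be mapped from disk, assuming `VmModule.mmap` is available in the installed `iree.runtime`:

import iree.runtime as rt

# Hypothetical variant: write the binary to disk instead of a memory buffer.
exported.compile(save_to="mlp.vmfb")

config = rt.Config("local-task")
vmm = rt.load_vm_module(
    rt.VmModule.mmap(config.vm_instance, "mlp.vmfb"),
    config,
)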
1 change: 1 addition & 0 deletions python/shark_turbine/aot/__init__.py
@@ -5,5 +5,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 from .compiled_module import CompiledModule
+from .exporter import export

 from .builtins import *
192 changes: 192 additions & 0 deletions python/shark_turbine/aot/exporter.py
@@ -0,0 +1,192 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence, Union
import functools
import io
from pathlib import Path
import platform

import torch

from iree.compiler.api import (
Invocation,
Session,
Source,
Output,
)

from .builtins import *
from .compiled_module import (
CompiledModule,
ExportProcDef,
)
from .support.ir_imports import (
Context,
Operation,
)
from .support.procedural import (
AbstractTypedef,
)


_is_windows = platform.system() == "Windows"


SaveableTarget = Union[str, Path, None, Output]


class ExportOutput:
"""Wrapper around a CompiledModule produced by `export`."""

def __init__(
self,
session: Session,
compiled_module: CompiledModule,
*,
importer_uses_session: bool = False,
):
self.session = session
self.session.set_flags("--iree-input-type=torch")
self.compiled_module = compiled_module
self._importer_uses_session = importer_uses_session

@property
def mlir_module(self) -> Operation:
"""Gets the MLIR module resulting from the last compilation phase."""
return CompiledModule.get_mlir_module(self.compiled_module)

def print_readable(self, large_elements_limit: int = 50):
"""Prints a human readable version of the current compilation IR."""
self.mlir_module.print(large_elements_limit=large_elements_limit)

    def save_mlir(self, file_path: Union[str, Path]):
        """Saves the current compilation IR to a path on disk.

        Args:
            file_path: Path to save the file. If it has a ".mlirbc"
                extension, it will be saved as bytecode. Otherwise as
                text.
        """
        file_path = Path(file_path)
        with open(file_path, "wb") as f:
            if file_path.suffix == ".mlirbc":
                self.mlir_module.write_bytecode(f)
            else:
                self.mlir_module.print(file=f, binary=True)

def _run_import(self):
CompiledModule.run_import(self.compiled_module)

    def compile(
        self,
        save_to: SaveableTarget,
        *,
        target_backends: Union[str, Sequence[str]] = ("llvm-cpu",),
    ) -> Optional[Output]:
        """Compiles the exported program to an executable binary.

        Args:
            save_to: Where to save the compiled binary. Can be one of:
                None: Outputs to a memory buffer and returns the API Output.
                (str, Path): Outputs to a file.
                Output: Raw compiler API Output object to save to.
            target_backends: A comma-delimited string of IREE target backends
                or a sequence of strings.
        Returns:
            None unless `save_to=None`, in which case the backing compiler API
            Output object is returned. Its memory can be accessed via its
            `map_memory()` method.
        """
        return_output = False
        if save_to is None:
            output = Output.open_membuffer()
            return_output = True
        elif isinstance(save_to, (str, Path)):
            save_to = Path(save_to)
            output = Output.open_file(str(save_to))
        else:
            assert isinstance(save_to, Output)
            output = save_to

target_backends = (
target_backends
if isinstance(target_backends, str)
else ",".join(target_backends)
)
inv = self.session.invocation()
if self._importer_uses_session:
inv.import_module(self.mlir_module)
else:
# Some platforms can't share the context across the importer and
# session (cough: Windows). Round-trip in this case.
buffer_io = io.BytesIO()
self.mlir_module.write_bytecode(buffer_io)
buffer = buffer_io.getvalue()
source = Source.wrap_buffer(self.session, buffer)
inv.parse_source(source)
inv.enable_console_diagnostics()

# TODO: Don't use flags to set the target backends: set module attributes.
self.session.set_flags(f"--iree-hal-target-backends={target_backends}")
if not inv.execute():
raise RuntimeError("Compilation failed: See diagnostics")

inv.output_vm_bytecode(output)
output.keep()
        if return_output:
return output
else:
return None


# Decorator which explicitly exports a function.
# TODO: Make this a public API on CompiledModule.
def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef:
if f is None:
return functools.partial(export_proc, signature=signature)
return ExportProcDef(f.__name__, f, signature=signature)


def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput:
    """One-shot export of an nn.Module.

    This is a very restrictive API compared to the lower-level
    `CompiledModule` facility. It is suitable for one-shot modules with a
    single entrypoint and static example arguments, where no additional
    configuration is needed for mutable parameters/buffers or state
    management. Dynamic shape constraints are not presently exposed via
    this API, but we expect to allow them in the future.

    Args:
        mdl: The nn.Module to export.
        *example_args: Example tensors.
    Returns:
        An ExportOutput object that wraps the compilation and provides
        easy access.
    """
signature = [abstractify(t) for t in example_args]

class Exported(CompiledModule, export_name=mdl._get_name()):
params = export_parameters(mdl)

@export_proc(signature=signature)
def main(self, *args):
return jittable(mdl.forward)(*args)

session = Session()
# There are some bugs with respect to Session/context interop that we
# haven't squashed yet. For now, default everyone to round-tripping
# via bytecode vs sharing the context between the importer/compiler.
importer_uses_session = False and not _is_windows
if importer_uses_session:
context = session.context
else:
context = Context()

cm = Exported(context=context, import_to="import")
return ExportOutput(session, cm, importer_uses_session=importer_uses_session)
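To make the `save_to` dispatch above concrete, a brief sketch of each target, reusing the `exported` object from the example earlier (file names are illustrative):

# Save the IR itself: bytecode for a ".mlirbc" extension, text otherwise.
exported.save_mlir("mlp.mlirbc")

# save_to=None: compile into a memory buffer and get the compiler Output back.
binary = exported.compile(save_to=None)
view = binary.map_memory()

# save_to=<str|Path>: write the VM flatbuffer to disk; returns None.
exported.compile(save_to="mlp.vmfb")

# save_to=<Output>: stream into an already-open compiler API Output object.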
1 change: 0 additions & 1 deletion python/shark_turbine/dynamo/passes.py
@@ -43,7 +43,6 @@
     torch.ops.aten.nll_loss_backward,
     torch.ops.aten._to_copy,
     torch.ops.aten._log_softmax_backward_data,
-
 ]

1 change: 0 additions & 1 deletion pytorch-cpu-requirements.txt
@@ -1,3 +1,2 @@
--f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
 torch==2.1.0
28 changes: 28 additions & 0 deletions tests/examples/aot_mlp_test.py
@@ -0,0 +1,28 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
from pathlib import Path
import sys
import subprocess
import unittest

REPO_DIR = Path(__file__).resolve().parent.parent.parent


def _run(local_path: str):
path = REPO_DIR / local_path
subprocess.check_call([sys.executable, str(path)])


class AOTMLPTest(unittest.TestCase):
def testMLPExportSimple(self):
_run("examples/aot_mlp/mlp_export_simple.py")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
1 change: 0 additions & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,2 @@
--f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
 torchvision
