
Add a simple aot.export entrypoint for one-shot nn.Module export. #84


Merged
merged 8 commits on Oct 9, 2023
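To give a feel for the new entrypoint, here is a minimal sketch of one-shot export in use. It condenses the example added under examples/aot_mlp/ in this PR; the toy module and shapes are illustrative assumptions, not part of the diff:

import torch
import shark_turbine.aot as aot

# Any single-entrypoint nn.Module with static example arguments qualifies.
model = torch.nn.Linear(8, 2)
example_x = torch.empty(4, 8, dtype=torch.float32)

exported = aot.export(model, example_x)
exported.print_readable()                # dump the imported MLIR
binary = exported.compile(save_to=None)  # compile to an in-memory VM module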
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -27,8 +27,10 @@ jobs:
      - name: Sync source deps
        run: |
          python -m pip install --upgrade pip
-         pip install --upgrade -r requirements.txt
-         pip install --upgrade -e .[torch,testing]
+         pip install --index-url https://download.pytorch.org/whl/cpu \
+           -r pytorch-cpu-requirements.txt \
+           -r torchvision-requirements.txt
+         pip install -e .[testing]

      - name: Run tests
        run: |
53 changes: 53 additions & 0 deletions examples/aot_mlp/mlp_export_simple.py
@@ -0,0 +1,53 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn as nn

import shark_turbine.aot as aot


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(8, 8, bias=True)
        self.layer1 = nn.Linear(8, 4, bias=True)
        self.layer2 = nn.Linear(4, 2, bias=True)
        self.layer3 = nn.Linear(2, 2, bias=True)

    def forward(self, x: torch.Tensor):
        x = self.layer0(x)
        x = torch.sigmoid(x)
        x = self.layer1(x)
        x = torch.sigmoid(x)
        x = self.layer2(x)
        x = torch.sigmoid(x)
        x = self.layer3(x)
        return x


model = MLP()
example_x = torch.empty(97, 8, dtype=torch.float32)
exported = aot.export(model, example_x)
exported.print_readable()
compiled_binary = exported.compile(save_to=None)


def infer():
    import numpy as np
    import iree.runtime as rt

    config = rt.Config("local-task")
    vmm = rt.load_vm_module(
        rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
        config,
    )
    x = np.random.rand(97, 8).astype(np.float32)
    y = vmm.main(x)
    print(y.to_host())


infer()
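The example above keeps the compiled artifact in memory. For reference, a variant that saves it to disk instead is sketched below; the file name mlp.vmfb is hypothetical, and rt.VmModule.mmap is assumed to be available for file-backed loading:

import iree.runtime as rt

# Saving to a path writes the VM bytecode to disk and returns None.
exported.compile(save_to="mlp.vmfb")

config = rt.Config("local-task")
vmm = rt.load_vm_module(
    rt.VmModule.mmap(config.vm_instance, "mlp.vmfb"),  # load from the saved file
    config,
)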
1 change: 1 addition & 0 deletions python/shark_turbine/aot/__init__.py
@@ -5,5 +5,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .compiled_module import CompiledModule
+from .exporter import export

from .builtins import *
192 changes: 192 additions & 0 deletions python/shark_turbine/aot/exporter.py
@@ -0,0 +1,192 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence, Union
import functools
import io
from pathlib import Path
import platform

import torch

from iree.compiler.api import (
    Invocation,
    Session,
    Source,
    Output,
)

from .builtins import *
from .compiled_module import (
    CompiledModule,
    ExportProcDef,
)
from .support.ir_imports import (
    Context,
    Operation,
)
from .support.procedural import (
    AbstractTypedef,
)


_is_windows = platform.system() == "Windows"


SaveableTarget = Union[str, Path, None, Output]


class ExportOutput:
    """Wrapper around a CompiledModule produced by `export`."""

    def __init__(
        self,
        session: Session,
        compiled_module: CompiledModule,
        *,
        importer_uses_session: bool = False,
    ):
        self.session = session
        self.session.set_flags("--iree-input-type=torch")
        self.compiled_module = compiled_module
        self._importer_uses_session = importer_uses_session

    @property
    def mlir_module(self) -> Operation:
        """Gets the MLIR module resulting from the last compilation phase."""
        return CompiledModule.get_mlir_module(self.compiled_module)

    def print_readable(self, large_elements_limit: int = 50):
        """Prints a human readable version of the current compilation IR."""
        self.mlir_module.print(large_elements_limit=large_elements_limit)

    def save_mlir(self, file_path: Union[str, Path]):
        """Saves the current compilation IR to a path on disk.

        Args:
            file_path: Path to save the file. If it has a ".mlirbc"
                extension, it will be saved as bytecode. Otherwise as
                text.
        """
        file_path = Path(file_path)
        with open(file_path, "wb") as f:
            if file_path.suffix == ".mlirbc":
                self.mlir_module.write_bytecode(f)
            else:
                self.mlir_module.print(f, binary=True)

    def _run_import(self):
        CompiledModule.run_import(self.compiled_module)

    def compile(
        self,
        save_to: SaveableTarget,
        *,
        target_backends: Union[str, Sequence[str]] = ("llvm-cpu",),
    ) -> Optional[Output]:
        """Compiles the exported program to an executable binary.

        Args:
            save_to: Where to save the compiled binary. Can be one of:
                None: Outputs to a memory buffer and returns the API Output.
                str or Path: Outputs to a file.
                Output: Raw compiler API Output object to save to.
            target_backends: A comma-delimited string of IREE target backends
                or a sequence of strings.
        Returns:
            None unless `save_to=None`, in which case the backing compiler
            API Output object is returned. It can be queried for its backing
            memory via its `map_memory()` method.
        """
        return_memory_view = False
        if save_to is None:
            output = Output.open_membuffer()
            return_memory_view = True
        elif isinstance(save_to, (str, Path)):
            save_to = Path(save_to)
            output = Output.open_file(str(save_to))
        else:
            # `save_to` must already be a raw compiler API Output.
            assert isinstance(save_to, Output)
            output = save_to

        target_backends = (
            target_backends
            if isinstance(target_backends, str)
            else ",".join(target_backends)
        )
        inv = self.session.invocation()
        if self._importer_uses_session:
            inv.import_module(self.mlir_module)
        else:
            # Some platforms can't share the context across the importer and
            # session (cough: Windows). Round-trip in this case.
            buffer_io = io.BytesIO()
            self.mlir_module.write_bytecode(buffer_io)
            buffer = buffer_io.getvalue()
            source = Source.wrap_buffer(self.session, buffer)
            inv.parse_source(source)
        inv.enable_console_diagnostics()

        # TODO: Don't use flags to set the target backends: set module attributes.
        self.session.set_flags(f"--iree-hal-target-backends={target_backends}")
        if not inv.execute():
            raise RuntimeError("Compilation failed: See diagnostics")

        inv.output_vm_bytecode(output)
        output.keep()
        if return_memory_view:
            return output
        else:
            return None


# Decorator which explicitly exports a function.
# TODO: Make this a public API on CompiledModule.
def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> ExportProcDef:
    if f is None:
        return functools.partial(export_proc, signature=signature)
    return ExportProcDef(f.__name__, f, signature=signature)


def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput:
    """One-shot export of an nn.Module.

    This is a very restrictive API compared to the lower-level
    `CompiledModule` facility. It is suitable for one-shot modules with a
    single entrypoint and static example arguments, where no additional
    configuration is needed for mutable parameters/buffers or state
    management. Dynamic shape constraints are not presently exposed via
    this API, but we expect to allow this in the future.

    Args:
        mdl: The nn.Module to export.
        *example_args: Example tensors.

    Returns:
        An ExportOutput object that wraps the compilation and provides
        easy access.
    """
    signature = [abstractify(t) for t in example_args]

    class Exported(CompiledModule, export_name=mdl._get_name()):
        params = export_parameters(mdl)

        @export_proc(signature=signature)
        def main(self, *args):
            return jittable(mdl.forward)(*args)

    session = Session()
    # There are some bugs with respect to Session/context interop that we
    # haven't squashed yet. For now, default everyone to round-tripping
    # via bytecode vs sharing the context between the importer/compiler.
    importer_uses_session = False and not _is_windows
    if importer_uses_session:
        context = session.context
    else:
        context = Context()

    cm = Exported(context=context, import_to="import")
    return ExportOutput(session, cm, importer_uses_session=importer_uses_session)
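Beyond compiling to a file, ExportOutput can also persist the imported IR and hand back the raw in-memory artifact; a short sketch of those two paths, reusing the `exported` object from the example above (file name hypothetical):

# The .mlirbc suffix selects bytecode; any other suffix saves textual IR.
exported.save_mlir("mlp.mlirbc")

# save_to=None returns the compiler API Output; map_memory() exposes the bytes.
out = exported.compile(save_to=None, target_backends="llvm-cpu")
vm_flatbuffer = out.map_memory()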
1 change: 0 additions & 1 deletion python/shark_turbine/dynamo/passes.py
@@ -43,7 +43,6 @@
    torch.ops.aten.nll_loss_backward,
    torch.ops.aten._to_copy,
    torch.ops.aten._log_softmax_backward_data,
-
]


1 change: 0 additions & 1 deletion pytorch-cpu-requirements.txt
@@ -1,3 +1,2 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.1.0
28 changes: 28 additions & 0 deletions tests/examples/aot_mlp_test.py
@@ -0,0 +1,28 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
from pathlib import Path
import sys
import subprocess
import unittest

REPO_DIR = Path(__file__).resolve().parent.parent.parent


def _run(local_path: str):
    path = REPO_DIR / local_path
    subprocess.check_call([sys.executable, str(path)])


class AOTMLPTest(unittest.TestCase):
    def testMLPExportSimple(self):
        _run("examples/aot_mlp/mlp_export_simple.py")


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
1 change: 0 additions & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,2 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision