Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a Resnet 18 Example #268

Merged
merged 22 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/resnet-18/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Dynamic AOT Resnet-18 Example

This example AOT-compiles a Resnet-18 module for performing inference on a dynamic number of input images.

To run this example (with Python3.11), you should clone the repository to your local device and install the requirements in a virtual environment.

```bash
git clone https://github.com/nod-ai/SHARK-Turbine.git
cd SHARK-Turbine/examples/resnet-18
python -m venv rn18_venv
source ./rn18_venv/bin/activate
pip install -r requirements.txt
```

Once the requirements are installed, you should be able to run the example.

```bash
python resnet-18.py
```
2 changes: 2 additions & 0 deletions examples/resnet-18/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
transformers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth considering moving resnet from examples to python/turbine_models in some form to enable use outside turbine?

shark_turbine==0.9.2
70 changes: 70 additions & 0 deletions examples/resnet-18/resnet-18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
import iree.runtime as rt

# Loading feature extractor and pretrained model from huggingface
# extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")


# define a function to do inference
# this will get passed to the compiled module as a jittable function
def forward(pixel_values_tensor: torch.Tensor):
with torch.no_grad():
logits = model.forward(pixel_values_tensor).logits
predicted_id = torch.argmax(logits, -1)
return predicted_id


# a dynamic module for doing inference
# this will be compiled AOT to a memory buffer
class RN18(CompiledModule):
IanNod marked this conversation as resolved.
Show resolved Hide resolved
params = export_parameters(model)

def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
# set a constraint for the dynamic number of batches
# interestingly enough, it doesn't seem to limit BATCH_SIZE
const = [x.dynamic_dim(0) < 16]
return jittable(forward)(x, constraints=const)


# build an mlir module with 1-shot exporter
exported = export(RN18)
# compile exported module to a memory buffer
compiled_binary = exported.compile(save_to=None)


# return type is rt.array_interop.DeviceArray
# np.array of outputs can be accessed via to_host() method
def shark_infer(x):
config = rt.Config("local-task")
vmm = rt.load_vm_module(
rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
config,
)
y = vmm.forward(x)
return y


# prints the text corresponding to output label codes
def print_labels(id):
for l in id:
print(model.config.id2label[l])


# finds discrepancies between id0 and id1
def compare_labels(id0, id1):
return (id0 != id1).nonzero(as_tuple=True)


# load some examples and check for discrepancies between
# compiled module and standard inference (forward function)

x = torch.randn(10, 3, 224, 224)
y0 = shark_infer(x)
y1 = forward(x)
print_labels(y0)
print(
f"Found {compare_labels(y0,y1)[0].size()[0]} discrepancies between turbine and standard result"
)
129 changes: 129 additions & 0 deletions python/turbine_models/custom_models/resnet_18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os
import sys
import re

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context
import iree.runtime as rt
from turbine_models.custom_models.sd_inference import utils

import argparse

parser = argparse.ArgumentParser()

parser.add_argument(
"--hf_model_name",
type=str,
help="HF model name",
default="microsoft/resnet-18",
)
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
parser.add_argument(
"--iree_target_triple",
type=str,
default="",
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

# TODO: Add other resnet models


class Resnet18Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForImageClassification.from_pretrained(
"microsoft/resnet-18"
)
# self.extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")

def forward(self, pixel_values_tensor: torch.Tensor):
with torch.no_grad():
logits = self.model.forward(pixel_values_tensor).logits
predicted_id = torch.argmax(logits, -1)
return predicted_id


def export_resnet_18_model(
resnet_model, compile_to="torch", device=None, target_triple=None, max_alloc=None
):
class CompiledResnet18Model(CompiledModule):
params = export_parameters(resnet_model.model)

def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
const = [x.dynamic_dim(0) < 16]
return jittable(resnet_model.forward)(x, constraints=const)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledResnet18Model(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
if compile_to != "vmfb":
return module_str
else:
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, "resnet_18")


def run_resnet_18_vmfb_comparison(resnet_model, args):
config = rt.Config(args.device)

if args.vmfb_path:
mod = rt.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists("resnet_18.vmfb"):
mod = rt.VmModule.mmap(config.vm_instance, "resnet_18.vmfb")
else:
sys.exit("no vmfb_path provided, required for run_vmfb")

vm_modules = [
mod,
rt.create_hal_module(config.vm_instance, config.device),
]
ctx = rt.SystemContext(
vm_modules=vm_modules,
config=config,
)
inp = torch.rand(5, 3, 224, 224, dtype=torch.float32)
device_inputs = [rt.asdevicearray(config.device, inp)]

# Turbine output
CompModule = ctx.modules.compiled_resnet18_model
turbine_output = CompModule["main"](*device_inputs)
print(
"TURBINE OUTPUT:",
turbine_output.to_host(),
turbine_output.to_host().shape,
turbine_output.to_host().dtype,
)

# Torch output
torch_output = resnet_model.forward(inp)
torch_output = torch_output.detach().cpu().numpy()
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

err = utils.largest_error(torch_output, turbine_output)
print("LARGEST ERROR:", err)
assert err < 9e-5


if __name__ == "__main__":
args = parser.parse_args()
resnet_model = Resnet18Model()
if args.run_vmfb:
run_resnet_18_vmfb_comparison(resnet_model, args)
else:
mod_str = export_resnet_18_model(
resnet_model,
args.compile_to,
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = "resnet_18"
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
39 changes: 39 additions & 0 deletions python/turbine_models/tests/resnet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import argparse
import logging
from turbine_models.custom_models import resnet_18
import unittest
import os
import pytest

arguments = {
"run_vmfb": True,
"compile_to": None,
"vmfb_path": "",
"device": "local-task",
"iree_target_triple": "",
"vulkan_max_allocation": "4294967296",
}

resnet_model = resnet_18.Resnet18Model()


class Resnet18Test(unittest.TestCase):
@pytest.mark.xfail(
reason="caused by lack of support for DenseResourceElementsAttr iteration over a generic FloatAttr"
)
def testExportResnet18Model(self):
with self.assertRaises(SystemExit) as cm:
resnet_18.export_resnet_18_model(
resnet_model,
"vmfb",
"cpu",
)
self.assertEqual(cm.exception.code, None)
namespace = argparse.Namespace(**arguments)
resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace)
os.remove("resnet_18.vmfb")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
Loading