Adding a Resnet 18 Example #268
Merged

Commits (22)
486bd75 Added a resnet-18 example (zjgarvey)
1760046 Merge branch 'nod-ai:main' into main (zjgarvey)
16851ae Trying different forwards (zjgarvey)
3f30cb8 Merge branch 'main' of https://github.com/zjgarvey/SHARK-Turbine (zjgarvey)
83a62e7 Split different approaches to forward (zjgarvey)
c78b43d Added back torch.no_grad() context manager (zjgarvey)
6535e10 Made a python file for the resnet-18 example (zjgarvey)
bbaa2c1 Updated the notebook (zjgarvey)
88f2d88 Merge branch 'nod-ai:main' into main (zjgarvey)
7baabab Cleaned up Resnet Example (zjgarvey)
5e2e2de Black Reformatting (zjgarvey)
cd64f3a Resnet example now uses 1-shot export (zjgarvey)
f548a5c black (zjgarvey)
9500da3 deleted mlir (zjgarvey)
a5f4e38 Added README and a test (zjgarvey)
e094299 ran black for formatting (zjgarvey)
ef0f9ae Simplified dependencies (zjgarvey)
3b8bab0 Simplified Example; removed requirements (zjgarvey)
ccc1f4b Update README.md (zjgarvey)
39d4ee6 added resnet 18 to custom models, added unit test (zjgarvey)
0fef610 fixed an attribute error in resnet 18 model (zjgarvey)
ccaefaa Marked Resnet18 test as xfail (zjgarvey)
@@ -0,0 +1,19 @@

# Dynamic AOT Resnet-18 Example

This example AOT-compiles a Resnet-18 module for performing inference on a dynamic number of input images.

To run this example (with Python 3.11), clone the repository to your local device and install the requirements in a virtual environment:

```bash
git clone https://github.com/nod-ai/SHARK-Turbine.git
cd SHARK-Turbine/examples/resnet-18
python -m venv rn18_venv
source ./rn18_venv/bin/activate
pip install -r requirements.txt
```

Once the requirements are installed, you should be able to run the example:

```bash
python resnet-18.py
```
@@ -0,0 +1,2 @@

```
transformers
shark_turbine==0.9.2
```
@@ -0,0 +1,70 @@

```python
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
import iree.runtime as rt

# Load the pretrained model from Hugging Face
# (the feature extractor is only needed when preprocessing real images)
# extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")


# Inference function; this gets passed to the compiled module as a jittable function.
def forward(pixel_values_tensor: torch.Tensor):
    with torch.no_grad():
        logits = model.forward(pixel_values_tensor).logits
        predicted_id = torch.argmax(logits, -1)
        return predicted_id


# A dynamic module for doing inference.
# This will be compiled AOT to a memory buffer.
class RN18(CompiledModule):
    params = export_parameters(model)

    def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
        # Set a constraint for the dynamic number of batches.
        # Interestingly enough, it doesn't seem to limit BATCH_SIZE.
        const = [x.dynamic_dim(0) < 16]
        return jittable(forward)(x, constraints=const)


# Build an MLIR module with the 1-shot exporter.
exported = export(RN18)
# Compile the exported module to a memory buffer.
compiled_binary = exported.compile(save_to=None)


# The return type is rt.array_interop.DeviceArray;
# an np.array of the outputs can be accessed via the to_host() method.
def shark_infer(x):
    config = rt.Config("local-task")
    vmm = rt.load_vm_module(
        rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
        config,
    )
    y = vmm.forward(x)
    return y


# Print the text corresponding to the output label codes.
def print_labels(id):
    for l in id:
        print(model.config.id2label[l])


# Find discrepancies between id0 and id1.
def compare_labels(id0, id1):
    return (id0 != id1).nonzero(as_tuple=True)


# Load some examples and check for discrepancies between the
# compiled module and standard inference (the forward function).
x = torch.randn(10, 3, 224, 224)
y0 = shark_infer(x)
y1 = forward(x)
print_labels(y0)
print(
    f"Found {compare_labels(y0,y1)[0].size()[0]} discrepancies between turbine and standard result"
)
```
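The script above validates the compiled module with random tensors. To classify a real image instead, the commented-out feature extractor can produce correctly shaped inputs. Below is a minimal sketch of that preprocessing step, assuming Pillow is installed; `cat.jpg` is a placeholder path and is not part of the example.

```python
# Minimal preprocessing sketch (assumption: Pillow is available; "cat.jpg" is a placeholder path).
from PIL import Image
from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")
image = Image.open("cat.jpg").convert("RGB")
# The extractor resizes and normalizes the image, returning a (1, 3, 224, 224) float32 tensor.
pixel_values = extractor(images=image, return_tensors="pt")["pixel_values"]
# Reuse the helpers defined in the example above.
print_labels(shark_infer(pixel_values))
```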
@@ -0,0 +1,129 @@

```python
import os
import sys
import re

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context
import iree.runtime as rt
from turbine_models.custom_models.sd_inference import utils

import argparse

parser = argparse.ArgumentParser()

parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="microsoft/resnet-18",
)
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

# TODO: Add other resnet models


class Resnet18Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForImageClassification.from_pretrained(
            "microsoft/resnet-18"
        )
        # self.extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")

    def forward(self, pixel_values_tensor: torch.Tensor):
        with torch.no_grad():
            logits = self.model.forward(pixel_values_tensor).logits
            predicted_id = torch.argmax(logits, -1)
            return predicted_id


def export_resnet_18_model(
    resnet_model, compile_to="torch", device=None, target_triple=None, max_alloc=None
):
    class CompiledResnet18Model(CompiledModule):
        params = export_parameters(resnet_model.model)

        def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
            const = [x.dynamic_dim(0) < 16]
            return jittable(resnet_model.forward)(x, constraints=const)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledResnet18Model(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    if compile_to != "vmfb":
        return module_str
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, "resnet_18")


def run_resnet_18_vmfb_comparison(resnet_model, args):
    config = rt.Config(args.device)

    if args.vmfb_path:
        mod = rt.VmModule.mmap(config.vm_instance, args.vmfb_path)
    elif os.path.exists("resnet_18.vmfb"):
        mod = rt.VmModule.mmap(config.vm_instance, "resnet_18.vmfb")
    else:
        sys.exit("no vmfb_path provided, required for run_vmfb")

    vm_modules = [
        mod,
        rt.create_hal_module(config.vm_instance, config.device),
    ]
    ctx = rt.SystemContext(
        vm_modules=vm_modules,
        config=config,
    )
    inp = torch.rand(5, 3, 224, 224, dtype=torch.float32)
    device_inputs = [rt.asdevicearray(config.device, inp)]

    # Turbine output
    CompModule = ctx.modules.compiled_resnet18_model
    turbine_output = CompModule["main"](*device_inputs)
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    # Torch output
    torch_output = resnet_model.forward(inp)
    torch_output = torch_output.detach().cpu().numpy()
    print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

    err = utils.largest_error(torch_output, turbine_output)
    print("LARGEST ERROR:", err)
    assert err < 9e-5


if __name__ == "__main__":
    args = parser.parse_args()
    resnet_model = Resnet18Model()
    if args.run_vmfb:
        run_resnet_18_vmfb_comparison(resnet_model, args)
    else:
        mod_str = export_resnet_18_model(
            resnet_model,
            args.compile_to,
            args.device,
            args.iree_target_triple,
            args.vulkan_max_allocation,
        )
        safe_name = "resnet_18"
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")
```
@@ -0,0 +1,39 @@

```python
import argparse
import logging
from turbine_models.custom_models import resnet_18
import unittest
import os
import pytest

arguments = {
    "run_vmfb": True,
    "compile_to": None,
    "vmfb_path": "",
    "device": "local-task",
    "iree_target_triple": "",
    "vulkan_max_allocation": "4294967296",
}

resnet_model = resnet_18.Resnet18Model()


class Resnet18Test(unittest.TestCase):
    @pytest.mark.xfail(
        reason="caused by lack of support for DenseResourceElementsAttr iteration over a generic FloatAttr"
    )
    def testExportResnet18Model(self):
        with self.assertRaises(SystemExit) as cm:
            resnet_18.export_resnet_18_model(
                resnet_model,
                "vmfb",
                "cpu",
            )
        self.assertEqual(cm.exception.code, None)
        namespace = argparse.Namespace(**arguments)
        resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace)
        os.remove("resnet_18.vmfb")


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
```
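Since the test combines a unittest.TestCase with a pytest xfail marker, it is presumably collected and run under pytest. The path below is a placeholder, as the test file's location is not shown in this diff.

```bash
# Hypothetical path: substitute the actual location of this test file in the repository.
pytest path/to/resnet_test.py -v
```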
Review comment: Maybe worth considering moving resnet from examples to python/turbine_models in some form, to enable use outside Turbine?