Adding a ResNet-18 Example #268
requirements.txt (new file):

```
transformers
datasets
shark_turbine
```
The example script (new file):

```python
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context
from iree.compiler.api import Session
import iree.runtime as rt
from datasets import load_dataset

# Load the feature extractor and pretrained model from Hugging Face.
extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")

# Load an example image.
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# If you want to see the cat picture:
# image.save("cats-image.jpg")

# If you want to run a custom image through inference:
# import PIL
# image = PIL.JpegImagePlugin.JpegImageFile("yourexamplepicture.jpg")

# Extract features from the image to feed to the model.
inputs = extractor(image, return_tensors="pt")
pixel_tensor = inputs.pixel_values


# Define a function to do inference.
# This will get passed to the compiled module as a jittable function.
def forward(pixel_values_tensor: torch.Tensor):
    with torch.no_grad():
        logits = model.forward(pixel_values_tensor).logits
    predicted_id = torch.argmax(logits, -1)
    return predicted_id


# A dynamic module for doing inference.
class RN18(CompiledModule):
    params = export_parameters(model)

    def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
        # Set a constraint for the dynamic number of batches.
        const = [x.dynamic_dim(0) < 16]
        return jittable(forward)(x, constraints=const)


# Build an MLIR module to compile with the one-shot exporter.
exported = export(RN18)

compiled_binary = exported.compile(save_to=None)


# The return type is rt.array_interop.DeviceArray;
# an np.array of outputs can be accessed via the to_host() method.
def shark_infer(x):
    config = rt.Config("local-task")
    vmm = rt.load_vm_module(
        rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
        config,
    )
    y = vmm.forward(x)
    return y


# Print the text labels for the output ids.
def print_labels(id):
    for num in id:
        print(model.config.id2label[num])


# Run a random batch through inference, with the real cat image at index 2.
# (Amusingly, the random tensors always classify as jellyfish.)
x = torch.randn(17, 3, 224, 224)
x[2] = pixel_tensor
y = shark_infer(x)
print_labels(y.to_host())
```
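The review discussion below mentions validating the compiled module by comparing its results against the plain (non-SHARK-Turbine) forward function on a larger dataset. A minimal sketch of that check follows; the helper names are hypothetical, and the imports are guarded so the sketch is inert when the example's heavy dependencies are not installed.

```python
# Hypothetical helpers for checking the compiled module against the eager
# PyTorch model, along the lines of the validation described in the review
# discussion below.
import importlib.util


def deps_available():
    """Return True when the example's dependencies can be imported."""
    return all(
        importlib.util.find_spec(mod) is not None
        for mod in ("torch", "transformers", "shark_turbine")
    )


def agreement(compiled_ids, eager_ids):
    """Fraction of predictions on which the compiled and eager paths agree."""
    matches = sum(1 for c, e in zip(compiled_ids, eager_ids) if c == e)
    return matches / len(eager_ids)


if deps_available():
    # shark_infer, forward, and x are defined in the example above.
    # compiled = shark_infer(x).to_host()   # ids from the IREE module
    # eager = forward(x).numpy()            # ids from eager PyTorch
    # assert agreement(compiled, eager) == 1.0
    pass
```

For argmax class ids the two paths should agree exactly; for raw logits a tolerance-based comparison (e.g. `numpy.allclose`) would be more appropriate.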
Review discussion:

- **Reviewer (IanNod):** Should add a test for this to ensure we don't break our example models.
- **Author:** I added something like a test, but I'm not sure if it is what you were requesting. I downloaded a bigger dataset and compared the results of the compiled module with the non-SHARK-Turbine forward function.
- **Reviewer (IanNod):** I was more thinking of unit tests that will run on the CI, so any time our code is updated it will be checked for any kind of failure on this model. The test for the aot_mlp example can be found here for reference: https://github.com/nod-ai/SHARK-Turbine/blob/main/tests/examples/aot_mlp_test.py
- **Author:** Amusingly enough, this example broke today upon updating the pip release of shark_turbine from 9.2 to 9.3.
- **Reviewer (IanNod):** That is unfortunate, my condolences. This does highlight why the tests are important: whatever change broke this example would have failed CI and would have needed to be addressed before merging. It may be helpful to post an issue describing the error, ideally with a minimal reproducer, so we can get more eyes on the problem.
- **Author:** Will do! Thanks for all the comments so far, @IanNod.
- **Reviewer (IanNod):** Of course, thanks for the work you are doing!
- **Reviewer:** Maybe worth considering moving resnet from examples to python/turbine_models in some form, to enable use outside turbine?