Shark_turbine 0.9.3 Breaks a Resnet-18 Model #282
I'll see if I can figure out how to get a more helpful stack trace for this smaller reproduction. In the meantime, here is the result of running a resnet-18 unit test in #268:
It would be helpful to look at the Torch IR that we are compiling through IREE that produces this error, and to link to it in a gist. You should be able to save that MLIR with `exported.save_mlir("mlir path")`. The recommended path is generally to use the lower-level CompiledModule, as you can see in python/turbine_models/custom_models/stateless_llama.py, where we save the MLIR when the flag --compile_to=torch (or linalg) is passed.
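For reference, a minimal sketch of that save path, assuming the one-shot `shark_turbine.aot.export` API and its `save_mlir` method (the `Small` module here is a hypothetical stand-in for the real model):

```python
import torch
from shark_turbine import aot

# Hypothetical tiny module standing in for resnet-18.
class Small(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

# export() traces the module into Torch-dialect MLIR; save_mlir() writes
# that IR to disk so it can be attached to a gist.
exported = aot.export(Small(), torch.empty(4, dtype=torch.float32))
exported.save_mlir("small.mlir")
```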
This is related to the switch from DenseElementsAttr -> DenseResourceElementsAttr. Avi and Sai did a pass through IREE and ported things to use a more generic interface, but this one must have been missed. DenseResourceElementsAttr does not support iteration through a generic FloatAttr. A stack trace would isolate the failure point, and then it will need to be fixed. We should get this model into the CI.
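To illustrate the distinction, a sketch using the MLIR Python bindings bundled with iree.compiler (this assumes `DenseResourceElementsAttr.get_from_buffer` is exposed in this build; the blob name is arbitrary):

```python
import numpy as np
from iree.compiler import ir

with ir.Context():
    data = np.arange(4, dtype=np.float32)

    # Inline dense attribute: the elements live in the attribute itself,
    # so passes can walk them value-by-value (e.g. as FloatAttrs).
    inline = ir.DenseElementsAttr.get(data)

    # Resource-backed attribute: the elements live in an opaque named blob,
    # so the generic per-element iteration available for DenseElementsAttr
    # is not supported -- the failure mode described above.
    ttype = ir.RankedTensorType.get([4], ir.F32Type.get())
    resource = ir.DenseResourceElementsAttr.get_from_buffer(
        memoryview(data), "resnet18_weights", ttype
    )
    print(inline, resource)
```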
Here's a gist containing the Torch IR. Both IRs are generated from the following Python code:

```python
from transformers import AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context
from iree.compiler.api import Session
import iree.runtime as rt

# Pretrained resnet-18 checkpoint from Hugging Face.
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")

def forward(pixel_values_tensor: torch.Tensor):
    with torch.no_grad():
        logits = model.forward(pixel_values_tensor).logits
        predicted_id = torch.argmax(logits, -1)
    return predicted_id

class RN18(CompiledModule):
    # Externalize the weights rather than inlining them into the IR.
    params = export_parameters(model, external=True)

    def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
        # Constrain the dynamic batch dimension to be < 16.
        const = [x.dynamic_dim(0) < 16]
        return jittable(forward)(x, constraints=const)

inst = RN18(context=Context(), import_to="INPUT")
torch_str = str(CompiledModule.get_mlir_module(inst))
with open("resnet18.mlir", "w+") as f:
    f.write(torch_str)

session = Session()
ExportedModule = exporter.ExportOutput(session, inst)
mlir_module = str(ExportedModule.mlir_module)
with open("exported_resnet18.mlir", "w+") as f:
    f.write(mlir_module)
```

Then running the following will produce an error message:

```python
import iree.compiler as ic

flags = [
    "--iree-input-type=torch",
    "--mlir-print-debuginfo",
    "--mlir-print-op-on-diagnostic=false",
    "--iree-llvmcpu-target-cpu-features=host",
    "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
    "--iree-stream-resource-index-bits=64",
    "--iree-vm-target-index-bits=64",
    "--iree-codegen-check-ir-before-llvm-conversion=false",
    "--iree-opt-const-expr-hoisting=False",
    "--iree-llvmcpu-enable-ukernels=all",
]

flatbuffer_blob = ic.compile_str(
    torch_str,
    target_backends=["llvm-cpu"],
    extra_args=flags,
)
```

The error occurs upon invoking ic.compile_str.
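As an aside, the iree.runtime import in the repro is for the step that is never reached; if compilation succeeded, the blob would be loaded and run roughly like this (a sketch assuming iree.runtime's load_vm_flatbuffer helper; with export_parameters(..., external=True) the weights would also have to be supplied at load time, which is not wired up here):

```python
import numpy as np
import iree.runtime as rt

# Load the compiled flatbuffer with the CPU driver and call the exported
# entry point. This only illustrates the call shape; the externalized
# parameters are not provided.
vm_module = rt.load_vm_flatbuffer(flatbuffer_blob, driver="local-task")
pixels = np.zeros((1, 3, 224, 224), dtype=np.float32)
predicted_id = vm_module.forward(pixels)
```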
I'll follow the instructions on how to get more information from that error message and add another comment.
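A sketch of one way to capture those artifacts from Python, assuming the IREE_SAVE_TEMPS environment variable honored by the iree.compiler tools wrapper (the directory path is arbitrary; torch_str and flags are from the repro above):

```python
import os
import iree.compiler as ic

# Must be set before compile_str is called so the compiler invocation
# inherits it and dumps its intermediate artifacts for inspection.
os.environ["IREE_SAVE_TEMPS"] = "/tmp/iree-crash"

flatbuffer_blob = ic.compile_str(
    torch_str,
    target_backends=["llvm-cpu"],
    extra_args=flags,
)
```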
After following the error message instructions, the information from IREE_SAVE_TEMPS is saved in this gist. The error message with llvm-symbolizer linked:
Using the dependencies:
the following Python code will execute successfully:
However, using the newer shark_turbine release (0.9.3), running the same Python script results in an error message.
Related to #268