Commit

[ONNX][EP] Add provider_options parameter in IREE-EP backend (nod-ai#330)

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
Shukla-Gaurav authored Sep 25, 2024
1 parent d535d60 commit 34dc084
Showing 5 changed files with 73 additions and 13 deletions.
43 changes: 33 additions & 10 deletions alt_e2eshark/e2e_testing/backends.py
@@ -145,26 +145,49 @@ def func(x: TestTensors) -> str:
 
 class OnnxrtIreeEpBackend(BackendBase):
     '''This backend uses onnxrt iree-ep to compile and run onnx models for a specified hal_target_backend'''
-    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", providers=["IreeExecutionProvider"], inter_op_num_threads=None, intra_op_num_threads=None):
+    # may need the device and target_backend for the future (e.g., when IREE-EP has support for specifying)
+    def __init__(self, *, device="local-task", hal_target_device="llvm-cpu", extra_args : List[str] = None):
         self.device = device
-        self.hal_target_backend = hal_target_backend
 
-        self.providers=providers
-        # TODO: have more session options be optionally configurable by init args
+        self.hal_target_device = hal_target_device
+        if extra_args:
+            self.extra_args = []
+            for a in extra_args:
+                if a[0:2] == "--":
+                    self.extra_args.append(a)
+                else:
+                    self.extra_args.append("--" + a)
+        elif hal_target_device == "hip":
+            # some extra args for Mi250 - some of these may not work for other chips
+            self.extra_args = [
+                "--iree-hip-target=gfx90a",
+            ]
+        elif hal_target_device == "llvm-cpu":
+            self.extra_args = [
+                "--iree-input-demote-i64-to-i32",
+                # "--iree-llvmcpu-fail-on-large-vector=0",
+                # "--iree-llvmcpu-stack-allocation-limit=300000",
+            ]
+        self.providers = ["IreeExecutionProvider"]
+        # set provider options.
+        provider_options_dict = dict()
+        provider_options_dict["hal_target_device"] = self.hal_target_device
+        provider_options_dict["device"] = self.device
+        provider_options_dict["compile_time_flags"] = "+".join(self.extra_args)
+        self.provider_options = [provider_options_dict]
         self.sess_opt = ort.SessionOptions()
         self.sess_opt.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-        if inter_op_num_threads:
-            self.sess_opt.inter_op_num_threads = inter_op_num_threads
-        if intra_op_num_threads:
-            self.sess_opt.intra_op_num_threads = intra_op_num_threads
-        # sess_opt.log_verbosity_level = 0
-        # self.sess_opt.log_severity_level = 0
 
     def compile(self, model: ModelProto, *, save_to: str = None) -> ort.InferenceSession:
+        if self.provider_options:
+            provider_options_dict = self.provider_options[0]
+            provider_options_dict["save_to"] = save_to
+
         session = ort.InferenceSession(
             model.SerializeToString(),
             self.sess_opt,
             providers=self.providers,
+            provider_options=self.provider_options,
         )
         # can't save an onnx runtime session
         return session
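
For reference, a minimal usage sketch of the new interface (not part of this commit): it assumes an onnxruntime build that registers "IreeExecutionProvider", and the model path and extra_args value are illustrative.

# Sketch: construct the backend and inspect the provider options it will hand to
# ort.InferenceSession via the parallel providers / provider_options lists.
import onnx
from e2e_testing.backends import OnnxrtIreeEpBackend

backend = OnnxrtIreeEpBackend(
    device="local-task",
    hal_target_device="llvm-cpu",
    extra_args=["iree-input-demote-i64-to-i32"],  # a leading "--" is prepended automatically
)
# backend.providers        == ["IreeExecutionProvider"]
# backend.provider_options == [{"hal_target_device": "llvm-cpu", "device": "local-task",
#                               "compile_time_flags": "--iree-input-demote-i64-to-i32"}]
session = backend.compile(onnx.load("model.onnx"), save_to="test-run/model/")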
1 change: 1 addition & 0 deletions alt_e2eshark/onnx_tests/helper_classes.py
@@ -20,6 +20,7 @@
 class AzureDownloadableModel(OnnxModelInfo):
     """This class can be used for models in our azure storage (both private and public)."""
     def __init__(self, name: str, onnx_model_path: str):
+        # TODO: Extract opset version from onnx.version.opset
         opset_version = 21
         parent_cache_dir = os.getenv('CACHE_DIR')
         if not parent_cache_dir:
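
The TODO above is not implemented in this commit; a hedged sketch of one way to read the default-domain opset from a loaded model instead of hard-coding 21 (the model path is illustrative):

# Sketch only: pick the ai.onnx (default-domain) opset recorded in the model,
# falling back to the current hard-coded value.
import onnx

model = onnx.load("model.onnx")
opset_version = next(
    (entry.version for entry in model.opset_import if entry.domain in ("", "ai.onnx")),
    21,
)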
35 changes: 35 additions & 0 deletions alt_e2eshark/onnx_tests/models/external_lists/onnxrt-iree-ep.txt
@@ -0,0 +1,35 @@
AlexNet_vaiq_int8
beit_base_patch16_384.in22k_ft_in22k_in1k
deit3_large_patch16_224.fb_in1k
deit_base_distilled_patch16_224.fb_in1k
DenseNet201_vaiq_int8
efficientformer_l1.snap_dist_in1k
EfficientNet_v2_s_vaiq_int8
eva_large_patch14_336.in22k_ft_in22k_in1k
flexivit_base.1200ep_in1k
flexivit_small.1200ep_in1k
focalnet_base_lrf.ms_in1k
focalnet_tiny_srf.ms_in1k
gmixer_24_224.ra3_in1k
gmlp_s16_224.ra3_in1k
GoogLeNet_vaiq_int8
levit_256.fb_dist_in1k
levit_384.fb_dist_in1k
mixer_b16_224.goog_in21k_ft_in1k
mixer_b16_224.miil_in21k_ft_in1k
pit_b_distilled_224
pit_ti_224
poolformer_m36
poolformer_m48
RegNet_y_8gf_vaiq_int8
regnety_120.sw_in12k
regnety_320.seer
regnety_640.seer
resmlp_big_24_224.fb_in1k
ResNet152_vaiq_int8
resnet50
resnet50_vaiq_int8
resnetrs420
VGG11_bn_vaiq_int8
VGG19_vaiq_int8
WideResNet_50_2_vaiq_int8
2 changes: 1 addition & 1 deletion alt_e2eshark/onnx_tests/models/model.py
@@ -9,4 +9,4 @@
 from .vision_models import *
 from .deeplab import *
 from .migraphx import *
-from .nlp import *
\ No newline at end of file
+from .nlp import *
5 changes: 3 additions & 2 deletions alt_e2eshark/run.py
@@ -106,12 +106,13 @@ def main(args):
     elif args.mode == "ort-ep":
         # TODO: allow specifying provider explicitly from cl args.
         config = OnnxEpTestConfig(
-            str(TEST_DIR), OnnxrtIreeEpBackend(device=args.device, hal_target_backend=args.backend))
+            str(TEST_DIR), OnnxrtIreeEpBackend(device=args.device, hal_target_device=args.backend))
     else:
         raise NotImplementedError(f"unsupported mode: {args.mode}")
 
     # get test list
     test_list = get_tests(args.groups, args.test_filter, args.testsfile)
+    test_list.sort()
 
     #setup test stages
     stages = ALL_STAGES if args.benchmark else DEFAULT_STAGES

@@ -316,7 +317,7 @@ def _get_argparse():
     parser.add_argument(
         "-b",
         "--backend",
-        choices=["llvm-cpu", "amd-aie", "rocm", "cuda", "vmvx", "metal-spirv", "vulkan-spirv"],
+        choices=["llvm-cpu", "amd-aie", "rocm", "hip", "cuda", "vmvx", "metal-spirv", "vulkan-spirv"],
         default="llvm-cpu",
         help="specifies the iree-hal-target-backend for compile phase",
     )
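
For context, a hedged illustration (not part of this commit) of how the new "hip" choice flows into the compile-time flags via the backend defaults added in backends.py; the gfx90a flag is the Mi250 default, so other GPUs would still need explicit extra_args:

# Sketch: run.py's ort-ep mode constructs OnnxrtIreeEpBackend(device=args.device,
# hal_target_device=args.backend); with -b hip the defaults resolve as below.
from e2e_testing.backends import OnnxrtIreeEpBackend

backend = OnnxrtIreeEpBackend(device="local-task", hal_target_device="hip")
assert backend.provider_options[0]["hal_target_device"] == "hip"
assert backend.provider_options[0]["compile_time_flags"] == "--iree-hip-target=gfx90a"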
