diff --git a/alt_e2eshark/e2e_testing/backends.py b/alt_e2eshark/e2e_testing/backends.py
index a2c473a0..70a10385 100644
--- a/alt_e2eshark/e2e_testing/backends.py
+++ b/alt_e2eshark/e2e_testing/backends.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 import abc
 import onnxruntime as ort
-from typing import TypeVar
+from typing import TypeVar, List
 from e2e_testing.storage import TestTensors
 from e2e_testing.framework import CompiledOutput, ModelArtifact
 from onnx import ModelProto
@@ -30,20 +30,45 @@ def load(self, artifact: CompiledOutput, func_name: str) -> Invoker:
 
 class SimpleIREEBackend(BackendBase):
     '''This backend uses iree to compile and run MLIR modules for a specified hal_target_backend'''
-    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu"):
+    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", extra_args: List[str] = None):
         self.device = device
         self.hal_target_backend = hal_target_backend
+        if extra_args:
+            self.extra_args = []
+            for a in extra_args:
+                if a[0:2] == "--":
+                    self.extra_args.append(a)
+                else:
+                    self.extra_args.append("--" + a)
+        elif hal_target_backend == "rocm":
+            # some extra args for Mi300x - some of these may not work for other chips
+            self.extra_args = [
+                "--iree-rocm-target-chip=gfx942",
+                # "--iree-global-opt-propagate-transposes=true",
+                # "--iree-opt-outer-dim-concat=true",
+                # "--iree-opt-const-eval=false",
+                # "--iree-rocm-waves-per-eu=2",
+                # "--iree-llvmgpu-enable-prefetch",
+                # "--iree-flow-enable-aggressive-fusion",
+                # "--iree-flow-enable-fuse-horizontal-contractions=true",
+                # "--iree-opt-aggressively-propagate-transposes=true",
+                # "--iree-codegen-llvmgpu-use-vector-distribution=true",
+                # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+                # maybe add iree-preprocessing-transpose-convolution-pipeline to preprocessing pipeline.
+            ]
+        elif hal_target_backend == "llvm-cpu":
+            self.extra_args = [
+                "--iree-input-demote-i64-to-i32",
+                # "--iree-llvmcpu-fail-on-large-vector=0",
+                # "--iree-llvmcpu-stack-allocation-limit=300000",
+            ]
 
     def compile(self, module, *, save_to: str = None):
         # compile to a vmfb for llvm-cpu
         b = ireec.tools.compile_str(
             str(module),
             target_backends=[self.hal_target_backend],
-            extra_args=[
-                "--iree-input-demote-i64-to-i32",
-                # "--iree-llvmcpu-fail-on-large-vector=0",
-                # "--iree-llvmcpu-stack-allocation-limit=300000",
-            ],
+            extra_args=self.extra_args,
         )
         # log the vmfb
         if save_to:
diff --git a/alt_e2eshark/e2e_testing/framework.py b/alt_e2eshark/e2e_testing/framework.py
index baf71d85..71e7c75f 100644
--- a/alt_e2eshark/e2e_testing/framework.py
+++ b/alt_e2eshark/e2e_testing/framework.py
@@ -30,6 +30,7 @@ def __init__(
         self.model = os.path.join(onnx_model_path, "model.onnx")
         self.opset_version = opset_version
         self.sess_options = ort.SessionOptions()
+        self.dim_param_dict = None
 
     def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
         """Applies self.model to self.input. Only override if necessary for specific models"""
@@ -55,6 +56,12 @@ def update_sess_options(self):
         """
         pass
 
+    def update_dim_param_dict(self):
+        """Can be overridden to modify a dictionary of dim parameters (self.dim_param_dict) used to
+        construct inputs for a model with dynamic dims.
+        """
+        pass
+
     def construct_model(self):
         """a method to be overwritten. 
To make a new test, define a subclass with an override for this method""" raise NotImplementedError( @@ -65,7 +72,10 @@ def construct_inputs(self): """can be overridden to generate specific inputs, but a default is provided for convenience""" if not os.path.exists(self.model): self.construct_model() - return get_sample_inputs_for_onnx_model(self.model) + self.update_dim_param_dict() + # print(self.get_signature()) + # print(get_op_frequency(self.model)) + return get_sample_inputs_for_onnx_model(self.model, self.dim_param_dict) def apply_postprocessing(self, output: TestTensors): """can be overridden to define post-processing methods for individual models""" @@ -73,6 +83,7 @@ def apply_postprocessing(self, output: TestTensors): def save_processed_output(self, output: TestTensors, save_to: str, name: str): """can be overridden to provide instructions on saving processed outputs (e.g., images, labels, text)""" + pass # the following helper methods aren't meant to be overriden @@ -80,7 +91,7 @@ def get_signature(self, *, from_inputs=True): """Returns the input or output signature of self.model""" if not os.path.exists(self.model): self.construct_model() - return get_signature_for_onnx_model(self.model, from_inputs=from_inputs) + return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict) def load_inputs(self, dir_path): """computes the input signature of the onnx model and loads inputs from bin files""" @@ -102,6 +113,16 @@ def load_golden_outputs(self, dir_path): """computes the input signature of the onnx model and loads golden outputs from bin files""" shapes, dtypes = self.get_signature(from_inputs=False) return TestTensors.load_from(shapes, dtypes, dir_path, "golden_output") + + def update_opset_version_and_overwrite(self): + if self.opset_version: + if not os.path.exists(self.model): + self.construct_model() + og_model = onnx.load(self.model) + model = onnx.version_converter.convert_version( + og_model, self.opset_version + ) + onnx.save(model, self.model) # TODO: extend TestModel to a union, or make TestModel a base class when supporting other frontends TestModel = OnnxModelInfo diff --git a/alt_e2eshark/e2e_testing/onnx_utils.py b/alt_e2eshark/e2e_testing/onnx_utils.py index b2c1cb80..d665c9cf 100644 --- a/alt_e2eshark/e2e_testing/onnx_utils.py +++ b/alt_e2eshark/e2e_testing/onnx_utils.py @@ -3,6 +3,8 @@ import onnxruntime import torch from e2e_testing.storage import TestTensors +from typing import Optional +from pathlib import Path def dtype_from_ort_node(node): @@ -26,42 +28,52 @@ def dtype_from_ort_node(node): raise NotImplementedError(f"Unhandled dtype string found: {dtypestr}") -def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg): +def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None): """A convenience function for generating sample inputs for an onnxruntime node""" + int_dims = [] for dim in node.shape: + if isinstance(dim, str) and dim_param_dict: + if not dim in dim_param_dict.keys(): + raise ValueError(f"input node {node.name} has a dim param='{dim}' not found in provided dim_param_dict: '{dim_param_dict}'") + else: + int_dims.append(dim_param_dict[dim]) + continue if not isinstance(dim, int): raise TypeError( - f"input node '{node.name}' has a dim='{dim}', with invalid type: {type(dim)}\nexpected type: int.\nIf your model has dim_params, consider fixing them or setting custom inputs for this test." 
+ f"input node '{node.name}' has dims={node.shape}. Node dim '{dim}' has invalid type: {type(dim)}\nexpected type: int.\nIf your model has dim_params, consider fixing them or setting custom inputs for this test." ) if dim <= 0: raise ValueError( f"input node '{node.name}' has a non-positive dim: {dim}. Consider setting cutsom inputs for this test." ) + int_dims.append(dim) rng = numpy.random.default_rng(19) if node.type == "tensor(float)": - return rng.random(node.shape).astype(numpy.float32) + return rng.random(int_dims).astype(numpy.float32) if node.type == "tensor(int)" or node.type == "tensor(int32)": - return rng.integers(0, 10000, size=node.shape, dtype=numpy.int32) + return rng.integers(0, 10000, size=int_dims, dtype=numpy.int32) if node.type == "tensor(int8)": - return rng.integers(-127, 128, size=node.shape, dtype=numpy.int8) + return rng.integers(-127, 128, size=int_dims, dtype=numpy.int8) if node.type == "tensor(int64)": - return rng.integers(0, 5, size=node.shape, dtype=numpy.int64) + return rng.integers(0, 5, size=int_dims, dtype=numpy.int64) if node.type == "tensor(bool)": - return rng.integers(0, 2, size=node.shape, dtype=bool) + return rng.integers(0, 2, size=int_dims, dtype=bool) raise NotImplementedError(f"Found an unhandled dtype: {node.type}.") -def get_sample_inputs_for_onnx_model(model_path): +def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None): """A convenience function for generating sample inputs for an onnx model""" - s = onnxruntime.InferenceSession(model_path, None) + opt = onnxruntime.SessionOptions() + opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + s = onnxruntime.InferenceSession(model_path, opt) inputs = s.get_inputs() sample_inputs = TestTensors( - tuple([generate_input_from_node(node) for node in inputs]) + tuple([generate_input_from_node(node, dim_param_dict) for node in inputs]) ) return sample_inputs -def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True): +def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True, dim_param_dict: Optional[dict[str, int]] = None): """A convenience funtion for retrieving the input or output shapes and dtypes""" s = onnxruntime.InferenceSession(model_path, None) if from_inputs: @@ -76,8 +88,13 @@ def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True): return shapes, dtypes -def get_op_frequency(model_path): - model = onnx.load(model_path) +def get_op_frequency(model_or_path): + if isinstance(model_or_path, str) or isinstance(model_or_path, Path): + model = onnx.load(model_or_path) + elif isinstance(model_or_path, onnx.ModelProto): + model = model_or_path + else: + raise TypeError(f'Input argument must be a path, string, or onnx model.') op_freq = dict() for n in model.graph.node: if n.op_type in op_freq: @@ -90,6 +107,9 @@ def get_op_frequency(model_path): def modify_model_output(model: onnx.ModelProto, final_node_key: int) -> onnx.ModelProto: """A helper function to change the output of an onnx model to a new output.""" + if final_node_key < 0: + final_node_key += len(model.graph.node) + final_node = model.graph.node[final_node_key] # clear old outputs @@ -142,6 +162,12 @@ def find_minimal_graph(graph: onnx.GraphProto, top_key: int): def find_node(model: onnx.ModelProto, n: int, op_name: str) -> onnx.NodeProto: """returns the output names for the nth node in the onnx model with op_type given by op_name""" + op_freq = get_op_frequency(model) + N = op_freq[op_name] + if n > N-1 or n < -N: + raise ValueError(f"There 
are {N} nodes with op name {op_name} in model. Provided index {n} is OOB.\n{op_freq}") + if n < 0: + n += N match_counter = 0 key = -1 for nde in model.graph.node: @@ -152,6 +178,4 @@ def find_node(model: onnx.ModelProto, n: int, op_name: str) -> onnx.NodeProto: node = nde break match_counter += 1 - if not node: - raise ValueError(f"Could not find {n} nodes of type {op_name} in {model}") return key diff --git a/alt_e2eshark/iree_requirements.txt b/alt_e2eshark/iree_requirements.txt index 0f6cb582..8b9773e8 100644 --- a/alt_e2eshark/iree_requirements.txt +++ b/alt_e2eshark/iree_requirements.txt @@ -1,3 +1,5 @@ +--pre +-f https://iree.dev/pip-release-links.html # install nightly build of iree-compiler and iree-runtime -iree-compiler -f https://iree.dev/pip-release-links.html -iree-runtime -f https://iree.dev/pip-release-links.html \ No newline at end of file +iree-compiler +iree-runtime \ No newline at end of file diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py index 942abc3e..75d6fe95 100644 --- a/alt_e2eshark/onnx_tests/helper_classes.py +++ b/alt_e2eshark/onnx_tests/helper_classes.py @@ -27,8 +27,8 @@ def __init__(self, name: str, onnx_model_path: str): self.cache_dir = os.path.join(parent_cache_dir, name) super().__init__(name, onnx_model_path, opset_version) - # def update_sess_options(self): - # self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + def update_sess_options(self): + self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL def construct_model(self): # try to find a .onnx file in the test-run dir @@ -79,7 +79,10 @@ def construct_model(self): if not os.path.exists(self.sibling_inst.model): self.sibling_inst.construct_model() self.model = self.sibling_inst.model - + + def update_dim_param_dict(self): + self.sibling_inst.update_dim_param_dict() + self.dim_param_dict = self.sibling_inst.dim_param_dict def get_sibling_constructor(sibling_class, og_constructor, og_name): """Returns a constructor for the sibling class. Useful for convenient registration. diff --git a/alt_e2eshark/onnx_tests/models/migraphx.py b/alt_e2eshark/onnx_tests/models/migraphx.py new file mode 100644 index 00000000..103edf28 --- /dev/null +++ b/alt_e2eshark/onnx_tests/models/migraphx.py @@ -0,0 +1,158 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..helper_classes import AzureDownloadableModel +from e2e_testing.registry import register_test + +# TODOs: +# 1. just update the opset versions and re-upload to azure. +# 2. get tf models into onnx and upload +# 3. setup dim params for other misc models +# 4. 
reupload cadence model 1 + +ALL_MODELS = [ + "migraphx_agentmodel__AgentModel", + "migraphx_bert__bert-large-uncased", + "migraphx_bert__bertsquad-12", + "migraphx_cadene__dpn92i1", + "migraphx_cadene__inceptionv4i16", + "migraphx_cadene__resnext101_64x4di1", + "migraphx_cadene__resnext101_64x4di16", + "migraphx_huggingface-transformers__bert_mrpc8", + "migraphx_mlperf__bert_large_mlperf", + "migraphx_mlperf__resnet50_v1", + "migraphx_onnx-misc__taau_low_res_downsample_d2s_for_infer_time_fp16_opset11", + "migraphx_onnx-model-zoo__gpt2-10", + "migraphx_ORT__bert_base_cased_1", + "migraphx_ORT__bert_base_uncased_1", + "migraphx_ORT__bert_large_uncased_1", + "migraphx_ORT__distilgpt2_1", + "migraphx_ORT__onnx_models__bert_base_cased_1_fp16_gpu", + "migraphx_ORT__onnx_models__bert_large_uncased_1_fp16_gpu", + "migraphx_ORT__onnx_models__distilgpt2_1_fp16_gpu", + "migraphx_pytorch-examples__wlang_gru", + "migraphx_pytorch-examples__wlang_lstm", + "migraphx_sd__unet__model", + "migraphx_sdxl__unet__model", + "migraphx_torchvision__densenet121i32", + "migraphx_torchvision__inceptioni1", + "migraphx_torchvision__inceptioni32", + "migraphx_torchvision__resnet50i1", + "migraphx_torchvision__resnet50i64", +] + + +def dim_param_constructor(dim_param_dict): + class AzureWithDimParams(AzureDownloadableModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if ( + self.name == "migraphx_sd__unet__model" + or self.name == "migraphx_sdxl__unet__model" + ): + # trying to update opset version seems to cause a crash or other issues. + self.opset_version = None + # even with the following, ort fails to allocate memory with default session options: + self.sess_options.add_session_config_entry( + "use_device_allocator_for_initializers", "1" + ) + self.update_opset_version_and_overwrite() + + def update_dim_param_dict(self): + self.dim_param_dict = dim_param_dict + + return AzureWithDimParams + + +ORT_model_names = [ + "migraphx_ORT__bert_base_cased_1", # batch_size, seq_len + "migraphx_ORT__bert_base_uncased_1", # batch_size, seq_len + # the following test currently crashes for some reason (maybe opset version related?) + # "migraphx_ORT__bert_large_uncased_1", # batch_size, seq_len + "migraphx_ORT__distilgpt2_1", # batch_size, seq_len + "migraphx_ORT__onnx_models__bert_base_cased_1_fp16_gpu", # batch_size, seq_len + "migraphx_ORT__onnx_models__bert_large_uncased_1_fp16_gpu", # batch_size, seq_len + "migraphx_ORT__onnx_models__distilgpt2_1_fp16_gpu", # batch_size, seq_len +] + +llm_dict_0 = {"batch_size": 1, "seq_len": 128} +for name in ORT_model_names: + register_test(dim_param_constructor(llm_dict_0), name) + +static_dim_model_names = [ + "migraphx_bert__bert-large-uncased", # need to specify input range for indices input [-2,1] + "migraphx_cadene__dpn92i1", # need to give names to nodes??? 
did this locally, need to reupload + "migraphx_cadene__inceptionv4i16", + "migraphx_cadene__resnext101_64x4di1", + "migraphx_cadene__resnext101_64x4di16", + "migraphx_onnx-misc__taau_low_res_downsample_d2s_for_infer_time_fp16_opset11", # fp16 resize issue + "migraphx_pytorch-examples__wlang_gru", + "migraphx_pytorch-examples__wlang_lstm", # also needs node names + "migraphx_torchvision__densenet121i32", + "migraphx_torchvision__inceptioni1", + "migraphx_torchvision__inceptioni32", + "migraphx_torchvision__resnet50i1", + "migraphx_torchvision__resnet50i64", + "migraphx_huggingface-transformers__bert_mrpc8", # need to specify input range for indices input [-2,1] +] + +for name in static_dim_model_names: + register_test(dim_param_constructor(None), name) + +misc_models = { + "migraphx_agentmodel__AgentModel": {"batch": 1}, + "migraphx_bert__bertsquad-12": { + "unk__492": 1, + "unk__493": 1, + "unk__494": 1, + "unk__495": 1, + }, + "migraphx_mlperf__bert_large_mlperf": { + "batch_size": 1 + }, # need to specify input range for indices input [-2,1] + "migraphx_mlperf__resnet50_v1": {"unk__616": 1}, + "migraphx_onnx-model-zoo__gpt2-10": { + "input1_dynamic_axes_1": 1, + "input1_dynamic_axes_2": 1, + "input1_dynamic_axes_3": 1, + }, + "migraphx_sd__unet__model": { + "batch": 1, + "channels": 4, + "height": 512, + "width": 512, + "sequence": 64, + }, + "migraphx_models__whisper-tiny-decoder" : {"batch_size" : 1, "decoder_sequence_length" : 64, "encoder_sequence_length / 2" : 32}, + "migraphx_models__whisper-tiny-encoder" : {"batch_size" : 1, "feature_size" : 80, "encoder_sequence_length" : 64}, + # this one crashes for some reason... + # "migraphx_sdxl__unet__model" : {"batch_size" : 1, "num_channels" : 4, "height" : 512, "width" : 512, "steps" : 2, "sequence_length" : 64} +} + +for key, dim_param in misc_models.items(): + register_test(dim_param_constructor(dim_param), key) + + +### -------------------------------- ### +# Truncated Model Tests # +### -------------------------------- ### + +# some smaller repros for failed to legalize cmd.stream.dispatch: + +need_repro_dict = { + "migraphx_ORT__bert_base_cased_1" : ["cased" , 4, "MatMul"], + "migraphx_ORT__bert_base_uncased_1" : ["uncased", 1, "Transpose"], + "migraphx_ORT__distilgpt2_1" : ["gpt", 3, "Add"], + "migraphx_ORT__onnx_models__distilgpt2_1_fp16_gpu" : ["gptf16", 3, "Add"], + "migraphx_onnx-model-zoo__gpt2-10" : ["gpt2_10", 0, "NonZero"], +} + +from ..helper_classes import TruncatedModel, get_trucated_constructor + +trunc_const = lambda key : get_trucated_constructor(TruncatedModel, dim_param_constructor(llm_dict_0), key) + +for (key, value) in need_repro_dict.items(): + register_test(trunc_const(key)(value[1], value[2]), f"mi_trunc_{value[0]}_{value[1]}_{value[2]}") diff --git a/alt_e2eshark/onnx_tests/models/model.py b/alt_e2eshark/onnx_tests/models/model.py index 71e31a6e..60000412 100644 --- a/alt_e2eshark/onnx_tests/models/model.py +++ b/alt_e2eshark/onnx_tests/models/model.py @@ -7,4 +7,5 @@ from .azure_models import * from .opt_models import * from .vision_models import * -from .deeplab import * \ No newline at end of file +from .deeplab import * +from .migraphx import * \ No newline at end of file diff --git a/alt_e2eshark/requirements.txt b/alt_e2eshark/requirements.txt deleted file mode 100644 index 16bff716..00000000 --- a/alt_e2eshark/requirements.txt +++ /dev/null @@ -1,18 +0,0 @@ -#install nightly build of torch_mlir, if on Linux (no macOS or Windows nightly builds) --f 
https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels -torch-mlir ; sys_platform == "linux" -# install nightly build of iree-compiler and iree-runtime -iree-compiler -f https://iree.dev/pip-release-links.html -iree-runtime -f https://iree.dev/pip-release-links.html -tabulate -simplejson -ml_dtypes -onnx -onnxruntime -transformers -huggingface-hub -sentencepiece -accelerate -auto-gptq -optimum -azure-storage-blob \ No newline at end of file diff --git a/alt_e2eshark/run.py b/alt_e2eshark/run.py index fcab1728..4cbff4ec 100644 --- a/alt_e2eshark/run.py +++ b/alt_e2eshark/run.py @@ -28,13 +28,15 @@ # import backends from e2e_testing.backends import SimpleIREEBackend, OnnxrtIreeEpBackend +from utils.report import generate_report ALL_STAGES = [ "setup", - "native_inference", "import_model", "preprocessing", "compilation", + "construct_inputs", + "native_inference", "compiled_inference", "postprocessing", ] @@ -72,7 +74,7 @@ def main(args): if args.mode == "onnx-iree": pipeline = REDUCE_TO_LINALG_PIPELINE if args.torchtolinalg else [] config = OnnxTestConfig( - str(TEST_DIR), SimpleIREEBackend(device=args.device, hal_target_backend=args.backend), pipeline + str(TEST_DIR), SimpleIREEBackend(device=args.device, hal_target_backend=args.backend, extra_args=args.iree_compile_args), pipeline ) elif args.mode == "ort-ep": # TODO: allow specifying provider explicitly from cl args. @@ -92,30 +94,33 @@ def main(args): if args.skip_stages: stages = [s for s in stages if s not in args.skip_stages] + parent_log_dir = os.path.join(TEST_DIR, args.rundirectory) - run_tests( + status_dict = run_tests( test_list, config, - args.rundirectory, + parent_log_dir, args.no_artifacts, args.verbose, stages, args.load_inputs ) + if args.report: + generate_report(args, stages, test_list, status_dict) + def run_tests( - test_list: List[Test], config: TestConfig, dir_name: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool -): - """runs tests in test_list based on config""" + test_list: List[Test], config: TestConfig, parent_log_dir: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool +) -> Dict[str, str]: + """runs tests in test_list based on config. 
Returns a dictionary containing the test statuses.""" # TODO: multi-process # TODO: setup exception handling and better logging # TODO: log command-line reproducers for each step # set up a parent log directory to store results - parent_log_dir = str(TEST_DIR) + "/" + dir_name + "/" if not os.path.exists(parent_log_dir): - os.mkdir(parent_log_dir) + os.makedirs(parent_log_dir) num_passes = 0 warnings.filterwarnings("ignore") @@ -124,15 +129,17 @@ def run_tests( print(f"Stages to be run: {stages}") print(f'Test list: {[test.unique_name for test in test_list]}') + status_dict = dict() + for t in test_list: if verbose: print(f"running test {t.unique_name}...") # set log directory for the individual test - log_dir = parent_log_dir + t.unique_name + "/" + log_dir = os.path.join(parent_log_dir, t.unique_name) + "/" if not os.path.exists(log_dir): - os.mkdir(log_dir) + os.makedirs(log_dir) try: # TODO: convert staging to an Enum and figure out how to specify staging from args @@ -143,23 +150,15 @@ def run_tests( if curr_stage in stages: # build an instance of the test info class inst = t.model_constructor(t.unique_name, log_dir) - # generate inputs from the test info instance - if load_inputs: - inputs = inst.load_inputs(log_dir) - else: - inputs = inst.construct_inputs() - inputs.save_to(log_dir + "input") - - # run native inference - curr_stage = "native_inference" - if curr_stage in stages: - golden_outputs_raw = inst.forward(inputs) - golden_outputs_raw.save_to(log_dir + "golden_output") - + # this is highly onnx specific. + # TODO: Figure out how to factor this out of run.py + if not os.path.exists(inst.model): + inst.construct_model() + + artifact_save_to = None if no_artifacts else log_dir # generate mlir from the instance using the config curr_stage = "import_model" if curr_stage in stages: - artifact_save_to = None if no_artifacts else log_dir model_artifact, func_name = config.import_model( inst, save_to=artifact_save_to ) @@ -176,6 +175,21 @@ def run_tests( if curr_stage in stages: compiled_artifact = config.compile(model_artifact, save_to=artifact_save_to) + # get inputs from inst + curr_stage = "construct_inputs" + if curr_stage in stages: + if load_inputs: + inputs = inst.load_inputs(log_dir) + else: + inputs = inst.construct_inputs() + inputs.save_to(log_dir + "input") + + # run native inference + curr_stage = "native_inference" + if curr_stage in stages: + golden_outputs_raw = inst.forward(inputs) + golden_outputs_raw.save_to(log_dir + "golden_output") + # run inference with the compiled module curr_stage = "compiled_inference" if curr_stage in stages: @@ -191,6 +205,7 @@ def run_tests( inst.save_processed_output(outputs, log_dir, "output") except Exception as e: + status_dict[t.unique_name] = curr_stage log_exception(e, log_dir, curr_stage, t.unique_name, verbose) continue @@ -205,18 +220,26 @@ def run_tests( ) # log the results test_passed = log_result(result, log_dir, [1e-3, 1e-3]) - num_passes += int(test_passed) - if verbose: - to_print = "\tPASS" if test_passed else "\tFAILED (Numerics)" - print(to_print) - elif not test_passed: - print(f"FAILED: {t.unique_name}") + if test_passed: + status_dict[t.unique_name] = "PASS" + num_passes+=1 + else: + status_dict[t.unique_name] = "Numerics" except Exception as e: + status_dict[inst.name] = "results-summary" log_exception(e, log_dir, "results-summary", t.unique_name, verbose) + + if verbose: + if t.unique_name not in status_dict.keys() or status_dict[t.unique_name] == "PASS": + print(f"\tPASSED") + else: + print(f"\tFAILED 
({status_dict[t.unique_name]})")
 
     print("\nTest Summary:")
     print(f"\tPASSES: {num_passes}\n\tTOTAL: {len(test_list)}")
     print(f"results stored in {parent_log_dir}")
+    status_dict = dict(sorted(status_dict.items(), key=lambda item: item[0].lower()))
+    return status_dict
 
 
 def log_result(result, log_dir, tol):
@@ -259,9 +282,8 @@ def _get_argparse():
     parser.add_argument(
         "-d",
         "--device",
-        choices=["local-task","local-sync","vulkan","hip","cuda"],
         default="local-task",
-        help="specifies the device for runtime config",
+        help="specifies the device for runtime config, e.g. local-task, local-sync, vulkan, hip, cuda",
     )
     parser.add_argument(
         "-b",
@@ -270,6 +292,13 @@
         default="llvm-cpu",
         help="specifies the iree-hal-target-backend for compile phase",
     )
+    parser.add_argument(
+        "-ica",
+        "--iree-compile-args",
+        nargs="*",
+        default=None,
+        help="Manually specify a space-separated list of extra args for iree-compile. Do not put `--` before the args.",
+    )
     # parser.add_argument(
     #     "-f",
     #     "--framework",
@@ -363,6 +392,17 @@
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "--report",
+        action="store_true",
+        default=False,
+        help="Generate test report summary",
+    )
+    parser.add_argument(
+        "--report-file",
+        default="report.md",
+        help="Output filename for the report summary.",
+    )
     # parser.add_argument(
     #     "-d",
     #     "--todtype",
@@ -384,18 +424,6 @@
     #     help="Skip running of tests. Useful for generating test summary after the run",
     # )
     # parser.add_argument(
-    #     "--report",
-    #     action="store_true",
-    #     default=False,
-    #     help="Generate test report summary",
-    # )
-    # parser.add_argument(
-    #     "--reportformat",
-    #     choices=["pipe", "github", "html", "csv"],
-    #     default="pipe",
-    #     help="Format of the test report summary file. It takes subset of tablefmt value of python tabulate",
-    # )
-    # parser.add_argument(
     #     "--uploadtestsfile",
     #     help="A file with lists of tests that should be uploaded",
     # )
diff --git a/alt_e2eshark/torch_mlir_requirements.txt b/alt_e2eshark/torch_mlir_requirements.txt
index a678680f..529ba0ee 100644
--- a/alt_e2eshark/torch_mlir_requirements.txt
+++ b/alt_e2eshark/torch_mlir_requirements.txt
@@ -1,3 +1,4 @@
-#install nightly build of torch_mlir, if on Linux (no macOS or Windows nightly builds)
+--pre
 -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels
+#install nightly build of torch_mlir, if on Linux (no macOS or Windows nightly builds)
 torch-mlir ; sys_platform == "linux"
\ No newline at end of file
diff --git a/alt_e2eshark/utils/report.py b/alt_e2eshark/utils/report.py
new file mode 100644
index 00000000..51318c75
--- /dev/null
+++ b/alt_e2eshark/utils/report.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+def generate_report(args, stages, test_list, status_dict):
+    """generates a markdown report for a test-run"""
+
+    # set up report summary
+    stages.append("results-summary")
+    stages.append("Numerics")
+    stages.append("PASS")
+    stages.reverse()
+    counts = {s: 0 for s in stages}
+    for value in status_dict.values():
+        counts[value] += 1
+    results_str = "## Summary\n\n|Stage|Count|\n|--|--|\n"
+    results_str += f"| Total | {len(test_list)} |\n"
+    for (key, value) in counts.items():
+        results_str += f"| {key} | {value} |\n"
+
+    # set up report detail
+    report_string = f"\n## Test Run Detail\n\nTest was run with the following arguments:\n{args}\n\n"
+    report_string += "| Test | Exit Status | Notes |\n"
+    report_string += "|--|--|--|\n"
+    for (key, value) in status_dict.items():
+        report_string += f"| {key} | {value} | |\n"
+
+    # get a report file and write to it
+    report_file = "report.md"
+    if args.report_file:
+        report_file = args.report_file
+    with open(report_file, "w") as file:
+        file.write(results_str)
+        file.write(report_string)
\ No newline at end of file
diff --git a/alt_e2eshark/utils/write_env.py b/alt_e2eshark/utils/write_env.py
new file mode 100644
index 00000000..e8d6cf5a
--- /dev/null
+++ b/alt_e2eshark/utils/write_env.py
@@ -0,0 +1,82 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from pathlib import Path
+import argparse
+
+def _get_argparse():
+    msg = "A script for setting up a .env file."
+    parser = argparse.ArgumentParser(prog="write_env.py", description=msg, epilog="")
+
+    parser.add_argument(
+        "-i",
+        "--iree-build",
+        help="specify path to iree-build",
+    )
+    parser.add_argument(
+        "-t",
+        "--torch-mlir-build",
+        help="specify path to torch-mlir/build",
+    )
+    parser.add_argument(
+        "-c",
+        "--cache",
+        help="specify path to cache directory for downloading large models (e.g., '/home/username/.cache')",
+    )
+    parser.add_argument(
+        "-a",
+        "--azure-private-connection",
+        help="specify azure-private-connection string for onnxprivatestorage",
+    )
+    return parser
+
+def test_path(path: Path):
+    if not path.exists():
+        raise OSError(f'path: {path.absolute()} could not be resolved')
+
+def main(args):
+    s = ""
+    pypaths = []
+
+    if args.iree_build:
+        iree_build_dir = Path(args.iree_build).resolve()
+        test_path(iree_build_dir)
+        compiler_bindings = iree_build_dir.joinpath("compiler/bindings/python")
+        runtime_bindings = iree_build_dir.joinpath("runtime/bindings/python")
+        test_path(compiler_bindings)
+        test_path(runtime_bindings)
+        pypaths.append(str(compiler_bindings))
+        pypaths.append(str(runtime_bindings))
+
+    if args.torch_mlir_build:
+        torch_mlir_build_dir = Path(args.torch_mlir_build).resolve()
+        test_path(torch_mlir_build_dir)
+        torch_mlir_bindings = torch_mlir_build_dir.joinpath("tools/torch-mlir/python_packages/torch_mlir/")
+        test_path(torch_mlir_bindings)
+        pypaths.append(str(torch_mlir_bindings))
+
+    if args.cache:
+        cache_dir = Path(args.cache).resolve()
+        test_path(cache_dir)
+        s += f"CACHE_DIR='{cache_dir}'\n"
+
+    if args.azure_private_connection:
+        s += f'AZ_PRIVATE_CONNECTION="{args.azure_private_connection}"\n'
+
+    if len(pypaths) > 0:
+        pypathstr = ":".join(pypaths)
+        s += f'PYTHONPATH="{pypathstr}"\n'
+
+    if len(s) > 0:
+        with open(".env", "w") as file:
+            file.write(s)
+
+    print("Check .env and run this script to export the variables to your environment (Linux):")
+    print("export $(cat .env | xargs)")
+
+if __name__ == "__main__":
+    parser = _get_argparse()
+    main(parser.parse_args())
\ No newline at end of file
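
Usage sketch (illustrative, not part of the patch; paths and flag values are assumptions): with these changes applied, a ROCm run exercising the new options could look roughly like the following, run from the alt_e2eshark directory. Other run.py options (e.g., test selection) may still be needed.

    python utils/write_env.py -i /path/to/iree-build -c /path/to/model-cache
    export $(cat .env | xargs)
    python run.py -b rocm -d hip -ica iree-rocm-target-chip=gfx942 --report --report-file report.md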