[alt] [WIP] setup MiGraphX models (nod-ai#323)
Still some work to do for getting some of these models working. Will
update more today.
zjgarvey authored Aug 20, 2024
1 parent a38e6c7 commit a555bf1
Showing 12 changed files with 457 additions and 94 deletions.
alt_e2eshark/e2e_testing/backends.py (32 additions & 7 deletions)
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import abc
import onnxruntime as ort
-from typing import TypeVar
+from typing import TypeVar, List
from e2e_testing.storage import TestTensors
from e2e_testing.framework import CompiledOutput, ModelArtifact
from onnx import ModelProto
@@ -30,20 +30,45 @@ def load(self, artifact: CompiledOutput, func_name: str) -> Invoker:

class SimpleIREEBackend(BackendBase):
    '''This backend uses iree to compile and run MLIR modules for a specified hal_target_backend'''
-    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu"):
+    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", extra_args : List[str] = None):
        self.device = device
        self.hal_target_backend = hal_target_backend
+        if extra_args:
+            self.extra_args = []
+            for a in extra_args:
+                if a[0:2] == "--":
+                    self.extra_args.append(a)
+                else:
+                    self.extra_args.append("--" + a)
+        elif hal_target_backend == "rocm":
+            # some extra args for Mi300x - some of these may not work for other chips
+            self.extra_args = [
+                "--iree-rocm-target-chip=gfx942",
+                # "--iree-global-opt-propagate-transposes=true",
+                # "--iree-opt-outer-dim-concat=true",
+                # "--iree-opt-const-eval=false",
+                # "--iree-rocm-waves-per-eu=2",
+                # "--iree-llvmgpu-enable-prefetch",
+                # "--iree-flow-enable-aggressive-fusion",
+                # "--iree-flow-enable-fuse-horizontal-contractions=true",
+                # "--iree-opt-aggressively-propagate-transposes=true",
+                # "--iree-codegen-llvmgpu-use-vector-distribution=true",
+                # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+                # maybe add iree-preprocessing-transpose-convolution-pipeline to preprocessing pipeline.
+            ]
+        elif hal_target_backend == "llvm-cpu":
+            self.extra_args = [
+                "--iree-input-demote-i64-to-i32",
+                # "--iree-llvmcpu-fail-on-large-vector=0",
+                # "--iree-llvmcpu-stack-allocation-limit=300000",
+            ]

    def compile(self, module, *, save_to: str = None):
        # compile to a vmfb for llvm-cpu
        b = ireec.tools.compile_str(
            str(module),
            target_backends=[self.hal_target_backend],
-            extra_args=[
-                "--iree-input-demote-i64-to-i32",
-                # "--iree-llvmcpu-fail-on-large-vector=0",
-                # "--iree-llvmcpu-stack-allocation-limit=300000",
-            ],
+            extra_args=self.extra_args,
        )
        # log the vmfb
        if save_to:
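For orientation, a minimal usage sketch of the new constructor (a sketch, assuming alt_e2eshark is on the path; the flag is one of those shown above):

from e2e_testing.backends import SimpleIREEBackend

# Flags may be given with or without the leading "--"; both forms are normalized.
backend = SimpleIREEBackend(
    device="local-task",
    hal_target_backend="llvm-cpu",
    extra_args=["iree-input-demote-i64-to-i32"],
)
vmfb = backend.compile(module)  # module: an MLIR module, stringified internally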
alt_e2eshark/e2e_testing/framework.py (23 additions & 2 deletions)
@@ -30,6 +30,7 @@ def __init__(
        self.model = os.path.join(onnx_model_path, "model.onnx")
        self.opset_version = opset_version
        self.sess_options = ort.SessionOptions()
+        self.dim_param_dict = None

    def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
        """Applies self.model to self.input. Only override if necessary for specific models"""
@@ -55,6 +56,12 @@ def update_sess_options(self):
"""
pass

def update_dim_param_dict(self):
"""Can be overridden to modify a dictionary of dim parameters (self.dim_param_dict) used to
construct inputs for a model with dynamic dims.
"""
pass

def construct_model(self):
"""a method to be overwritten. To make a new test, define a subclass with an override for this method"""
raise NotImplementedError(
@@ -65,22 +72,26 @@ def construct_inputs(self):
"""can be overridden to generate specific inputs, but a default is provided for convenience"""
if not os.path.exists(self.model):
self.construct_model()
return get_sample_inputs_for_onnx_model(self.model)
self.update_dim_param_dict()
# print(self.get_signature())
# print(get_op_frequency(self.model))
return get_sample_inputs_for_onnx_model(self.model, self.dim_param_dict)

def apply_postprocessing(self, output: TestTensors):
"""can be overridden to define post-processing methods for individual models"""
return output

def save_processed_output(self, output: TestTensors, save_to: str, name: str):
"""can be overridden to provide instructions on saving processed outputs (e.g., images, labels, text)"""
pass

# the following helper methods aren't meant to be overriden

def get_signature(self, *, from_inputs=True):
"""Returns the input or output signature of self.model"""
if not os.path.exists(self.model):
self.construct_model()
return get_signature_for_onnx_model(self.model, from_inputs=from_inputs)
return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict)

def load_inputs(self, dir_path):
"""computes the input signature of the onnx model and loads inputs from bin files"""
@@ -102,6 +113,16 @@ def load_golden_outputs(self, dir_path):
"""computes the input signature of the onnx model and loads golden outputs from bin files"""
shapes, dtypes = self.get_signature(from_inputs=False)
return TestTensors.load_from(shapes, dtypes, dir_path, "golden_output")

def update_opset_version_and_overwrite(self):
if self.opset_version:
if not os.path.exists(self.model):
self.construct_model()
og_model = onnx.load(self.model)
model = onnx.version_converter.convert_version(
og_model, self.opset_version
)
onnx.save(model, self.model)

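A hedged usage sketch of the new opset updater; the subclass name and path are illustrative, not from the repository:

# Hypothetical test-info subclass; OnnxModelInfo takes (name, onnx_model_path, opset_version).
info = MyModelInfo("my_test", "/tmp/my_test/", 21)
info.update_opset_version_and_overwrite()  # rewrites model.onnx in place at opset 21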
# TODO: extend TestModel to a union, or make TestModel a base class when supporting other frontends
TestModel = OnnxModelInfo
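To use the new dynamic-dim support end to end, a test would override update_dim_param_dict; a minimal sketch with invented names:

from e2e_testing.framework import OnnxModelInfo

class MyDynamicModel(OnnxModelInfo):
    def update_dim_param_dict(self):
        # map each symbolic dim name in the onnx graph to a concrete size
        self.dim_param_dict = {"batch_size": 1, "sequence_length": 128}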
alt_e2eshark/e2e_testing/onnx_utils.py (39 additions & 15 deletions)
@@ -3,6 +3,8 @@
import onnxruntime
import torch
from e2e_testing.storage import TestTensors
+from typing import Optional
+from pathlib import Path


def dtype_from_ort_node(node):
@@ -26,42 +28,52 @@ def dtype_from_ort_node(node):
    raise NotImplementedError(f"Unhandled dtype string found: {dtypestr}")


-def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg):
+def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
    """A convenience function for generating sample inputs for an onnxruntime node"""
+    int_dims = []
    for dim in node.shape:
+        if isinstance(dim, str) and dim_param_dict:
+            if not dim in dim_param_dict.keys():
+                raise ValueError(f"input node {node.name} has a dim param='{dim}' not found in provided dim_param_dict: '{dim_param_dict}'")
+            else:
+                int_dims.append(dim_param_dict[dim])
+            continue
        if not isinstance(dim, int):
            raise TypeError(
-                f"input node '{node.name}' has a dim='{dim}', with invalid type: {type(dim)}\nexpected type: int.\nIf your model has dim_params, consider fixing them or setting custom inputs for this test."
+                f"input node '{node.name}' has dims={node.shape}. Node dim '{dim}' has invalid type: {type(dim)}\nexpected type: int.\nIf your model has dim_params, consider fixing them or setting custom inputs for this test."
            )
        if dim <= 0:
            raise ValueError(
                f"input node '{node.name}' has a non-positive dim: {dim}. Consider setting custom inputs for this test."
            )
+        int_dims.append(dim)
    rng = numpy.random.default_rng(19)
    if node.type == "tensor(float)":
-        return rng.random(node.shape).astype(numpy.float32)
+        return rng.random(int_dims).astype(numpy.float32)
    if node.type == "tensor(int)" or node.type == "tensor(int32)":
-        return rng.integers(0, 10000, size=node.shape, dtype=numpy.int32)
+        return rng.integers(0, 10000, size=int_dims, dtype=numpy.int32)
    if node.type == "tensor(int8)":
-        return rng.integers(-127, 128, size=node.shape, dtype=numpy.int8)
+        return rng.integers(-127, 128, size=int_dims, dtype=numpy.int8)
    if node.type == "tensor(int64)":
-        return rng.integers(0, 5, size=node.shape, dtype=numpy.int64)
+        return rng.integers(0, 5, size=int_dims, dtype=numpy.int64)
    if node.type == "tensor(bool)":
-        return rng.integers(0, 2, size=node.shape, dtype=bool)
+        return rng.integers(0, 2, size=int_dims, dtype=bool)
    raise NotImplementedError(f"Found an unhandled dtype: {node.type}.")


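To illustrate the resolution logic above with a hypothetical input node: if node.shape == ["batch_size", 3, 224, 224] and node.type == "tensor(float)", then

sample = generate_input_from_node(node, {"batch_size": 2})

resolves int_dims to [2, 3, 224, 224] and returns a float32 array of that shape; a symbolic dim missing from the dict raises the ValueError above.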
-def get_sample_inputs_for_onnx_model(model_path):
+def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None):
    """A convenience function for generating sample inputs for an onnx model"""
-    s = onnxruntime.InferenceSession(model_path, None)
+    opt = onnxruntime.SessionOptions()
+    opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+    s = onnxruntime.InferenceSession(model_path, opt)
    inputs = s.get_inputs()
    sample_inputs = TestTensors(
-        tuple([generate_input_from_node(node) for node in inputs])
+        tuple([generate_input_from_node(node, dim_param_dict) for node in inputs])
    )
    return sample_inputs


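A short usage sketch of the helper above (the path and dim name are illustrative):

# Random sample inputs, generated with ORT graph optimizations disabled,
# with the symbolic dim "batch_size" pinned to 1.
inputs = get_sample_inputs_for_onnx_model("/tmp/my_test/model.onnx", {"batch_size": 1})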
-def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True):
+def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True, dim_param_dict: Optional[dict[str, int]] = None):
    """A convenience function for retrieving the input or output shapes and dtypes"""
    s = onnxruntime.InferenceSession(model_path, None)
    if from_inputs:
@@ -76,8 +88,13 @@ def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True):
    return shapes, dtypes


-def get_op_frequency(model_path):
-    model = onnx.load(model_path)
+def get_op_frequency(model_or_path):
+    if isinstance(model_or_path, str) or isinstance(model_or_path, Path):
+        model = onnx.load(model_or_path)
+    elif isinstance(model_or_path, onnx.ModelProto):
+        model = model_or_path
+    else:
+        raise TypeError(f'Input argument must be a path, string, or onnx model.')
    op_freq = dict()
    for n in model.graph.node:
        if n.op_type in op_freq:
@@ -90,6 +107,9 @@
def modify_model_output(model: onnx.ModelProto, final_node_key: int) -> onnx.ModelProto:
    """A helper function to change the output of an onnx model to a new output."""

+    if final_node_key < 0:
+        final_node_key += len(model.graph.node)

    final_node = model.graph.node[final_node_key]

    # clear old outputs
@@ -142,6 +162,12 @@ def find_minimal_graph(graph: onnx.GraphProto, top_key: int):

def find_node(model: onnx.ModelProto, n: int, op_name: str) -> onnx.NodeProto:
    """returns the key (index) of the nth node in the onnx model with op_type given by op_name"""
+    op_freq = get_op_frequency(model)
+    N = op_freq[op_name]
+    if n > N-1 or n < -N:
+        raise ValueError(f"There are {N} nodes with op name {op_name} in model. Provided index {n} is OOB.\n{op_freq}")
+    if n < 0:
+        n += N
    match_counter = 0
    key = -1
    for nde in model.graph.node:
@@ -152,6 +178,4 @@ def find_node(model: onnx.ModelProto, n: int, op_name: str) -> onnx.NodeProto:
            node = nde
            break
        match_counter += 1
-    if not node:
-        raise ValueError(f"Could not find {n} nodes of type {op_name} in {model}")
    return key
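A sketch of the new bounds handling (model is assumed to be a loaded onnx.ModelProto with at least one Conv node):

first = find_node(model, 0, "Conv")   # key of the first Conv node
last = find_node(model, -1, "Conv")   # negative indices count from the end
# an out-of-range index now raises ValueError with the op-frequency table,
# instead of falling through to the old post-loop check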
alt_e2eshark/iree_requirements.txt (4 additions & 2 deletions)
@@ -1,3 +1,5 @@
+--pre
+-f https://iree.dev/pip-release-links.html
# install nightly build of iree-compiler and iree-runtime
-iree-compiler -f https://iree.dev/pip-release-links.html
-iree-runtime -f https://iree.dev/pip-release-links.html
+iree-compiler
+iree-runtime
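Because `--pre` and the `-f` index URL now sit on their own lines, they apply to every requirement in the file, so a plain `pip install -r alt_e2eshark/iree_requirements.txt` resolves both packages against the nightly index.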
alt_e2eshark/onnx_tests/helper_classes.py (6 additions & 3 deletions)
@@ -27,8 +27,8 @@ def __init__(self, name: str, onnx_model_path: str):
        self.cache_dir = os.path.join(parent_cache_dir, name)
        super().__init__(name, onnx_model_path, opset_version)

-    # def update_sess_options(self):
-    #     self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+    def update_sess_options(self):
+        self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

    def construct_model(self):
        # try to find a .onnx file in the test-run dir
@@ -79,7 +79,10 @@ def construct_model(self):
        if not os.path.exists(self.sibling_inst.model):
            self.sibling_inst.construct_model()
        self.model = self.sibling_inst.model

+    def update_dim_param_dict(self):
+        self.sibling_inst.update_dim_param_dict()
+        self.dim_param_dict = self.sibling_inst.dim_param_dict

def get_sibling_constructor(sibling_class, og_constructor, og_name):
    """Returns a constructor for the sibling class. Useful for convenient registration.
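A hypothetical registration sketch based on the docstring above; the class and test names are invented:

# Builds a constructor for a sibling test that reuses "my_base_test"'s model
# and, via update_dim_param_dict above, inherits its dim-param mapping.
constructor = get_sibling_constructor(MySiblingInfo, MyBaseInfo, "my_base_test")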
(The remaining 7 changed files are not shown.)
