Commit 9215c13

Add Flux_1.dev transformer and text encoder tests (nod-ai#411)
The most significant infrastructure change is a rework of `OnnxModelInfo`:

1. `OnnxModelInfo` now has an `ExtraOptions` attribute, which holds test-specific options for testing stages beyond setup. This is required to allow running with external params, passing compile flags, and setting custom MLIR import options.
2. All attributes configurable via the `update_*` methods in `OnnxModelInfo` are now set immediately in the init method. This required reworking a few helper classes which relied on calling these methods at a specific time.

Another infrastructure change is the migration from torch-mlir's onnx importer to `iree-import-onnx`, which allows externalizing parameters.

Adds the following three tests:

- `flux_1_dev_transformer`
- `flux_1_dev_clip`
- `flux_1_dev_t5`
1 parent 25ba264 commit 9215c13

10 files changed: +559 −101 lines changed
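
To make the `ExtraOptions` rework concrete, here is a minimal sketch of a test that opts into external parameters and backend-specific compile flags. The subclass name and flag values are hypothetical; `ExtraOptions`, `ImporterOptions`, and `CompilerOptions` are the NamedTuples this commit adds in `e2e_testing/framework.py` (full diff below).

from e2e_testing.framework import (
    CompilerOptions,
    ExtraOptions,
    ImporterOptions,
    OnnxModelInfo,
)

class MyLargeModelTest(OnnxModelInfo):  # hypothetical test class
    def update_extra_options(self):
        # Runs inside OnnxModelInfo.__init__, so every later testing
        # stage (import, compile, run) sees these options.
        self.extra_options = ExtraOptions(
            import_model_options=ImporterOptions(
                externalize_params=True,  # keep weights out of the MLIR
                large_model=True,
            ),
            compilation_options=CompilerOptions(
                # flag value is illustrative only
                backend_specific_flags={"rocm": ("iree-opt-level=O3",)},
            ),
        )

Because the `update_*` hooks now run in `__init__`, the options are fixed before any stage consumes them, which removes the old requirement to call them at a specific time.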

alt_e2eshark/base_requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ ml_dtypes
 onnx
 onnxruntime
 transformers
+diffusers
 huggingface-hub
 sentencepiece
 accelerate

alt_e2eshark/e2e_testing/backends.py

Lines changed: 47 additions & 41 deletions

@@ -7,7 +7,7 @@
 import onnxruntime as ort
 from typing import TypeVar, List
 from e2e_testing.storage import TestTensors, get_shape_string
-from e2e_testing.framework import CompiledOutput, ModelArtifact
+from e2e_testing.framework import CompiledOutput, ModelArtifact, CompilerOptions, RuntimeOptions
 from onnx import ModelProto
 import os
 from pathlib import Path
@@ -18,45 +18,56 @@
 class BackendBase(abc.ABC):
 
     @abc.abstractmethod
-    def compile(self, module: ModelArtifact) -> CompiledOutput:
+    def compile(self, module: ModelArtifact, extra_options : CompilerOptions) -> CompiledOutput:
         """specifies how to compile an MLIR Module"""
 
     @abc.abstractmethod
-    def load(self, artifact: CompiledOutput, func_name: str) -> Invoker:
+    def load(self, artifact: CompiledOutput, func_name: str, extra_options : RuntimeOptions) -> Invoker:
         """loads the function with name func_name from compiled artifact. This method should return a function callable from python."""
 
 
 from iree import compiler as ireec
 from iree import runtime as ireert
 
 
+def flag(arg : str) -> str:
+    if arg.startswith("--"):
+        return arg
+    return f'--{arg}'
+
 class SimpleIREEBackend(BackendBase):
     '''This backend uses iree to compile and run MLIR modules for a specified hal_target_backend'''
     def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", extra_args : List[str] = None):
         self.device = device
         self.hal_target_backend = hal_target_backend
-        self.extra_args = []
-        if extra_args:
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
-
-    def compile(self, module, *, save_to: str = None):
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
+        if hal_target_backend == "rocm":
+            self.extra_args += [
+                f"--iree-hip-target={self.target_chip}",
+            ]
+        if hal_target_backend == "llvm-cpu":
+            self.extra_args += [
+                "--iree-llvmcpu-target-cpu=host",
+            ]
+
+    def compile(self, module, *, save_to: str = None, extra_options : CompilerOptions):
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
+        compile_args = self.extra_args + [flag(arg) for arg in test_specific_args]
         # compile to a vmfb for llvm-cpu
         b = ireec.tools.compile_str(
             str(module),
             target_backends=[self.hal_target_backend],
-            extra_args=self.extra_args,
+            extra_args=compile_args,
         )
         # log the vmfb
         if save_to:
             with open(os.path.join(save_to, "compiled_model.vmfb"), "wb") as f:
                 f.write(b)
         return b
 
-    def load(self, artifact, *, func_name="main"):
+    def load(self, artifact, *, func_name="main", extra_options : RuntimeOptions):
         config = ireert.Config(self.device)
         ctx = ireert.SystemContext(config=config)
         vm_module = ireert.VmModule.copy_buffer(ctx.instance, artifact)
@@ -80,13 +91,7 @@ def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", target
         self.device = device
         self.hal_target_backend = hal_target_backend
         self.target_chip = target_chip
-        self.extra_args = []
-        if extra_args:
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
         if hal_target_backend == "rocm":
             self.extra_args += [
                 f"--iree-hip-target={self.target_chip}",
@@ -96,15 +101,17 @@ def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", target
                 "--iree-llvmcpu-target-cpu=host",
             ]
 
-    def compile(self, module_path: str, *, save_to : str = None) -> str:
+    def compile(self, module_path: str, *, save_to : str = None, extra_options : CompilerOptions) -> str:
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
+        compile_args = self.extra_args + [flag(arg) for arg in test_specific_args]
         vmfb_path = os.path.join(save_to, "compiled_model.vmfb")
         arg_string = f"--iree-hal-target-backends={self.hal_target_backend} "
-        for arg in self.extra_args:
-            arg_string += arg
-            arg_string += " "
+        arg_string += ' '.join(compile_args)
         detail_log = os.path.join(save_to, "detail", "compilation.detail.log")
         commands_log = os.path.join(save_to, "commands", "compilation.commands.log")
-        script = f"iree-compile {module_path} {arg_string}-o {vmfb_path} 1> {detail_log} 2>&1"
+        script = f"iree-compile {module_path} {arg_string} -o {vmfb_path} 1> {detail_log} 2>&1"
         with open(commands_log, "w") as file:
             file.write(script)
         # remove old vmfb if it exists
@@ -116,16 +123,21 @@ def compile(self, module_path: str, *, save_to : str = None) -> str:
             raise FileNotFoundError(error_msg)
         return vmfb_path
 
-    def load(self, vmfb_path: str, *, func_name=None):
+    def load(self, vmfb_path: str, *, func_name=None, extra_options : RuntimeOptions):
         """A bit hacky. func returns a script that would dump outputs to terminal output. Modified in config.run method"""
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
         run_dir = Path(vmfb_path).parent
         def func(x: TestTensors) -> str:
-            script = f"iree-run-module --module='{vmfb_path}' --device={self.device}"
+            script = f"iree-run-module --module='{vmfb_path}' --device={self.device} "
+            for arg in test_specific_args:
+                script += f'{flag(arg)} '
             if func_name:
-                script += f" --function='{func_name}'"
+                script += f"--function='{func_name}' "
            torch_inputs = x.to_torch().data
            for index, input in enumerate(torch_inputs):
-                script += f" --input='{get_shape_string(input)}=@{run_dir}/input.{index}.bin'"
+                script += f"--input='{get_shape_string(input)}=@{run_dir}/input.{index}.bin' "
            return script
        return func
@@ -135,16 +147,10 @@ class OnnxrtIreeEpBackend(BackendBase):
     def __init__(self, *, device="local-task", hal_target_device="llvm-cpu", extra_args : List[str] = None):
         self.device = device
         self.hal_target_device = hal_target_device
-        if extra_args:
-            self.extra_args = []
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
-        elif hal_target_device == "hip":
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
+        if hal_target_device == "hip":
             # some extra args for Mi250 - some of these may not work for other chips
-            self.extra_args = [
+            self.extra_args += [
                 "--iree-hip-target=gfx90a",
             ]
         self.providers = ["IreeExecutionProvider"]
@@ -159,7 +165,7 @@ def __init__(self, *, device="local-task", hal_target_device="llvm-cpu", extra_a
         # sess_opt.log_verbosity_level = 0
         # self.sess_opt.log_severity_level = 0
 
-    def compile(self, model: ModelProto, *, save_to: str = None) -> ort.InferenceSession:
+    def compile(self, model: ModelProto, *, save_to: str = None, extra_options : CompilerOptions) -> ort.InferenceSession:
         if self.provider_options:
             provider_options_dict = self.provider_options[0]
             provider_options_dict["save_to"] = save_to
@@ -173,7 +179,7 @@ def compile(self, model: ModelProto, *, save_to: str = None) -> ort.InferenceSes
         # can't save an onnx runtime session
         return session
 
-    def load(self, session: ort.InferenceSession, *, func_name=None) -> Invoker:
+    def load(self, session: ort.InferenceSession, *, func_name=None, extra_options : RuntimeOptions) -> Invoker:
         def func(x: TestTensors):
             data = x.to_numpy().data
             session_inputs = session.get_inputs()
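
The new `flag` helper and the merge in `compile` determine the final flag list: the backend's default args come first, then the test's common args, then its backend-specific args, all normalized to `--` form. A self-contained sketch of that behavior follows; it re-declares a minimal `CompilerOptions` so it runs without the repo on `PYTHONPATH`, and the IREE flag strings are illustrative only.

from typing import Dict, NamedTuple, Tuple

def flag(arg: str) -> str:
    """Normalize an argument to '--arg' form, as in backends.py."""
    return arg if arg.startswith("--") else f"--{arg}"

class CompilerOptions(NamedTuple):  # local stand-in for the framework class
    backend_specific_flags: Dict[str, Tuple[str, ...]] = dict()
    common_extra_args: Tuple[str, ...] = tuple()

backend_defaults = ["--iree-llvmcpu-target-cpu=host"]  # set in __init__
opts = CompilerOptions(
    backend_specific_flags={"llvm-cpu": ("iree-llvmcpu-enable-ukernels=all",)},
    common_extra_args=("iree-input-demote-i64-to-i32",),
)

# Mirrors the merge performed in SimpleIREEBackend.compile:
test_specific_args = list(opts.common_extra_args)
test_specific_args += list(opts.backend_specific_flags.get("llvm-cpu", ()))
compile_args = backend_defaults + [flag(a) for a in test_specific_args]
print(compile_args)
# ['--iree-llvmcpu-target-cpu=host',
#  '--iree-input-demote-i64-to-i32',
#  '--iree-llvmcpu-enable-ukernels=all']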

alt_e2eshark/e2e_testing/framework.py

Lines changed: 49 additions & 15 deletions

@@ -8,14 +8,39 @@
 import abc
 import os
 from pathlib import Path
-from typing import Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable
+from typing import Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable, List
 from e2e_testing.storage import TestTensors
 from e2e_testing.onnx_utils import *
 
 # This file two types of classes: framework-specific base classes for storing model info, and generic classes for testing infrastructure.
 
 Module = TypeVar("Module")
 
+class ImporterOptions(NamedTuple):
+    opset_version : Optional[int] = None
+    large_model : bool = False
+    externalize_params : bool = False
+    externalize_inputs_threshold : Optional[int] = None
+    num_elements_threshold: int = 100
+    params_scope : str = "model"
+    param_gb_threshold : Optional[float] = None
+
+class CompilerOptions(NamedTuple):
+    """Specify, for specific iree-hal-target-backends, a tuple of extra compiler flags.
+    Also allows backend-agnostic options to be included."""
+    backend_specific_flags : Dict[str, Tuple[str]] = dict()
+    common_extra_args : Tuple[str] = tuple()
+
+class RuntimeOptions(NamedTuple):
+    """Specify, for specific iree-hal-target-backends, a tuple of extra runtime flags.
+    Also allows backend-agnostic options to be included."""
+    backend_specific_flags : Dict[str, Tuple[str]] = dict()
+    common_extra_args : Tuple[str] = tuple()
+
+class ExtraOptions(NamedTuple):
+    import_model_options : ImporterOptions = ImporterOptions()
+    compilation_options : CompilerOptions = CompilerOptions()
+    compiled_inference_options : RuntimeOptions = RuntimeOptions()
 
 class OnnxModelInfo:
     """Stores information about an onnx test: the filepath to model.onnx, how to construct/download it, and how to construct sample inputs for a test run."""
@@ -29,16 +54,21 @@ def __init__(
         self.name = name
         self.model = os.path.join(onnx_model_path, "model.onnx")
         self.opset_version = opset_version
-        self.sess_options = ort.SessionOptions()
+
         self.dim_param_dict = None
+        self.update_dim_param_dict()
         self.input_name_to_shape_map = None
+        self.update_input_name_to_shape_map()
+        self.sess_options = ort.SessionOptions()
+        self.update_sess_options()
+        self.extra_options = ExtraOptions()
+        self.update_extra_options()
 
     def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
         """Applies self.model to self.input. Only override if necessary for specific models"""
         input = input.to_numpy().data
         if not os.path.exists(self.model):
             self.construct_model()
-        self.update_sess_options()
         session = ort.InferenceSession(self.model, self.sess_options)
         session_inputs = session.get_inputs()
         session_outputs = session.get_outputs()
@@ -50,23 +80,27 @@ def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
 
         return TestTensors(model_output)
 
-    def update_sess_options(self):
-        """Can be overridden to modify session options (self.sess_options) for gold inference.
-        It is sometimes useful to disable all optimizations, which can be done with:
-        self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
-        """
-        pass
-
     def update_dim_param_dict(self):
         """Can be overridden to modify a dictionary of dim parameters (self.dim_param_dict) used to
         construct inputs for a model with dynamic dims.
         """
         pass
 
-    def contruct_input_name_to_shape_map(self):
+    def update_input_name_to_shape_map(self):
         """Can be overriden to construct an assocation map between the name of the input nodes and their shapes."""
         pass
 
+    def update_sess_options(self):
+        """Can be overridden to modify session options (self.sess_options) for gold inference.
+        It is sometimes useful to disable all optimizations, which can be done with:
+        self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+        """
+        pass
+
+    def update_extra_options(self):
+        """Can be overridden to set self.extra_options = ExtraOptions(**kwargs)"""
+        pass
+
     def construct_model(self):
         """a method to be overwritten. To make a new test, define a subclass with an override for this method"""
         raise NotImplementedError(
@@ -151,7 +185,7 @@ def get_metadata(self):
 class TestConfig(abc.ABC):
 
     @abc.abstractmethod
-    def import_model(self, program: TestModel, *, save_to: str) -> Tuple[ModelArtifact, str | None]:
+    def import_model(self, program: TestModel, *, save_to: str, extra_options : ImporterOptions) -> Tuple[ModelArtifact, str | None]:
         """imports the test model to model artifact (e.g., loads the onnx model )"""
         pass
 
@@ -161,16 +195,16 @@ def preprocess_model(self, model_artifact: ModelArtifact, *, save_to: str) -> Mo
         pass
 
     @abc.abstractmethod
-    def compile(self, module: ModelArtifact, *, save_to: str) -> CompiledOutput:
+    def compile(self, module: ModelArtifact, *, save_to: str, extra_options : CompilerOptions) -> CompiledOutput:
         """converts the test program to a compiled artifact"""
         pass
 
     @abc.abstractmethod
-    def run(self, artifact: CompiledOutput, input: TestTensors) -> TestTensors:
+    def run(self, artifact: CompiledOutput, input: TestTensors, extra_options : RuntimeOptions) -> TestTensors:
         """runs the input through the compiled artifact"""
         pass
 
-    def benchmark(self, artifact: CompiledOutput, input: TestTensors, repetitions: int, *, func_name=None) -> float:
+    def benchmark(self, artifact: CompiledOutput, input: TestTensors, repetitions: int, *, func_name=None, extra_options : RuntimeOptions) -> float:
         """returns a float representing inference time in ms"""
         pass
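
The commit message cites the migration to `iree-import-onnx` as what enables externalized parameters, and the `ImporterOptions` fields above look like a direct mirror of that tool's command-line surface. A hedged sketch of how such options might be lowered to CLI flags; the flag spellings are an assumption inferred from the field names, not confirmed by this diff:

from typing import List

from e2e_testing.framework import ImporterOptions

def importer_cli_args(opts: ImporterOptions) -> List[str]:
    """Assumed ImporterOptions -> iree-import-onnx flag mapping (hypothetical)."""
    args = []
    if opts.opset_version is not None:
        args.append(f"--opset-version={opts.opset_version}")
    if opts.large_model:
        args.append("--large-model")
    if opts.externalize_params:
        args.append("--externalize-params")
        args.append(f"--num-elements-threshold={opts.num_elements_threshold}")
        args.append(f"--params-scope={opts.params_scope}")
    return args

print(importer_cli_args(ImporterOptions(opset_version=21, externalize_params=True)))
# ['--opset-version=21', '--externalize-params',
#  '--num-elements-threshold=100', '--params-scope=model']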

alt_e2eshark/e2e_testing/storage.py

Lines changed: 4 additions & 10 deletions

@@ -171,20 +171,14 @@ def to_dtype(self, dtype, *, index: Optional[int] = None) -> "TestTensors":
         """returns a copy of self with a converted dtype (at a particular index, if specified)"""
         if self.type == numpy.ndarray:
             if index:
-                try:
-                    new_data = self.data
-                    new_data[index] = new_data[index].astype(dtype)
-                except Exception as e:
-                    print("to_dtype failed due to excepton {e}.")
+                new_data = self.data
+                new_data[index] = new_data[index].astype(dtype)
             else:
                 new_data = tuple([d.astype(dtype) for d in self.data])
         if self.type == torch.Tensor:
             if index:
-                try:
-                    new_data = self.data
-                    new_data[index] = new_data[index].to(dtype=dtype)
-                except Exception as e:
-                    print("to_dtype failed due to excepton {e}.")
+                new_data = self.data
+                new_data[index] = new_data[index].to(dtype=dtype)
             else:
                 new_data = tuple([d.to(dtype=dtype) for d in self.data])
         return TestTensors(new_data)
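
For reference, a tiny usage sketch of `to_dtype` after this cleanup, assuming (as elsewhere in the suite) that `TestTensors` wraps a tuple of numpy arrays. With the try/except removed, a failing cast now raises to the caller instead of printing a message and silently continuing:

import numpy

from e2e_testing.storage import TestTensors

tensors = TestTensors((numpy.zeros(2, dtype=numpy.float64),
                       numpy.ones(3, dtype=numpy.float64)))
# Convert every tensor to float32; errors now propagate to the caller.
converted = tensors.to_dtype(numpy.float32)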
