import onnxruntime as ort
from typing import TypeVar, List
from e2e_testing.storage import TestTensors, get_shape_string
-from e2e_testing.framework import CompiledOutput, ModelArtifact
+from e2e_testing.framework import CompiledOutput, ModelArtifact, CompilerOptions, RuntimeOptions
from onnx import ModelProto
import os
from pathlib import Path

class BackendBase(abc.ABC):

    @abc.abstractmethod
-    def compile(self, module: ModelArtifact) -> CompiledOutput:
+    def compile(self, module: ModelArtifact, extra_options: CompilerOptions) -> CompiledOutput:
        """specifies how to compile an MLIR Module"""

    @abc.abstractmethod
-    def load(self, artifact: CompiledOutput, func_name: str) -> Invoker:
+    def load(self, artifact: CompiledOutput, func_name: str, extra_options: RuntimeOptions) -> Invoker:
        """loads the function with name func_name from compiled artifact. This method should return a function callable from python."""

from iree import compiler as ireec
from iree import runtime as ireert


+def flag(arg: str) -> str:
+    if arg.startswith("--"):
+        return arg
+    return f'--{arg}'
+
class SimpleIREEBackend(BackendBase):
    '''This backend uses iree to compile and run MLIR modules for a specified hal_target_backend'''
    def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", extra_args: List[str] = None):
        self.device = device
        self.hal_target_backend = hal_target_backend
-        self.extra_args = []
-        if extra_args:
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
-
-    def compile(self, module, *, save_to: str = None):
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
+        if hal_target_backend == "rocm":
+            self.extra_args += [
+                f"--iree-hip-target={self.target_chip}",
+            ]
+        if hal_target_backend == "llvm-cpu":
+            self.extra_args += [
+                "--iree-llvmcpu-target-cpu=host",
+            ]
+
+    def compile(self, module, *, save_to: str = None, extra_options: CompilerOptions):
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
+        compile_args = self.extra_args + [flag(arg) for arg in test_specific_args]
        # compile to a vmfb for llvm-cpu
        b = ireec.tools.compile_str(
            str(module),
            target_backends=[self.hal_target_backend],
-            extra_args=self.extra_args,
+            extra_args=compile_args,
        )
        # log the vmfb
        if save_to:
            with open(os.path.join(save_to, "compiled_model.vmfb"), "wb") as f:
                f.write(b)
        return b

-    def load(self, artifact, *, func_name="main"):
+    def load(self, artifact, *, func_name="main", extra_options: RuntimeOptions):
        config = ireert.Config(self.device)
        ctx = ireert.SystemContext(config=config)
        vm_module = ireert.VmModule.copy_buffer(ctx.instance, artifact)
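
The new extra_options plumbing assumes that CompilerOptions and RuntimeOptions (imported from e2e_testing.framework above) expose a common_extra_args sequence plus a backend_specific_flags mapping keyed by backend name. Their actual definitions are not part of this diff; a minimal sketch of the assumed shape, and of how a test might populate it, is:

from dataclasses import dataclass, field
from typing import Dict, Tuple

# Assumed shapes only; the real classes live in e2e_testing/framework.py.
@dataclass
class CompilerOptions:
    common_extra_args: Tuple[str, ...] = ()
    backend_specific_flags: Dict[str, Tuple[str, ...]] = field(default_factory=dict)

@dataclass
class RuntimeOptions:
    common_extra_args: Tuple[str, ...] = ()
    backend_specific_flags: Dict[str, Tuple[str, ...]] = field(default_factory=dict)

# Hypothetical per-test options: flag names here are placeholders, and may be
# given with or without the leading "--", since flag() normalizes them.
options = CompilerOptions(
    common_extra_args=("some-common-flag",),
    backend_specific_flags={"rocm": ("some-rocm-only-flag",)},
)
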
@@ -80,13 +91,7 @@ def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", target
        self.device = device
        self.hal_target_backend = hal_target_backend
        self.target_chip = target_chip
-        self.extra_args = []
-        if extra_args:
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
        if hal_target_backend == "rocm":
            self.extra_args += [
                f"--iree-hip-target={self.target_chip}",
@@ -96,15 +101,17 @@ def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", target
                "--iree-llvmcpu-target-cpu=host",
            ]

-    def compile(self, module_path: str, *, save_to: str = None) -> str:
+    def compile(self, module_path: str, *, save_to: str = None, extra_options: CompilerOptions) -> str:
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
+        compile_args = self.extra_args + [flag(arg) for arg in test_specific_args]
        vmfb_path = os.path.join(save_to, "compiled_model.vmfb")
        arg_string = f"--iree-hal-target-backends={self.hal_target_backend} "
-        for arg in self.extra_args:
-            arg_string += arg
-            arg_string += " "
+        arg_string += ' '.join(compile_args)
        detail_log = os.path.join(save_to, "detail", "compilation.detail.log")
        commands_log = os.path.join(save_to, "commands", "compilation.commands.log")
-        script = f"iree-compile {module_path} {arg_string} -o {vmfb_path} 1> {detail_log} 2>&1"
+        script = f"iree-compile {module_path} {arg_string} -o {vmfb_path} 1> {detail_log} 2>&1"
        with open(commands_log, "w") as file:
            file.write(script)
        # remove old vmfb if it exists
@@ -116,16 +123,21 @@ def compile(self, module_path: str, *, save_to : str = None) -> str:
            raise FileNotFoundError(error_msg)
        return vmfb_path

-    def load(self, vmfb_path: str, *, func_name=None):
+    def load(self, vmfb_path: str, *, func_name=None, extra_options: RuntimeOptions):
        """A bit hacky. func returns a script that would dump outputs to terminal output. Modified in config.run method"""
+        test_specific_args = list(extra_options.common_extra_args)
+        if self.hal_target_backend in extra_options.backend_specific_flags.keys():
+            test_specific_args += list(extra_options.backend_specific_flags[self.hal_target_backend])
        run_dir = Path(vmfb_path).parent
        def func(x: TestTensors) -> str:
-            script = f"iree-run-module --module='{vmfb_path}' --device={self.device}"
+            script = f"iree-run-module --module='{vmfb_path}' --device={self.device} "
+            for arg in test_specific_args:
+                script += f'{flag(arg)} '
            if func_name:
-                script += f" --function='{func_name}'"
+                script += f"--function='{func_name}' "
            torch_inputs = x.to_torch().data
            for index, input in enumerate(torch_inputs):
-                script += f" --input='{get_shape_string(input)}=@{run_dir}/input.{index}.bin'"
+                script += f"--input='{get_shape_string(input)}=@{run_dir}/input.{index}.bin' "
            return script
        return func
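
With this change, the script builder splices any test-specific runtime flags between the device selection and the function/input arguments. As a rough illustration (the run directory, input shape, func_name value, and the some-runtime-flag name below are all hypothetical, with the flag passed via RuntimeOptions.common_extra_args), the returned func would assemble something like:

iree-run-module --module='/tmp/testrun/compiled_model.vmfb' --device=local-task --some-runtime-flag=1 --function='main' --input='1x3x224x224xf32=@/tmp/testrun/input.0.bin'
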
@@ -135,16 +147,10 @@ class OnnxrtIreeEpBackend(BackendBase):
    def __init__(self, *, device="local-task", hal_target_device="llvm-cpu", extra_args: List[str] = None):
        self.device = device
        self.hal_target_device = hal_target_device
-        if extra_args:
-            self.extra_args = []
-            for a in extra_args:
-                if a[0:2] == "--":
-                    self.extra_args.append(a)
-                else:
-                    self.extra_args.append("--" + a)
-        elif hal_target_device == "hip":
+        self.extra_args = [] if extra_args is None else [flag(a) for a in extra_args]
+        if hal_target_device == "hip":
            # some extra args for Mi250 - some of these may not work for other chips
-            self.extra_args = [
+            self.extra_args += [
                "--iree-hip-target=gfx90a",
            ]
        self.providers = ["IreeExecutionProvider"]
@@ -159,7 +165,7 @@ def __init__(self, *, device="local-task", hal_target_device="llvm-cpu", extra_a
        # sess_opt.log_verbosity_level = 0
        # self.sess_opt.log_severity_level = 0

-    def compile(self, model: ModelProto, *, save_to: str = None) -> ort.InferenceSession:
+    def compile(self, model: ModelProto, *, save_to: str = None, extra_options: CompilerOptions) -> ort.InferenceSession:
        if self.provider_options:
            provider_options_dict = self.provider_options[0]
            provider_options_dict["save_to"] = save_to
@@ -173,7 +179,7 @@ def compile(self, model: ModelProto, *, save_to: str = None) -> ort.InferenceSes
        # can't save an onnx runtime session
        return session

-    def load(self, session: ort.InferenceSession, *, func_name=None) -> Invoker:
+    def load(self, session: ort.InferenceSession, *, func_name=None, extra_options: RuntimeOptions) -> Invoker:
        def func(x: TestTensors):
            data = x.to_numpy().data
            session_inputs = session.get_inputs()