1
1
from msvcrt import kbhit
2
- from shark .iree_utils .compile_utils import get_iree_compiled_module , load_vmfb_using_mmap
2
+ from shark .iree_utils .compile_utils import (
3
+ get_iree_compiled_module ,
4
+ load_vmfb_using_mmap ,
5
+ clean_device_info ,
6
+ get_iree_target_triple ,
7
+ )
3
8
from apps .shark_studio .web .utils .file_utils import (
4
9
get_checkpoints_path ,
5
10
get_resource_path ,
@@ -32,8 +37,8 @@ def __init__(
32
37
self .model_map = model_map
33
38
self .static_kwargs = static_kwargs
34
39
self .base_model_id = base_model_id
35
- self .device_name = device
36
- self .device = device . split ( "=>" )[ - 1 ]. strip ( " " )
40
+ self .triple = get_iree_target_triple ( device )
41
+ self .device , self . device_id = clean_device_info ( device )
37
42
self .import_mlir = import_mlir
38
43
self .iree_module_dict = {}
39
44
self .tempfiles = {}
@@ -46,22 +51,24 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
46
51
# initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
47
52
# and your model map is populated with any IR - unique model IDs and their static params,
48
53
# call this method to get the artifacts associated with your map.
49
- self .pipe_id = pipe_id
54
+ self .pipe_id = self . safe_name ( pipe_id )
50
55
self .pipe_vmfb_path = Path (os .path .join (get_checkpoints_path (".." ), self .pipe_id ))
51
56
self .pipe_vmfb_path .mkdir (parents = True , exist_ok = True )
52
- print ("\n [LOG] Checking for pre-compiled artifacts." )
53
57
if submodel == "None" :
58
+ print ("\n [LOG] Gathering any pre-compiled artifacts...." )
54
59
for key in self .model_map :
55
60
self .get_compiled_map (pipe_id , submodel = key )
56
61
else :
57
62
self .get_precompiled (pipe_id , submodel )
58
63
ireec_flags = []
59
64
if submodel in self .iree_module_dict :
60
65
if "vmfb" in self .iree_module_dict [submodel ]:
61
- print (f"[LOG] Found executable for { submodel } at { self . iree_module_dict [ submodel ][ 'vmfb' ] } ..." )
66
+ print (f"\n [LOG] Executable for { submodel } already loaded ..." )
62
67
return
68
+ elif "vmfb_path" in self .model_map [submodel ]:
69
+ return
63
70
elif submodel not in self .tempfiles :
64
- print (f"[LOG] Tempfile for { submodel } not found. Fetching torch IR..." )
71
+ print (f"\n [LOG] Tempfile for { submodel } not found. Fetching torch IR..." )
65
72
if submodel in self .static_kwargs :
66
73
init_kwargs = self .static_kwargs [submodel ]
67
74
for key in self .static_kwargs ["pipe" ]:
@@ -90,16 +97,6 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
90
97
return
91
98
92
99
93
- def hijack_weights (self , weights_path , submodel = "None" ):
94
- if submodel == "None" :
95
- for i in self .model_map :
96
- self .hijack_weights (weights_path , i )
97
- else :
98
- if submodel in self .iree_module_dict :
99
- self .model_map [submodel ]["external_weights_file" ] = weights_path
100
- return
101
-
102
-
103
100
def get_precompiled (self , pipe_id , submodel = "None" ):
104
101
if submodel == "None" :
105
102
for model in self .model_map :
@@ -112,33 +109,10 @@ def get_precompiled(self, pipe_id, submodel="None"):
112
109
break
113
110
for file in vmfbs :
114
111
if submodel in file :
115
- print (f"Found existing .vmfb at { file } " )
116
- self .iree_module_dict [submodel ] = {}
117
- (
118
- self .iree_module_dict [submodel ]["vmfb" ],
119
- self .iree_module_dict [submodel ]["config" ],
120
- self .iree_module_dict [submodel ]["temp_file_to_unlink" ],
121
- ) = load_vmfb_using_mmap (
122
- os .path .join (vmfbs_path , file ),
123
- self .device ,
124
- device_idx = 0 ,
125
- rt_flags = [],
126
- external_weight_file = self .model_map [submodel ]['external_weight_file' ],
127
- )
112
+ self .model_map [submodel ]["vmfb_path" ] = os .path .join (vmfbs_path , file )
128
113
return
129
114
130
115
131
- def safe_dict (self , kwargs : dict ):
132
- flat_args = {}
133
- for i in kwargs :
134
- if isinstance (kwargs [i ], dict ) and "pass_dict" not in kwargs [i ]:
135
- flat_args [i ] = [kwargs [i ][j ] for j in kwargs [i ]]
136
- else :
137
- flat_args [i ] = kwargs [i ]
138
-
139
- return flat_args
140
-
141
-
142
116
def import_torch_ir (self , submodel , kwargs ):
143
117
torch_ir = self .model_map [submodel ]["initializer" ](
144
118
** self .safe_dict (kwargs ), compile_to = "torch"
@@ -160,18 +134,53 @@ def import_torch_ir(self, submodel, kwargs):
160
134
def load_submodels (self , submodels : list ):
161
135
for submodel in submodels :
162
136
if submodel in self .iree_module_dict :
137
+ print (f"\n [LOG] { submodel } is ready for inference." )
138
+ if "vmfb_path" in self .model_map [submodel ]:
163
139
print (
164
- f"\n [LOG] Loading .vmfb for { submodel } from { self .iree_module_dict [submodel ]['vmfb' ]} "
140
+ f"\n [LOG] Loading .vmfb for { submodel } from { self .model_map [submodel ]['vmfb_path' ]} "
141
+ )
142
+ self .iree_module_dict [submodel ] = {}
143
+ (
144
+ self .iree_module_dict [submodel ]["vmfb" ],
145
+ self .iree_module_dict [submodel ]["config" ],
146
+ self .iree_module_dict [submodel ]["temp_file_to_unlink" ],
147
+ ) = load_vmfb_using_mmap (
148
+ self .model_map [submodel ]["vmfb_path" ],
149
+ self .device ,
150
+ device_idx = 0 ,
151
+ rt_flags = [],
152
+ external_weight_file = self .model_map [submodel ]['external_weight_file' ],
165
153
)
166
154
else :
167
155
self .get_compiled_map (self .pipe_id , submodel )
168
156
return
169
157
170
158
159
+ def unload_submodels (self , submodels : list ):
160
+ for submodel in submodels :
161
+ if submodel in self .iree_module_dict :
162
+ del self .iree_module_dict [submodel ]
163
+ gc .collect ()
164
+ return
165
+
166
+
171
167
def run (self , submodel , inputs ):
172
- inp = [ireert .asdevicearray (self .iree_module_dict [submodel ]["config" ].device , inputs )]
168
+ if not isinstance (inputs , list ):
169
+ inputs = [inputs ]
170
+ inp = [ireert .asdevicearray (self .iree_module_dict [submodel ]["config" ].device , input ) for input in inputs ]
173
171
return self .iree_module_dict [submodel ]['vmfb' ]['main' ](* inp )
174
172
175
173
176
- def safe_name (name ):
177
- return name .replace ("/" , "_" ).replace ("-" , "_" )
174
+ def safe_name (self , name ):
175
+ return name .replace ("/" , "_" ).replace ("-" , "_" ).replace ("\\ " , "_" )
176
+
177
+
178
+ def safe_dict (self , kwargs : dict ):
179
+ flat_args = {}
180
+ for i in kwargs :
181
+ if isinstance (kwargs [i ], dict ) and "pass_dict" not in kwargs [i ]:
182
+ flat_args [i ] = [kwargs [i ][j ] for j in kwargs [i ]]
183
+ else :
184
+ flat_args [i ] = kwargs [i ]
185
+
186
+ return flat_args
0 commit comments