        "dim": 32,
        "buffer_prefix": "albert"
    },
-    "hf_Bart": {
-        "dim": 16,
-        "buffer_prefix": "bart"
-    },
-    "hf_Bert": {
-        "dim": 16,
-        "buffer_prefix": "bert"
-    },
-    "hf_GPT2": {
-        "dim": 16,
-        "buffer_prefix": "gpt2"
-    },
-    "hf_T5": {
-        "dim": 4,
-        "buffer_prefix": "t5"
-    },
+    # "hf_Bart": {
+    #     "dim": 16,
+    # },
+    # "hf_Bert": {
+    #     "dim": 16,
+    #     "buffer_prefix": "bert"
+    # },
+    # "hf_GPT2": {
+    #     "dim": 16,
+    #     "buffer_prefix": "gpt2"
+    # },
+    # "hf_T5": {
+    #     "dim": 4,
+    #     "buffer_prefix": "t5"
+    # },
    "mnasnet1_0": {
        "dim": 256,
    },
@@ -182,30 +181,21 @@ def export_torchbench_model(

    _, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)

+    for idx, i in enumerate(forward_args.values()):
+        np.save(f"input{idx}", i.clone().detach().cpu())
    if dtype == torch.float16:
        model = model.half()
    model.to("cuda:0")

    if not isinstance(forward_args, dict):
        forward_args = [i.type(dtype) for i in forward_args]
-    elif "hf" in model_id:
-        forward_args["head_mask"] = torch.zeros(model.config.num_hidden_layers, device="cuda:0")

    mapper = {}
    if (external_weights_dir is not None):
        if not os.path.exists(external_weights_dir):
            os.mkdir(external_weights_dir)
-        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.{external_weights}")
-        if os.path.exists(external_weight_path):
-            print("External weights for this module already exist at {external_weight_path}. Will not overwrite.")
-        utils.save_external_weights(
-            mapper,
-            model,
-            external_weights,
-            external_weight_path,
-        )
-    if weights_only:
-        return external_weight_path
+        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")
+

    decomp_list = [torch.ops.aten.reflection_pad2d]
    if decomp_attn == True:
@@ -225,18 +215,20 @@ def __init__(self, model):
            self.mod = model

        def forward(self, inp):
-            return self.mod(**inp, return_dict=False)
-    # In transformers, the position ids buffer is registered as non-persistent,
-    # which makes it fail to globalize in the FX import.
-    # Add them manually to the state dict here.
-
-    prefix = torchbench_models_dict[model_id]["buffer_prefix"]
-    getattr(model, prefix).embeddings.register_buffer(
-        "position_ids",
-        getattr(model, prefix).embeddings.position_ids,
-        persistent=True,
-    )
+            return self.mod(**inp)
+
+    if "Bart" not in model_id:
+        # In some transformers models, the position ids buffer is registered as non-persistent,
+        # which makes it fail to globalize in the FX import.
+        # Add them manually to the state dict here.

+        prefix = torchbench_models_dict[model_id]["buffer_prefix"]
+        getattr(model, prefix).embeddings.register_buffer(
+            "position_ids",
+            getattr(model, prefix).embeddings.position_ids,
+            persistent=True,
+        )
+        breakpoint()

    fxb = FxProgramsBuilder(HF_M(model))
    @fxb.export_program(args=(forward_args,))
    def _forward(module: HF_M(model), inputs):
@@ -252,6 +244,7 @@ class CompiledTorchbenchModel(CompiledModule):

    if external_weights:
        externalize_module_parameters(model)
+        save_module_parameters(external_weight_path, model)

    inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")
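
Note (illustrative, not part of the commit): the position_ids workaround above relies on the fact that re-registering an existing buffer with persistent=True makes it appear in state_dict(), which lets the FX import globalize it. A minimal standalone sketch of that behavior, using a made-up Embeddings module rather than a real transformers model:

import torch
import torch.nn as nn

class Embeddings(nn.Module):
    def __init__(self):
        super().__init__()
        # Mirrors how transformers registers position_ids: non-persistent,
        # so it is excluded from state_dict().
        self.register_buffer(
            "position_ids", torch.arange(512).unsqueeze(0), persistent=False
        )

emb = Embeddings()
print("position_ids" in emb.state_dict())  # False

# Re-register the same tensor as a persistent buffer; it now shows up in
# state_dict(), analogous to what the diff does for the HF embeddings module.
emb.register_buffer("position_ids", emb.position_ids, persistent=True)
print("position_ids" in emb.state_dict())  # True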