from transformers import AutoModelForCausalLM
import safetensors.torch  # submodule import needed for safetensors.torch.save_file below
from iree.compiler.ir import Context
import torch
import shark_turbine.aot as aot
from shark_turbine.aot import *
from turbine_models.custom_models import remap_gguf  # provides the gguf tensor-name remapping
from turbine_models.custom_models.sd_inference import utils
import argparse

# Attention head count for the gguf tensor-name mapping below; assumed to be
# 16 for the default bert-large-uncased model. Adjust for other models.
HEADS = 16

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="bert-large-uncased",
)
parser.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, if required by the model",
)
parser.add_argument(
    "--compile_to", type=str, default="linalg", help="linalg, vmfb"
)
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
)
parser.add_argument(
    "--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm"
)
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="host",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")


def export_bert_model(
    hf_model_name,
    hf_auth_token=None,
    external_weights=None,
    compile_to="linalg",
    device=None,
    target_triple=None,
    max_alloc=None,
):
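    """Export a Hugging Face BERT-style model to MLIR and optionally compile
    it to an IREE vmfb, with weights either inlined or saved to an external
    safetensors/gguf file."""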
    # Derive a filesystem-safe base name from the HF model id.
    safe_name = hf_model_name.split("/")[-1].strip().replace("-", "_")
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        token=hf_auth_token,
        torch_dtype=torch.float,
        trust_remote_code=True,
    )
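
    # When externalizing weights, build a name mapping from the parameter
    # names the compiled module will use ("params.<name>") to the tensor
    # names stored in the side-car file.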
    mapper = {}
    if external_weights is not None:
        if external_weights == "safetensors":
            mod_params = dict(model.named_parameters())
            for name in mod_params:
                mapper["params." + name] = name
            safetensors.torch.save_file(mod_params, safe_name + ".safetensors")

        elif external_weights == "gguf":
            tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
            mapper = tensor_mapper.mapping

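    # CompiledModule is shark_turbine.aot's export vehicle: class attributes
    # declare the module's globals (here, the model parameters) and methods
    # become exported entry points traced via jittable.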
    class BertModule(CompiledModule):
        if external_weights:
            params = export_parameters(
                model, external=True, external_scope="", name_mapper=mapper.get
            )
        else:
            params = export_parameters(model)
        compute = jittable(model.forward)

        def run_forward(
            self,
            x=AbstractTensor(1, 1, dtype=torch.int64),
            mask=AbstractTensor(1, 1, dtype=torch.int64),
        ):
            return self.compute(x, attention_mask=mask)
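
    # Instantiating the CompiledModule performs the export; the resulting
    # MLIR module is serialized to text for saving and optional compilation.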
    inst = BertModule(context=Context())
    module_str = str(CompiledModule.get_mlir_module(inst))

    with open(f"{safe_name}.mlir", "w+") as f:
        f.write(module_str)
    print("Saved to", safe_name + ".mlir")

    if compile_to == "vmfb":
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


if __name__ == "__main__":
    args = parser.parse_args()
    export_bert_model(
        args.hf_model_name,
        args.hf_auth_token,
        args.external_weights,
        args.compile_to,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )
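
# Example invocation (script name is illustrative):
#   python export_bert_model.py --hf_model_name=bert-large-uncased \
#       --external_weights=safetensors --compile_to=vmfb --device=cpu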