Commit 2af4b87

Add bert script
1 parent cc61977 commit 2af4b87

File tree

  • python/turbine_models/custom_models

1 file changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
from transformers import AutoModelForCausalLM
import safetensors.torch  # import the torch submodule so safetensors.torch.save_file resolves
from iree.compiler.ir import Context
import torch
import shark_turbine.aot as aot
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils

# Assumed import: the gguf branch below needs the tensor-name remapping helper
# that lives alongside the other custom_models scripts in this repo.
from turbine_models.custom_models import remap_gguf
import argparse

HEADS = 16  # bert-large-uncased uses 16 attention heads; needed by the gguf remapper

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="bert-large-uncased",
)
parser.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, required",
)
parser.add_argument(
    "--compile_to", type=str, default="linalg", help="linalg, vmfb"
)
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
)
parser.add_argument(
    "--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm"
)
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="host",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")


def export_bert_model(
    hf_model_name,
    hf_auth_token=None,
    external_weights=None,
    compile_to="linalg",
    device=None,
    target_triple=None,
    max_alloc=None,
):
    # Filesystem-safe base name for the emitted artifacts.
    safe_name = hf_model_name.split("/")[-1].strip().replace("-", "_")
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        token=hf_auth_token,
        torch_dtype=torch.float,
        trust_remote_code=True,
    )

    # Optionally externalize the weights so the IR/vmfb stays small and readable.
    mapper = {}
    if external_weights is not None:
        if external_weights == "safetensors":
            mod_params = dict(model.named_parameters())
            for name in mod_params:
                mapper["params." + name] = name
            safetensors.torch.save_file(mod_params, safe_name + ".safetensors")

        elif external_weights == "gguf":
            # NOTE: this remapping is carried over from the llama export path
            # and still uses the LLAMA arch table.
            tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
            mapper = tensor_mapper.mapping

    class BertModule(CompiledModule):
        if external_weights:
            params = export_parameters(
                model, external=True, external_scope="", name_mapper=mapper.get
            )
        else:
            params = export_parameters(model)
        compute = jittable(model.forward)

        def run_forward(
            self,
            x=AbstractTensor(1, 1, dtype=torch.int64),
            mask=AbstractTensor(1, 1, dtype=torch.int64),
        ):
            return self.compute(x, attention_mask=mask)

    inst = BertModule(context=Context())
    module_str = str(CompiledModule.get_mlir_module(inst))

    with open(f"{safe_name}.mlir", "w+") as f:
        f.write(module_str)
    print("Saved to", safe_name + ".mlir")

    if compile_to == "vmfb":
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


if __name__ == "__main__":
    args = parser.parse_args()
    export_bert_model(
        args.hf_model_name,
        args.hf_auth_token,
        args.external_weights,
        args.compile_to,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )
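
For reference, a minimal usage sketch of driving the exporter from Python rather than the CLI, assuming the script is importable as turbine_models.custom_models.bert (the module path is an assumption); with these arguments it only lowers the default model to linalg MLIR, skipping external weights and vmfb compilation:

# Hypothetical driver; the module path below is an assumption.
from turbine_models.custom_models.bert import export_bert_model

# Lowers bert-large-uncased and writes bert_large_uncased.mlir to the current
# directory; no external weights are saved and no vmfb is compiled.
export_bert_model(
    "bert-large-uncased",
    hf_auth_token=None,
    external_weights=None,
    compile_to="linalg",
)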
