Skip to content

Commit 9dae027

Browse files
authored
Adds script to generate external parameters as safetensors (#191)
- Script to generate safetensors from an HF model using model_builder.py - Provides support for brevitas quantized parameters
1 parent 0eac063 commit 9dae027

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from turbine_models.model_builder import HFTransformerBuilder
2+
from transformers import AutoTokenizer, AutoModelForCausalLM
3+
import torch
4+
5+
import argparse
6+
7+
# Command-line interface for the external-parameter export script.
parser = argparse.ArgumentParser(
    description="Export the parameters of a Hugging Face model to a safetensors file."
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name ID",
    default="meta-llama/Llama-2-7b-chat-hf",
)
parser.add_argument(
    "--quantization",
    type=str,
    default="int4",
    help='Weight quantization mode; "int4" and "int8" quantize weights with brevitas',
)
parser.add_argument(
    "--weight_path",
    type=str,
    default="Llama2_7b_i4quant.safetensors",
    help="Path of the safetensors file to write",
)
parser.add_argument(
    "--hf_auth_token", type=str, help="The HF auth token required for some models"
)
19+
20+
21+
def quantize(model, quantization):
    """Collect the tensors to export for ``model``, optionally weight-quantized.

    For ``"int4"`` / ``"int8"`` the layers exposed by
    ``get_model_impl(model).layers`` are passed to brevitas' ``quantize_model``
    (its return value is not used, so the model is presumably modified in
    place -- confirm against brevitas). Every ``QuantLinear`` module found
    afterwards contributes three entries keyed by its module path:
    ``<prefix>.weight``, ``<prefix>.weight_scale`` and ``<prefix>.weight_zp``.

    For any other ``quantization`` value no quantization is performed and the
    model's named parameters are returned as they are, so the function always
    yields a dict that can be saved.

    Args:
        model: The ``torch.nn.Module`` whose parameters are exported.
        quantization: Quantization mode; ``"int4"`` and ``"int8"`` select the
            weight bit width (4 or 8).

    Returns:
        A dict mapping parameter names to tensors. Parameters whose name
        contains ``wrapped_scaling_impl`` or ``wrapped_zero_point_impl`` are
        dropped, and quantized entries override same-named parameters.
    """
    accumulates = torch.float32  # TODO (ian): adjust based on model precision

    # Defined up front so the unquantized path has a (empty) set of overrides.
    int_weights = {}

    if quantization in ("int4", "int8"):
        # brevitas is imported lazily so it is only required when quantizing.
        from brevitas_examples.common.generative.quantize import quantize_model
        from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

        print("Applying weight quantization...")
        weight_bit_width = 4 if quantization == "int4" else 8
        quantize_model(
            get_model_impl(model).layers,
            dtype=accumulates,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
            weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=128,  # TODO: make adjustable
            quantize_weight_zero_point=False,
        )
        from brevitas_examples.llm.llm_quant.export import LinearWeightBlockQuantHandler
        from brevitas.nn.quant_linear import QuantLinear

        class DummyLinearWeightBlockQuantHandler(LinearWeightBlockQuantHandler):
            # Only used for ``prepare_for_export``; it is never run as a layer.
            def forward(self, x):
                raise NotImplementedError

        for prefix, layer in model.named_modules():
            if not isinstance(layer, QuantLinear):
                continue
            print(f"Exporting layer {prefix}")
            exporter = DummyLinearWeightBlockQuantHandler()
            exporter.prepare_for_export(layer)
            print(
                f" weight = ({exporter.int_weight.shape}, {exporter.int_weight.dtype}), "
                f"scale=({exporter.scale.shape}, {exporter.scale.dtype}), "
                f"zero=({exporter.zero_point.shape}, {exporter.zero_point.dtype})"
            )
            int_weights[f"{prefix}.weight"] = exporter.int_weight
            int_weights[f"{prefix}.weight_scale"] = exporter.scale
            int_weights[f"{prefix}.weight_zp"] = exporter.zero_point
    else:
        print(f"No weight quantization applied for quantization={quantization!r}")

    # Drop the quantizer-internal parameters; the exported scale / zero-point
    # entries collected above are written instead.
    all_weights = {
        name: param
        for name, param in model.named_parameters()
        if "wrapped_scaling_impl" not in name
        and "wrapped_zero_point_impl" not in name
    }

    all_weights.update(int_weights)
    return all_weights
69+
70+
71+
if __name__ == "__main__":
    # Script entry point: build the HF model, collect (optionally quantized)
    # weights and write them to a safetensors file.
    args = parser.parse_args()
    model_builder = HFTransformerBuilder(
        example_input=None,
        hf_id=args.hf_model_name,
        auto_model=AutoModelForCausalLM,
        hf_auth_token=args.hf_auth_token,
    )
    model_builder.build_model()
    quant_weights = quantize(model_builder.model, args.quantization)
    # TODO: Add more than just safetensor support
    # Import the submodule explicitly: a bare `import safetensors` does not
    # guarantee that the `safetensors.torch` attribute has been loaded.
    from safetensors.torch import save_file

    save_file(quant_weights, args.weight_path)
    print("Saved safetensor output to ", args.weight_path)

python/turbine_models/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def build_model(self) -> None:
4040
"""
4141
# TODO: check cloud storage for existing ir
4242
self.model = self.auto_model.from_pretrained(
43-
self.hf_id, auth_token=self.hf_auth_token, config=self.auto_config
43+
self.hf_id, use_auth_token=self.hf_auth_token, config=self.auto_config
4444
)
4545
if self.auto_tokenizer is not None:
4646
self.tokenizer = self.auto_tokenizer.from_pretrained(
47-
self.hf_id, auth_token=self.hf_auth_token
47+
self.hf_id, use_auth_token=self.hf_auth_token
4848
)
4949
else:
5050
self.tokenizer = None

0 commit comments

Comments
 (0)