Adds script to generate external parameters as safetensors #191

Merged: 1 commit, Nov 18, 2023
85 changes: 85 additions & 0 deletions python/turbine_models/gen_external_params/gen_external_params.py
@@ -0,0 +1,85 @@
from turbine_models.model_builder import HFTransformerBuilder
from transformers import AutoModelForCausalLM
import torch

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name ID",
    default="meta-llama/Llama-2-7b-chat-hf",
)
parser.add_argument("--quantization", type=str, default="int4")
parser.add_argument("--weight_path", type=str, default="Llama2_7b_i4quant.safetensors")
parser.add_argument(
    "--hf_auth_token", type=str, help="The HF auth token required for some models"
)


def quantize(model, quantization):
    accumulates = torch.float32  # TODO (ian): adjust based on model precision
    int_weights = {}  # stays empty when no quantization is applied
    if quantization in ["int4", "int8"]:
        from brevitas_examples.common.generative.quantize import quantize_model
        from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

        print("Applying weight quantization...")
        weight_bit_width = 4 if quantization == "int4" else 8
        quantize_model(
            get_model_impl(model).layers,
            dtype=accumulates,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
            weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=128,  # TODO: make adjustable
            quantize_weight_zero_point=False,
        )
        from brevitas_examples.llm.llm_quant.export import LinearWeightBlockQuantHandler
        from brevitas.nn.quant_linear import QuantLinear

        # The handler is only used for its prepare_for_export() side effects,
        # so forward() is deliberately left unimplemented.
        class DummyLinearWeightBlockQuantHandler(LinearWeightBlockQuantHandler):
            def forward(self, x):
                raise NotImplementedError

        # Pull the integer weight, scale, and zero point out of every quantized
        # linear layer, keyed by the layer's module path.
        for prefix, layer in model.named_modules():
            if isinstance(layer, QuantLinear):
                print(f"Exporting layer {prefix}")
                exporter = DummyLinearWeightBlockQuantHandler()
                exporter.prepare_for_export(layer)
                print(
                    f"  weight = ({exporter.int_weight.shape}, {exporter.int_weight.dtype}), "
                    f"scale=({exporter.scale.shape}, {exporter.scale.dtype}), "
                    f"zero=({exporter.zero_point.shape}, {exporter.zero_point.dtype})"
                )
                int_weights[f"{prefix}.weight"] = exporter.int_weight
                int_weights[f"{prefix}.weight_scale"] = exporter.scale
                int_weights[f"{prefix}.weight_zp"] = exporter.zero_point

    # Drop the quantization wrapper parameters, then overwrite the float weights
    # of quantized layers with their integer counterparts.
    all_weights = dict(model.named_parameters())
    for k in list(all_weights.keys()):
        if "wrapped_scaling_impl" in k or "wrapped_zero_point_impl" in k:
            del all_weights[k]

    all_weights.update(int_weights)
    return all_weights
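Reviewer note on what the exported per-group tensors mean: with weight_quant_granularity="per_group" and weight_group_size=128 above, each row of a weight matrix is split into groups of 128 values, and every group carries its own scale and (asymmetric) zero point. A rough dequantization sketch, assuming the scale and zero-point tensors broadcast over a trailing group axis; the exact shapes come from Brevitas' LinearWeightBlockQuantHandler and may differ:

import torch

def dequantize_per_group(int_weight, scale, zero_point, group_size=128):
    # Hypothetical helper: reshape (out_features, in_features) integer weights
    # into groups along the input dim, then apply the asymmetric affine mapping
    # w = (q - zero_point) * scale per group.
    out_f, in_f = int_weight.shape
    q = int_weight.to(torch.float32).reshape(out_f, in_f // group_size, group_size)
    w = (q - zero_point.to(torch.float32)) * scale.to(torch.float32)
    return w.reshape(out_f, in_f)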


if __name__ == "__main__":
    args = parser.parse_args()
    model_builder = HFTransformerBuilder(
        example_input=None,
        hf_id=args.hf_model_name,
        auto_model=AutoModelForCausalLM,
        hf_auth_token=args.hf_auth_token,
    )
    model_builder.build_model()
    quant_weights = quantize(model_builder.model, args.quantization)
    # TODO: Add more than just safetensor support
    import safetensors.torch  # import the submodule explicitly; `import safetensors` alone does not expose it

    safetensors.torch.save_file(quant_weights, args.weight_path)
    print("Saved safetensor output to", args.weight_path)
4 changes: 2 additions & 2 deletions python/turbine_models/model_builder.py
@@ -40,11 +40,11 @@ def build_model(self) -> None:
"""
# TODO: check cloud storage for existing ir
self.model = self.auto_model.from_pretrained(
self.hf_id, auth_token=self.hf_auth_token, config=self.auto_config
self.hf_id, use_auth_token=self.hf_auth_token, config=self.auto_config
)
if self.auto_tokenizer is not None:
self.tokenizer = self.auto_tokenizer.from_pretrained(
self.hf_id, auth_token=self.hf_auth_token
self.hf_id, use_auth_token=self.hf_auth_token
)
else:
self.tokenizer = None
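For context on the rename: at the time of this PR, transformers' from_pretrained expected the token under use_auth_token (auth_token is not a recognized kwarg), so a call for a gated checkpoint looks roughly like this; the token value is a placeholder:

from transformers import AutoModelForCausalLM

# Sketch of the corrected call path. `use_auth_token` was the accepted kwarg
# for gated models at the time (newer transformers releases prefer `token`).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    use_auth_token="hf_...",  # placeholder token
)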