add autoround and remove name in path (intel#112)
zhenwei-intel authored Feb 5, 2024
1 parent 01e10e6 commit e2d3652
Showing 8 changed files with 29 additions and 19 deletions.
8 changes: 5 additions & 3 deletions neural_speed/__init__.py
@@ -83,7 +83,7 @@ def get_model_type(model_config):
model_type = "chatglm2"
return model_type

def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False,
weight_dtype="int4", alg="sym", group_size=32,
scale_dtype="fp32", compute_dtype="int8", use_ggml=False):
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -108,6 +108,8 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
quant_desc = "gptq"
if use_awq:
quant_desc = "awq"
if use_autoround:
quant_desc = "autoround"
quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)

if not use_quant:
@@ -120,8 +122,8 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
format(self.bin_file))
return

if use_gptq or use_awq:
convert_model(model_name, quant_bin, "f32")
if use_gptq or use_awq or use_autoround:
convert_model(model_name, quant_bin, use_quantized_model=True)
return

if not os.path.exists(fp32_bin):
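
For context, a minimal usage sketch of the extended init() signature, assuming the neural_speed package as shown above; the checkpoint path is a placeholder, not part of this commit.

# Hedged sketch: loading a pre-quantized AutoRound checkpoint through the new flag.
# "./llama-2-7b-autoround" is an illustrative path, not taken from this change.
from neural_speed import Model

model = Model()
model.init("./llama-2-7b-autoround",
           use_quant=True,
           use_autoround=True,        # new keyword added by this commit
           weight_dtype="int4",
           compute_dtype="int8")
# With use_autoround=True the checkpoint goes through
# convert_model(..., use_quantized_model=True) and the converted weights are
# written as ne_<model_type>_q_autoround.bin under the output path.
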
5 changes: 2 additions & 3 deletions neural_speed/convert/__init__.py
@@ -22,12 +22,11 @@
model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper"}


def convert_model(model, outfile, outtype, whisper_repo_path=None):
def convert_model(model, outfile, outtype="f32", whisper_repo_path=None, use_quantized_model=False):
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
model_type = model_maps.get(config.model_type, config.model_type)

quantized_model = 'gptq' in str(model).lower() or 'awq' in str(model).lower()
if quantized_model:
if use_quantized_model:
path = Path(Path(__file__).parent.absolute(), "convert_quantized_{}.py".format(model_type))
else:
path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
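
For context, a short sketch of calling the updated helper directly: quantized checkpoints are now flagged explicitly via use_quantized_model instead of being detected from "gptq"/"awq" substrings in the model name. The import location and input path below are assumptions, not shown in this diff.

# Hedged sketch: converting a GPTQ/AWQ/AutoRound checkpoint with the explicit flag.
# The import path and checkpoint directory are assumptions for illustration only.
from neural_speed.convert import convert_model

convert_model("./llama-2-7b-autoround",      # local directory or model id
              "ne_llama_q_autoround.bin",    # output file
              use_quantized_model=True)      # dispatches to convert_quantized_<model_type>.py
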
8 changes: 4 additions & 4 deletions neural_speed/convert/common.py
@@ -196,7 +196,7 @@ def unpack_weight(qweight, scales, qzeros, q_config):
if "quant_method" not in q_config:
raise ValueError(f"Unsupported q_config without quant_method: {q_config}")
quant_method = q_config["quant_method"]
if quant_method == "gptq":
if quant_method == "gptq" or quant_method == "autoround":
qbits = q_config["bits"]
if qbits == 4:
return unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config)
@@ -354,7 +354,7 @@ def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2
# gptq_scale = torch.cat([gptq_scale,gptq_scale,gptq_scale,gptq_scale], dim=1).view(-1,1)
pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), tensor), dim=-1)
pack_tensor.numpy().tofile(fout)
print(f"converting {dst_name} qauntized tensor to ggml q4 block")
print(f"converting {dst_name} quantized tensor to ggml q4 block")

def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None):
qzeros = model[f"{src_name}.qzeros"]
@@ -381,7 +381,7 @@ def convert_q4_1_tensor(src_name, dst_name, model, fout, q_config, n_head, n_hea
gptq_zeros = -gptq_scale*gptq_zeros
pack_tensor = torch.cat((gptq_scale.half().view(torch.int8), gptq_zeros.half().view(torch.int8), tensor), dim=-1)
pack_tensor.numpy().tofile(fout)
print(f"converting {dst_name} qauntized tensor to ggml q4 1 block")
print(f"converting {dst_name} quantized tensor to ggml q4 1 block")


def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
@@ -411,4 +411,4 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
write_header(fout, shape, dst_name, 0)
weight.numpy().tofile(fout)

print(f"converting {dst_name} qauntized tensor to fp32 tensor")
print(f"converting {dst_name} quantized tensor to fp32 tensor")
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_quantized_gptj.py
@@ -96,7 +96,7 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")
print(f"converting {dst_name} quantized tensor to bestla q4 block")


def main(args_in: Optional[List[str]] = None) -> None:
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_quantized_llama.py
@@ -92,7 +92,7 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")
print(f"converting {dst_name} quantized tensor to bestla q4 block")

def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
5 changes: 4 additions & 1 deletion scripts/cal_acc.py
@@ -20,14 +20,17 @@
parser.add_argument('--model_name', type=str, default="~/Llama-2-7b-chat-hf")
parser.add_argument('--tasks', type=str, default="lambada_openai")
parser.add_argument('--model_format', type=str, default="runtime")
parser.add_argument("--use_gptq", action="store_true")
parser.add_argument("--use_awq", action="store_true")
parser.add_argument("--use_autoround", action="store_true")
args = parser.parse_args()

model_name = args.model_name
model_format = args.model_format
tasks = args.tasks
results = evaluate(
model="hf-causal",
model_args=f'pretrained="{model_name}"',
model_args=f'pretrained="{model_name}",use_gptq={args.use_gptq},use_awq={args.use_awq},use_autoround={args.use_autoround}',
tasks=[f"{tasks}"],
# limit=5,
model_format=f"{model_format}"
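
For context, what cal_acc.py effectively runs when --use_autoround is passed. The evaluate helper is the one already imported at the top of the script (outside this hunk), and the model path is a placeholder.

# Hedged sketch mirroring cal_acc.py with --use_autoround; evaluate() comes from
# the harness import at the top of the script, which is not shown in this diff.
results = evaluate(
    model="hf-causal",
    model_args='pretrained="./llama-2-7b-autoround",use_gptq=False,use_awq=False,use_autoround=True',
    tasks=["lambada_openai"],
    model_format="runtime",
)
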
3 changes: 2 additions & 1 deletion scripts/convert.py
@@ -27,14 +27,15 @@ def main(args_in: Optional[List[str]] = None) -> None:
)
parser.add_argument("--outfile", type=Path, required=True, help="path to write to")
parser.add_argument("model", type=Path, help="directory containing model file or model id")
parser.add_argument("--use_quantized_model", action="store_true", help="use quantized model: awq/gptq/autoround")
args = parser.parse_args(args_in)

if args.model.exists():
dir_model = args.model.as_posix()
else:
dir_model = args.model

convert_model(dir_model, args.outfile, args.outtype)
convert_model(dir_model, args.outfile, args.outtype, use_quantized_model=args.use_quantized_model)


if __name__ == "__main__":
15 changes: 10 additions & 5 deletions scripts/huggingface.py
@@ -614,16 +614,21 @@ class AutoCausalLM(HuggingFaceAutoLM):

def __init__(self, *args, pretrained, model_format, **kwargs):
self.model_format = model_format
# if self.model_format == "runtime":
# from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
# use_gptq = kwargs.pop("use_gptq", False)
# self.woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_gptq=use_gptq)
if self.model_format == "runtime":
self.use_gptq = kwargs.pop("use_gptq", False)
self.use_awq = kwargs.pop("use_awq", False)
self.use_autoround = kwargs.pop("use_autoround", False)
self.use_quant = kwargs.pop("use_quant", True)
super().__init__(*args, pretrained=pretrained, model_format=model_format, **kwargs)

if self.model_format == "runtime":
from neural_speed import Model
self.runtime_model = Model()
self.runtime_model.init(pretrained, weight_dtype="int4", compute_dtype="int8")
self.runtime_model.init(pretrained, weight_dtype="int4", compute_dtype="int8",
use_quant=self.use_quant,
use_gptq=self.use_gptq,
use_awq=self.use_awq,
use_autoround=self.use_autoround)

if self.model_format == "onnx":
if not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")) and \
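
For context, a sketch of how the runtime branch now forwards the quantization switches into neural_speed. Constructing AutoCausalLM directly like this assumes the base HuggingFaceAutoLM class needs no further arguments, which may not hold in practice, and the model path is a placeholder.

# Hedged sketch; AutoCausalLM is the class defined in this file, and any extra
# constructor arguments required by HuggingFaceAutoLM are omitted here.
lm = AutoCausalLM(pretrained="./llama-2-7b-autoround",   # placeholder path
                  model_format="runtime",
                  use_autoround=True)                    # popped from kwargs above
# Internally the runtime branch now calls:
#   Model().init(pretrained, weight_dtype="int4", compute_dtype="int8",
#                use_quant=True, use_gptq=False, use_awq=False, use_autoround=True)
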
