diff --git a/awq/models/llama.py b/awq/models/llama.py
index fa40a9ce..363621e9 100644
--- a/awq/models/llama.py
+++ b/awq/models/llama.py
@@ -70,8 +70,8 @@ def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs)
 from typing import List, Tuple, Union
 from awq.utils.utils import set_module_name
 from awq.modules.fused.mlp import QuantLlamaMLP
-from awq.modules.fused.norm import FTLlamaRMSNorm
 from awq.modules.fused.attn import QuantAttentionFused
+from awq.modules.fused.norm import FasterTransformerRMSNorm
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
@@ -143,7 +143,7 @@ def _fuse_qkv(self, module: LlamaAttention):
 
     def fuse_rmsnorm(self):
         for name, module in self.rmsnorm_modules:
-            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
+            norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
             set_module_name(self.model, name, norm)
 
     def fuse_mlp(self):
diff --git a/awq/models/mistral.py b/awq/models/mistral.py
index f0bd233c..54d6719f 100644
--- a/awq/models/mistral.py
+++ b/awq/models/mistral.py
@@ -1,11 +1,12 @@
 import logging
+from typing import Dict
 from .base import BaseAWQForCausalLM
 
 try:
     from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM
 except:
     # TODO: Remove once released on PyPi
-    logging.warning("You need the latest transformers 4.34.0.dev0: pip install git+https://github.com/huggingface/transformers.git")
+    logging.warning("You need the latest transformers 4.34.0.dev0: pip install -U git+https://github.com/huggingface/transformers.git")
     MistralForCausalLM = None
     MistralDecoderLayer = None
 
@@ -13,6 +14,13 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MistralDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"
 
+    @staticmethod
+    def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
+        fuser = MistralFuser(model, quant_config)
+        fuser.fuse_attention()
+        fuser.fuse_rmsnorm()
+        fuser.fuse_mlp()
+
     @staticmethod
     def get_model_layers(model: MistralForCausalLM):
         return model.model.layers
@@ -65,3 +73,88 @@ def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwarg
         ))
 
         return layers
+
+import torch
+from typing import List, Tuple, Union
+from awq.utils.utils import set_module_name
+from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.attn import QuantAttentionFused
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP
+
+class MistralFuser:
+    def __init__(self, model, quant_config):
+        self.model = model
+        self.quant_config = quant_config
+
+        self.attention_modules: List[Tuple[str, MistralAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MistralAttention)
+        ]
+
+        self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MistralRMSNorm)
+        ]
+
+        self.mlp_modules: List[Tuple[str, MistralMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MistralMLP)
+        ]
+
+    def fuse_attention(self):
+        for name, module in self.attention_modules:
+            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
+            attn = QuantAttentionFused(
+                module.hidden_size,
+                module.num_heads,
+                module.num_key_value_heads,
+                qkv_layer,
+                module.o_proj,
+                next(iter(qkv_layer.state_dict().values())).device,
+                self.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
+
+    def _fuse_qkv(self, module: MistralAttention):
+        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        if isinstance(q_proj, WQLinear_GEMV):
+            q_linear = WQLinear_GEMV
+        else:
+            q_linear = WQLinear_GEMM
+
+        qkv_layer = q_linear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            next(iter(module.state_dict().values())).device
+        )
+
+        if isinstance(qkv_layer, WQLinear_GEMV):
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            qkv_layer.split_k_iters = q_proj.split_k_iters
+        else:
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+
+        qkv_layer.bias = bias
+
+        return qkv_layer
+
+    def fuse_rmsnorm(self):
+        for name, module in self.rmsnorm_modules:
+            norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
+            set_module_name(self.model, name, norm)
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
diff --git a/awq/modules/fused/norm.py b/awq/modules/fused/norm.py
index 3db77c90..8c6fe50a 100644
--- a/awq/modules/fused/norm.py
+++ b/awq/modules/fused/norm.py
@@ -2,11 +2,8 @@
 from torch import nn
 import awq_inference_engine
 
-class FTLlamaRMSNorm(nn.Module):
+class FasterTransformerRMSNorm(nn.Module):
     def __init__(self, weight, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
         super().__init__()
         self.weight = weight
         self.variance_epsilon = eps
diff --git a/examples/basic_generate.py b/examples/basic_generate.py
index 5a9a678f..1d76908c 100644
--- a/examples/basic_generate.py
+++ b/examples/basic_generate.py
@@ -1,23 +1,23 @@
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer, TextStreamer
 
-quant_path = "casperhansen/vicuna-7b-v1.5-awq"
-quant_file = "awq_model_w4_g128.pt"
+quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"
 
 # Load model
-model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
+model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, safetensors=True)
 tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
-streamer = TextStreamer(tokenizer, skip_special_tokens=True)
+streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 # Convert prompt to tokens
 prompt_template = """\
-A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
-
-USER: {prompt}
-ASSISTANT:"""
+<|im_start|>system
+You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant"""
 
 tokens = tokenizer(
-    prompt_template.format(prompt="How are you today?"),
+    prompt_template.format(prompt="Why is ice cream so good, yes so good?"),
     return_tensors='pt'
 ).input_ids.cuda()
 
diff --git a/examples/benchmark.py b/examples/benchmark.py
index c6b0db73..54f612b3 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -39,11 +39,12 @@ def generate(model, input_ids, n_generate):
 
     return context_time, generate_time
 
-def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
+def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors):
     print(f" -- Loading model...")
     model = AutoAWQForCausalLM.from_quantized(
         model_path, quant_file, fuse_layers=True,
-        max_new_tokens=n_generate, batch_size=batch_size
+        max_new_tokens=n_generate, batch_size=batch_size,
+        safetensors=safetensors
     )
 
     print(f" -- Warming up...")
@@ -108,7 +109,8 @@ def main(args):
             args.quant_file,
             settings["n_generate"],
             input_ids,
-            args.batch_size
+            args.batch_size,
+            args.safetensors
         )
 
         all_stats.append(stats)
@@ -126,7 +128,8 @@
     parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
     parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
-    parser.add_argument("--batch_size", type=int, default=1, help="weights filename")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
+    parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
     args = parser.parse_args()
 
     main(args)
\ No newline at end of file
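For quick reference, a minimal usage sketch of the options this patch introduces; the model id and keyword arguments mirror the updated examples/basic_generate.py and benchmark.py above, while passing fuse_layers=True to exercise the new MistralFuser path is an assumption, not something the patch's own example does.

```python
# Sketch only: arguments are taken from the examples in this diff.
# fuse_layers=True (assumed) routes a Mistral checkpoint through the new
# MistralAWQForCausalLM.fuse_layers -> MistralFuser.fuse_attention /
# fuse_rmsnorm / fuse_mlp path; safetensors=True matches the new
# --safetensors flag added to examples/benchmark.py.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"

model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
```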