Skip to content

Commit

Permalink
Merge pull request #90 from casper-hansen/mistral_fused
Browse files Browse the repository at this point in the history
Mistral fused modules
  • Loading branch information
casper-hansen authored Oct 2, 2023
2 parents aaab103 + 92579e9 commit 11efba0
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 20 deletions.
4 changes: 2 additions & 2 deletions awq/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs)
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

Expand Down Expand Up @@ -143,7 +143,7 @@ def _fuse_qkv(self, module: LlamaAttention):

def fuse_rmsnorm(self):
for name, module in self.rmsnorm_modules:
norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
set_module_name(self.model, name, norm)

def fuse_mlp(self):
Expand Down
95 changes: 94 additions & 1 deletion awq/models/mistral.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import logging
from typing import Dict
from .base import BaseAWQForCausalLM

try:
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM
except:
# TODO: Remove once released on PyPi
logging.warning("You need the latest transformers 4.34.0.dev0: pip install git+https://github.com/huggingface/transformers.git")
logging.warning("You need the latest transformers 4.34.0.dev0: pip install -U git+https://github.com/huggingface/transformers.git")
MistralForCausalLM = None
MistralDecoderLayer = None

class MistralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MistralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
fuser = MistralFuser(model, quant_config)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()

@staticmethod
def get_model_layers(model: MistralForCausalLM):
return model.model.layers
Expand Down Expand Up @@ -65,3 +73,88 @@ def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwarg
))

return layers

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP

class MistralFuser:
def __init__(self, model, quant_config):
self.model = model
self.quant_config = quant_config

self.attention_modules: List[Tuple[str, MistralAttention]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, MistralAttention)
]

self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, MistralRMSNorm)
]

self.mlp_modules: List[Tuple[str, MistralMLP]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, MistralMLP)
]

def fuse_attention(self):
for name, module in self.attention_modules:
qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
attn = QuantAttentionFused(
module.hidden_size,
module.num_heads,
module.num_key_value_heads,
qkv_layer,
module.o_proj,
next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens
)
set_module_name(self.model, name, attn)

def _fuse_qkv(self, module: MistralAttention):
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

if isinstance(q_proj, WQLinear_GEMV):
q_linear = WQLinear_GEMV
else:
q_linear = WQLinear_GEMM

qkv_layer = q_linear(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
next(iter(module.state_dict().values())).device
)

if isinstance(qkv_layer, WQLinear_GEMV):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
qkv_layer.split_k_iters = q_proj.split_k_iters
else:
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

qkv_layer.bias = bias

return qkv_layer

def fuse_rmsnorm(self):
for name, module in self.rmsnorm_modules:
norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
set_module_name(self.model, name, norm)

def fuse_mlp(self):
for name, module in self.mlp_modules:
mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
set_module_name(self.model, name, mlp)
5 changes: 1 addition & 4 deletions awq/modules/fused/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from torch import nn
import awq_inference_engine

class FTLlamaRMSNorm(nn.Module):
class FasterTransformerRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = weight
self.variance_epsilon = eps
Expand Down
18 changes: 9 additions & 9 deletions examples/basic_generate.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "casperhansen/vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"
quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: {prompt}
ASSISTANT:"""
<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant"""

tokens = tokenizer(
prompt_template.format(prompt="How are you today?"),
prompt_template.format(prompt="Why is ice cream so good, yes so good?"),
return_tensors='pt'
).input_ids.cuda()

Expand Down
11 changes: 7 additions & 4 deletions examples/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def generate(model, input_ids, n_generate):

return context_time, generate_time

def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors):
print(f" -- Loading model...")
model = AutoAWQForCausalLM.from_quantized(
model_path, quant_file, fuse_layers=True,
max_new_tokens=n_generate, batch_size=batch_size
max_new_tokens=n_generate, batch_size=batch_size,
safetensors=safetensors
)

print(f" -- Warming up...")
Expand Down Expand Up @@ -108,7 +109,8 @@ def main(args):
args.quant_file,
settings["n_generate"],
input_ids,
args.batch_size
args.batch_size,
args.safetensors
)

all_stats.append(stats)
Expand All @@ -126,7 +128,8 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
args = parser.parse_args()

main(args)

0 comments on commit 11efba0

Please sign in to comment.