From 33dfb04853310e52fa30abf93af9d6ed85550855 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B0=91=E5=B9=B4?= <48116214+shaonianyr@users.noreply.github.com>
Date: Sat, 6 Apr 2024 21:06:04 +0800
Subject: [PATCH] add starcoder2 support (#406)

Co-authored-by: charrli
---
 awq/models/__init__.py   |   1 +
 awq/models/auto.py       |   1 +
 awq/models/base.py       |   1 +
 awq/models/starcoder2.py | 141 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 144 insertions(+)
 create mode 100644 awq/models/starcoder2.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 75542fe4..b2496170 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -15,3 +15,4 @@
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .starcoder2 import Starcoder2AWQForCausalLM
\ No newline at end of file
diff --git a/awq/models/auto.py b/awq/models/auto.py
index cf35a279..a99b7a75 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -24,6 +24,7 @@
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "starcoder2": Starcoder2AWQForCausalLM,
 }
 
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 12607348..f32576b8 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -68,6 +68,7 @@
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "starcoder2": "AutoModelForCausalLM",
 }
 
 
diff --git a/awq/models/starcoder2.py b/awq/models/starcoder2.py
new file mode 100644
index 00000000..2e493514
--- /dev/null
+++ b/awq/models/starcoder2.py
@@ -0,0 +1,141 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.starcoder2.modeling_starcoder2 import (
+    Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
+    Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "Starcoder2DecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldStarcoder2ForCausalLM):
+        fuser = Starcoder2Fuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldStarcoder2ForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
+        return dict(
+            is_scalable=True,
+            scale_name="mlp.act",
+            scale_layer=module.mlp.act,
+            scale_shape=module.mlp.c_fc.out_features,
+        )
+        # return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldStarcoder2ForCausalLM, device):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.c_fc],
+                inp=input_feat["mlp.c_fc"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.act,
+                layers=[module.mlp.c_proj],
+                inp=input_feat["mlp.c_proj"],
+            )
+        )
+
+        return layers
+
+class Starcoder2Fuser:
+    def __init__(self, model: OldStarcoder2ForCausalLM):
+        self.model = model
+
+        self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldStarcoder2DecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.eps
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.eps,
+            )
+            blocks.append(
+                LlamaLikeBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                )
+            )
+
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file
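
Not part of the patch itself: a minimal quantization sketch showing how the
newly registered "starcoder2" model type would be used. The checkpoint name
("bigcode/starcoder2-7b"), the output directory, and the quant_config values
are assumptions, not taken from the patch; the calls are AutoAWQ's standard
quantization flow.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "bigcode/starcoder2-7b"  # assumed source checkpoint
    quant_path = "starcoder2-7b-awq"      # assumed output directory
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # from_pretrained resolves the config's model_type ("starcoder2") to
    # Starcoder2AWQForCausalLM via the model-map entry added to awq/models/auto.py above.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Calibration consults get_layers_for_scaling()/get_act_for_scaling() to
    # decide which Starcoder2 submodules receive AWQ scales.
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)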
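Also an illustration rather than part of the patch: a loading sketch that
exercises the fused path. Passing fuse_layers=True to from_quantized() invokes
Starcoder2AWQForCausalLM.fuse_layers(), i.e. the Starcoder2Fuser defined above.
The path and prompt are placeholders carried over from the previous sketch.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "starcoder2-7b-awq"  # assumed path from the sketch above

    # fuse_layers=True triggers Starcoder2Fuser.fuse_transformer(), replacing
    # each decoder layer with a fused LlamaLikeBlock.
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path)

    tokens = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.cuda()
    output = model.generate(tokens, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))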