From 6f14fc7436d9a3fb5fc69299e4eb37db4ee9c891 Mon Sep 17 00:00:00 2001
From: Crystalcareai <162942000+Crystalcareai@users.noreply.github.com>
Date: Mon, 12 Aug 2024 12:20:37 -0500
Subject: [PATCH] Add Internlm2 support (#576)

Co-authored-by: Casper
---
 awq/models/__init__.py  |  1 +
 awq/models/auto.py      |  1 +
 awq/models/base.py      |  1 +
 awq/models/internlm2.py | 76 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 79 insertions(+)
 create mode 100644 awq/models/internlm2.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 78bde8ae..2f1a88e2 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -23,3 +23,4 @@
 from .cohere import CohereAWQForCausalLM
 from .deepseek_v2 import DeepseekV2AWQForCausalLM
 from .minicpm import MiniCPMAWQForCausalLM
+from .internlm2 import InternLM2AWQForCausalLM

diff --git a/awq/models/auto.py b/awq/models/auto.py
index 1c806d77..3a6416f1 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -33,6 +33,7 @@
     "cohere": CohereAWQForCausalLM,
     "deepseek_v2": DeepseekV2AWQForCausalLM,
     "minicpm": MiniCPMAWQForCausalLM,
+    "internlm2": InternLM2AWQForCausalLM,
 }
 
 

diff --git a/awq/models/base.py b/awq/models/base.py
index 4338393e..1d376fc0 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -84,6 +84,7 @@
     "cohere": "AutoModelForCausalLM",
     "deepseek_v2": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
+    "internlm2": "AutoModelForCausalLM",
 }
 
 

diff --git a/awq/models/internlm2.py b/awq/models/internlm2.py
new file mode 100644
index 00000000..4ab9b7f7
--- /dev/null
+++ b/awq/models/internlm2.py
@@ -0,0 +1,76 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+
+
+class InternLM2AWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "InternLM2DecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def get_model_layers(model):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module):
+        return dict(
+            is_scalable=True,
+            scale_name="feed_forward.w2",
+            scale_layer=module.feed_forward.w2,
+            scale_shape=module.feed_forward.w2.out_features,
+        )
+
+    @staticmethod
+    def move_embed(model, device: str):
+        model.model.tok_embeddings = model.model.tok_embeddings.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.attention_norm,
+                layers=[
+                    module.attention.wqkv,
+                ],
+                inp=input_feat["attention.wqkv"],
+                module2inspect=module.attention,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        layers.append(
+            dict(
+                prev_op=module.attention.wqkv,
+                layers=[module.attention.wo],
+                inp=input_feat["attention.wo"],
+            )
+        )
+
+        # feed forward input
+        layers.append(
+            dict(
+                prev_op=module.ffn_norm,
+                layers=[
+                    module.feed_forward.w1,
+                    module.feed_forward.w3,
+                ],
+                inp=input_feat["feed_forward.w1"],
+                module2inspect=module.feed_forward,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # feed forward output
+        layers.append(
+            dict(
+                prev_op=module.feed_forward.w1,
+                layers=[module.feed_forward.w2],
+                inp=input_feat["feed_forward.w2"],
+            )
+        )
+
+        return layers
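
With this patch applied, an InternLM2 checkpoint can be quantized through the standard AutoAWQ flow; nothing model-specific is needed beyond trust_remote_code, since InternLM2 ships custom modeling code on the Hugging Face Hub. A minimal sketch, assuming the checkpoint name internlm/internlm2-chat-7b (an example, not required by the patch) and AutoAWQ's usual GEMM quantization config:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "internlm/internlm2-chat-7b"  # example checkpoint (assumption)
    quant_path = "internlm2-chat-7b-awq"
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # InternLM2 uses remote modeling code, so trust_remote_code is required on both loads.
    model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Runs AWQ calibration and quantizes the wqkv/wo and w1/w3/w2 projections
    # declared in get_layers_for_scaling above.
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

Dispatch works because "internlm2" is the model_type in the checkpoint's config.json, matching the key registered in auto.py and base.py above.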