Skip to content

Commit

Permalink
Add support for Phi-3-vision series model (#596)
Browse files Browse the repository at this point in the history
Co-authored-by: Casper <casperbh.96@gmail.com>
  • Loading branch information
Isotr0py and casper-hansen authored Nov 14, 2024
1 parent 76bc0a8 commit 0187ac1
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .starcoder2 import Starcoder2AWQForCausalLM
from .llava_next import LlavaNextAWQForCausalLM
from .phi3 import Phi3AWQForCausalLM
from .phi3_v import Phi3VAWQForCausalLM
from .cohere import CohereAWQForCausalLM
from .deepseek_v2 import DeepseekV2AWQForCausalLM
from .minicpm import MiniCPMAWQForCausalLM
Expand Down
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"starcoder2": Starcoder2AWQForCausalLM,
"llava_next": LlavaNextAWQForCausalLM,
"phi3": Phi3AWQForCausalLM,
"phi3_v": Phi3VAWQForCausalLM,
"cohere": CohereAWQForCausalLM,
"deepseek_v2": DeepseekV2AWQForCausalLM,
"minicpm": MiniCPMAWQForCausalLM,
Expand Down
1 change: 1 addition & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
"starcoder2": "AutoModelForCausalLM",
"llava_next": "AutoModelForVision2Seq",
"phi3": "AutoModelForCausalLM",
"phi3_v": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
"deepseek_v2": "AutoModelForCausalLM",
"minicpm": "AutoModelForCausalLM",
Expand Down
72 changes: 72 additions & 0 deletions awq/models/phi3_v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import Phi3Block
from awq.modules.fused.model import Phi3Model as AWQPhi3Model
from transformers.models.phi3.modeling_phi3 import (
Phi3DecoderLayer as OldPhi3DecoderLayer
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Phi3VAWQForCausalLM(BaseAWQForCausalLM):
    """AWQ quantization wrapper for the Phi-3-vision model family.

    Only the text decoder layers are quantized; the vision embedding tower
    is explicitly excluded via ``modules_to_not_convert``.
    """

    # Class name of the HF decoder layer used to locate quantizable blocks.
    layer_type = "Phi3DecoderLayer"
    # Config key that holds the model's maximum supported sequence length.
    max_seq_len_key = "max_position_embeddings"
    # Keep the vision embedding modules in full precision.
    modules_to_not_convert = ["vision_embed_tokens"]

    @staticmethod
    def get_model_layers(model):
        """Return the list of decoder layers to quantize."""
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldPhi3DecoderLayer):
        """Phi-3 has no scalable activation function."""
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        """Move the token embedding table onto *device*."""
        embed = model.model.embed_tokens
        model.model.embed_tokens = embed.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldPhi3DecoderLayer, input_feat, module_kwargs):
        """Describe the (prev_op -> linear layers) pairs AWQ uses for scaling.

        Returns the four scaling groups of a Phi-3 decoder layer, in order:
        attention input, attention output, MLP input, MLP output.
        """
        attn = module.self_attn
        mlp = module.mlp

        # Attention: the fused QKV projection is scaled against the input
        # layernorm; the output projection is scaled against the QKV weights.
        attention_in = dict(
            prev_op=module.input_layernorm,
            layers=[attn.qkv_proj],
            inp=input_feat["self_attn.qkv_proj"],
            module2inspect=attn,
            kwargs=module_kwargs,
        )
        attention_out = dict(
            prev_op=attn.qkv_proj,
            layers=[attn.o_proj],
            inp=input_feat["self_attn.o_proj"],
        )

        # MLP: the fused gate/up projection is scaled against the
        # post-attention layernorm; the down projection against gate/up.
        mlp_in = dict(
            prev_op=module.post_attention_layernorm,
            layers=[mlp.gate_up_proj],
            inp=input_feat["mlp.gate_up_proj"],
            module2inspect=mlp,
        )
        mlp_out = dict(
            prev_op=mlp.gate_up_proj,
            layers=[mlp.down_proj],
            inp=input_feat["mlp.down_proj"],
        )

        return [attention_in, attention_out, mlp_in, mlp_out]

0 comments on commit 0187ac1

Please sign in to comment.