Skip to content

Commit

Permalink
Add phi3 support (#481)
Browse files Browse the repository at this point in the history
Co-authored-by: Casper <casperbh.96@gmail.com>
  • Loading branch information
pprp and casper-hansen authored Jun 8, 2024
1 parent 5fa02b5 commit 6a46ad6
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 0 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from .gemma import GemmaAWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .phi3 import Phi3AWQForCausalLM
from .cohere import CohereAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"phi3": Phi3AWQForCausalLM,
"cohere": CohereAWQForCausalLM,
}

Expand Down
1 change: 1 addition & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"phi3": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
}

Expand Down
128 changes: 128 additions & 0 deletions awq/models/phi3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import Phi3Block
from awq.modules.fused.model import Phi3Model as AWQPhi3Model
from transformers.models.phi3.modeling_phi3 import (
Phi3DecoderLayer as OldPhi3DecoderLayer,
Phi3ForCausalLM as OldPhi3ForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Phi3AWQForCausalLM(BaseAWQForCausalLM):
    """AWQ integration hooks for Phi-3 causal language models.

    Every hook is a static method that either navigates the model
    (decoder layers, token embedding) or describes, per decoder layer,
    which linear layers are grouped together for scale search.
    """

    layer_type = "Phi3DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldPhi3ForCausalLM):
        """Fuse the transformer of ``model`` in place via ``Phi3Fuser``."""
        Phi3Fuser(model).fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldPhi3ForCausalLM):
        """Return the container of decoder layers (``model.model.layers``)."""
        decoder_layers = model.model.layers
        return decoder_layers

    @staticmethod
    def get_act_for_scaling(module: OldPhi3DecoderLayer):
        """Report that this layer type has no scalable activation."""
        return {"is_scalable": False}

    @staticmethod
    def move_embed(model: OldPhi3ForCausalLM, device: str):
        """Move the token embedding to ``device`` and re-attach it."""
        relocated = model.model.embed_tokens.to(device)
        model.model.embed_tokens = relocated

    @staticmethod
    def get_layers_for_scaling(module: OldPhi3DecoderLayer, input_feat, module_kwargs):
        """Describe the four scaling groups of one Phi-3 decoder layer.

        Args:
            module: the decoder layer to describe.
            input_feat: mapping indexed by the dotted sub-module names used
                below (e.g. ``"self_attn.qkv_proj"``).
            module_kwargs: passed through unchanged as ``kwargs`` of the
                attention-input group only.

        Returns:
            A list of dicts in the order attention input, attention output,
            MLP input, MLP output. Each dict names the preceding op
            (``prev_op``), the target ``layers`` and the recorded ``inp``;
            the two input groups additionally carry ``module2inspect``.
        """
        # Input layernorm feeding the fused QKV projection.
        attention_input = {
            "prev_op": module.input_layernorm,
            "layers": [module.self_attn.qkv_proj],
            "inp": input_feat["self_attn.qkv_proj"],
            "module2inspect": module.self_attn,
            "kwargs": module_kwargs,
        }

        # Fused QKV projection feeding the attention output projection.
        attention_output = {
            "prev_op": module.self_attn.qkv_proj,
            "layers": [module.self_attn.o_proj],
            "inp": input_feat["self_attn.o_proj"],
        }

        # Post-attention layernorm feeding the fused gate/up projection.
        mlp_input = {
            "prev_op": module.post_attention_layernorm,
            "layers": [module.mlp.gate_up_proj],
            "inp": input_feat["mlp.gate_up_proj"],
            "module2inspect": module.mlp,
        }

        # Fused gate/up projection feeding the down projection.
        mlp_output = {
            "prev_op": module.mlp.gate_up_proj,
            "layers": [module.mlp.down_proj],
            "inp": input_feat["mlp.down_proj"],
        }

        return [attention_input, attention_output, mlp_input, mlp_output]


class Phi3Fuser:
    """Replace a Phi-3 model's decoder stack with fused AWQ blocks."""

    def __init__(self, model: OldPhi3ForCausalLM):
        """Store ``model`` and index its decoder layers by qualified name."""
        self.model = model

        # Case-insensitive match on the class name of every sub-module.
        wanted = "Phi3DecoderLayer".lower()
        self.phi3_blocks: List[Tuple[str, OldPhi3DecoderLayer]] = [
            (qualified_name, submodule)
            for qualified_name, submodule in self.model.named_modules()
            if wanted in submodule.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        """Build one ``Phi3Block`` per decoder layer and install the fused model.

        The inner ``self.model.model`` is replaced by an ``AWQPhi3Model``
        built from the new blocks plus the existing token embedding and
        final norm.
        """
        config = self.model.config
        fused_blocks = []

        for layer in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            # Device of the first state-dict entry of this layer.
            layer_device = next(iter(layer.state_dict().values())).device
            fused_qkv = layer.self_attn.qkv_proj

            pre_attn_norm = layer.input_layernorm
            norm_1 = FasterTransformerRMSNorm(
                pre_attn_norm.weight, pre_attn_norm.variance_epsilon
            )
            post_attn_norm = layer.post_attention_layernorm
            norm_2 = FasterTransformerRMSNorm(
                post_attn_norm.weight, post_attn_norm.variance_epsilon
            )

            fused_block = Phi3Block(
                hidden_size=config.hidden_size,
                n_heads=config.num_attention_heads,
                n_kv_heads=config.num_key_value_heads,
                qkv_layer=fused_qkv,
                o_proj=layer.self_attn.o_proj,
                mlp=layer.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=layer_device,
                max_seq_len=config.max_position_embeddings,
                rope_theta=config.rope_theta,
                rope_scaling=config.rope_scaling,
            )
            fused_blocks.append(fused_block)

        self.model.model = AWQPhi3Model(
            config.vocab_size,
            fused_blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        # Re-binds `blocks` to its current value.
        # NOTE(review): looks like a no-op re-registration — confirm before removing.
        fused_model = self.model.model
        setattr(fused_model, "blocks", fused_model.blocks)
72 changes: 72 additions & 0 deletions awq/modules/fused/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,75 @@ def forward(
out = h_attn + h_mlp

return out, None, past_key_value


class Phi3Block(nn.Module):
    """
    Fused decoder block for Phi-3 and closely related architectures.

    The forward pass is: first norm, fused attention, residual add,
    then second norm, MLP, residual add.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        rope_scaling=None,
        use_alibi=False,
        head_dim=None,
    ):
        """Assemble the block and move its sub-modules to ``dev``.

        Args:
            hidden_size: model hidden width.
            n_heads: number of attention heads.
            n_kv_heads: number of key/value heads.
            qkv_layer: fused QKV projection handed to the attention module.
            o_proj: attention output projection.
            mlp: feed-forward module applied after the second norm.
            norm_1: norm applied before attention.
            norm_2: norm applied before the MLP.
            dev: device every sub-module is moved to.
            max_seq_len: forwarded to the attention module.
            rope_theta: forwarded to the attention module.
            rope_scaling: forwarded to the attention module.
            use_alibi: forwarded to the attention module.
            head_dim: explicit per-head width; when falsy,
                ``hidden_size // n_heads`` is stored instead.
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads

        # To support models with separate head_dim: an explicit value wins
        # over the one derived from the hidden size.
        derived_head_dim = hidden_size // n_heads
        self.head_dim = head_dim if head_dim else derived_head_dim

        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)

        # The attention module receives the caller's raw ``head_dim``
        # (possibly None), not the resolved ``self.head_dim``.
        fused_attention = QuantAttentionFused(
            hidden_size,
            n_heads,
            n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            head_dim=head_dim,
        )
        self.attn = fused_attention.to(dev)

        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        """Run attention and MLP with residual connections.

        ``attn_bias`` and ``is_causal`` are accepted but not used here.

        Returns:
            ``(out, None, past_key_value)`` where ``past_key_value`` is the
            value returned by the attention module.
        """
        normed_input = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=normed_input,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        # The residual is moved to the device of the attention output
        # before the add.
        residual = hidden_states.to(attn_output.device)
        h = residual + attn_output

        mlp_output = self.mlp.forward(self.norm_2(h))
        out = h + mlp_output

        return out, None, past_key_value
67 changes: 67 additions & 0 deletions awq/modules/fused/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FalconDecoderLayer,
LlamaLikeBlock,
MixtralBlock,
Phi3Block,
CohereBlock,
)

Expand Down Expand Up @@ -306,3 +307,69 @@ def forward(
hidden_states=(),
attentions=(),
)

class Phi3Model(nn.Module):
    """
    Fused transformer body for Phi-3 and closely related architectures.

    Holds the token embedding, a list of fused blocks and the final norm.
    ``embed_tokens`` and ``layers`` are read-only aliases for ``embedding``
    and ``blocks``.
    """

    def __init__(self, vocab_size, blocks, embedding, norm):
        """Register the embedding, the blocks (as a ``ModuleList``) and the norm."""
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        # Elements are expected to be Phi3Block instances.
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm
        # Fed to and updated by fused_utils.prepare_input_ids on every forward.
        self.last_forward_num_tokens = 0

    @property
    def embed_tokens(self):
        """Alias for ``self.embedding``."""
        return self.embedding

    @property
    def layers(self):
        """Alias for ``self.blocks``."""
        return self.blocks

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        """Embed ``input_ids``, run every fused block, apply the final norm.

        The incoming ``attention_mask`` and ``attn_bias`` are not used; the
        mask given to the blocks is built by
        ``fused_utils.prepare_attention_mask``. ``input_ids`` must be
        two-dimensional, since its shape is unpacked into two values.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` is the
            normed output, with ``past_key_values=None`` and empty
            ``hidden_states`` / ``attentions`` tuples.
        """
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _, num_tokens = input_ids.shape

        fused_utils.prepare_cache(self.blocks, num_tokens)

        hidden = self.embedding(input_ids)

        first_attn = self.blocks[0].attn
        block_mask = fused_utils.prepare_attention_mask(
            seqlen=num_tokens,
            start_pos=first_attn.start_pos,
            device=input_ids.device,
            type_as=hidden,
        )

        for block in self.blocks:
            # Hand activations and mask to the helper before each block call.
            hidden, block_mask = fused_utils.prepare_correct_devices(
                block,
                hidden,
                block_mask,
            )
            block_result = block(
                hidden, None, attention_mask=block_mask, is_causal=is_causal
            )
            hidden, _, _ = block_result

        hidden = self.norm(hidden)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden,
            past_key_values=None,
            hidden_states=(),
            attentions=(),
        )

0 comments on commit 6a46ad6

Please sign in to comment.