add llama model patcher
jiqing-feng committed Feb 27, 2024
1 parent b04b435 commit f1970a3
Showing 4 changed files with 188 additions and 142 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/ipex/__init__.py
@@ -1 +1 @@
from .model_patcher import export_model
from .model_patcher import LlamaModelPatcher
174 changes: 57 additions & 117 deletions optimum/exporters/ipex/llama_functions.py
@@ -2,7 +2,6 @@
from typing import List, Optional, Tuple, Union

import torch
from intel_extension_for_pytorch.llm.modules import linear2SiluMul, linearAdd
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -96,46 +95,6 @@ def llama_attn_forward(
return attn_output, attn_weights, past_key_value


def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs


def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
@@ -252,86 +211,67 @@ def llama_model_forward(
)


class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
super().__init__()
for k, v in module.__dict__.items():
setattr(self, k, v)
for k, v in module.__class__.__dict__.items():
if k.startswith("__") or k.startswith("forward"):
continue
setattr(self.__class__, k, getattr(module.__class__, k))
self.distributed = distributed
if not self.distributed:
self.mha_linear_add = linearAdd(module.self_attn.o_proj)
self.mlp_linear_add = linearAdd(module.mlp.down_proj)
del self.__dict__["_modules"]["self_attn"].o_proj
del self.__dict__["_modules"]["mlp"].down_proj
self.linear_silu_mul = linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
del self.__dict__["_modules"]["mlp"].gate_proj
del self.__dict__["_modules"]["mlp"].up_proj

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
if not self.distributed:
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
hidden_states = residual + hidden_states
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
if hasattr(self, "mha_linear_add"):
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

mlp_gate = self.linear_silu_mul(hidden_states)
mlp_gate = self.linear_silu_mul(hidden_states)

if not self.distributed:
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)
if output_attentions:
outputs += (self_attn_weights,)

if use_cache:
outputs += (present_key_value,)
if use_cache:
outputs += (present_key_value,)

return outputs
return outputs
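
As an aside (not part of this commit), the behaviour of the fused IPEX modules can be read off from their call sites above. A rough, unfused reference of what they compute, assuming standard LLaMA semantics, looks like this; the `reference_*` names are illustrative only:

import torch
import torch.nn.functional as F


def reference_linear_add(linear: torch.nn.Linear, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # What `mha_linear_add` / `mlp_linear_add` are used for above:
    # a linear projection fused with the residual addition.
    return linear(x) + residual


def reference_linear_silu_mul(gate: torch.nn.Linear, up: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # What `linear_silu_mul` is used for above: the SwiGLU gate of the LLaMA MLP,
    # i.e. the gate/up projections fused with the SiLU activation and elementwise product.
    return F.silu(gate(x)) * up(x)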
140 changes: 123 additions & 17 deletions optimum/exporters/ipex/model_patcher.py
@@ -1,18 +1,21 @@
from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache
from intel_extension_for_pytorch.llm.modules import (
ApplyRotaryEmbedding,
IndirectAccessKVCache,
linear2SiluMul,
linearAdd,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)

from .llama_functions import (
_IPEXLlamaDecoderLayerRef,
llama_attn_forward,
llama_decoder_layer_forward,
llama_layer_norm_forward,
llama_model_forward,
prepare_inputs_for_generation,
)


@@ -28,19 +31,42 @@ def convert_functions(m, target_m, new_function_name, new_function):
convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, config, distributed=False):
def patch_op(m, target_m, op_name, op):
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
new_m = new_class(sub_m, config, distributed)
setattr(m, name, new_m)
convert_class(sub_m, target_m, new_class, config, distributed)
setattr(sub_m, op_name, op)
patch_op(sub_m, target_m, op_name, op)


def patch_op(m, target_m, new_op_name, new_op):
def unpatch_op(m, target_m, op_name):
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
setattr(sub_m, new_op_name, new_op)
patch_op(sub_m, target_m, new_op_name, new_op)
delattr(sub_m, op_name)
unpatch_op(sub_m, target_m, op_name)


def patch_linear(m, target_m, linear_name, linear_class, attr_list, attr_list_2=None, distributed=None):
if attr_list_2:
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
attr_1 = sub_m
attr_2 = sub_m
for target_attr in attr_list:
attr_1 = getattr(attr_1, target_attr)
for target_attr in attr_list_2:
attr_2 = getattr(attr_2, target_attr)
setattr(sub_m, linear_name, linear_class(attr_1, attr_2))
patch_linear(sub_m, target_m, linear_name, linear_class, attr_list, attr_list_2)
else:
if linear_class is linearAdd and distributed:
return
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
attr = sub_m
for target_attr in attr_list:
attr = getattr(attr, target_attr)
setattr(sub_m, linear_name, linear_class(attr))
patch_linear(sub_m, target_m, linear_name, linear_class, attr_list)
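
For illustration only (these lines are not part of the diff): with the attribute paths passed in export_llama_model below, patch_linear walks every matching LlamaDecoderLayer, resolves each path with getattr, and attaches the fused module under the given name. For a single layer it amounts to roughly:

# Hypothetical single-layer equivalent of the three patch_linear calls below,
# for one LlamaDecoderLayer instance `layer`:
layer.mha_linear_add = linearAdd(layer.self_attn.o_proj)
layer.mlp_linear_add = linearAdd(layer.mlp.down_proj)
layer.linear_silu_mul = linear2SiluMul(layer.mlp.gate_proj, layer.mlp.up_proj)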


def export_llama_model(model):
Expand All @@ -54,16 +80,96 @@ def export_llama_model(model):
patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

convert_func(model, "prepare_inputs_for_generation", prepare_inputs_for_generation)
convert_functions(model, LlamaModel, "forward", llama_model_forward)
convert_functions(model, LlamaAttention, "forward", llama_attn_forward)
convert_functions(model, LlamaRMSNorm, "forward", llama_layer_norm_forward)
convert_functions(model, LlamaDecoderLayer, "forward", llama_decoder_layer_forward)

convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
patch_linear(model, LlamaDecoderLayer, "mha_linear_add", linearAdd, ["self_attn", "o_proj"])
patch_linear(model, LlamaDecoderLayer, "mlp_linear_add", linearAdd, ["mlp", "down_proj"])
patch_linear(model, LlamaDecoderLayer, "linear_silu_mul", linear2SiluMul, ["mlp", "gate_proj"], ["mlp", "up_proj"])
return model


def export_model(model):
if isinstance(model, LlamaForCausalLM):
model = export_llama_model(model)
return model
class ModelPatcher:
def __init__(self, model, ipex_ops=None, ipex_functions=None, ipex_linears=None, original_functions=None):
self.model = model
self.ipex_ops = ipex_ops or []
self.ipex_functions = ipex_functions or []
self.ipex_linears = ipex_linears or []
self.original_functions = original_functions or []

def patch_ops(self):
for module, op_name, op in self.ipex_ops:
patch_op(self.model, module, op_name, op)

def unpatch_ops(self):
for module, op_name, op in self.ipex_ops:
unpatch_op(self.model, module, op_name)

def patch_functions(self):
for module, func_name, func in self.ipex_functions:
convert_functions(self.model, module, func_name, func)

def unpatch_functions(self):
for module, func_name, func in self.original_functions:
convert_functions(self.model, module, func_name, func)

def patch_linears(self):
for module, linear_name, linear_class, attr_list, attr_list_2, distributed in self.ipex_linears:
patch_linear(self.model, module, linear_name, linear_class, attr_list, attr_list_2, distributed)

def unpatch_linears(self):
for module, linear_name, linear_class, attr_list, attr_list_2, distributed in self.ipex_linears:
unpatch_op(self.model, module, linear_name)

def __enter__(self):
self.patch_ops()
self.patch_functions()
self.patch_linears()
return self.model

def __exit__(self, *args, **kwargs):
self.unpatch_ops()
self.unpatch_functions()
self.unpatch_linears()

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)


class LlamaModelPatcher(ModelPatcher):
def __init__(self, model):
super().__init__(model)

ipex_rope = ApplyRotaryEmbedding(
model.config.max_position_embeddings,
model.config.hidden_size // model.config.num_attention_heads,
model.config.rope_theta,
model.config.architectures[0],
)
ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
self.ipex_ops = [
(LlamaAttention, "ipex_rope", ipex_rope),
(LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product),
]
self.ipex_functions = [
(LlamaModel, "forward", llama_model_forward),
(LlamaAttention, "forward", llama_attn_forward),
(LlamaRMSNorm, "forward", llama_layer_norm_forward),
(LlamaDecoderLayer, "forward", llama_decoder_layer_forward),
]

self.ipex_linears = [
(LlamaDecoderLayer, "mha_linear_add", linearAdd, ["self_attn", "o_proj"], None, None),
(LlamaDecoderLayer, "mlp_linear_add", linearAdd, ["mlp", "down_proj"], None, None),
(LlamaDecoderLayer, "linear_silu_mul", linear2SiluMul, ["mlp", "gate_proj"], ["mlp", "up_proj"], None),
]

self.original_functions = [
(LlamaModel, "forward", model.model.forward),
(LlamaAttention, "forward", model.model.layers[0].self_attn.forward),
(LlamaRMSNorm, "forward", model.model.norm.forward),
(LlamaDecoderLayer, "forward", model.model.layers[0].forward),
]
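
A minimal usage sketch of the new patcher as a context manager (not part of this commit; the checkpoint name and prompt are placeholders, and intel_extension_for_pytorch must be installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.exporters.ipex import LlamaModelPatcher

# Placeholder checkpoint; any LlamaForCausalLM checkpoint should work the same way.
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Entering the context patches the IPEX ops, forward functions and fused linears;
# exiting restores the original forwards recorded in `original_functions` and
# removes the patched attributes again.
with LlamaModelPatcher(model) as patched_model:
    with torch.no_grad():
        outputs = patched_model(**inputs)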
