
Commit 3fdb3a5

fix non patch path
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Dec 16, 2024
1 parent 72ac9e6 commit 3fdb3a5
Showing 2 changed files with 55 additions and 4 deletions.
4 changes: 3 additions & 1 deletion optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,6 +27,7 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
     return model
 
 
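Note on the patching helpers: convert_class and convert_functions are defined elsewhere in the exporter and are not part of this diff. As a rough, hypothetical sketch of the pattern the new `convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)` line relies on — walking the module tree and swapping every GPT2MLP instance for its IPEX wrapper — assuming a simple recursive traversal (the helper name below is made up for illustration):

```python
from torch import nn


def convert_class_sketch(model: nn.Module, orig_cls: type, new_cls: type, config) -> nn.Module:
    """Hypothetical stand-in for optimum-intel's convert_class: replace every
    child module that is an instance of `orig_cls` with `new_cls(child, config)`."""
    for name, child in model.named_children():
        if isinstance(child, orig_cls):
            # Wrap the original module; the wrapper reuses its weights (see
            # _setattr_from_module in modeling_utils.py) and adds fused IPEX ops.
            setattr(model, name, new_cls(child, config))
        else:
            convert_class_sketch(child, orig_cls, new_cls, config)
    return model
```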
55 changes: 52 additions & 3 deletions optimum/exporters/ipex/modeling_utils.py
@@ -46,6 +46,7 @@
         LinearAdd,
         LinearAddAdd,
         LinearGelu,
+        LinearNewGelu,
         PagedAttention,
     )
 
@@ -557,7 +558,10 @@ def _gpt2_block_forward(
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
     # residual connection
-    hidden_states = attn_output + residual
+    if hasattr(self.attn, "linear_add"):
+        hidden_states = self.attn.linear_add(attn_output, residual)
+    else:
+        hidden_states = attn_output + residual
 
     if encoder_hidden_states is not None:
         # add one self-attention block for cross-attention
@@ -586,7 +590,10 @@ def _gpt2_block_forward(
     hidden_states = self.ln_2(hidden_states)
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
-    hidden_states = residual + feed_forward_hidden_states
+    if hasattr(self.mlp, "linear_add"):
+        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+    else:
+        hidden_states = residual + feed_forward_hidden_states
 
     if use_cache:
         outputs = (hidden_states,) + outputs
@@ -780,6 +787,13 @@ def __init__(self, module, config) -> None:
         self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
         self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
         self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
         query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
@@ -795,7 +809,8 @@ def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
             attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
-        attn_output = self.c_proj(attn_output)
+        if not hasattr(self, "linear_add"):
+            attn_output = self.c_proj(attn_output)
         return attn_output
 
 
@@ -866,6 +881,40 @@ def forward(
         return output
 
 
+class _IPEXGPT2MLP(nn.Module):
+    def __init__(self, module, config) -> None:
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.module_device = next(module.parameters()).device
+        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+        self.c_fc_linear.bias = self.c_fc.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        if hasattr(self, "linear_new_gelu"):
+            hidden_states = self.linear_new_gelu(hidden_states)
+        else:
+            hidden_states = self.c_fc(hidden_states)
+            hidden_states = self.act(hidden_states)
+        if not hasattr(self, "linear_add"):
+            hidden_states = self.c_proj(hidden_states)
+        return hidden_states
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config):
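Why the patched paths skip self.c_proj / self.c_fc: LinearAdd fuses the output projection with the residual add, and LinearNewGelu fuses the c_fc projection with GPT-2's NewGELU activation, so when those fused modules exist the plain projection and activation must not run a second time. The hasattr checks in _gpt2_block_forward then pick the right branch: if no fused op was created, the block falls back to the ordinary projection plus residual add, which is the "non patch path" the commit title refers to. Below is a minimal plain-PyTorch sketch of the fused semantics, assuming the IPEX kernels behave like these classes (the *Sketch names are hypothetical; the real LinearAdd / LinearNewGelu come from intel_extension_for_pytorch.llm.modules):

```python
import math

import torch
from torch import nn


class LinearAddSketch(nn.Module):
    """Hypothetical stand-in for IPEX's LinearAdd: projection + residual add in one call."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + residual


class LinearNewGeluSketch(nn.Module):
    """Hypothetical stand-in for IPEX's LinearNewGelu: projection + NewGELU in one call."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        # NewGELU (tanh approximation), the activation GPT-2 uses.
        return 0.5 * y * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (y + 0.044715 * torch.pow(y, 3.0))))
```

Under these assumptions, mlp.linear_add(feed_forward_hidden_states, residual) computes the same result as applying the output projection and then adding the residual, i.e. the fallback branch, just in a single fused call.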
