Fuse the GPT-2 projection layers with the residual add for IPEX export.

model_patcher.py: GPT2MLP modules are now converted to _IPEXGPT2MLP.
modeling_utils.py: imports LinearNewGelu, gives _IPEXGPT2Attention a `linear_add`
built from c_proj, adds _IPEXGPT2MLP, and makes _gpt2_block_forward call
`linear_add(output, residual)` whenever the sub-module provides it.

The LinearAllreduce guard tests the class name of the original `c_proj`. Testing
the freshly built nn.Linear object against a list of strings is always true, so
that form never skipped anything.

Blank context lines and indentation were restored from the hunk line counts. The
post-image hash on the modeling_utils.py index line was not recomputed.

diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 03937754a..8c5ef5030 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,6 +27,7 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
     return model
 
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index beb162d5c..aa558c437 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -46,6 +46,7 @@
         LinearAdd,
         LinearAddAdd,
         LinearGelu,
+        LinearNewGelu,
         PagedAttention,
     )
 
@@ -557,7 +558,10 @@ def _gpt2_block_forward(
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
     # residual connection
-    hidden_states = attn_output + residual
+    if hasattr(self.attn, "linear_add"):
+        hidden_states = self.attn.linear_add(attn_output, residual)
+    else:
+        hidden_states = attn_output + residual
 
     if encoder_hidden_states is not None:
         # add one self-attention block for cross-attention
@@ -586,7 +590,10 @@ def _gpt2_block_forward(
     hidden_states = self.ln_2(hidden_states)
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
-    hidden_states = residual + feed_forward_hidden_states
+    if hasattr(self.mlp, "linear_add"):
+        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+    else:
+        hidden_states = residual + feed_forward_hidden_states
 
     if use_cache:
         outputs = (hidden_states,) + outputs
@@ -780,6 +787,14 @@ def __init__(self, module, config) -> None:
         self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
         self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
         self.c_proj_linear.bias = self.c_proj.bias
+        # No linear_add when the original c_proj is a LinearAllreduce (matched by class name).
+        if self.module_device.type == "cpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
         query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
@@ -795,7 +810,8 @@ def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
             attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
-        attn_output = self.c_proj(attn_output)
+        if not hasattr(self, "linear_add"):
+            attn_output = self.c_proj(attn_output)
         return attn_output
 
 
@@ -866,6 +882,49 @@ def forward(
         return output
 
 
+class _IPEXGPT2MLP(nn.Module):
+    """GPT-2 MLP wrapper that routes `c_fc`/`c_proj` through IPEX linear modules when they are set up.
+
+    `linear_new_gelu` (cpu only) replaces `c_fc` followed by `act`. When `linear_add` exists, `forward`
+    leaves `c_proj` out and `_gpt2_block_forward` calls `self.mlp.linear_add(output, residual)` instead.
+    """
+
+    def __init__(self, module, config) -> None:
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.module_device = next(module.parameters()).device
+        # nn.Linear counterparts of c_fc / c_proj, using the transposed weights and the same bias.
+        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+        self.c_fc_linear.bias = self.c_fc.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+        # No linear_add when the original c_proj is a LinearAllreduce (matched by class name).
+        if self.module_device.type == "cpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        if hasattr(self, "linear_new_gelu"):
+            hidden_states = self.linear_new_gelu(hidden_states)
+        else:
+            hidden_states = self.c_fc(hidden_states)
+            hidden_states = self.act(hidden_states)
+        # With linear_add present, the caller passes this output and the residual to linear_add.
+        if not hasattr(self, "linear_add"):
+            hidden_states = self.c_proj(hidden_states)
+        return hidden_states
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config):