From 5c4d13fe3e6b7cda9bab96f85dc4847118d51bc9 Mon Sep 17 00:00:00 2001 From: Pleaplusone <38376071+ganyi1996ppo@users.noreply.github.com> Date: Mon, 22 Apr 2024 14:06:26 +0800 Subject: [PATCH 01/31] add xpu patch to optimum intel (#7) * add xpu patch to optimum intel * simple path for xpu inference --- optimum/exporters/ipex/model_patcher.py | 25 +- .../exporters/ipex/modeling/modeling_llama.py | 208 ++++++++++++++++ optimum/exporters/ipex/modeling/utils.py | 86 +++++++ optimum/exporters/ipex/modeling/xpu/utils.py | 9 + .../ipex/modeling/xpu/xpu_modeling_llama.py | 226 ++++++++++++++++++ optimum/exporters/ipex/modeling_utils.py | 15 ++ 6 files changed, 561 insertions(+), 8 deletions(-) create mode 100644 optimum/exporters/ipex/modeling/modeling_llama.py create mode 100644 optimum/exporters/ipex/modeling/utils.py create mode 100644 optimum/exporters/ipex/modeling/xpu/utils.py create mode 100644 optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 60ff3b721b..3ff3e00302 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -29,6 +29,9 @@ _llama_model_forward, ) +# from modeling.utils import _IPEXPatcher +from modeling.modeling_llama import _IPEXLlamaDecoderLayer + _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -73,19 +76,25 @@ def _patch_llama_model(model): model.config.rope_theta, model.config.architectures[0], ) - ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) - patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) - - convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) - convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) - - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + if model.device_type == "cpu": + ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) + patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) + + convert_functions(model, LlamaModel, "forward", _llama_model_forward) + convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) + convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) + + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + else: + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer) + convert_functions(model, LlamaModel, "forward", _llama_model_forward) return model def _patch_model(model): if isinstance(model, LlamaForCausalLM): model = _patch_llama_model(model) + # _IPEXPatcher.patch_model(model) + return model diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py new file mode 100644 index 0000000000..92d28a6381 --- /dev/null +++ b/optimum/exporters/ipex/modeling/modeling_llama.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +from typing import Optional, Tuple +import intel_extension_for_pytorch as ipex + + + +class _IPEXLlamaAttention(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + self.module = module + self.config = config + self.distributed = distributed + + def preprocess_for_optimize(self, hidden_states, layer_past, 
**kwargs): + pass + + def qkv_gemm(self, hidden_states, **kwargs): + pass + + def rope(self, query, key, value, position_ids, layer_past, **kwargs): + pass + + def get_present(self, query, key, value, use_cache, **kwargs): + pass + + def sdpa(self, query, key, value, attention_mask, past_key_value, **kwargs): + pass + + def out_proj(self, hidden_states, residual, **kwargs): + pass + + def post_process_for_optimize(self): + pass + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + residual (`torch.Tensor`): residual tensor to the layer of shape ` + """ + + self.preprocess_for_optimize(hidden_states, past_key_value, kwargs) + + query, key, value = self.qkv_gemm(hidden_states, kwargs) + + key, value = self.rope(key, value, position_ids, past_key_value, kwargs) + + present = self.get_present(query, key, value, use_cache) + + attn_output, attn_weight = self.sdpa(query, key, value, attention_mask, past_key_value, kwargs) + + attn_output = self.out_proj(attn_output, residual) + + self.post_process_for_optimize() + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weight, ) + else: + outputs += (None, ) + return outputs + +class _IPEXLlamaMLP(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + self.module = module + self.config = config + self.distributed = distributed + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + **kwargs + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + pass + + +class _IPEXLlamaDecoderLayer(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + super().__init__() + self.layer_idx = module.self_attn.layer_idx + self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + self.input_layernorm = ipex.llm.modules.RMSNorm(module.input_layernorm.weight, module.input_layernorm.variance_epsilon) + self.post_attention_layernorm = ipex.llm.modules.RMSNorm(module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon) + + + + def preprocess_for_optimize( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + postion_ids: Optional[torch.LongTensor] = None, + past_key_value: 
Optional[Tuple[torch.Tensor]] = None, + output_attention: Optional[bool] = True, + use_cache: Optional[bool] = False, + **kwargs + ): + return hidden_states, attention_mask, postion_ids, past_key_value + + def postprocess_for_optimize( + self, + hidden_states, + output_attention, + use_cache, + self_attn_weight, + present_key_value, + **kwargs + ): + outputs = (hidden_states,) + if output_attention: + outputs += (self_attn_weight,) + if use_cache: + outputs += (present_key_value,) + return outputs + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + outputs = self.preprocess_for_optimize( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + kwargs + ) + (hidden_states, attention_mask, position_ids, past_key_value) = outputs + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weight, present_key_value = self.attn( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + residual, + kwargs, + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, residual, kwargs) + + outputs = self.postprocess_for_optimize( + hidden_states, + output_attentions, + use_cache, + self_attn_weight, + present_key_value, + kwargs + ) + + return outputs + + + diff --git a/optimum/exporters/ipex/modeling/utils.py b/optimum/exporters/ipex/modeling/utils.py new file mode 100644 index 0000000000..acfe88bab4 --- /dev/null +++ b/optimum/exporters/ipex/modeling/utils.py @@ -0,0 +1,86 @@ +import torch +import intel_extension_for_pytorch +from typing import List +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) + +from .xpu.utils import update_patcher_info_on_xpu + +def update_patcher_info_on_cpu(model_name): + pass + + +class _IPEXPatcher: + def __init__(self): + self.op_patch_list: List = [] + self.function_convert_list: List = [] + self.class_convert_list: List = [] + + def update_op_list(self, op_list): + self.op_patch_list.extend(op_list) + + def update_function_convert_list(self, function_converts): + self.function_convert_list.extend(function_converts) + + def update_class_convert_list(self, class_converts): + 
self.class_convert_list.extend(class_converts) + + def patch_op_recursive(self, model): + + def patch_op(model, target_m, new_op_name, new_op): + for name, sub_m in model.named_children(): + if isinstance(sub_m, target_m): + setattr(sub_m, new_op_name, new_op) + patch_op(sub_m, target_m, new_op_name, new_op) + + for op_patch in self.op_patch_list: + target_m, new_op_name, new_op = op_patch + new_op_inst = new_op(model) + patch_op(model, target_m, new_op_name, new_op_inst) + + def convert_function_recursive(self, model): + + def convert_functions(m, target_m, new_function_name, new_function): + for _, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + bound_method = new_function.__get__(sub_m, sub_m.__class__) + setattr(m, new_function_name, bound_method) + convert_functions(sub_m, target_m, new_function_name, new_function) + + for function_convert in self.function_convert_list: + target_m, new_function_name, new_function = function_convert + convert_functions(model, target_m, new_function_name, new_function) + + def convert_class_recursive(self, model): + + def convert_class(m, target_m, new_class, config, distributed=False): + for name, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + new_m = new_class(sub_m, config, distributed) + setattr(m, name, new_m) + convert_class(sub_m, target_m, new_class, config, distributed) + + for class_convert in self.class_convert_list: + target_m, new_class, config, distributed = class_convert + convert_class(model, target_m, new_class, config, distributed) + + def retrive_patch_info(self, model_name, device): + if device.device_type == "xpu": + update_patcher_info_on_xpu(model_name) + elif device.device_type == "cpu": + update_patcher_info_on_cpu(model_name) + else: + raise RuntimeError(f"Optimum-intel only support CPU and XPU device optimization. 
But we find this model on {device}.") + + def patch_model(self, model): + # if isinstance(model, LlamaForCausalLM): + self.retrive_patch_info(model.__class__.name, model.device) + self.patch_op_recursive(model) + self.convert_function_recursive(model) + self.convert_class_recursive(model) + diff --git a/optimum/exporters/ipex/modeling/xpu/utils.py b/optimum/exporters/ipex/modeling/xpu/utils.py new file mode 100644 index 0000000000..24dd107dad --- /dev/null +++ b/optimum/exporters/ipex/modeling/xpu/utils.py @@ -0,0 +1,9 @@ + +from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.ModuleReplacer import ModuleReplacer + +def update_patcher_info_on_xpu(patcher, model_name): + patch_info = ModuleReplacer.get_patch_info_from_model(model_name) + op_patch_list, function_convert_list, class_convert_list = patch_info + patcher.update_op_list(op_patch_list) + patcher.update_function_convert_list(function_convert_list) + patcher.update_class_convert_list(class_convert_list) \ No newline at end of file diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py new file mode 100644 index 0000000000..30b6a7e6d5 --- /dev/null +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -0,0 +1,226 @@ +from typing import Tuple +import torch +import torch.nn as nn +from typing import Optional +import math + +import intel_extension_for_pytorch +from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.llama import NewIPEXLLAMABlock + +from ..modeling_llama import _IPEXLlamaDecoderLayer, _IPEXLlamaAttention, _IPEXLlamaMLP + + +def matmul_add_add(attn_output, weight, bias=None, residual=None): + seq_len, bs, _ = attn_output.size() + if residual is None: + attn_output = torch.matmul(attn_output, weight) + if bias is not None: + attn_output += bias + else: + if bias is not None: + attn_output = torch.ops.torch_ipex.mm_bias_resadd( + attn_output, weight, bias, 1.0, residual, 1.0 + ) + else: + attn_output = torch.addmm( + residual.flatten(0, -2), + attn_output.flatten(0, -2), + weight, + beta=1.0, + ) + attn_output = attn_output.view(seq_len, bs, -1) + return attn_output + +class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention): + def __init__(self, module, config, distributed=False, optimized_module=None) -> None: + super().__init__(module, config, distributed) + self.num_heads = module.num_heads + self.head_dim = module.head_dim + self.num_kv_heads = module.num_key_value_heads + self.embed_dim = module.embed_dim + self.port_parameters(module) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + residual (`torch.Tensor`): residual tensor to the layer of shape ` + """ + # allocate cache and copy past_key_value + bs, seqlen, _ = hidden_states.size() + _, prev_seqlen, _, _ = past_key_value[0].size() + if self.num_kv_heads == self.num_heads: + query = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) + key = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) + value = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) + torch.ops.torch_ipex.mm_qkv_out( + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query[:, prev_seqlen:, :, :], key[:, prev_seqlen:, :, :], value[:, prev_seqlen:, :, :]) + else: + query = torch.empty([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) + key = torch.empty([bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim], dtype=query.dtype, device=query.device) + value = torch.empty([bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim], dtype=query.dtype, device=query.device) + torch.ops.torch_ipex.mm_qkv_group_out( + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) + key[:, :prev_seqlen, :, :] = past_key_value[0].tranpose(1, 2) + value[:, :prev_seqlen, :, :] = past_key_value[1].tranpose(1, 2) + + # rope + query = query.view([-1, seqlen, self.num_heads, self.head_dim]) + key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) + value = value.view([-1, seqlen, self.num_kv_heads, self.head_dim]) + + query = self.ipex_rope( + query, + position_ids, + self.num_kv_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + seqlen, + ) + + key = self.ipex_rope( + key, + position_ids, + self.num_kv_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + seqlen, + ) + + key = key.tranpose(1, 2) + value = value.tranpose(1, 2) + present = (key, value) if use_cache else None + + scale = 1.0 / math.sqrt(self.head_dim) + attn_output, attn_weight = torch.nn.functional.scaled_dot_product_attention(query, key, value, attention_mask, dropout_p=0.0, scale=scale) + attn_output = attn_output.tranpose(1, 2).view([bs, seqlen, self.embed_dim]) + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bs, seqlen, self.embed_dim]) + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weight, ) + else: + outputs += (None) + return outputs + + + def port_parameters(self, module): + self.qkv_proj_bias = None + self.qkv_proj_weight = None + if self.num_heads == self.num_kv_heads: + q_proj = module.self_attn.q_proj.weight.contiguous(0, 1) + k_proj = module.self_attn.k_proj.weight.contiguous(0, 1) + v_proj = module.self_attn.v_proj.weight.contiguous(0, 1) + self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) + if module.self_attn.q_proj.bias is not None: + self.qkv_proj_bias = torch.stack([module.self_attn.q_proj.bias, module.self_attn.k_proj.bias, module.self_attn.v_proj.bias]).contiguous().view([3, -1]) + else: + group = self.num_heads // 
self.num_kv_heads + q_proj = module.self_attn.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) + k_proj = module.self_attn.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + v_proj = module.self_attn.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view([self.num_kv_heads, group + 2, self.head_dim, self.embed_dim]) + if module.self_attn.q_proj.bias is not None: + q_bias = module.self_attn.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) + k_bias = module.self_attn.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + v_bias = module.self_attn.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view([self.num_kv_heads, group + 2, self.head_dim]) + self.o_proj_weight = module.o_proj.weight + self.o_proj_bias = module.o_proj.bias + + + +class _IPEXLlamaMLPXPU(_IPEXLlamaMLP): + def __init__(self, module, config, distributed=False, optimized_module=None) -> None: + super().__init__(module, config, distributed) + self.mlp_impl = None + if optimized_module is not None: + self.mlp_impl = optimized_module + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor = None, + **kwargs + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + out = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + out = matmul_add_add(out, self.down_proj_weight, self.down_proj_bias, residual) + return out + + + def port_parameter(self, module): + self.up_proj_weight = module.up_proj.weight.tranpose(0, 1).contiguous() + self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() + self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() + self.up_proj_bias = module.up_proj.bias + self.gate_proj_bias = module.gate_proj.bias + self.down_proj_bias = module.down_proj.bias + + + +# class _IPEXLlamaDecoderLayerXPU(_IPEXLlamaDecoderLayer): +# def __init__(self, module, config, distributed=False) -> None: +# super().__init__(module, config, distributed) +# self.block_impl = NewIPEXLLAMABlock(module, config) +# self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, self.block_impl.attn) +# self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, self.block_impl.mlp) + +# def preprocess_for_optimize( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_value: Optional[Tuple[torch.Tensor]] = None, +# output_attention: Optional[bool] = True, +# use_cache: Optional[bool] = False, +# **kwargs +# ): +# return self.block_impl.preprocess_for_optimize( +# hidden_states, +# attention_mask, +# position_ids, +# past_key_value, +# output_attention, +# use_cache, +# **kwargs +# ) + + + +# def postprocess_for_optimize(self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs): +# return self.block_impl.postprocess_for_optimize( +# hidden_states, +# output_attention, +# use_cache, +# self_attn_weight, +# present_key_value, +# **kwargs +# ) + diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f75e559eaf..45345f8538 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -112,6 
+112,19 @@ def _llama_attn_forward( return attn_output, attn_weights, past_key_value +def padding_attn_mask(attn_mask, alignment): + if attn_mask is None: + return None + assert isinstance(attn_mask, torch.Tensor), f"attn mask is supposed to be a tensor, instead we got {type(attn_mask)}" + if attn_mask.device == torch.device("cpu"): + return attn_mask + last_dim_size= attn_mask.size(-1) + aligned_size = (last_dim_size + alignment - 1) // alignment * alignment + mask_size = [*attn_mask.size()[:-1], aligned_size] + new_attn_mask = torch.empty(mask_size, dtype=attn_mask.dtype, device=attn_mask.device).fill_(-65504.0) + new_attn_mask[..., :last_dim_size] = attn_mask + return new_attn_mask + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( @@ -168,6 +181,8 @@ def _llama_model_forward( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + attention_mask = padding_attn_mask(attention_mask, 8) + # embed positions hidden_states = inputs_embeds From b1d6989da03e2b21bc7e1896c3b6843797ce8ce2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 25 Apr 2024 05:40:31 +0000 Subject: [PATCH 02/31] can run but precision error --- optimum/exporters/ipex/model_patcher.py | 13 +-- optimum/exporters/ipex/modeling/__init__.py | 2 + .../exporters/ipex/modeling/modeling_llama.py | 55 +++++++++---- .../ipex/modeling/xpu/xpu_modeling_llama.py | 80 +++++++++++-------- optimum/exporters/ipex/modeling_utils.py | 10 ++- 5 files changed, 103 insertions(+), 57 deletions(-) create mode 100644 optimum/exporters/ipex/modeling/__init__.py diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 3ff3e00302..0058e09fff 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -30,7 +30,7 @@ ) # from modeling.utils import _IPEXPatcher -from modeling.modeling_llama import _IPEXLlamaDecoderLayer +from .modeling.modeling_llama import _IPEXLlamaDecoderLayer _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) @@ -65,10 +65,10 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - if is_ipex_version("<", "2.5.0"): - raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache") + if is_ipex_version("<", "2.1.0"): + raise ImportError("Only ipex version > 2.1.0 supports RotaryEmbedding and IndirectAccessKVCache") - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding + from intel_extension_for_pytorch.llm.modules import RotaryEmbedding ipex_rope = RotaryEmbedding( model.config.max_position_embeddings, @@ -77,7 +77,8 @@ def _patch_llama_model(model): model.config.architectures[0], ) patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) - if model.device_type == "cpu": + if "cpu" in str(model.device): + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) @@ -87,7 +88,7 @@ def _patch_llama_model(model): convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) else: - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer) + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) convert_functions(model, LlamaModel, "forward", _llama_model_forward) return model diff --git 
a/optimum/exporters/ipex/modeling/__init__.py b/optimum/exporters/ipex/modeling/__init__.py new file mode 100644 index 0000000000..139597f9cb --- /dev/null +++ b/optimum/exporters/ipex/modeling/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py index 92d28a6381..6e9216d645 100644 --- a/optimum/exporters/ipex/modeling/modeling_llama.py +++ b/optimum/exporters/ipex/modeling/modeling_llama.py @@ -7,11 +7,18 @@ class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, distributed=False) -> None: + super().__init__() + # for k, v in module.__dict__.items(): + # setattr(self, k, v) + # for k, v in module.__class__.__dict__.items(): + # if k.startswith("__") or k.startswith("forward"): + # continue + # setattr(self.__class__, k, getattr(module.__class__, k)) self.module = module self.config = config self.distributed = distributed - def preprocess_for_optimize(self, hidden_states, layer_past, **kwargs): + def epreprocess_for_optimize(self, hidden_states, layer_past, **kwargs): pass def qkv_gemm(self, hidden_states, **kwargs): @@ -59,10 +66,10 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states residual (`torch.Tensor`): residual tensor to the layer of shape ` """ + + self.preprocess_for_optimize(hidden_states=hidden_states, layer_past=past_key_value, **kwargs) - self.preprocess_for_optimize(hidden_states, past_key_value, kwargs) - - query, key, value = self.qkv_gemm(hidden_states, kwargs) + query, key, value = self.qkv_gemm(hidden_states= hidden_states, **kwargs) key, value = self.rope(key, value, position_ids, past_key_value, kwargs) @@ -82,6 +89,13 @@ def forward( class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: + super().__init__() + # for k, v in module.__dict__.items(): + # setattr(self, k, v) + # for k, v in module.__class__.__dict__.items(): + # if k.startswith("__") or k.startswith("forward"): + # continue + # setattr(self.__class__, k, getattr(module.__class__, k)) self.module = module self.config = config self.distributed = distributed @@ -102,14 +116,24 @@ def forward( class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False) -> None: super().__init__() + # for k, v in module.__dict__.items(): + # setattr(self, k, v) + # for k, v in module.__class__.__dict__.items(): + # if k.startswith("__") or k.startswith("forward"): + # continue + # setattr(self.__class__, k, getattr(module.__class__, k)) self.layer_idx = module.self_attn.layer_idx - self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + # TODO: add device check + if False: + self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + else: + from .xpu.xpu_modeling_llama import _IPEXLlamaAttentionXPU, _IPEXLlamaMLPXPU + self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, distributed) self.input_layernorm = ipex.llm.modules.RMSNorm(module.input_layernorm.weight, module.input_layernorm.variance_epsilon) self.post_attention_layernorm = ipex.llm.modules.RMSNorm(module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon) - - def preprocess_for_optimize( self, hidden_states: torch.Tensor, @@ -132,10 +156,11 @@ def 
postprocess_for_optimize( **kwargs ): outputs = (hidden_states,) - if output_attention: - outputs += (self_attn_weight,) if use_cache: outputs += (present_key_value,) + if output_attention: + outputs += (self_attn_weight,) + return outputs @@ -170,14 +195,14 @@ def forward( past_key_value, output_attentions, use_cache, - kwargs + **kwargs ) (hidden_states, attention_mask, position_ids, past_key_value) = outputs residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weight, present_key_value = self.attn( + hidden_states, present_key_value, self_attn_weight = self.attn( hidden_states, attention_mask, position_ids, @@ -185,13 +210,13 @@ def forward( output_attentions, use_cache, residual, - kwargs, + **kwargs, ) # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states, residual, kwargs) + hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = self.postprocess_for_optimize( hidden_states, @@ -199,7 +224,7 @@ def forward( use_cache, self_attn_weight, present_key_value, - kwargs + **kwargs ) return outputs diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 30b6a7e6d5..24a4254e6a 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -37,8 +37,16 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> self.num_heads = module.num_heads self.head_dim = module.head_dim self.num_kv_heads = module.num_key_value_heads - self.embed_dim = module.embed_dim + self.embed_dim = module.config.hidden_size self.port_parameters(module) + from intel_extension_for_pytorch.llm.modules import RotaryEmbedding + + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) def forward( self, @@ -69,26 +77,29 @@ def forward( """ # allocate cache and copy past_key_value bs, seqlen, _ = hidden_states.size() - _, prev_seqlen, _, _ = past_key_value[0].size() + prev_seqlen = 0 + if past_key_value: + _, _, prev_seqlen, _ = past_key_value[0].size() if self.num_kv_heads == self.num_heads: - query = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) - key = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) - value = torch.empty_like([bs, prev_seqlen + seqlen, self.num_heads * self.head_dim], dtype=query.dtype, device=query.device) + query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) torch.ops.torch_ipex.mm_qkv_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query[:, prev_seqlen:, :, :], key[:, prev_seqlen:, :, :], value[:, prev_seqlen:, :, :]) + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) else: - query = torch.empty([bs, prev_seqlen + seqlen, self.num_heads * 
self.head_dim], dtype=query.dtype, device=query.device) - key = torch.empty([bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim], dtype=query.dtype, device=query.device) - value = torch.empty([bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim], dtype=query.dtype, device=query.device) + query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + key = torch.empty((bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + value = torch.empty((bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) torch.ops.torch_ipex.mm_qkv_group_out( hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) - key[:, :prev_seqlen, :, :] = past_key_value[0].tranpose(1, 2) - value[:, :prev_seqlen, :, :] = past_key_value[1].tranpose(1, 2) + if past_key_value: + key[:, :prev_seqlen, :] = past_key_value[0].transpose(1, 2).view(bs, prev_seqlen, -1) + value[:, :prev_seqlen, :] = past_key_value[1].transpose(1, 2).view(bs, prev_seqlen, -1) # rope - query = query.view([-1, seqlen, self.num_heads, self.head_dim]) - key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) - value = value.view([-1, seqlen, self.num_kv_heads, self.head_dim]) + #query = query.view([-1, seqlen, self.num_heads, self.head_dim]) + #key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) + value = value.view([bs, prev_seqlen + seqlen, self.num_kv_heads, self.head_dim]) query = self.ipex_rope( query, @@ -110,19 +121,21 @@ def forward( seqlen, ) - key = key.tranpose(1, 2) - value = value.tranpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) present = (key, value) if use_cache else None scale = 1.0 / math.sqrt(self.head_dim) - attn_output, attn_weight = torch.nn.functional.scaled_dot_product_attention(query, key, value, attention_mask, dropout_p=0.0, scale=scale) - attn_output = attn_output.tranpose(1, 2).view([bs, seqlen, self.embed_dim]) + attn_output = torch.xpu.IpexSDP(query.transpose(1,2), key, value, None, attention_mask, None, scale, 1.0, 0.0, True, False) + # attn_output, attn_weight = torch.nn.functional.scaled_dot_product_attention(query, key, value, attention_mask, dropout_p=0.0, scale=scale) + attn_output = attn_output.transpose(1, 2).view([bs, seqlen, self.embed_dim]) attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bs, seqlen, self.embed_dim]) outputs = (attn_output, present) if output_attentions: - outputs += (attn_weight, ) + raise ValueError("not support output attn_weight") + # outputs += (attn_weight, ) else: - outputs += (None) + outputs += (None, ) return outputs @@ -130,22 +143,22 @@ def port_parameters(self, module): self.qkv_proj_bias = None self.qkv_proj_weight = None if self.num_heads == self.num_kv_heads: - q_proj = module.self_attn.q_proj.weight.contiguous(0, 1) - k_proj = module.self_attn.k_proj.weight.contiguous(0, 1) - v_proj = module.self_attn.v_proj.weight.contiguous(0, 1) + q_proj = module.q_proj.weight.transpose(0, 1) + k_proj = module.k_proj.weight.transpose(0, 1) + v_proj = module.v_proj.weight.transpose(0, 1) self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) - if module.self_attn.q_proj.bias is not None: - self.qkv_proj_bias = torch.stack([module.self_attn.q_proj.bias, module.self_attn.k_proj.bias, module.self_attn.v_proj.bias]).contiguous().view([3, 
-1]) + if module.q_proj.bias is not None: + self.qkv_proj_bias = torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]).contiguous().view([3, -1]) else: group = self.num_heads // self.num_kv_heads - q_proj = module.self_attn.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) - k_proj = module.self_attn.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - v_proj = module.self_attn.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + q_proj = module.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) + k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view([self.num_kv_heads, group + 2, self.head_dim, self.embed_dim]) - if module.self_attn.q_proj.bias is not None: - q_bias = module.self_attn.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) - k_bias = module.self_attn.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - v_bias = module.self_attn.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + if module.q_proj.bias is not None: + q_bias = module.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) + k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view([self.num_kv_heads, group + 2, self.head_dim]) self.o_proj_weight = module.o_proj.weight self.o_proj_bias = module.o_proj.bias @@ -158,6 +171,7 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> self.mlp_impl = None if optimized_module is not None: self.mlp_impl = optimized_module + self.port_parameter(module) def forward( self, @@ -176,7 +190,7 @@ def forward( def port_parameter(self, module): - self.up_proj_weight = module.up_proj.weight.tranpose(0, 1).contiguous() + self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() self.up_proj_bias = module.up_proj.bias diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 45345f8538..a3503c3f29 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -110,7 +110,7 @@ def _llama_attn_forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, past_key_value, attn_weights def padding_attn_mask(attn_mask, alignment): if attn_mask is None: @@ -191,11 +191,15 @@ def _llama_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + # XPU + #if True: + # past_key_values = [] + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values) > idx else None layer_outputs = decoder_layer( hidden_states, @@ -285,7 +289,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, 
present_key_value, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, From f2de914dfec1e6cd7a5c10084e7263fd4c0f0b10 Mon Sep 17 00:00:00 2001 From: "yi.gan" Date: Thu, 25 Apr 2024 19:54:23 -0700 Subject: [PATCH 03/31] optimize optimum --- .../exporters/ipex/modeling/modeling_llama.py | 1 + .../ipex/modeling/xpu/xpu_modeling_llama.py | 104 +++++++++++------- 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py index 6e9216d645..9d3204c766 100644 --- a/optimum/exporters/ipex/modeling/modeling_llama.py +++ b/optimum/exporters/ipex/modeling/modeling_llama.py @@ -209,6 +209,7 @@ def forward( past_key_value, output_attentions, use_cache, + None, residual, **kwargs, ) diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 24a4254e6a..752e226f2f 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -6,6 +6,7 @@ import intel_extension_for_pytorch from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.llama import NewIPEXLLAMABlock +from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU from ..modeling_llama import _IPEXLlamaDecoderLayer, _IPEXLlamaAttention, _IPEXLlamaMLP @@ -41,12 +42,18 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> self.port_parameters(module) from intel_extension_for_pytorch.llm.modules import RotaryEmbedding - self.ipex_rope = RotaryEmbedding( - module.config.max_position_embeddings, - module.config.hidden_size // module.config.num_attention_heads, - module.config.rope_theta, - module.config.architectures[0], - ) + # self.ipex_rope = RotaryEmbedding( + # module.config.max_position_embeddings, + # module.config.hidden_size // module.config.num_attention_heads, + # module.config.rope_theta, + # module.config.architectures[0], + # ) + self.ipex_rope = _IPEXRopeXPU( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) def forward( self, @@ -81,55 +88,74 @@ def forward( if past_key_value: _, _, prev_seqlen, _ = past_key_value[0].size() if self.num_kv_heads == self.num_heads: - query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + # query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + # key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + # value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + query = torch.empty_like(hidden_states) + key = torch.empty_like(hidden_states) + value = torch.empty_like(hidden_states) torch.ops.torch_ipex.mm_qkv_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, 
prev_seqlen:, :]) + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) + # torch.ops.torch_ipex.mm_qkv_out( + # hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) else: query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - key = torch.empty((bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - value = torch.empty((bs, prev_seqlen + seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + key = torch.empty((bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + value = torch.empty((bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) torch.ops.torch_ipex.mm_qkv_group_out( hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) - if past_key_value: - key[:, :prev_seqlen, :] = past_key_value[0].transpose(1, 2).view(bs, prev_seqlen, -1) - value[:, :prev_seqlen, :] = past_key_value[1].transpose(1, 2).view(bs, prev_seqlen, -1) - + # if past_key_value: + # key[:, :prev_seqlen, :] = past_key_value[0].transpose(1, 2).view(bs, prev_seqlen, -1) + # value[:, :prev_seqlen, :] = past_key_value[1].transpose(1, 2).view(bs, prev_seqlen, -1) # rope #query = query.view([-1, seqlen, self.num_heads, self.head_dim]) #key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) - value = value.view([bs, prev_seqlen + seqlen, self.num_kv_heads, self.head_dim]) - - query = self.ipex_rope( - query, - position_ids, - self.num_kv_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - seqlen, - ) - - key = self.ipex_rope( - key, - position_ids, - self.num_kv_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - seqlen, - ) + # value = value.view([bs, prev_seqlen + seqlen, self.num_kv_heads, self.head_dim]) + query = query.view([bs, seqlen, self.num_heads, self.head_dim]) + key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim]) + sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) + sin = sin.squeeze()[position_ids].unsqueeze(2) + cos = cos.squeeze()[position_ids].unsqueeze(2) + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + # query = self.ipex_rope( + # query, + # position_ids, + # self.num_kv_heads, + # self.head_dim, + # self.head_dim // 2, + # self.head_dim, + # seqlen, + # ) + # key = self.ipex_rope( + # key, + # position_ids, + # self.num_kv_heads, + # self.head_dim, + # self.head_dim // 2, + # self.head_dim, + # seqlen, + # ) + value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim]) + if past_key_value is not None: + value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) + key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) present = (key, value) if use_cache else None scale = 1.0 / math.sqrt(self.head_dim) - attn_output = torch.xpu.IpexSDP(query.transpose(1,2), key, value, None, attention_mask, None, scale, 1.0, 0.0, True, False) + # import pdb;pdb.set_trace() + is_causal = False + # if past_key_value is not None: + # is_causal = True + attn_output = torch.xpu.IpexSDP(query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False) # attn_output, attn_weight = 
torch.nn.functional.scaled_dot_product_attention(query, key, value, attention_mask, dropout_p=0.0, scale=scale) attn_output = attn_output.transpose(1, 2).view([bs, seqlen, self.embed_dim]) attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bs, seqlen, self.embed_dim]) + # attn_output = torch.matmul(attn_output, self.o_proj_weight) + # attn_output = attn_output + residual outputs = (attn_output, present) if output_attentions: raise ValueError("not support output attn_weight") @@ -160,7 +186,7 @@ def port_parameters(self, module): k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view([self.num_kv_heads, group + 2, self.head_dim]) - self.o_proj_weight = module.o_proj.weight + self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() self.o_proj_bias = module.o_proj.bias From 92954572bc80cbe2186d006945675305e799f3c2 Mon Sep 17 00:00:00 2001 From: "yi.gan" Date: Thu, 25 Apr 2024 23:28:16 -0700 Subject: [PATCH 04/31] further optimize --- .../ipex/modeling/xpu/xpu_modeling_llama.py | 40 ++++++++++++------- optimum/exporters/ipex/modeling_utils.py | 8 +++- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 752e226f2f..7e8a1e8c0f 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -88,14 +88,14 @@ def forward( if past_key_value: _, _, prev_seqlen, _ = past_key_value[0].size() if self.num_kv_heads == self.num_heads: - # query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - # key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - # value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - query = torch.empty_like(hidden_states) - key = torch.empty_like(hidden_states) - value = torch.empty_like(hidden_states) + query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + # query = torch.empty_like(hidden_states) + # key = torch.empty_like(hidden_states) + # value = torch.empty_like(hidden_states) torch.ops.torch_ipex.mm_qkv_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) # torch.ops.torch_ipex.mm_qkv_out( # hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) else: @@ -112,11 +112,19 @@ def forward( #key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) # value = value.view([bs, prev_seqlen + seqlen, self.num_kv_heads, self.head_dim]) query = query.view([bs, seqlen, self.num_heads, self.head_dim]) - key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - sin, cos = 
self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) - sin = sin.squeeze()[position_ids].unsqueeze(2) - cos = cos.squeeze()[position_ids].unsqueeze(2) - self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + key = key.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) + # sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) + # sin = sin.squeeze()[position_ids].unsqueeze(2) + # cos = cos.squeeze()[position_ids].unsqueeze(2) + if hasattr(kwargs, "sin") and hasattr(kwargs, "cos"): + print("cache sin cos") + sin = kwargs["sin"] + cos = kwargs["cos"] + else: + sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) + sin = sin.squeeze()[position_ids].unsqueeze(2) + cos = cos.squeeze()[position_ids].unsqueeze(2) + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key[:, prev_seqlen:, :, :]) # query = self.ipex_rope( # query, # position_ids, @@ -136,10 +144,12 @@ def forward( # self.head_dim, # seqlen, # ) - value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim]) + value = value.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) if past_key_value is not None: - value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) - key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + value[:, :prev_seqlen, :, :] = past_key_value[1].transpose(1, 2) + key[:, :prev_seqlen, :, :] = past_key_value[0].transpose(1, 2) + # value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) + # key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index a3503c3f29..d5b5bfd94c 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -194,7 +194,12 @@ def _llama_model_forward( # XPU #if True: # past_key_values = [] - + seqlen = hidden_states.size(1) + head_dim = self.layers[0].attn.head_dim + sin, cos = self.layers[0].attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) + sin = sin.squeeze()[position_ids].unsqueeze(2) + cos = cos.squeeze()[position_ids].unsqueeze(2) + sin_cos = {"sin": sin, "cos": cos} for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -208,6 +213,7 @@ def _llama_model_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + **sin_cos, ) hidden_states = layer_outputs[0] From c55216a07c486646ee5e809269ca1ff8908862d4 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Wed, 8 May 2024 04:34:41 -0700 Subject: [PATCH 05/31] finalize --- optimum/exporters/ipex/model_patcher.py | 31 ++- optimum/exporters/ipex/modeling/__init__.py | 1 - .../exporters/ipex/modeling/modeling_llama.py | 115 ++------- optimum/exporters/ipex/modeling/utils.py | 86 ------- optimum/exporters/ipex/modeling/xpu/utils.py | 9 - .../ipex/modeling/xpu/xpu_modeling_llama.py | 231 +++++++++--------- optimum/exporters/ipex/modeling_utils.py | 10 +- optimum/intel/ipex/modeling_base.py | 23 +- 8 files changed, 168 insertions(+), 338 deletions(-) delete mode 100644 optimum/exporters/ipex/modeling/utils.py delete mode 100644 optimum/exporters/ipex/modeling/xpu/utils.py diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 0058e09fff..50b22c7389 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ 
b/optimum/exporters/ipex/model_patcher.py @@ -29,10 +29,8 @@ _llama_model_forward, ) -# from modeling.utils import _IPEXPatcher from .modeling.modeling_llama import _IPEXLlamaDecoderLayer - _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -65,21 +63,24 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - if is_ipex_version("<", "2.1.0"): - raise ImportError("Only ipex version > 2.1.0 supports RotaryEmbedding and IndirectAccessKVCache") - - from intel_extension_for_pytorch.llm.modules import RotaryEmbedding - - ipex_rope = RotaryEmbedding( - model.config.max_position_embeddings, - model.config.hidden_size // model.config.num_attention_heads, - model.config.rope_theta, - model.config.architectures[0], - ) - patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) + + ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.3.0" + if is_ipex_version("<", ipex_version): + raise ImportError(f"Only ipex version > {ipex_version} supports RotaryEmbedding and IndirectAccessKVCache") + if "cpu" in str(model.device): + from intel_extension_for_pytorch.llm.modules import RotaryEmbedding from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache + + ipex_rope = RotaryEmbedding( + model.config.max_position_embeddings, + model.config.hidden_size // model.config.num_attention_heads, + model.config.rope_theta, + model.config.architectures[0], + ) ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) + + patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) convert_functions(model, LlamaModel, "forward", _llama_model_forward) @@ -96,6 +97,4 @@ def _patch_llama_model(model): def _patch_model(model): if isinstance(model, LlamaForCausalLM): model = _patch_llama_model(model) - # _IPEXPatcher.patch_model(model) - return model diff --git a/optimum/exporters/ipex/modeling/__init__.py b/optimum/exporters/ipex/modeling/__init__.py index 139597f9cb..8b13789179 100644 --- a/optimum/exporters/ipex/modeling/__init__.py +++ b/optimum/exporters/ipex/modeling/__init__.py @@ -1,2 +1 @@ - diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py index 9d3204c766..0ecef1ff83 100644 --- a/optimum/exporters/ipex/modeling/modeling_llama.py +++ b/optimum/exporters/ipex/modeling/modeling_llama.py @@ -4,41 +4,13 @@ import intel_extension_for_pytorch as ipex - class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, distributed=False) -> None: super().__init__() - # for k, v in module.__dict__.items(): - # setattr(self, k, v) - # for k, v in module.__class__.__dict__.items(): - # if k.startswith("__") or k.startswith("forward"): - # continue - # setattr(self.__class__, k, getattr(module.__class__, k)) self.module = module self.config = config self.distributed = distributed - def epreprocess_for_optimize(self, hidden_states, layer_past, **kwargs): - pass - - def qkv_gemm(self, hidden_states, **kwargs): - pass - - def rope(self, query, key, value, position_ids, layer_past, **kwargs): - pass - - def get_present(self, query, key, value, use_cache, **kwargs): - pass - - def sdpa(self, query, key, value, attention_mask, past_key_value, **kwargs): - pass - - def out_proj(self, hidden_states, residual, **kwargs): - pass - - def post_process_for_optimize(self): - pass - def forward( self, hidden_states: torch.Tensor, @@ -66,46 +38,17 @@ def forward( 
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states residual (`torch.Tensor`): residual tensor to the layer of shape ` """ - - self.preprocess_for_optimize(hidden_states=hidden_states, layer_past=past_key_value, **kwargs) - - query, key, value = self.qkv_gemm(hidden_states= hidden_states, **kwargs) - - key, value = self.rope(key, value, position_ids, past_key_value, kwargs) - - present = self.get_present(query, key, value, use_cache) - - attn_output, attn_weight = self.sdpa(query, key, value, attention_mask, past_key_value, kwargs) - - attn_output = self.out_proj(attn_output, residual) + pass - self.post_process_for_optimize() - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weight, ) - else: - outputs += (None, ) - return outputs class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: super().__init__() - # for k, v in module.__dict__.items(): - # setattr(self, k, v) - # for k, v in module.__class__.__dict__.items(): - # if k.startswith("__") or k.startswith("forward"): - # continue - # setattr(self.__class__, k, getattr(module.__class__, k)) self.module = module self.config = config self.distributed = distributed - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - **kwargs - ): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor, **kwargs): """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -116,23 +59,22 @@ def forward( class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False) -> None: super().__init__() - # for k, v in module.__dict__.items(): - # setattr(self, k, v) - # for k, v in module.__class__.__dict__.items(): - # if k.startswith("__") or k.startswith("forward"): - # continue - # setattr(self.__class__, k, getattr(module.__class__, k)) self.layer_idx = module.self_attn.layer_idx - # TODO: add device check - if False: - self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) - else: + module_device = str(module.self_attn.q_proj.weight.device) + if "xpu" in module_device: from .xpu.xpu_modeling_llama import _IPEXLlamaAttentionXPU, _IPEXLlamaMLPXPU + self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, distributed) self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, distributed) - self.input_layernorm = ipex.llm.modules.RMSNorm(module.input_layernorm.weight, module.input_layernorm.variance_epsilon) - self.post_attention_layernorm = ipex.llm.modules.RMSNorm(module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon) + else: + self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + self.input_layernorm = ipex.llm.modules.RMSNorm( + module.input_layernorm.weight, module.input_layernorm.variance_epsilon + ) + self.post_attention_layernorm = ipex.llm.modules.RMSNorm( + module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon + ) def preprocess_for_optimize( self, @@ -142,18 +84,12 @@ def preprocess_for_optimize( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attention: Optional[bool] = True, use_cache: Optional[bool] = False, - **kwargs + **kwargs, ): return hidden_states, attention_mask, postion_ids, past_key_value def postprocess_for_optimize( - self, - hidden_states, - output_attention, - use_cache, 
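Both layer norms in the converted decoder layer are replaced by the fused IPEX RMSNorm module built from the original weight and epsilon. For reference, this is the computation being fused, written as a plain-PyTorch sketch of standard Llama RMSNorm rather than the IPEX kernel itself:

    import torch

    def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Normalize by the root mean square over the last dimension, then scale by
        # the learned weight; matches LlamaRMSNorm up to the usual fp32 upcast.
        dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return weight * x.to(dtype)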
- self_attn_weight, - present_key_value, - **kwargs + self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs ): outputs = (hidden_states,) if use_cache: @@ -163,7 +99,6 @@ def postprocess_for_optimize( return outputs - def forward( self, hidden_states: torch.Tensor, @@ -189,13 +124,7 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ outputs = self.preprocess_for_optimize( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - **kwargs + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs ) (hidden_states, attention_mask, position_ids, past_key_value) = outputs residual = hidden_states @@ -220,15 +149,7 @@ def forward( hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = self.postprocess_for_optimize( - hidden_states, - output_attentions, - use_cache, - self_attn_weight, - present_key_value, - **kwargs + hidden_states, output_attentions, use_cache, self_attn_weight, present_key_value, **kwargs ) return outputs - - - diff --git a/optimum/exporters/ipex/modeling/utils.py b/optimum/exporters/ipex/modeling/utils.py deleted file mode 100644 index acfe88bab4..0000000000 --- a/optimum/exporters/ipex/modeling/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import intel_extension_for_pytorch -from typing import List -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, -) - -from .xpu.utils import update_patcher_info_on_xpu - -def update_patcher_info_on_cpu(model_name): - pass - - -class _IPEXPatcher: - def __init__(self): - self.op_patch_list: List = [] - self.function_convert_list: List = [] - self.class_convert_list: List = [] - - def update_op_list(self, op_list): - self.op_patch_list.extend(op_list) - - def update_function_convert_list(self, function_converts): - self.function_convert_list.extend(function_converts) - - def update_class_convert_list(self, class_converts): - self.class_convert_list.extend(class_converts) - - def patch_op_recursive(self, model): - - def patch_op(model, target_m, new_op_name, new_op): - for name, sub_m in model.named_children(): - if isinstance(sub_m, target_m): - setattr(sub_m, new_op_name, new_op) - patch_op(sub_m, target_m, new_op_name, new_op) - - for op_patch in self.op_patch_list: - target_m, new_op_name, new_op = op_patch - new_op_inst = new_op(model) - patch_op(model, target_m, new_op_name, new_op_inst) - - def convert_function_recursive(self, model): - - def convert_functions(m, target_m, new_function_name, new_function): - for _, sub_m in m.named_children(): - if isinstance(sub_m, target_m): - bound_method = new_function.__get__(sub_m, sub_m.__class__) - setattr(m, new_function_name, bound_method) - convert_functions(sub_m, target_m, new_function_name, new_function) - - for function_convert in self.function_convert_list: - target_m, new_function_name, new_function = function_convert - convert_functions(model, target_m, new_function_name, new_function) - - def convert_class_recursive(self, model): - - def convert_class(m, target_m, new_class, config, distributed=False): - for name, sub_m in m.named_children(): - if isinstance(sub_m, target_m): - new_m = new_class(sub_m, config, distributed) - setattr(m, name, new_m) - convert_class(sub_m, target_m, new_class, config, distributed) - - for class_convert in self.class_convert_list: - 
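The deleted `_IPEXPatcher` and the surviving helpers in `model_patcher.py` rely on the same recursion over `named_children` to swap module instances in place. A self-contained sketch of that pattern (the wrapper class is a stand-in for `_IPEXLlamaDecoderLayer`; weights are reused, not copied):

    import torch.nn as nn

    def convert_class(model: nn.Module, target_cls, new_cls, config=None, distributed=False):
        # Walk the module tree and replace every instance of target_cls with a
        # wrapper built around the original submodule.
        for name, child in model.named_children():
            if isinstance(child, target_cls):
                setattr(model, name, new_cls(child, config, distributed))
            else:
                convert_class(child, target_cls, new_cls, config, distributed)
        return model

Used the way the patcher does, this would look roughly like `convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)`.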
target_m, new_class, config, distributed = class_convert - convert_class(model, target_m, new_class, config, distributed) - - def retrive_patch_info(self, model_name, device): - if device.device_type == "xpu": - update_patcher_info_on_xpu(model_name) - elif device.device_type == "cpu": - update_patcher_info_on_cpu(model_name) - else: - raise RuntimeError(f"Optimum-intel only support CPU and XPU device optimization. But we find this model on {device}.") - - def patch_model(self, model): - # if isinstance(model, LlamaForCausalLM): - self.retrive_patch_info(model.__class__.name, model.device) - self.patch_op_recursive(model) - self.convert_function_recursive(model) - self.convert_class_recursive(model) - diff --git a/optimum/exporters/ipex/modeling/xpu/utils.py b/optimum/exporters/ipex/modeling/xpu/utils.py deleted file mode 100644 index 24dd107dad..0000000000 --- a/optimum/exporters/ipex/modeling/xpu/utils.py +++ /dev/null @@ -1,9 +0,0 @@ - -from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.ModuleReplacer import ModuleReplacer - -def update_patcher_info_on_xpu(patcher, model_name): - patch_info = ModuleReplacer.get_patch_info_from_model(model_name) - op_patch_list, function_convert_list, class_convert_list = patch_info - patcher.update_op_list(op_patch_list) - patcher.update_function_convert_list(function_convert_list) - patcher.update_class_convert_list(class_convert_list) \ No newline at end of file diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 7e8a1e8c0f..07d9d3ad5b 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -3,6 +3,7 @@ import torch.nn as nn from typing import Optional import math +import gc import intel_extension_for_pytorch from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.llama import NewIPEXLLAMABlock @@ -19,9 +20,7 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None): attn_output += bias else: if bias is not None: - attn_output = torch.ops.torch_ipex.mm_bias_resadd( - attn_output, weight, bias, 1.0, residual, 1.0 - ) + attn_output = torch.ops.torch_ipex.mm_bias_resadd(attn_output, weight, bias, 1.0, residual, 1.0) else: attn_output = torch.addmm( residual.flatten(0, -2), @@ -32,6 +31,13 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None): attn_output = attn_output.view(seq_len, bs, -1) return attn_output + +def reference_elimination(c, b): + for item in gc.get_objects(): + if isinstance(item, torch.Tensor) and item.data_ptr() == c.data_ptr() and item is not c: + item.data = b + + class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention): def __init__(self, module, config, distributed=False, optimized_module=None) -> None: super().__init__(module, config, distributed) @@ -40,14 +46,9 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> self.num_kv_heads = module.num_key_value_heads self.embed_dim = module.config.hidden_size self.port_parameters(module) + torch.xpu.empty_cache() from intel_extension_for_pytorch.llm.modules import RotaryEmbedding - # self.ipex_rope = RotaryEmbedding( - # module.config.max_position_embeddings, - # module.config.hidden_size // module.config.num_attention_heads, - # module.config.rope_theta, - # module.config.architectures[0], - # ) self.ipex_rope = _IPEXRopeXPU( module.config.max_position_embeddings, module.config.hidden_size // 
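`reference_elimination` exists because `port_parameters` repacks the q/k/v weights into one stacked buffer; any other live tensor that still aliases the old storage is re-pointed at the new buffer so the original allocation can actually be freed (followed by `torch.xpu.empty_cache()` in the patch). A hedged sketch of the idea, essentially the same heap walk as the patched helper; scanning `gc` objects is expensive and only makes sense once, at load time:

    import gc
    import torch

    def repoint_aliases(old: torch.Tensor, new: torch.Tensor) -> int:
        # Find every live tensor sharing old's storage (same data_ptr) and make it
        # view the repacked buffer instead, so the old allocation can be released.
        hits = 0
        for obj in gc.get_objects():
            if isinstance(obj, torch.Tensor) and obj is not old and obj.data_ptr() == old.data_ptr():
                obj.data = new
                hits += 1
        return hits

A later patch in this series drops the scan and simply reassigns the `module.*.weight.data` attributes directly, which releases the duplicates without walking the heap.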
module.config.num_attention_heads, @@ -88,34 +89,44 @@ def forward( if past_key_value: _, _, prev_seqlen, _ = past_key_value[0].size() if self.num_kv_heads == self.num_heads: - query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - key = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - value = torch.empty((bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - # query = torch.empty_like(hidden_states) - # key = torch.empty_like(hidden_states) - # value = torch.empty_like(hidden_states) + query = torch.empty( + (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + key = torch.empty( + (bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + value = torch.empty( + (bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) torch.ops.torch_ipex.mm_qkv_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) - # torch.ops.torch_ipex.mm_qkv_out( - # hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key[:, prev_seqlen:, :], value[:, prev_seqlen:, :]) + hidden_states, + self.qkv_proj_weight, + self.qkv_proj_bias, + query, + key[:, prev_seqlen:, :], + value[:, prev_seqlen:, :], + ) else: - query = torch.empty((bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - key = torch.empty((bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) - value = torch.empty((bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device) + query = torch.empty( + (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + key = torch.empty( + (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + value = torch.empty( + (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) torch.ops.torch_ipex.mm_qkv_group_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value) - # if past_key_value: - # key[:, :prev_seqlen, :] = past_key_value[0].transpose(1, 2).view(bs, prev_seqlen, -1) - # value[:, :prev_seqlen, :] = past_key_value[1].transpose(1, 2).view(bs, prev_seqlen, -1) - # rope - #query = query.view([-1, seqlen, self.num_heads, self.head_dim]) - #key = key.view([-1, seqlen, self.num_kv_heads, self.head_dim]) - # value = value.view([bs, prev_seqlen + seqlen, self.num_kv_heads, self.head_dim]) + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value + ) + query = query.view([bs, seqlen, self.num_heads, self.head_dim]) key = key.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) - # sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) - # sin = sin.squeeze()[position_ids].unsqueeze(2) - # cos = cos.squeeze()[position_ids].unsqueeze(2) + if hasattr(kwargs, "sin") and hasattr(kwargs, "cos"): print("cache sin cos") sin = kwargs["sin"] @@ -125,56 +136,32 @@ def forward( sin = sin.squeeze()[position_ids].unsqueeze(2) cos = cos.squeeze()[position_ids].unsqueeze(2) 
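`torch.ops.torch_ipex.mm_qkv_out` computes the three projections in one GEMM and writes them into caller-allocated buffers. A plain-PyTorch equivalent that only shows the layout assumption (each `nn.Linear` weight transposed and stacked into a `(3, hidden_in, hidden_out)` tensor, as `port_parameters` does), without the out-variant or the fused kernel:

    import torch

    def stacked_qkv(hidden_states: torch.Tensor, qkv_weight: torch.Tensor, qkv_bias=None):
        # hidden_states: (bs, seq, hidden_in); qkv_weight: (3, hidden_in, hidden_out).
        qkv = torch.einsum("bsh,tho->tbso", hidden_states, qkv_weight)
        if qkv_bias is not None:                  # qkv_bias: (3, hidden_out)
            qkv = qkv + qkv_bias[:, None, None, :]
        query, key, value = qkv.unbind(0)         # three (bs, seq, hidden_out) tensors
        return query, key, value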
self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key[:, prev_seqlen:, :, :]) - # query = self.ipex_rope( - # query, - # position_ids, - # self.num_kv_heads, - # self.head_dim, - # self.head_dim // 2, - # self.head_dim, - # seqlen, - # ) - - # key = self.ipex_rope( - # key, - # position_ids, - # self.num_kv_heads, - # self.head_dim, - # self.head_dim // 2, - # self.head_dim, - # seqlen, - # ) value = value.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) if past_key_value is not None: value[:, :prev_seqlen, :, :] = past_key_value[1].transpose(1, 2) key[:, :prev_seqlen, :, :] = past_key_value[0].transpose(1, 2) - # value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) - # key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) present = (key, value) if use_cache else None scale = 1.0 / math.sqrt(self.head_dim) - # import pdb;pdb.set_trace() is_causal = False - # if past_key_value is not None: - # is_causal = True - attn_output = torch.xpu.IpexSDP(query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False) - # attn_output, attn_weight = torch.nn.functional.scaled_dot_product_attention(query, key, value, attention_mask, dropout_p=0.0, scale=scale) + attn_output = torch.xpu.IpexSDP( + query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False + ) attn_output = attn_output.transpose(1, 2).view([bs, seqlen, self.embed_dim]) - attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bs, seqlen, self.embed_dim]) - # attn_output = torch.matmul(attn_output, self.o_proj_weight) - # attn_output = attn_output + residual + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( + [bs, seqlen, self.embed_dim] + ) outputs = (attn_output, present) if output_attentions: raise ValueError("not support output attn_weight") - # outputs += (attn_weight, ) else: - outputs += (None, ) + outputs += (None,) return outputs - def port_parameters(self, module): self.qkv_proj_bias = None self.qkv_proj_weight = None @@ -183,24 +170,76 @@ def port_parameters(self, module): k_proj = module.k_proj.weight.transpose(0, 1) v_proj = module.v_proj.weight.transpose(0, 1) self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) + reference_elimination(module.q_proj.weight.data, self.qkv_proj_weight[0, :, :].transpose(0, 1)) + module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) + reference_elimination(module.k_proj.weight.data, self.qkv_proj_weight[1, :, :].transpose(0, 1)) + module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) + reference_elimination(module.v_proj.weight.data, self.qkv_proj_weight[2, :, :].transpose(0, 1)) + module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) if module.q_proj.bias is not None: - self.qkv_proj_bias = torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]).contiguous().view([3, -1]) + self.qkv_proj_bias = ( + torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) + .contiguous() + .view([3, -1]) + ) + reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[0]) + module.q_proj.bias.data = self.qkv_proj_bias[0] + reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[1]) + module.k_proj.bias.data = self.qkv_proj_bias[1] + 
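`torch.xpu.IpexSDP` is the fused attention kernel on XPU; the commented-out line it replaces used the stock PyTorch operator. A reference fallback that computes the same attention, useful when validating the fused path, assuming `(bs, num_heads, seq, head_dim)` inputs with key/value already expanded to the query head count and an additive float mask:

    import math
    import torch
    import torch.nn.functional as F

    def sdpa_reference(query, key, value, attention_mask=None):
        # attention_mask is an additive mask broadcastable to (bs, 1, q_len, kv_len), or None.
        scale = 1.0 / math.sqrt(query.size(-1))
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, scale=scale
        )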
reference_elimination(module.v_proj.bias.data, self.qkv_proj_bias[2]) + module.v_proj.bias.data = self.qkv_proj_bias[2] else: group = self.num_heads // self.num_kv_heads q_proj = module.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view([self.num_kv_heads, group + 2, self.head_dim, self.embed_dim]) + self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( + [self.num_kv_heads, group + 2, self.head_dim, self.embed_dim] + ) + reference_elimination( + module.q_proj.data, + self.qkv_proj_weight[:, :group, :, :].view( + [self.num_kv_heads * group * self.head_dim, self.embed_dim] + ), + ) + module.q_proj.data = self.qkv_proj_weight[:, :group, :, :].view( + [self.num_kv_heads * group * self.head_dim, self.embed_dim] + ) + reference_elimination( + module.k_proj.data, + self.qkv_proj_weight[:, group, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]), + ) + module.k_proj.data = self.qkv_proj_weight[:, group, :, :].view( + [self.num_kv_heads * self.head_dim, self.embed_dim] + ) + reference_elimination( + module.v_proj.data, + self.qkv_proj_weight[:, group + 1, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]), + ) + module.v_proj.data = self.qkv_proj_weight[:, group + 1, :, :].view( + [self.num_kv_heads * self.head_dim, self.embed_dim] + ) if module.q_proj.bias is not None: q_bias = module.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view([self.num_kv_heads, group + 2, self.head_dim]) + self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( + [self.num_kv_heads, group + 2, self.head_dim] + ) + reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[:, :group, self.head_dim].view(-1)) + module.q_proj.bias.data = self.qkv_proj_bias[:, :group, self.head_dim].view(-1) + reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[:, group, self.head_dim].view(-1)) + module.k_proj.bias.data = self.qkv_proj_bias[:, group, self.head_dim].view(-1) + reference_elimination( + module.v_proj.bias.data, self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) + ) + module.v_proj.bias.data = self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() + reference_elimination(module.o_proj.weight.data, self.o_proj_weight.transpose(0, 1)) + module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) self.o_proj_bias = module.o_proj.bias - class _IPEXLlamaMLPXPU(_IPEXLlamaMLP): def __init__(self, module, config, distributed=False, optimized_module=None) -> None: super().__init__(module, config, distributed) @@ -209,12 +248,7 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> self.mlp_impl = optimized_module self.port_parameter(module) - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor = None, - **kwargs - ): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -224,53 +258,16 @@ def 
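In the grouped-query case the projections are packed group by group: for each of the `num_kv_heads` groups the buffer holds `group` query heads followed by one key head and one value head, giving the `(num_kv_heads, group + 2, head_dim, hidden)` layout above. A small shape check of that packing, pure bookkeeping with illustrative sizes and no IPEX ops:

    import torch

    num_kv_heads, group, head_dim, hidden = 8, 4, 64, 2048
    num_heads = num_kv_heads * group

    q_w = torch.randn(num_heads * head_dim, hidden)      # nn.Linear weight shapes
    k_w = torch.randn(num_kv_heads * head_dim, hidden)
    v_w = torch.randn(num_kv_heads * head_dim, hidden)

    packed = torch.cat(
        [
            q_w.view(num_kv_heads, group, head_dim, hidden),
            k_w.view(num_kv_heads, 1, head_dim, hidden),
            v_w.view(num_kv_heads, 1, head_dim, hidden),
        ],
        dim=1,
    )  # (num_kv_heads, group + 2, head_dim, hidden)

    # Slicing the packed buffer recovers the original projections.
    assert torch.equal(packed[:, :group].reshape(num_heads * head_dim, hidden), q_w)
    assert torch.equal(packed[:, group].reshape(num_kv_heads * head_dim, hidden), k_w)
    assert torch.equal(packed[:, group + 1].reshape(num_kv_heads * head_dim, hidden), v_w)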
forward( out = matmul_add_add(out, self.down_proj_weight, self.down_proj_bias, residual) return out - def port_parameter(self, module): self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() + reference_elimination(module.up_proj.weight.data, self.up_proj_weight.transpose(0, 1)) + module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() + reference_elimination(module.gate_proj.weight.data, self.gate_proj_weight.transpose(0, 1)) + module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() + reference_elimination(module.down_proj.weight.data, self.down_proj_weight.transpose(0, 1)) + module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) self.up_proj_bias = module.up_proj.bias self.gate_proj_bias = module.gate_proj.bias self.down_proj_bias = module.down_proj.bias - - - -# class _IPEXLlamaDecoderLayerXPU(_IPEXLlamaDecoderLayer): -# def __init__(self, module, config, distributed=False) -> None: -# super().__init__(module, config, distributed) -# self.block_impl = NewIPEXLLAMABlock(module, config) -# self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, self.block_impl.attn) -# self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, self.block_impl.mlp) - -# def preprocess_for_optimize( -# self, -# hidden_states: torch.Tensor, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_value: Optional[Tuple[torch.Tensor]] = None, -# output_attention: Optional[bool] = True, -# use_cache: Optional[bool] = False, -# **kwargs -# ): -# return self.block_impl.preprocess_for_optimize( -# hidden_states, -# attention_mask, -# position_ids, -# past_key_value, -# output_attention, -# use_cache, -# **kwargs -# ) - - - -# def postprocess_for_optimize(self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs): -# return self.block_impl.postprocess_for_optimize( -# hidden_states, -# output_attention, -# use_cache, -# self_attn_weight, -# present_key_value, -# **kwargs -# ) - diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index d5b5bfd94c..991cf8d9b8 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -112,13 +112,16 @@ def _llama_attn_forward( return attn_output, past_key_value, attn_weights + def padding_attn_mask(attn_mask, alignment): if attn_mask is None: return None - assert isinstance(attn_mask, torch.Tensor), f"attn mask is supposed to be a tensor, instead we got {type(attn_mask)}" + assert isinstance( + attn_mask, torch.Tensor + ), f"attn mask is supposed to be a tensor, instead we got {type(attn_mask)}" if attn_mask.device == torch.device("cpu"): return attn_mask - last_dim_size= attn_mask.size(-1) + last_dim_size = attn_mask.size(-1) aligned_size = (last_dim_size + alignment - 1) // alignment * alignment mask_size = [*attn_mask.size()[:-1], aligned_size] new_attn_mask = torch.empty(mask_size, dtype=attn_mask.dtype, device=attn_mask.device).fill_(-65504.0) @@ -191,9 +194,6 @@ def _llama_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - # XPU - #if True: - # past_key_values = [] seqlen = hidden_states.size(1) head_dim = self.layers[0].attn.head_dim sin, cos = self.layers[0].attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) diff --git 
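`padding_attn_mask` rounds the mask's last dimension up to a multiple of the kernel's alignment and fills the extra columns with a large negative value so the padded key positions vanish after softmax. The rounding is the usual ceil-division trick; a compact equivalent (the `-65504.0` fill is the most negative finite float16 value):

    import torch

    def pad_attn_mask(attn_mask: torch.Tensor, alignment: int) -> torch.Tensor:
        last = attn_mask.size(-1)
        aligned = (last + alignment - 1) // alignment * alignment   # ceil to a multiple
        if aligned == last:
            return attn_mask
        padded = attn_mask.new_full((*attn_mask.shape[:-1], aligned), -65504.0)
        padded[..., :last] = attn_mask
        return padded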
a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 2b739ea502..671b95b8a4 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -62,8 +62,6 @@ def _is_patched_with_ipex(model, task): - if is_ipex_version("<", "2.5.0"): - return False if isinstance(model, torch.jit.ScriptModule): for node in model.graph.nodes(): @@ -161,7 +159,9 @@ def _from_transformers( local_files_only: bool = False, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, trust_remote_code: bool = False, + **kwargs, ): + device_map = kwargs.pop("device_map", None) if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", @@ -172,7 +172,6 @@ def _from_transformers( "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." ) token = use_auth_token - if is_torch_version("<", "2.1.0"): raise ImportError("`torch>=2.0.0` is needed to trace your model") @@ -186,15 +185,25 @@ def _from_transformers( "force_download": force_download, "torch_dtype": torch_dtype, "trust_remote_code": trust_remote_code, + "device_map": device_map, } model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - traced_model = ipex_jit_trace(model, task, use_cache) - config.torchscript = True - config.torch_dtype = torch_dtype + if "cpu" in str(model.device): + traced_model = ipex_jit_trace(model, task, use_cache) + config.torchscript = True + config.torch_dtype = torch_dtype + return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) + else: + from optimum.exporters.ipex.model_patcher import _patch_model + + if _is_patched_with_ipex(model, task): + model = _patch_model(model) + else: + raise NotImplementedError(f"The given model is not support yet") - return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) + return model @classmethod def _from_pretrained( From 5b3b72de66e5bf8550fb04903ea932a80c22808d Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Wed, 8 May 2024 18:34:41 -0700 Subject: [PATCH 06/31] fix version --- optimum/exporters/ipex/model_patcher.py | 4 ++-- optimum/intel/ipex/modeling_base.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 50b22c7389..737bc14b0e 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -64,9 +64,9 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.3.0" + ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" if is_ipex_version("<", ipex_version): - raise ImportError(f"Only ipex version > {ipex_version} supports RotaryEmbedding and IndirectAccessKVCache") + raise ImportError(f"Only ipex version >= {ipex_version} supports RotaryEmbedding and IndirectAccessKVCache") if "cpu" in str(model.device): from intel_extension_for_pytorch.llm.modules import RotaryEmbedding diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 671b95b8a4..41de50cd2a 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -62,7 +62,9 @@ def _is_patched_with_ipex(model, task): - + ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" + if is_ipex_version("<", 
ipex_version): + return False if isinstance(model, torch.jit.ScriptModule): for node in model.graph.nodes(): # Jit will record the codes position so we can check if the node use ipex exporter. @@ -172,8 +174,6 @@ def _from_transformers( "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." ) token = use_auth_token - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.0.0` is needed to trace your model") task = cls.export_feature model_kwargs = { @@ -191,6 +191,8 @@ def _from_transformers( model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) if "cpu" in str(model.device): + if is_torch_version("<", "2.1.0"): + raise ImportError("`torch>=2.1.0` is needed to trace your model") traced_model = ipex_jit_trace(model, task, use_cache) config.torchscript = True config.torch_dtype = torch_dtype From 48971449d8b81fc88f3373541406253ce265aa80 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 10 May 2024 23:25:13 -0700 Subject: [PATCH 07/31] fix ipex version check --- optimum/intel/ipex/modeling_base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 41de50cd2a..04c5a6508f 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -62,16 +62,18 @@ def _is_patched_with_ipex(model, task): - ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" - if is_ipex_version("<", ipex_version): - return False if isinstance(model, torch.jit.ScriptModule): + if is_ipex_version("<", "2.5.0"): + return False for node in model.graph.nodes(): # Jit will record the codes position so we can check if the node use ipex exporter. if "torch_ipex::rotary_position_embedding" in node.__str__(): return True return False else: + ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" + if is_ipex_version("<", ipex_version): + return False return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK From 5351f4a1d411b6813e8ed47a826ffdc667491518 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 23 May 2024 13:05:07 -0400 Subject: [PATCH 08/31] ipex 2.3 released --- optimum/exporters/ipex/model_patcher.py | 8 ++++---- optimum/exporters/ipex/modeling_utils.py | 17 ++++++++++------- optimum/intel/ipex/modeling_base.py | 2 +- tests/ipex/test_modeling.py | 16 ---------------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 60ff3b721b..3996c2b23f 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -62,10 +62,10 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - if is_ipex_version("<", "2.5.0"): - raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache") + if is_ipex_version("<", "2.3.0"): + raise ImportError("Only ipex version >= 2.3.0 supports RotaryEmbedding and IndirectAccessKVCacheAttention") - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding ipex_rope = RotaryEmbedding( model.config.max_position_embeddings, @@ -73,7 +73,7 @@ def _patch_llama_model(model): model.config.rope_theta, model.config.architectures[0], ) - ipex_scale_dot_product = 
IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) + ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings) patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f75e559eaf..9d53caf4fc 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -219,7 +219,7 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayerRef(nn.Module): def __init__(self, module, config, distributed=False): - if is_ipex_version("<", "2.5.0"): + if is_ipex_version("<", "2.3.0"): raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd @@ -278,7 +278,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, ) - if not self.distributed: + if hasattr(self, "mha_linear_add"): hidden_states = self.mha_linear_add(hidden_states, residual) else: hidden_states = self.self_attn.o_proj(hidden_states) @@ -288,12 +288,15 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - mlp_gate = self.linear_silu_mul(hidden_states) - - if not self.distributed: - hidden_states = self.mlp_linear_add(mlp_gate, residual) + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.mlp.down_proj(mlp_gate) + hidden_states = residual + hidden_states else: - hidden_states = self.mlp.down_proj(mlp_gate) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index e929a4ddb8..c9d43e3dc0 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -63,7 +63,7 @@ def _is_patched_with_ipex(model, task): - if is_ipex_version("<", "2.5.0"): + if is_ipex_version("<", "2.3.0"): return False if isinstance(model, torch.jit.ScriptModule): diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 2a2f18f6f8..7eb34ef47c 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -219,22 +219,6 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_assisted_decoding(self, model_arch): - model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") - ipex_output = ipex_model.generate(**tokens, do_sample=False) - ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) - transformers_output = transformers_model.generate(**tokens, do_sample=False) - transformers_output_assisted = transformers_model.generate( - **tokens, do_sample=False, 
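The decoder layer only builds the fused `LinearAdd`/`Linear2SiluMul` modules in the non-distributed case, so the forward pass feature-tests with `hasattr` and otherwise falls back to the stock projection plus an explicit residual add. The shape of that pattern, with a placeholder fused op standing in for the IPEX module:

    import torch.nn as nn

    class AttnOut(nn.Module):
        # Minimal sketch: the fused op is optional (e.g. absent under tensor parallelism).
        def __init__(self, o_proj: nn.Linear, fused_linear_add=None):
            super().__init__()
            self.o_proj = o_proj
            if fused_linear_add is not None:
                self.mha_linear_add = fused_linear_add   # fuses o_proj with the residual add

        def forward(self, attn_output, residual):
            if hasattr(self, "mha_linear_add"):
                return self.mha_linear_add(attn_output, residual)
            return self.o_proj(attn_output) + residual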
assistant_model=ipex_model - ) - self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) - self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) - @parameterized.expand( grid_parameters( { From 6289b573704312c77e34a22e4498887986340d64 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 24 May 2024 00:18:23 -0700 Subject: [PATCH 09/31] change versions --- optimum/exporters/ipex/model_patcher.py | 2 +- optimum/intel/ipex/modeling_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 737bc14b0e..64fd199925 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -64,7 +64,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" + ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.5.0" if is_ipex_version("<", ipex_version): raise ImportError(f"Only ipex version >= {ipex_version} supports RotaryEmbedding and IndirectAccessKVCache") diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 04c5a6508f..2edc6ba214 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -71,7 +71,7 @@ def _is_patched_with_ipex(model, task): return True return False else: - ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.5.0" + ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.5.0" if is_ipex_version("<", ipex_version): return False return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK From 3824300519fd99b0f557bf950ab6b9af3dc0a2be Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 24 May 2024 00:18:40 -0700 Subject: [PATCH 10/31] debug beam search --- .../ipex/modeling/xpu/xpu_modeling_llama.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 07d9d3ad5b..2ad62f37e6 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -93,22 +93,22 @@ def forward( (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device ) key = torch.empty( - (bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), + (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, - device=hidden_states.device, + device=hidden_states.device ) value = torch.empty( - (bs, prev_seqlen + seqlen, self.num_heads * self.head_dim), + (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, - device=hidden_states.device, + device=hidden_states.device ) torch.ops.torch_ipex.mm_qkv_out( hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, - key[:, prev_seqlen:, :], - value[:, prev_seqlen:, :], + key, + value, ) else: query = torch.empty( @@ -125,21 +125,17 @@ def forward( ) query = query.view([bs, seqlen, self.num_heads, self.head_dim]) - key = key.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) + key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - if hasattr(kwargs, "sin") and hasattr(kwargs, "cos"): - print("cache sin cos") - sin = kwargs["sin"] - cos = kwargs["cos"] - else: - sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2) - sin = 
sin.squeeze()[position_ids].unsqueeze(2) - cos = cos.squeeze()[position_ids].unsqueeze(2) - self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key[:, prev_seqlen:, :, :]) - value = value.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim]) + + sin = kwargs.pop("sin", None) + cos = kwargs.pop("cos", None) + + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim]) if past_key_value is not None: - value[:, :prev_seqlen, :, :] = past_key_value[1].transpose(1, 2) - key[:, :prev_seqlen, :, :] = past_key_value[0].transpose(1, 2) + key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) query = query.transpose(1, 2) key = key.transpose(1, 2) From 872a3ebcf5557b14ce99a831a6e7e7e99537bf7e Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 24 May 2024 02:50:39 -0700 Subject: [PATCH 11/31] remove reference elimination --- .../ipex/modeling/xpu/xpu_modeling_llama.py | 38 +------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py index 2ad62f37e6..961c6b9609 100644 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py @@ -32,12 +32,6 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None): return attn_output -def reference_elimination(c, b): - for item in gc.get_objects(): - if isinstance(item, torch.Tensor) and item.data_ptr() == c.data_ptr() and item is not c: - item.data = b - - class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention): def __init__(self, module, config, distributed=False, optimized_module=None) -> None: super().__init__(module, config, distributed) @@ -166,11 +160,8 @@ def port_parameters(self, module): k_proj = module.k_proj.weight.transpose(0, 1) v_proj = module.v_proj.weight.transpose(0, 1) self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) - reference_elimination(module.q_proj.weight.data, self.qkv_proj_weight[0, :, :].transpose(0, 1)) module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) - reference_elimination(module.k_proj.weight.data, self.qkv_proj_weight[1, :, :].transpose(0, 1)) module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) - reference_elimination(module.v_proj.weight.data, self.qkv_proj_weight[2, :, :].transpose(0, 1)) module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) if module.q_proj.bias is not None: self.qkv_proj_bias = ( @@ -178,11 +169,8 @@ def port_parameters(self, module): .contiguous() .view([3, -1]) ) - reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[0]) module.q_proj.bias.data = self.qkv_proj_bias[0] - reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[1]) module.k_proj.bias.data = self.qkv_proj_bias[1] - reference_elimination(module.v_proj.bias.data, self.qkv_proj_bias[2]) module.v_proj.bias.data = self.qkv_proj_bias[2] else: group = self.num_heads // self.num_kv_heads @@ -192,26 +180,12 @@ def port_parameters(self, module): self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( [self.num_kv_heads, group + 2, self.head_dim, self.embed_dim] ) - reference_elimination( - module.q_proj.data, - self.qkv_proj_weight[:, :group, :, :].view( - [self.num_kv_heads * group * 
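The beam-search fix goes back to growing the key/value cache with `torch.cat` along the sequence axis after RoPE, rather than writing the current step into a pre-allocated slice. With the `(bs, seq, kv_heads, head_dim)` layout used before the transpose, the append step amounts to:

    import torch

    def append_kv(key, value, past_key_value):
        # key/value: (bs, cur_seq, num_kv_heads, head_dim) for the new tokens only.
        # past_key_value: cached (bs, num_kv_heads, past_seq, head_dim) tensors, or None.
        if past_key_value is not None:
            key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1)
            value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1)
        # Back to (bs, num_kv_heads, total_seq, head_dim) for attention.
        return key.transpose(1, 2), value.transpose(1, 2)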
self.head_dim, self.embed_dim] - ), - ) module.q_proj.data = self.qkv_proj_weight[:, :group, :, :].view( [self.num_kv_heads * group * self.head_dim, self.embed_dim] ) - reference_elimination( - module.k_proj.data, - self.qkv_proj_weight[:, group, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]), - ) module.k_proj.data = self.qkv_proj_weight[:, group, :, :].view( [self.num_kv_heads * self.head_dim, self.embed_dim] ) - reference_elimination( - module.v_proj.data, - self.qkv_proj_weight[:, group + 1, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]), - ) module.v_proj.data = self.qkv_proj_weight[:, group + 1, :, :].view( [self.num_kv_heads * self.head_dim, self.embed_dim] ) @@ -222,16 +196,10 @@ def port_parameters(self, module): self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( [self.num_kv_heads, group + 2, self.head_dim] ) - reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[:, :group, self.head_dim].view(-1)) module.q_proj.bias.data = self.qkv_proj_bias[:, :group, self.head_dim].view(-1) - reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[:, group, self.head_dim].view(-1)) module.k_proj.bias.data = self.qkv_proj_bias[:, group, self.head_dim].view(-1) - reference_elimination( - module.v_proj.bias.data, self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) - ) module.v_proj.bias.data = self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() - reference_elimination(module.o_proj.weight.data, self.o_proj_weight.transpose(0, 1)) module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) self.o_proj_bias = module.o_proj.bias @@ -243,7 +211,8 @@ def __init__(self, module, config, distributed=False, optimized_module=None) -> if optimized_module is not None: self.mlp_impl = optimized_module self.port_parameter(module) - + torch.xpu.empty_cache() + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): """ Args: @@ -256,13 +225,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** def port_parameter(self, module): self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() - reference_elimination(module.up_proj.weight.data, self.up_proj_weight.transpose(0, 1)) module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() - reference_elimination(module.gate_proj.weight.data, self.gate_proj_weight.transpose(0, 1)) module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() - reference_elimination(module.down_proj.weight.data, self.down_proj_weight.transpose(0, 1)) module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) self.up_proj_bias = module.up_proj.bias self.gate_proj_bias = module.gate_proj.bias From d1d0ca0a50ed697271c912a266b10ac71e4b5892 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 14:22:33 -0400 Subject: [PATCH 12/31] refactor IPEXLlamaAttention --- optimum/exporters/ipex/model_patcher.py | 18 +- optimum/exporters/ipex/modeling_utils.py | 295 ++++++++++++++--------- 2 files changed, 184 insertions(+), 129 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 3996c2b23f..7f99dbddd1 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -13,7 +13,6 @@ # 
limitations under the License. from transformers.models.llama.modeling_llama import ( - LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, @@ -24,7 +23,6 @@ from .modeling_utils import ( _IPEXLlamaDecoderLayerRef, - _llama_attn_forward, _llama_layer_norm_forward, _llama_model_forward, ) @@ -63,24 +61,10 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version >= 2.3.0 supports RotaryEmbedding and IndirectAccessKVCacheAttention") - - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding - - ipex_rope = RotaryEmbedding( - model.config.max_position_embeddings, - model.config.hidden_size // model.config.num_attention_heads, - model.config.rope_theta, - model.config.architectures[0], - ) - ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings) - patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) - patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) + raise ImportError("Only ipex version >= 2.3.0 supports llama model patching") convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 9d53caf4fc..04a6fe5082 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -29,90 +29,6 @@ def _llama_layer_norm_forward(self, hidden_states): return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 -def _llama_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len - - query = query.view(bsz, q_len, self.num_heads, self.head_dim) - key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - # Use ipex op to rotary position embedding more efficient. - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - - if use_cache: - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. 
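The removed `_llama_attn_forward` applied rotary embeddings through the fused `ipex_rope` op. For reference, the rotation it implements is the standard "rotate half" form used by Llama; a plain-PyTorch sketch over `(bs, seq, heads, head_dim)` tensors, where `cos`/`sin` carry the full `head_dim` (the half-dimension tables repeated twice) and broadcast over the head axis:

    import torch

    def rotate_half(x):
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def apply_rope(x, cos, sin):
        # x: (bs, seq, heads, head_dim); cos/sin broadcastable to the same shape,
        # e.g. (bs, seq, 1, head_dim) built from the position ids.
        return x * cos + rotate_half(x) * sin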
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) - else: - value_states = value.transpose(1, 2) - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -216,12 +132,147 @@ def _llama_model_forward( ) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 -class _IPEXLlamaDecoderLayerRef(nn.Module): - def __init__(self, module, config, distributed=False): +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 +class _IPEXLlamaAttentionRef(nn.Module): + def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") + raise ImportError( + "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" + ) + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + + super().__init__() + for k, v in module.__dict__.items(): + setattr(self, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(self.__class__, k, getattr(module.__class__, k)) + self.config = config + self.distributed = distributed + if not self.distributed: + self.mha_linear_add = LinearAdd(self.o_proj) + del self.__dict__["_modules"]["o_proj"] + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=module.config.max_position_embeddings + ) + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states 
(`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + residual (`torch.Tensor`): residual tensor to the layer of shape ` + """ + bsz, seq_len, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len + + query = query.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + # Use ipex op to rotary position embedding more efficient. + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + + if use_cache: + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. 
+ (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + else: + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + kv_seq_len = key_states.shape[-2] + + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size) + + if hasattr(self, "mha_linear_add"): + attn_output = self.mha_linear_add(attn_output, residual) + else: + attn_output = self.o_proj(attn_output) + attn_output = residual + attn_output + + if not output_attentions: + attn_weights = None + + return attn_output, past_key_value, attn_weights + + +class _IPEXLlamaMLP(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + if is_ipex_version("<", "2.3.0"): + raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd super().__init__() @@ -231,15 +282,47 @@ def __init__(self, module, config, distributed=False): if k.startswith("__") or k.startswith("forward"): continue setattr(self.__class__, k, getattr(module.__class__, k)) + self.config = config self.distributed = distributed if not self.distributed: - self.mha_linear_add = LinearAdd(module.self_attn.o_proj) - self.mlp_linear_add = LinearAdd(module.mlp.down_proj) - del self.__dict__["_modules"]["self_attn"].o_proj - del self.__dict__["_modules"]["mlp"].down_proj - self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj) - del self.__dict__["_modules"]["mlp"].gate_proj - del self.__dict__["_modules"]["mlp"].up_proj + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj(mlp_gate) + hidden_states = residual + hidden_states + else: + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = residual + hidden_states + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 +class _IPEXLlamaDecoderLayerRef(nn.Module): + def __init__(self, 
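`Linear2SiluMul` and `LinearAdd` fuse the two halves of the Llama MLP, and the eager fallback in the same forward spells the computation out. As a standalone reference (SiLU-gated MLP with the residual folded in at the end):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def llama_mlp(hidden_states, gate_proj: nn.Linear, up_proj: nn.Linear,
                  down_proj: nn.Linear, residual: torch.Tensor) -> torch.Tensor:
        # Linear2SiluMul fuses: silu(gate_proj(x)) * up_proj(x)
        gated = F.silu(gate_proj(hidden_states)) * up_proj(hidden_states)
        # LinearAdd fuses the down projection with the residual addition.
        return down_proj(gated) + residual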
module, config, distributed=False): + super().__init__() + for k, v in module.__dict__.items(): + setattr(self, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(self.__class__, k, getattr(module.__class__, k)) + self.distributed = distributed + self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) def forward( self, @@ -270,34 +353,22 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, present_key_value, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=None, + residual=residual, + **kwargs, ) - if hasattr(self, "mha_linear_add"): - hidden_states = self.mha_linear_add(hidden_states, residual) - else: - hidden_states = self.self_attn.o_proj(hidden_states) - hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - - if hasattr(self, "linear_silu_mul"): - mlp_gate = self.linear_silu_mul(hidden_states) - if hasattr(self, "mlp_linear_add"): - hidden_states = self.mlp_linear_add(mlp_gate, residual) - else: - hidden_states = self.mlp.down_proj(mlp_gate) - hidden_states = residual + hidden_states - else: - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = (hidden_states,) From 89e10d61718370df6bc0174066e8619e47bb343d Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 18:19:19 -0700 Subject: [PATCH 13/31] add xpu port --- optimum/exporters/ipex/modeling_utils.py | 156 +++++++++++++++++------ optimum/intel/ipex/modeling_base.py | 36 +++--- 2 files changed, 132 insertions(+), 60 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index c0eca8361f..e361802842 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -178,12 +178,6 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 class _IPEXLlamaAttentionRef(nn.Module): def __init__(self, module, config, distributed=False) -> None: - if is_ipex_version("<", "2.3.0"): - raise ImportError( - "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" - ) - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding - super().__init__() for k, v in module.__dict__.items(): setattr(self, k, v) @@ -193,19 +187,33 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - if not self.distributed: - self.mha_linear_add = LinearAdd(self.o_proj) - del self.__dict__["_modules"]["o_proj"] - self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( - text_max_length=module.config.max_position_embeddings - ) - self.ipex_rope = RotaryEmbedding( - module.config.max_position_embeddings, - module.config.hidden_size // module.config.num_attention_heads, - module.config.rope_theta, - 
module.config.architectures[0], - ) - + self.module_device = module.q_proj.weight.device.type + if self.module_device == "xpu": + from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU + self.ipex_rope = _IPEXRopeXPU( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + self.port_parameters(module) + torch.xpu.empty_cache() + else: + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + if not self.distributed: + self.mha_linear_add = LinearAdd(self.o_proj) + del self.__dict__["_modules"]["o_proj"] + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=module.config.max_position_embeddings + ) + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + + def forward( self, hidden_states: torch.Tensor, @@ -310,9 +318,60 @@ def forward( attn_weights = None return attn_output, past_key_value, attn_weights - - -class _IPEXLlamaMLP(nn.Module): + + def port_parameters(self, module): + self.qkv_proj_bias = None + self.qkv_proj_weight = None + if self.num_heads == self.num_key_value_heads: + q_proj = module.q_proj.weight.transpose(0, 1) + k_proj = module.k_proj.weight.transpose(0, 1) + v_proj = module.v_proj.weight.transpose(0, 1) + self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) + module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) + module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) + module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) + if module.q_proj.bias is not None: + self.qkv_proj_bias = ( + torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) + .contiguous() + .view([3, -1]) + ) + module.q_proj.bias.data = self.qkv_proj_bias[0] + module.k_proj.bias.data = self.qkv_proj_bias[1] + module.v_proj.bias.data = self.qkv_proj_bias[2] + else: + q_proj = module.q_proj.weight.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.embed_dim) + k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( + [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.embed_dim] + ) + module.q_proj.data = self.qkv_proj_weight[:, :self.num_key_value_groups, :, :].view( + [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.embed_dim] + ) + module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view( + [self.num_kv_heads * self.head_dim, self.embed_dim] + ) + module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view( + [self.num_kv_heads * self.head_dim, self.embed_dim] + ) + if module.q_proj.bias is not None: + q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim) + k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( + [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim] + ) + 
module.q_proj.bias.data = self.qkv_proj_bias[:, :self.num_key_value_groups, self.head_dim].view(-1) + module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) + module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) + + self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() + module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) + self.o_proj_bias = module.o_proj.bias + + +class _IPEXLlamaMLPRef(nn.Module): def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") @@ -327,31 +386,50 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - if not self.distributed: - self.mlp_linear_add = LinearAdd(module.down_proj) - del self.__dict__["_modules"]["down_proj"] - self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) - del self.__dict__["_modules"]["gate_proj"] - del self.__dict__["_modules"]["up_proj"] + self.module_device = module.gate_proj.weight.device.type + if self.module_device == "xpu": + self.port_parameter(module) + torch.xpu.empty_cache() + else: + if not self.distributed: + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` """ - if hasattr(self, "linear_silu_mul"): - mlp_gate = self.linear_silu_mul(hidden_states) - if hasattr(self, "mlp_linear_add"): - hidden_states = self.mlp_linear_add(mlp_gate, residual) - else: - hidden_states = self.down_proj(mlp_gate) - hidden_states = residual + hidden_states + if self.module_device == "xpu": + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual) else: - hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) - hidden_states = residual + hidden_states - + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj(mlp_gate) + hidden_states = residual + hidden_states + else: + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = residual + hidden_states return hidden_states + def port_parameter(self, module): + self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() + module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) + self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() + module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) + self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() + module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) + self.up_proj_bias = module.up_proj.bias + 
self.gate_proj_bias = module.gate_proj.bias + self.down_proj_bias = module.down_proj.bias # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayerRef(nn.Module): @@ -365,7 +443,7 @@ def __init__(self, module, config, distributed=False): setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) def forward( self, diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index af21b4a402..ebc57ec3c1 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -140,10 +140,12 @@ def __init__( self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model_save_dir = model_save_dir self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) - - self.input_names = { - inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" - } + if self._device.type == "cpu": + self.input_names = { + inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" + } + else: + self.input_names = {"past_key_values": None, "position_ids": None} # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 AutoConfig.register(self.base_model_prefix, AutoConfig) @@ -169,7 +171,6 @@ def _from_transformers( trust_remote_code: bool = False, _commit_hash: str = None, ): - device_map = kwargs.pop("device_map", None) if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", @@ -196,22 +197,15 @@ def _from_transformers( model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - if "cpu" in str(model.device): - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.1.0` is needed to trace your model") - traced_model = ipex_jit_trace(model, task, use_cache) - config.torchscript = True - config.torch_dtype = torch_dtype - return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) - else: - from optimum.exporters.ipex.model_patcher import _patch_model - + if is_torch_xpu_available(check_device=True): + model.to("xpu:0") if _is_patched_with_ipex(model, task): model = _patch_model(model) - else: - raise NotImplementedError(f"The given model is not support yet") - - return model + else: + model = ipex_jit_trace(model, task, use_cache) + config.torchscript = True + config.torch_dtype = torch_dtype + return cls(model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) @classmethod def _from_pretrained( @@ -462,7 +456,7 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - if self._is_ipex_exported: + if self._is_ipex_exported and self._device.type == "cpu": self._reorder_cache = _ipex_reorder_cache else: # Check if _reorder_cache is a static method @@ -552,7 +546,7 @@ def forward( if "position_ids" in self.input_names or not self.input_names: inputs["position_ids"] = position_ids - if self.use_cache: + if self.use_cache and self._device.type == "cpu": if past_key_values is None: past_key_values = self._prepare_past_key_values(input_ids) From 9acaba40e19d10cd48e48e36f065ba91824e98f5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Thu, 23 May 2024 10:32:53 +0200 Subject: [PATCH 14/31] Fix llama and gemma modeling patching for openvino export (#714) * Fix compatibility for transformers v4.41.0 llama and gemma modeling patching * fix for dev transformers version * update setup --- optimum/exporters/openvino/model_patcher.py | 104 +++++++++++++++++++- optimum/intel/openvino/trainer.py | 6 +- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 93a8430522..0265b3a5fc 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -301,7 +301,7 @@ def __exit__(self, exc_type, exc_value, traceback): # adopted from # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058 -def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): +def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): from transformers.modeling_attn_mask_utils import AttentionMaskConverter if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None: @@ -314,10 +314,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = 
torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] + # difference with original modeling if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache @@ -329,7 +331,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + # difference with original modeling causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -366,6 +370,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po return causal_mask +# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036 +def _llama_gemma_update_causal_mask_latest( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values, + output_attentions, +): + from transformers.cache_utils import StaticCache + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
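[Reviewer note] The mask construction in the legacy helper above is easier to follow in isolation. A minimal standalone sketch (names are illustrative, not part of this patch) of what the `ones * min_dtype` / `triu` / `arange > cache_position` sequence produces, and why `torch.finfo(torch.float16).min` is used rather than the float32 minimum (the float32 value overflows to -inf once the model actually runs in bfloat16/float16):

    import torch

    def build_causal_mask(sequence_length, target_length, cache_position, dtype=torch.float16):
        # float16's minimum stays finite after casting to bf16/fp16; float32's would overflow
        min_dtype = torch.finfo(torch.float16).min
        mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype) * min_dtype
        if sequence_length != 1:
            mask = torch.triu(mask, diagonal=1)  # current and past positions stay visible
        # hide cache slots beyond the current position
        mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
        return mask[None, None, :, :]  # broadcastable over (batch, heads)

    # decoding one new token against a static cache of capacity 8, 4 tokens already cached:
    print(build_causal_mask(1, 8, torch.tensor([4])))
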
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling + # using minimum from dtype with larger bandwith (floa32) may lead to overflow + # during execution on platforms with default lower precision (bfloat16, float16) + min_dtype = torch.finfo(torch.float16).min + + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + # difference with original modeling + causal_mask = ( + torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + ) + + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
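[Reviewer note] The padding-mask merge a few lines above is easiest to see on a toy tensor. Values below are illustrative only: `causal_mask` holds 0 where attention is allowed and `min_dtype` where it is not, while `attention_mask` is the plain 2D tokenizer mask (1 = real token, 0 = padding), so the sum is exactly 0 only at positions that are causally visible but padded:

    import torch

    min_dtype = torch.finfo(torch.float16).min
    causal_mask = torch.tensor([[[[0.0, 0.0, 0.0, min_dtype]]]])   # (bsz=1, 1, q_len=1, kv_len=4)
    attention_mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]])          # token 1 is padding

    padding_mask = causal_mask[:, :, :, :4] + attention_mask[:, None, None, :]
    padding_mask = padding_mask == 0                   # True only for visible-but-padded slots
    merged = causal_mask.masked_fill(padding_mask, min_dtype)
    # `merged` now masks both future positions and padding tokens
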
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0 +if is_transformers_version(">", "4.40.2"): + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest +else: + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy + + class GemmaModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 0745a1cd79..c8b29800fa 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -906,7 +906,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): output_path = os.path.join(output_dir, OV_XML_FILE_NAME) self.compression_controller.prepare_for_export() model_type = self.model.config.model_type.replace("_", "-") - onnx_config_class = TasksManager.get_exporter_config_constructor( + exporter_config_class = TasksManager.get_exporter_config_constructor( exporter="onnx", model=self.model, task=self.task, @@ -914,9 +914,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): ) if self.task == "text-generation": - onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache) + onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache) else: - onnx_config = onnx_config_class(self.model.config) + onnx_config = exporter_config_class(self.model.config) num_parameters = self.model.num_parameters() save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model From 2f4909cacf36238955042de415cc43a63e332b7d Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Fri, 24 May 2024 17:59:38 +0200 Subject: [PATCH 15/31] Fix nncf quantization for decoder models (#727) * Fix nncf quantization for decoder models * add test * update op quant op * remove deprecated warning * update expected quantized * enable stateful * style --- optimum/exporters/ipex/model_patcher.py | 2 - optimum/exporters/ipex/modeling_utils.py | 250 ++++++++++++++------- optimum/intel/ipex/modeling_base.py | 4 +- optimum/intel/openvino/modeling_decoder.py | 5 +- optimum/intel/openvino/quantization.py | 9 +- tests/openvino/test_quantization.py | 18 +- 6 files changed, 186 insertions(+), 102 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index be40cdda46..ef30c2cc07 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -23,7 +23,6 @@ from .modeling_utils import ( _IPEXLlamaDecoderLayerRef, - _llama_layer_norm_forward, _llama_model_forward, ) @@ -66,7 +65,6 @@ def _patch_llama_model(model): raise ImportError(f"Only ipex version >= {ipex_version} supports llama model patching") convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e361802842..3c760f1ad0 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -58,11 +58,6 @@ def padding_attn_mask(attn_mask, alignment): 
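[Reviewer note] The `padding_attn_mask(attn_mask, alignment)` helper touched by this hunk is only partially visible here, so the following is a hedged sketch of the idea rather than the repo's exact code (the fill value is an assumption): round the key-length dimension of an additive float mask up to a multiple of `alignment` so downstream fused kernels see aligned shapes, and fill the new slots with the dtype minimum so they are never attended to.

    import torch

    def pad_attn_mask_sketch(attn_mask: torch.Tensor, alignment: int) -> torch.Tensor:
        # assumes an additive float mask (0 = visible, large negative = masked)
        last = attn_mask.size(-1)
        aligned = (last + alignment - 1) // alignment * alignment
        if aligned == last:
            return attn_mask
        new_mask = torch.full(
            (*attn_mask.shape[:-1], aligned), torch.finfo(attn_mask.dtype).min, dtype=attn_mask.dtype
        )
        new_mask[..., :last] = attn_mask
        return new_mask

    m = torch.zeros(1, 1, 4, 10)                       # kv_len = 10
    m_aligned = pad_attn_mask_sketch(m, alignment=8)   # kv_len padded to 16
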
new_attn_mask[..., :last_dim_size] = attn_mask return new_attn_mask -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 -def _llama_layer_norm_forward(self, hidden_states): - return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) - - # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -129,8 +124,8 @@ def _llama_model_forward( next_decoder_cache = () if use_cache else None seqlen = hidden_states.size(1) - head_dim = self.layers[0].attn.head_dim - sin, cos = self.layers[0].attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) + head_dim = self.layers[0].self_attn.head_dim + sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) sin = sin.squeeze()[position_ids].unsqueeze(2) cos = cos.squeeze()[position_ids].unsqueeze(2) sin_cos = {"sin": sin, "cos": cos} @@ -213,7 +208,128 @@ def __init__(self, module, config, distributed=False) -> None: module.config.architectures[0], ) - + def qkv_gemm(self, hidden_states): + bsz, seq_len, _ = hidden_states.size() + if self.module_device == "xpu": + if self.num_key_value_heads == self.num_heads: + query = torch.empty( + (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + key = torch.empty( + (bsz, seq_len, self.num_heads * self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + value = torch.empty( + (bsz, seq_len, self.num_heads * self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + torch.ops.torch_ipex.mm_qkv_out( + hidden_states, + self.qkv_proj_weight, + self.qkv_proj_bias, + query, + key, + value, + ) + else: + query = torch.empty( + (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + key = torch.empty( + (bsz, seq_len, self.num_key_value_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + value = torch.empty( + (bsz, seq_len, self.num_key_value_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + torch.ops.torch_ipex.mm_qkv_group_out( + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value + ) + else: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = query.view(bsz, seq_len, self.num_heads, self.head_dim) + key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + return query, key, value + + def rope(self, query, key, kv_seq_len, position_ids, **kwargs): + if self.module_device == "xpu": + sin = kwargs.pop("sin", None) + cos = kwargs.pop("cos", None) + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + else: + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + + return query, key + + def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): + + if self.module_device == "xpu": + scale = 1.0 / math.sqrt(self.head_dim) + is_causal = False + attn_output = torch.xpu.IpexSDP( + query, key, value, None, attention_mask, None, scale, 1.0, 
0.0, is_causal, False + ) + attn_weights = None + past_key_value = (key, value) if use_cache else None + else: + if use_cache: + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + else: + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + return attn_output, attn_weights, past_key_value + + def forward( self, hidden_states: torch.Tensor, @@ -242,83 +358,37 @@ def forward( residual (`torch.Tensor`): residual tensor to the layer of shape ` """ bsz, seq_len, _ = hidden_states.size() - - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len - - query = query.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - # Use ipex op to rotary position embedding more efficient. - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - - if use_cache: - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. 
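[Reviewer note] Two fused kernels (torch.xpu.IpexSDP on XPU, IndirectAccessKVCacheAttention on CPU) and one eager fallback now share this `sdpa` method, so a plain-PyTorch reference with explicit shapes may help when comparing their outputs. This is illustrative only; shapes follow the patched code, where query/key/value are (batch, heads, seq, head_dim) after the transposes:

    import math
    import torch

    def eager_sdpa_reference(query, key, value, attention_mask, num_key_value_groups):
        # equivalent of transformers' repeat_kv: expand kv heads to match query heads
        key = key.repeat_interleave(num_key_value_groups, dim=1)
        value = value.repeat_interleave(num_key_value_groups, dim=1)
        attn_weights = query @ key.transpose(2, 3) / math.sqrt(query.size(-1))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        return attn_weights @ value

    q = torch.randn(1, 8, 4, 16)   # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(1, 2, 4, 16)   # 2 kv heads -> num_key_value_groups = 8 // 2 = 4
    v = torch.randn(1, 2, 4, 16)
    out = eager_sdpa_reference(q, k, v, None, num_key_value_groups=4)   # (1, 8, 4, 16)
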
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) - else: - value_states = value.transpose(1, 2) - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size) - - if hasattr(self, "mha_linear_add"): - attn_output = self.mha_linear_add(attn_output, residual) + + query, key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, kv_seq_len, position_ids, **kwargs) + + if self.module_device == "xpu": + if past_key_value is not None: + key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output, attn_weights, past_key_value= self.sdpa(query, key, value, past_key_value, attention_mask, use_cache) + attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) + + if self.module_device == "xpu": + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bsz, seq_len, self.hidden_size]) else: - attn_output = self.o_proj(attn_output) - attn_output = residual + attn_output - + if hasattr(self, "mha_linear_add"): + attn_output = self.mha_linear_add(attn_output, residual) + else: + attn_output = self.o_proj(attn_output) + attn_output = residual + attn_output + if not output_attentions: attn_weights = None return attn_output, past_key_value, attn_weights + def port_parameters(self, module): self.qkv_proj_bias = None self.qkv_proj_weight = None @@ -373,10 +443,6 @@ def port_parameters(self, module): class _IPEXLlamaMLPRef(nn.Module): def __init__(self, module, config, distributed=False) -> None: - if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") - from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd - super().__init__() for k, v in module.__dict__.items(): setattr(self, k, v) @@ -391,6 +457,7 @@ def __init__(self, module, config, distributed=False) -> None: self.port_parameter(module) torch.xpu.empty_cache() else: + from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd if not self.distributed: self.mlp_linear_add = LinearAdd(module.down_proj) del self.__dict__["_modules"]["down_proj"] @@ -444,7 +511,24 @@ def __init__(self, module, config, distributed=False): self.distributed = distributed self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) - + self.module_device = 
module.mlp.gate_proj.weight.device.type + from intel_extension_for_pytorch.llm.modules import RMSNorm + if self.module_device == "xpu": + self.input_layernorm = RMSNorm( + self.input_layernorm.weight, self.input_layernorm.variance_epsilon + ) + self.post_attention_layernorm = RMSNorm( + self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon + ) + else: + self.input_layernorm = RMSNorm( + self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight + ) + self.post_attention_layernorm = RMSNorm( + self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight + ) + + def forward( self, hidden_states: torch.Tensor, diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index ebc57ec3c1..af9ba30426 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -546,8 +546,8 @@ def forward( if "position_ids" in self.input_names or not self.input_names: inputs["position_ids"] = position_ids - if self.use_cache and self._device.type == "cpu": - if past_key_values is None: + if self.use_cache: + if past_key_values is None and self._device.type == "cpu": past_key_values = self._prepare_past_key_values(input_ids) inputs["past_key_values"] = past_key_values diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 933d92a502..72cd1b6487 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -42,7 +42,7 @@ from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel -from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE +from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE if TYPE_CHECKING: @@ -409,6 +409,7 @@ def prepare_inputs( elif self.use_cache: for input_name in self.key_value_input_names: model_inputs = self.model.input(input_name) + dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()] shape = model_inputs.get_partial_shape() if self.config.model_type == "chatglm": shape[0] = 0 @@ -419,7 +420,7 @@ def prepare_inputs( shape[2] = 0 else: shape[1] = 0 - inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape()) + inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype) else: # past_key_values are not used explicitly, instead they are handled inside the model if past_key_values is None: diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 17305b947e..43cf1dd93b 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -347,7 +347,6 @@ def _quantize_ovbasemodel( remove_unused_columns=remove_unused_columns, data_collator=data_collator, ) - if self.model.export_feature == "text-generation" and self.model.use_cache: calibration_dataset = self._prepare_text_generation_dataset( quantization_config, calibration_dataloader @@ -430,6 +429,7 @@ def _quantize_ovbasemodel( ), **kwargs, ) + self.model.model = quantized_model if save_directory is not None: self.model.save_pretrained(save_directory) @@ -696,8 +696,6 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf def _prepare_text_generation_dataset( self, quantization_config: 
OVQuantizationConfig, calibration_dataloader: OVDataLoader ) -> nncf.Dataset: - # TODO: this function is not covered by tests, remove if not relevant anymore or cover by tests otherwise - # Prefetch past_key_values self.model.update_pkv_precision(True) self.model.compile() @@ -705,15 +703,16 @@ def _prepare_text_generation_dataset( num_samples = quantization_config.num_samples or 200 - self.model.request = InferRequestWrapper(self.model.model.request, collected_inputs) + self.model.request = InferRequestWrapper(self.model.request, collected_inputs) try: for data in calibration_dataloader: self.model.generate(**data, max_new_tokens=1) if len(collected_inputs) >= num_samples: break finally: - self.model.model.request = self.model.model.request.request + self.model.request = self.model.request.request calibration_dataset = nncf.Dataset(collected_inputs) + return calibration_dataset def _prepare_unet_dataset( diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 98eb121d72..09b395ea12 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -73,12 +73,16 @@ class OVQuantizerTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + SUPPORTED_ARCHITECTURES_TORCH_MODEL = ( (OVModelForSequenceClassification, "bert", 32, 35), - # (OVModelForCausalLM, "gpt2", 41, 23), + (OVModelForCausalLM, "gpt2", 41, 3), + ) + SUPPORTED_ARCHITECTURES_OV_MODEL = ( + (OVModelForSequenceClassification, "bert", 32, 35), + (OVModelForCausalLM, "gpt2", 31, 22), ) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL) def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): model_id = MODEL_NAMES[model_name] task = model_cls.export_feature @@ -123,23 +127,21 @@ def preprocess_function(examples, tokenizer): loaded_config = OVConfig.from_pretrained(tmp_dir) self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict()) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL) def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): model_id = MODEL_NAMES[model_name] task = model_cls.export_feature dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task] - if "gpt2" in model_id: - expected_int8 -= 1 def preprocess_function(examples, tokenizer): return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True) with tempfile.TemporaryDirectory() as tmp_dir: - transformers_model = model_cls.from_pretrained(model_id, export=True) + ov_model = model_cls.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + quantizer = OVQuantizer.from_pretrained(ov_model, task=task) calibration_dataset = quantizer.get_calibration_dataset( dataset_name, From f186ce72c3925c1bd07f6743836e38e376bcccab Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 21:36:31 -0700 Subject: [PATCH 16/31] remove --- optimum/exporters/ipex/model_patcher.py | 4 +- optimum/exporters/ipex/modeling/__init__.py | 1 - .../exporters/ipex/modeling/modeling_llama.py | 155 ------------ .../ipex/modeling/xpu/xpu_modeling_llama.py 
| 235 ------------------ optimum/exporters/ipex/modeling_utils.py | 124 +++++---- 5 files changed, 62 insertions(+), 457 deletions(-) delete mode 100644 optimum/exporters/ipex/modeling/__init__.py delete mode 100644 optimum/exporters/ipex/modeling/modeling_llama.py delete mode 100644 optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index ef30c2cc07..839d501637 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -16,7 +16,6 @@ LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, - LlamaRMSNorm, ) from optimum.intel.utils.import_utils import is_ipex_version @@ -26,7 +25,6 @@ _llama_model_forward, ) -from .modeling.modeling_llama import _IPEXLlamaDecoderLayer _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -60,7 +58,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.5.0" + ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.3.0" if is_ipex_version("<", ipex_version): raise ImportError(f"Only ipex version >= {ipex_version} supports llama model patching") diff --git a/optimum/exporters/ipex/modeling/__init__.py b/optimum/exporters/ipex/modeling/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/optimum/exporters/ipex/modeling/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py deleted file mode 100644 index 0ecef1ff83..0000000000 --- a/optimum/exporters/ipex/modeling/modeling_llama.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import torch.nn as nn -from typing import Optional, Tuple -import intel_extension_for_pytorch as ipex - - -class _IPEXLlamaAttention(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.module = module - self.config = config - self.distributed = distributed - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - residual: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
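[Reviewer note] For readers unfamiliar with how `_patch_llama_model` swaps classes at runtime: `convert_class` and `convert_functions` are defined elsewhere in the exporter, but the underlying idea is a recursive walk of the module tree. A hedged sketch of the pattern, not the repo's implementation:

    import torch.nn as nn

    def convert_class_sketch(model: nn.Module, target_cls, new_cls, *args):
        # replace every submodule of type `target_cls` with a wrapper built from it
        for name, child in model.named_children():
            if isinstance(child, target_cls):
                setattr(model, name, new_cls(child, *args))
            else:
                convert_class_sketch(child, target_cls, new_cls, *args)
        return model

    # e.g. convert_class_sketch(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
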
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - residual (`torch.Tensor`): residual tensor to the layer of shape ` - """ - pass - - -class _IPEXLlamaMLP(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.module = module - self.config = config - self.distributed = distributed - - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor, **kwargs): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - """ - pass - - -class _IPEXLlamaDecoderLayer(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.layer_idx = module.self_attn.layer_idx - module_device = str(module.self_attn.q_proj.weight.device) - if "xpu" in module_device: - from .xpu.xpu_modeling_llama import _IPEXLlamaAttentionXPU, _IPEXLlamaMLPXPU - - self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, distributed) - else: - self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) - self.input_layernorm = ipex.llm.modules.RMSNorm( - module.input_layernorm.weight, module.input_layernorm.variance_epsilon - ) - self.post_attention_layernorm = ipex.llm.modules.RMSNorm( - module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon - ) - - def preprocess_for_optimize( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - postion_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attention: Optional[bool] = True, - use_cache: Optional[bool] = False, - **kwargs, - ): - return hidden_states, attention_mask, postion_ids, past_key_value - - def postprocess_for_optimize( - self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs - ): - outputs = (hidden_states,) - if use_cache: - outputs += (present_key_value,) - if output_attention: - outputs += (self_attn_weight,) - - return outputs - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
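[Reviewer note] The `ipex.llm.modules.RMSNorm` wrapper used above (and the fused torch_ipex kernel behind it) is assumed to be numerically equivalent to the stock Llama RMSNorm; for documentation purposes, the plain-PyTorch math it replaces is:

    import torch

    def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)          # accumulate the variance in fp32
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return weight * hidden_states.to(input_dtype)

    x = torch.randn(2, 4, 8).to(torch.float16)
    w = torch.ones(8, dtype=torch.float16)
    y = rms_norm_reference(x, w, eps=1e-6)                       # same shape and dtype as x
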
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - outputs = self.preprocess_for_optimize( - hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs - ) - (hidden_states, attention_mask, position_ids, past_key_value) = outputs - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value, self_attn_weight = self.attn( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - None, - residual, - **kwargs, - ) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states, residual, **kwargs) - - outputs = self.postprocess_for_optimize( - hidden_states, output_attentions, use_cache, self_attn_weight, present_key_value, **kwargs - ) - - return outputs diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py deleted file mode 100644 index 961c6b9609..0000000000 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Tuple -import torch -import torch.nn as nn -from typing import Optional -import math -import gc - -import intel_extension_for_pytorch -from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.llama import NewIPEXLLAMABlock -from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU - -from ..modeling_llama import _IPEXLlamaDecoderLayer, _IPEXLlamaAttention, _IPEXLlamaMLP - - -def matmul_add_add(attn_output, weight, bias=None, residual=None): - seq_len, bs, _ = attn_output.size() - if residual is None: - attn_output = torch.matmul(attn_output, weight) - if bias is not None: - attn_output += bias - else: - if bias is not None: - attn_output = torch.ops.torch_ipex.mm_bias_resadd(attn_output, weight, bias, 1.0, residual, 1.0) - else: - attn_output = torch.addmm( - residual.flatten(0, -2), - attn_output.flatten(0, -2), - weight, - beta=1.0, - ) - attn_output = attn_output.view(seq_len, bs, -1) - return attn_output - - -class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention): - def __init__(self, module, config, distributed=False, optimized_module=None) -> None: - super().__init__(module, config, distributed) - self.num_heads = module.num_heads - self.head_dim = module.head_dim - self.num_kv_heads = module.num_key_value_heads - self.embed_dim = module.config.hidden_size - self.port_parameters(module) - torch.xpu.empty_cache() - from intel_extension_for_pytorch.llm.modules import RotaryEmbedding - - self.ipex_rope = _IPEXRopeXPU( - module.config.max_position_embeddings, - module.config.hidden_size // module.config.num_attention_heads, - module.config.rope_theta, - module.config.architectures[0], - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - residual: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask 
(`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - residual (`torch.Tensor`): residual tensor to the layer of shape ` - """ - # allocate cache and copy past_key_value - bs, seqlen, _ = hidden_states.size() - prev_seqlen = 0 - if past_key_value: - _, _, prev_seqlen, _ = past_key_value[0].size() - if self.num_kv_heads == self.num_heads: - query = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - value = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - torch.ops.torch_ipex.mm_qkv_out( - hidden_states, - self.qkv_proj_weight, - self.qkv_proj_bias, - query, - key, - value, - ) - else: - query = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - value = torch.empty( - (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - torch.ops.torch_ipex.mm_qkv_group_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value - ) - - query = query.view([bs, seqlen, self.num_heads, self.head_dim]) - key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - - - sin = kwargs.pop("sin", None) - cos = kwargs.pop("cos", None) - - self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) - value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - if past_key_value is not None: - key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) - value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - present = (key, value) if use_cache else None - - scale = 1.0 / math.sqrt(self.head_dim) - is_causal = False - attn_output = torch.xpu.IpexSDP( - query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False - ) - attn_output = attn_output.transpose(1, 2).view([bs, seqlen, self.embed_dim]) - attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( - [bs, seqlen, self.embed_dim] - ) - outputs = (attn_output, present) - if output_attentions: - raise ValueError("not support output attn_weight") - else: - outputs += (None,) - return outputs - - def port_parameters(self, module): - self.qkv_proj_bias = None - self.qkv_proj_weight = None - if self.num_heads == self.num_kv_heads: - q_proj = module.q_proj.weight.transpose(0, 1) - k_proj = module.k_proj.weight.transpose(0, 1) - v_proj = module.v_proj.weight.transpose(0, 1) - self.qkv_proj_weight = torch.stack([q_proj, k_proj, 
v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) - module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) - module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) - module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) - if module.q_proj.bias is not None: - self.qkv_proj_bias = ( - torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) - .contiguous() - .view([3, -1]) - ) - module.q_proj.bias.data = self.qkv_proj_bias[0] - module.k_proj.bias.data = self.qkv_proj_bias[1] - module.v_proj.bias.data = self.qkv_proj_bias[2] - else: - group = self.num_heads // self.num_kv_heads - q_proj = module.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) - k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_kv_heads, group + 2, self.head_dim, self.embed_dim] - ) - module.q_proj.data = self.qkv_proj_weight[:, :group, :, :].view( - [self.num_kv_heads * group * self.head_dim, self.embed_dim] - ) - module.k_proj.data = self.qkv_proj_weight[:, group, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] - ) - module.v_proj.data = self.qkv_proj_weight[:, group + 1, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] - ) - if module.q_proj.bias is not None: - q_bias = module.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) - k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( - [self.num_kv_heads, group + 2, self.head_dim] - ) - module.q_proj.bias.data = self.qkv_proj_bias[:, :group, self.head_dim].view(-1) - module.k_proj.bias.data = self.qkv_proj_bias[:, group, self.head_dim].view(-1) - module.v_proj.bias.data = self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) - self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() - module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) - self.o_proj_bias = module.o_proj.bias - - -class _IPEXLlamaMLPXPU(_IPEXLlamaMLP): - def __init__(self, module, config, distributed=False, optimized_module=None) -> None: - super().__init__(module, config, distributed) - self.mlp_impl = None - if optimized_module is not None: - self.mlp_impl = optimized_module - self.port_parameter(module) - torch.xpu.empty_cache() - - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - """ - up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) - out = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) - out = matmul_add_add(out, self.down_proj_weight, self.down_proj_bias, residual) - return out - - def port_parameter(self, module): - self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() - module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) - self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() - module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) - self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() - module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) - 
self.up_proj_bias = module.up_proj.bias - self.gate_proj_bias = module.gate_proj.bias - self.down_proj_bias = module.down_proj.bias diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 3c760f1ad0..c0575e0e88 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -21,8 +21,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import repeat_kv -from optimum.intel.utils.import_utils import is_ipex_version - def matmul_add_add(attn_output, weight, bias=None, residual=None): seq_len, bs, _ = attn_output.size() @@ -43,6 +41,7 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None): attn_output = attn_output.view(seq_len, bs, -1) return attn_output + def padding_attn_mask(attn_mask, alignment): if attn_mask is None: return None @@ -58,6 +57,7 @@ def padding_attn_mask(attn_mask, alignment): new_attn_mask[..., :last_dim_size] = attn_mask return new_attn_mask + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -185,6 +185,7 @@ def __init__(self, module, config, distributed=False) -> None: self.module_device = module.q_proj.weight.device.type if self.module_device == "xpu": from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU + self.ipex_rope = _IPEXRopeXPU( module.config.max_position_embeddings, module.config.hidden_size // module.config.num_attention_heads, @@ -194,7 +195,12 @@ def __init__(self, module, config, distributed=False) -> None: self.port_parameters(module) torch.xpu.empty_cache() else: - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + from intel_extension_for_pytorch.llm.modules import ( + IndirectAccessKVCacheAttention, + LinearAdd, + RotaryEmbedding, + ) + if not self.distributed: self.mha_linear_add = LinearAdd(self.o_proj) del self.__dict__["_modules"]["o_proj"] @@ -207,24 +213,18 @@ def __init__(self, module, config, distributed=False) -> None: module.config.rope_theta, module.config.architectures[0], ) - + def qkv_gemm(self, hidden_states): bsz, seq_len, _ = hidden_states.size() if self.module_device == "xpu": + query_shape = (bsz, seq_len, self.num_heads * self.head_dim) + kv_shape = (bsz, seq_len, self.num_key_value_heads * self.head_dim) + dtype = hidden_states.dtype + device = hidden_states.device if self.num_key_value_heads == self.num_heads: - query = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - value = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(query_shape, dtype=dtype, device=device) + value = torch.empty(query_shape, dtype=dtype, device=device) torch.ops.torch_ipex.mm_qkv_out( hidden_states, self.qkv_proj_weight, @@ -234,15 +234,9 @@ def qkv_gemm(self, hidden_states): value, ) else: - query = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bsz, seq_len, self.num_key_value_heads * self.head_dim), 
dtype=hidden_states.dtype, device=hidden_states.device - ) - value = torch.empty( - (bsz, seq_len, self.num_key_value_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(kv_shape, dtype=dtype, device=device) + value = torch.empty(kv_shape, dtype=dtype, device=device) torch.ops.torch_ipex.mm_qkv_group_out( hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value ) @@ -254,9 +248,9 @@ def qkv_gemm(self, hidden_states): query = query.view(bsz, seq_len, self.num_heads, self.head_dim) key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - + return query, key, value - + def rope(self, query, key, kv_seq_len, position_ids, **kwargs): if self.module_device == "xpu": sin = kwargs.pop("sin", None) @@ -281,11 +275,10 @@ def rope(self, query, key, kv_seq_len, position_ids, **kwargs): self.head_dim, kv_seq_len, ) - - return query, key - + + return query, key + def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): - if self.module_device == "xpu": scale = 1.0 / math.sqrt(self.head_dim) is_causal = False @@ -326,10 +319,9 @@ def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - + return attn_output, attn_weights, past_key_value - - + def forward( self, hidden_states: torch.Tensor, @@ -359,10 +351,10 @@ def forward( """ bsz, seq_len, _ = hidden_states.size() kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len - + query, key, value = self.qkv_gemm(hidden_states) query, key = self.rope(query, key, kv_seq_len, position_ids, **kwargs) - + if self.module_device == "xpu": if past_key_value is not None: key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) @@ -370,25 +362,28 @@ def forward( query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - - attn_output, attn_weights, past_key_value= self.sdpa(query, key, value, past_key_value, attention_mask, use_cache) + + attn_output, attn_weights, past_key_value = self.sdpa( + query, key, value, past_key_value, attention_mask, use_cache + ) attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) - + if self.module_device == "xpu": - attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bsz, seq_len, self.hidden_size]) + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( + [bsz, seq_len, self.hidden_size] + ) else: if hasattr(self, "mha_linear_add"): attn_output = self.mha_linear_add(attn_output, residual) else: attn_output = self.o_proj(attn_output) attn_output = residual + attn_output - + if not output_attentions: attn_weights = None return attn_output, past_key_value, attn_weights - - + def port_parameters(self, module): self.qkv_proj_bias = None self.qkv_proj_weight = None @@ -410,20 +405,22 @@ def port_parameters(self, module): module.k_proj.bias.data = self.qkv_proj_bias[1] module.v_proj.bias.data = self.qkv_proj_bias[2] else: - q_proj = module.q_proj.weight.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.embed_dim) - k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, 
self.embed_dim) - v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + q_proj = module.q_proj.weight.view( + self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.hidden_size + ) + k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) + v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.embed_dim] + [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] ) - module.q_proj.data = self.qkv_proj_weight[:, :self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.embed_dim] + module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].view( + [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] ) module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] + [self.num_kv_heads * self.head_dim, self.hidden_size] ) module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] + [self.num_kv_heads * self.head_dim, self.hidden_size] ) if module.q_proj.bias is not None: q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim) @@ -432,10 +429,9 @@ def port_parameters(self, module): self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim] ) - module.q_proj.bias.data = self.qkv_proj_bias[:, :self.num_key_value_groups, self.head_dim].view(-1) + module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) - self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) self.o_proj_bias = module.o_proj.bias @@ -458,6 +454,7 @@ def __init__(self, module, config, distributed=False) -> None: torch.xpu.empty_cache() else: from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + if not self.distributed: self.mlp_linear_add = LinearAdd(module.down_proj) del self.__dict__["_modules"]["down_proj"] @@ -483,8 +480,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** hidden_states = self.down_proj(mlp_gate) hidden_states = residual + hidden_states else: - hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) - hidden_states = residual + hidden_states + hidden_states = self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + hidden_states = residual + hidden_states return hidden_states def port_parameter(self, module): @@ -498,6 +497,7 @@ def port_parameter(self, module): self.gate_proj_bias = module.gate_proj.bias self.down_proj_bias = module.down_proj.bias + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayerRef(nn.Module): def __init__(self, module, config, distributed=False): @@ -513,22 +513,20 @@ def 
__init__(self, module, config, distributed=False): self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) self.module_device = module.mlp.gate_proj.weight.device.type from intel_extension_for_pytorch.llm.modules import RMSNorm + if self.module_device == "xpu": - self.input_layernorm = RMSNorm( - self.input_layernorm.weight, self.input_layernorm.variance_epsilon - ) + self.input_layernorm = RMSNorm(self.input_layernorm.weight, self.input_layernorm.variance_epsilon) self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon ) else: self.input_layernorm = RMSNorm( - self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight + self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight ) self.post_attention_layernorm = RMSNorm( - self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight + self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight ) - - + def forward( self, hidden_states: torch.Tensor, From 1ff78b21663585660cfb55bf022f34c39d17a487 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 22:28:25 -0700 Subject: [PATCH 17/31] fix version --- optimum/intel/ipex/modeling_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index af9ba30426..e71d0c0ad1 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -63,7 +63,7 @@ def _is_patched_with_ipex(model, task): - ipex_version = "2.1.0" if model.device.type == "xpu" else "2.3.0" + ipex_version = "2.2.0" if model.device.type == "xpu" else "2.3.0" if is_ipex_version("<", ipex_version): return False From ff7f785ced1ff0d85c76570ccefa5c8eb4011b1f Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 19:17:36 -0400 Subject: [PATCH 18/31] bug fix --- optimum/intel/ipex/modeling_base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index e71d0c0ad1..6316fbd225 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -63,17 +63,18 @@ def _is_patched_with_ipex(model, task): - ipex_version = "2.2.0" if model.device.type == "xpu" else "2.3.0" - if is_ipex_version("<", ipex_version): - return False - if isinstance(model, torch.jit.ScriptModule): + if is_ipex_version("<", "2.3.0"): + return False for node in model.graph.nodes(): # Jit will record the codes position so we can check if the node use ipex exporter. 
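Read together, the version fix and the bug fix here make _is_patched_with_ipex gate traced and eager models differently. A consolidated sketch of the resulting check (the helper names are passed in as arguments only to keep the sketch self-contained; the real function uses the module-level imports):

import torch

def is_patched_with_ipex_sketch(model, task, is_ipex_version, supported_model_types, exported_tasks):
    if isinstance(model, torch.jit.ScriptModule):
        # Traced models only need the CPU minimum version; the graph is then inspected
        # for the fused rotary-embedding op left behind by the IPEX exporter.
        if is_ipex_version("<", "2.3.0"):
            return False
        return any("torch_ipex::rotary_position_embedding" in str(node) for node in model.graph.nodes())
    # Eager models are gated on a device-specific minimum version (XPU support lands at 2.2.0).
    minimum_version = "2.2.0" if model.device.type == "xpu" else "2.3.0"
    if is_ipex_version("<", minimum_version):
        return False
    return model.config.model_type in supported_model_types and task in exported_tasks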
if "torch_ipex::rotary_position_embedding" in node.__str__(): return True return False else: + ipex_version = "2.2.0" if model.device.type == "xpu" else "2.3.0" + if is_ipex_version("<", ipex_version): + return False return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK From e3dac8975ea1004b00dbfae56aec37a0d0988af6 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 19:39:35 -0400 Subject: [PATCH 19/31] change module --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index c0575e0e88..ca03b62fd9 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -448,7 +448,7 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - self.module_device = module.gate_proj.weight.device.type + self.module_device = module.down_proj.weight.device.type if self.module_device == "xpu": self.port_parameter(module) torch.xpu.empty_cache() From 8725f49add550bfff9bcc1270b7adfda7210ed1c Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 19:57:27 -0400 Subject: [PATCH 20/31] improve device --- optimum/exporters/ipex/modeling_utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index ca03b62fd9..6dc2145082 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -122,13 +122,15 @@ def _llama_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - - seqlen = hidden_states.size(1) - head_dim = self.layers[0].self_attn.head_dim - sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) - sin = sin.squeeze()[position_ids].unsqueeze(2) - cos = cos.squeeze()[position_ids].unsqueeze(2) - sin_cos = {"sin": sin, "cos": cos} + if hidden_states.device.type == "xpu": + seqlen = hidden_states.size(1) + head_dim = self.layers[0].self_attn.head_dim + sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) + sin = sin.squeeze()[position_ids].unsqueeze(2) + cos = cos.squeeze()[position_ids].unsqueeze(2) + decoder_layer_kwargs = {"sin": sin, "cos": cos} + else: + decoder_layer_kwargs = {} for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -142,7 +144,7 @@ def _llama_model_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - **sin_cos, + **decoder_layer_kwargs, ) hidden_states = layer_outputs[0] @@ -182,7 +184,7 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - self.module_device = module.q_proj.weight.device.type + self.module_device = next(module.parameters()).device.type if self.module_device == "xpu": from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU @@ -448,7 +450,7 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - 
self.module_device = module.down_proj.weight.device.type + self.module_device = next(module.parameters()).device.type if self.module_device == "xpu": self.port_parameter(module) torch.xpu.empty_cache() @@ -510,8 +512,8 @@ def __init__(self, module, config, distributed=False): setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) + self.module_device = next(module.parameters()).device.type self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) - self.module_device = module.mlp.gate_proj.weight.device.type from intel_extension_for_pytorch.llm.modules import RMSNorm if self.module_device == "xpu": From 57cfe117bff9ee6a1e184e9299d2ac63fc4dfe55 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 21:36:31 -0700 Subject: [PATCH 21/31] remove --- optimum/exporters/ipex/model_patcher.py | 4 +- optimum/exporters/ipex/modeling/__init__.py | 1 - .../exporters/ipex/modeling/modeling_llama.py | 155 ------------ .../ipex/modeling/xpu/xpu_modeling_llama.py | 235 ------------------ optimum/exporters/ipex/modeling_utils.py | 148 +++++------ optimum/intel/ipex/modeling_base.py | 9 +- 6 files changed, 80 insertions(+), 472 deletions(-) delete mode 100644 optimum/exporters/ipex/modeling/__init__.py delete mode 100644 optimum/exporters/ipex/modeling/modeling_llama.py delete mode 100644 optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index ef30c2cc07..839d501637 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -16,7 +16,6 @@ LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, - LlamaRMSNorm, ) from optimum.intel.utils.import_utils import is_ipex_version @@ -26,7 +25,6 @@ _llama_model_forward, ) -from .modeling.modeling_llama import _IPEXLlamaDecoderLayer _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -60,7 +58,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.5.0" + ipex_version = "2.2.0" if "xpu" in str(model.device) else "2.3.0" if is_ipex_version("<", ipex_version): raise ImportError(f"Only ipex version >= {ipex_version} supports llama model patching") diff --git a/optimum/exporters/ipex/modeling/__init__.py b/optimum/exporters/ipex/modeling/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/optimum/exporters/ipex/modeling/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/optimum/exporters/ipex/modeling/modeling_llama.py b/optimum/exporters/ipex/modeling/modeling_llama.py deleted file mode 100644 index 0ecef1ff83..0000000000 --- a/optimum/exporters/ipex/modeling/modeling_llama.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import torch.nn as nn -from typing import Optional, Tuple -import intel_extension_for_pytorch as ipex - - -class _IPEXLlamaAttention(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.module = module - self.config = config - self.distributed = distributed - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - residual: 
Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - residual (`torch.Tensor`): residual tensor to the layer of shape ` - """ - pass - - -class _IPEXLlamaMLP(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.module = module - self.config = config - self.distributed = distributed - - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor, **kwargs): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - """ - pass - - -class _IPEXLlamaDecoderLayer(nn.Module): - def __init__(self, module, config, distributed=False) -> None: - super().__init__() - self.layer_idx = module.self_attn.layer_idx - module_device = str(module.self_attn.q_proj.weight.device) - if "xpu" in module_device: - from .xpu.xpu_modeling_llama import _IPEXLlamaAttentionXPU, _IPEXLlamaMLPXPU - - self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, distributed) - else: - self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) - self.input_layernorm = ipex.llm.modules.RMSNorm( - module.input_layernorm.weight, module.input_layernorm.variance_epsilon - ) - self.post_attention_layernorm = ipex.llm.modules.RMSNorm( - module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon - ) - - def preprocess_for_optimize( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - postion_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attention: Optional[bool] = True, - use_cache: Optional[bool] = False, - **kwargs, - ): - return hidden_states, attention_mask, postion_ids, past_key_value - - def postprocess_for_optimize( - self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs - ): - outputs = (hidden_states,) - if use_cache: - outputs += (present_key_value,) - if output_attention: - outputs += (self_attn_weight,) - - return outputs - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, 
*optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - outputs = self.preprocess_for_optimize( - hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs - ) - (hidden_states, attention_mask, position_ids, past_key_value) = outputs - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value, self_attn_weight = self.attn( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - None, - residual, - **kwargs, - ) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states, residual, **kwargs) - - outputs = self.postprocess_for_optimize( - hidden_states, output_attentions, use_cache, self_attn_weight, present_key_value, **kwargs - ) - - return outputs diff --git a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py b/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py deleted file mode 100644 index 961c6b9609..0000000000 --- a/optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Tuple -import torch -import torch.nn as nn -from typing import Optional -import math -import gc - -import intel_extension_for_pytorch -from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.llama import NewIPEXLLAMABlock -from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU - -from ..modeling_llama import _IPEXLlamaDecoderLayer, _IPEXLlamaAttention, _IPEXLlamaMLP - - -def matmul_add_add(attn_output, weight, bias=None, residual=None): - seq_len, bs, _ = attn_output.size() - if residual is None: - attn_output = torch.matmul(attn_output, weight) - if bias is not None: - attn_output += bias - else: - if bias is not None: - attn_output = torch.ops.torch_ipex.mm_bias_resadd(attn_output, weight, bias, 1.0, residual, 1.0) - else: - attn_output = torch.addmm( - residual.flatten(0, -2), - attn_output.flatten(0, -2), - weight, - beta=1.0, - ) - attn_output = attn_output.view(seq_len, bs, -1) - return attn_output - - -class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention): - def __init__(self, module, config, distributed=False, optimized_module=None) -> None: - super().__init__(module, config, distributed) - self.num_heads = module.num_heads - self.head_dim = module.head_dim - self.num_kv_heads = module.num_key_value_heads - self.embed_dim = module.config.hidden_size - self.port_parameters(module) - torch.xpu.empty_cache() - from intel_extension_for_pytorch.llm.modules import RotaryEmbedding - - self.ipex_rope = _IPEXRopeXPU( - module.config.max_position_embeddings, - module.config.hidden_size // module.config.num_attention_heads, - module.config.rope_theta, - module.config.architectures[0], - ) - - def forward( - self, - hidden_states: torch.Tensor, - 
attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - residual: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - residual (`torch.Tensor`): residual tensor to the layer of shape ` - """ - # allocate cache and copy past_key_value - bs, seqlen, _ = hidden_states.size() - prev_seqlen = 0 - if past_key_value: - _, _, prev_seqlen, _ = past_key_value[0].size() - if self.num_kv_heads == self.num_heads: - query = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - value = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - torch.ops.torch_ipex.mm_qkv_out( - hidden_states, - self.qkv_proj_weight, - self.qkv_proj_bias, - query, - key, - value, - ) - else: - query = torch.empty( - (bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - value = torch.empty( - (bs, seqlen, self.num_kv_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - torch.ops.torch_ipex.mm_qkv_group_out( - hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value - ) - - query = query.view([bs, seqlen, self.num_heads, self.head_dim]) - key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - - - sin = kwargs.pop("sin", None) - cos = kwargs.pop("cos", None) - - self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) - value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim]) - if past_key_value is not None: - key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) - value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - present = (key, value) if use_cache else None - - scale = 1.0 / math.sqrt(self.head_dim) - is_causal = False - attn_output = torch.xpu.IpexSDP( - query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False - ) - attn_output = attn_output.transpose(1, 2).view([bs, seqlen, self.embed_dim]) - attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( - 
[bs, seqlen, self.embed_dim] - ) - outputs = (attn_output, present) - if output_attentions: - raise ValueError("not support output attn_weight") - else: - outputs += (None,) - return outputs - - def port_parameters(self, module): - self.qkv_proj_bias = None - self.qkv_proj_weight = None - if self.num_heads == self.num_kv_heads: - q_proj = module.q_proj.weight.transpose(0, 1) - k_proj = module.k_proj.weight.transpose(0, 1) - v_proj = module.v_proj.weight.transpose(0, 1) - self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) - module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) - module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) - module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) - if module.q_proj.bias is not None: - self.qkv_proj_bias = ( - torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) - .contiguous() - .view([3, -1]) - ) - module.q_proj.bias.data = self.qkv_proj_bias[0] - module.k_proj.bias.data = self.qkv_proj_bias[1] - module.v_proj.bias.data = self.qkv_proj_bias[2] - else: - group = self.num_heads // self.num_kv_heads - q_proj = module.q_proj.weight.view(self.num_kv_heads, group, self.head_dim, self.embed_dim) - k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_kv_heads, group + 2, self.head_dim, self.embed_dim] - ) - module.q_proj.data = self.qkv_proj_weight[:, :group, :, :].view( - [self.num_kv_heads * group * self.head_dim, self.embed_dim] - ) - module.k_proj.data = self.qkv_proj_weight[:, group, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] - ) - module.v_proj.data = self.qkv_proj_weight[:, group + 1, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] - ) - if module.q_proj.bias is not None: - q_bias = module.q_proj.bias.view(self.num_kv_heads, group, self.head_dim) - k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( - [self.num_kv_heads, group + 2, self.head_dim] - ) - module.q_proj.bias.data = self.qkv_proj_bias[:, :group, self.head_dim].view(-1) - module.k_proj.bias.data = self.qkv_proj_bias[:, group, self.head_dim].view(-1) - module.v_proj.bias.data = self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1) - self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() - module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) - self.o_proj_bias = module.o_proj.bias - - -class _IPEXLlamaMLPXPU(_IPEXLlamaMLP): - def __init__(self, module, config, distributed=False, optimized_module=None) -> None: - super().__init__(module, config, distributed) - self.mlp_impl = None - if optimized_module is not None: - self.mlp_impl = optimized_module - self.port_parameter(module) - torch.xpu.empty_cache() - - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - """ - up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) - out = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) - out = matmul_add_add(out, self.down_proj_weight, 
self.down_proj_bias, residual) - return out - - def port_parameter(self, module): - self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() - module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) - self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() - module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) - self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() - module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) - self.up_proj_bias = module.up_proj.bias - self.gate_proj_bias = module.gate_proj.bias - self.down_proj_bias = module.down_proj.bias diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 3c760f1ad0..6dc2145082 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -21,8 +21,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import repeat_kv -from optimum.intel.utils.import_utils import is_ipex_version - def matmul_add_add(attn_output, weight, bias=None, residual=None): seq_len, bs, _ = attn_output.size() @@ -43,6 +41,7 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None): attn_output = attn_output.view(seq_len, bs, -1) return attn_output + def padding_attn_mask(attn_mask, alignment): if attn_mask is None: return None @@ -58,6 +57,7 @@ def padding_attn_mask(attn_mask, alignment): new_attn_mask[..., :last_dim_size] = attn_mask return new_attn_mask + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -122,13 +122,15 @@ def _llama_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - - seqlen = hidden_states.size(1) - head_dim = self.layers[0].self_attn.head_dim - sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) - sin = sin.squeeze()[position_ids].unsqueeze(2) - cos = cos.squeeze()[position_ids].unsqueeze(2) - sin_cos = {"sin": sin, "cos": cos} + if hidden_states.device.type == "xpu": + seqlen = hidden_states.size(1) + head_dim = self.layers[0].self_attn.head_dim + sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2) + sin = sin.squeeze()[position_ids].unsqueeze(2) + cos = cos.squeeze()[position_ids].unsqueeze(2) + decoder_layer_kwargs = {"sin": sin, "cos": cos} + else: + decoder_layer_kwargs = {} for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -142,7 +144,7 @@ def _llama_model_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - **sin_cos, + **decoder_layer_kwargs, ) hidden_states = layer_outputs[0] @@ -182,9 +184,10 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - self.module_device = module.q_proj.weight.device.type + self.module_device = next(module.parameters()).device.type if self.module_device == "xpu": from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU + self.ipex_rope = _IPEXRopeXPU( module.config.max_position_embeddings, module.config.hidden_size // module.config.num_attention_heads, @@ -194,7 +197,12 @@ def __init__(self, 
module, config, distributed=False) -> None: self.port_parameters(module) torch.xpu.empty_cache() else: - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + from intel_extension_for_pytorch.llm.modules import ( + IndirectAccessKVCacheAttention, + LinearAdd, + RotaryEmbedding, + ) + if not self.distributed: self.mha_linear_add = LinearAdd(self.o_proj) del self.__dict__["_modules"]["o_proj"] @@ -207,24 +215,18 @@ def __init__(self, module, config, distributed=False) -> None: module.config.rope_theta, module.config.architectures[0], ) - + def qkv_gemm(self, hidden_states): bsz, seq_len, _ = hidden_states.size() if self.module_device == "xpu": + query_shape = (bsz, seq_len, self.num_heads * self.head_dim) + kv_shape = (bsz, seq_len, self.num_key_value_heads * self.head_dim) + dtype = hidden_states.dtype + device = hidden_states.device if self.num_key_value_heads == self.num_heads: - query = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - value = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(query_shape, dtype=dtype, device=device) + value = torch.empty(query_shape, dtype=dtype, device=device) torch.ops.torch_ipex.mm_qkv_out( hidden_states, self.qkv_proj_weight, @@ -234,15 +236,9 @@ def qkv_gemm(self, hidden_states): value, ) else: - query = torch.empty( - (bsz, seq_len, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - key = torch.empty( - (bsz, seq_len, self.num_key_value_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - value = torch.empty( - (bsz, seq_len, self.num_key_value_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(kv_shape, dtype=dtype, device=device) + value = torch.empty(kv_shape, dtype=dtype, device=device) torch.ops.torch_ipex.mm_qkv_group_out( hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value ) @@ -254,9 +250,9 @@ def qkv_gemm(self, hidden_states): query = query.view(bsz, seq_len, self.num_heads, self.head_dim) key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - + return query, key, value - + def rope(self, query, key, kv_seq_len, position_ids, **kwargs): if self.module_device == "xpu": sin = kwargs.pop("sin", None) @@ -281,11 +277,10 @@ def rope(self, query, key, kv_seq_len, position_ids, **kwargs): self.head_dim, kv_seq_len, ) - - return query, key - + + return query, key + def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): - if self.module_device == "xpu": scale = 1.0 / math.sqrt(self.head_dim) is_causal = False @@ -326,10 +321,9 @@ def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - + return attn_output, attn_weights, past_key_value - - + def forward( self, hidden_states: torch.Tensor, @@ -359,10 
+353,10 @@ def forward( """ bsz, seq_len, _ = hidden_states.size() kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len - + query, key, value = self.qkv_gemm(hidden_states) query, key = self.rope(query, key, kv_seq_len, position_ids, **kwargs) - + if self.module_device == "xpu": if past_key_value is not None: key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) @@ -370,25 +364,28 @@ def forward( query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - - attn_output, attn_weights, past_key_value= self.sdpa(query, key, value, past_key_value, attention_mask, use_cache) + + attn_output, attn_weights, past_key_value = self.sdpa( + query, key, value, past_key_value, attention_mask, use_cache + ) attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) - + if self.module_device == "xpu": - attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view([bsz, seq_len, self.hidden_size]) + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( + [bsz, seq_len, self.hidden_size] + ) else: if hasattr(self, "mha_linear_add"): attn_output = self.mha_linear_add(attn_output, residual) else: attn_output = self.o_proj(attn_output) attn_output = residual + attn_output - + if not output_attentions: attn_weights = None return attn_output, past_key_value, attn_weights - - + def port_parameters(self, module): self.qkv_proj_bias = None self.qkv_proj_weight = None @@ -410,20 +407,22 @@ def port_parameters(self, module): module.k_proj.bias.data = self.qkv_proj_bias[1] module.v_proj.bias.data = self.qkv_proj_bias[2] else: - q_proj = module.q_proj.weight.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.embed_dim) - k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) - v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim) + q_proj = module.q_proj.weight.view( + self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.hidden_size + ) + k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) + v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.embed_dim] + [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] ) - module.q_proj.data = self.qkv_proj_weight[:, :self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.embed_dim] + module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].view( + [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] ) module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] + [self.num_kv_heads * self.head_dim, self.hidden_size] ) module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view( - [self.num_kv_heads * self.head_dim, self.embed_dim] + [self.num_kv_heads * self.head_dim, self.hidden_size] ) if module.q_proj.bias is not None: q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim) @@ -432,10 +431,9 @@ def port_parameters(self, module): self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( 
[self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim] ) - module.q_proj.bias.data = self.qkv_proj_bias[:, :self.num_key_value_groups, self.head_dim].view(-1) + module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) - self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) self.o_proj_bias = module.o_proj.bias @@ -452,12 +450,13 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed - self.module_device = module.gate_proj.weight.device.type + self.module_device = next(module.parameters()).device.type if self.module_device == "xpu": self.port_parameter(module) torch.xpu.empty_cache() else: from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + if not self.distributed: self.mlp_linear_add = LinearAdd(module.down_proj) del self.__dict__["_modules"]["down_proj"] @@ -483,8 +482,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** hidden_states = self.down_proj(mlp_gate) hidden_states = residual + hidden_states else: - hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) - hidden_states = residual + hidden_states + hidden_states = self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + hidden_states = residual + hidden_states return hidden_states def port_parameter(self, module): @@ -498,6 +499,7 @@ def port_parameter(self, module): self.gate_proj_bias = module.gate_proj.bias self.down_proj_bias = module.down_proj.bias + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayerRef(nn.Module): def __init__(self, module, config, distributed=False): @@ -510,25 +512,23 @@ def __init__(self, module, config, distributed=False): setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) + self.module_device = next(module.parameters()).device.type self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) - self.module_device = module.mlp.gate_proj.weight.device.type from intel_extension_for_pytorch.llm.modules import RMSNorm + if self.module_device == "xpu": - self.input_layernorm = RMSNorm( - self.input_layernorm.weight, self.input_layernorm.variance_epsilon - ) + self.input_layernorm = RMSNorm(self.input_layernorm.weight, self.input_layernorm.variance_epsilon) self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon ) else: self.input_layernorm = RMSNorm( - self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight + self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight ) self.post_attention_layernorm = RMSNorm( - self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight + self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight ) - - + def forward( self, hidden_states: torch.Tensor, diff 
--git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index af9ba30426..6316fbd225 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -63,17 +63,18 @@ def _is_patched_with_ipex(model, task): - ipex_version = "2.1.0" if model.device.type == "xpu" else "2.3.0" - if is_ipex_version("<", ipex_version): - return False - if isinstance(model, torch.jit.ScriptModule): + if is_ipex_version("<", "2.3.0"): + return False for node in model.graph.nodes(): # Jit will record the codes position so we can check if the node use ipex exporter. if "torch_ipex::rotary_position_embedding" in node.__str__(): return True return False else: + ipex_version = "2.2.0" if model.device.type == "xpu" else "2.3.0" + if is_ipex_version("<", ipex_version): + return False return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK From ee78f95460f894aadb19da99a6e5b24ce754e568 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 23:17:12 -0700 Subject: [PATCH 22/31] simplfy rmsnorm --- optimum/exporters/ipex/model_patcher.py | 7 +++-- optimum/exporters/ipex/modeling_utils.py | 39 +++++++++--------------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 839d501637..b1872ef622 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -16,12 +16,14 @@ LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, + LlamaRMSNorm, ) from optimum.intel.utils.import_utils import is_ipex_version from .modeling_utils import ( - _IPEXLlamaDecoderLayerRef, + _IPEXLlamaDecoderLayer, + _llama_layer_norm_forward, _llama_model_forward, ) @@ -63,7 +65,8 @@ def _patch_llama_model(model): raise ImportError(f"Only ipex version >= {ipex_version} supports llama model patching") convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 6dc2145082..fc8676dc2d 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union import torch +from intel_extension_for_pytorch.llm.functional import rms_norm from torch import nn from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast @@ -58,6 +59,11 @@ def padding_attn_mask(attn_mask, alignment): return new_attn_mask +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 +def _llama_layer_norm_forward(self, hidden_states): + return rms_norm(hidden_states, self.weight, self.variance_epsilon) + + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -173,7 +179,7 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 -class _IPEXLlamaAttentionRef(nn.Module): +class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, 
distributed=False) -> None: super().__init__() for k, v in module.__dict__.items(): @@ -322,7 +328,7 @@ def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - return attn_output, attn_weights, past_key_value + return attn_output, past_key_value, attn_weights def forward( self, @@ -365,7 +371,7 @@ def forward( key = key.transpose(1, 2) value = value.transpose(1, 2) - attn_output, attn_weights, past_key_value = self.sdpa( + attn_output, past_key_value, attn_weights = self.sdpa( query, key, value, past_key_value, attention_mask, use_cache ) attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) @@ -384,7 +390,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, past_key_value, attn_weights + return attn_output, attn_weights, past_key_value def port_parameters(self, module): self.qkv_proj_bias = None @@ -439,7 +445,7 @@ def port_parameters(self, module): self.o_proj_bias = module.o_proj.bias -class _IPEXLlamaMLPRef(nn.Module): +class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: super().__init__() for k, v in module.__dict__.items(): @@ -501,7 +507,7 @@ def port_parameter(self, module): # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 -class _IPEXLlamaDecoderLayerRef(nn.Module): +class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False): super().__init__() for k, v in module.__dict__.items(): @@ -511,23 +517,8 @@ def __init__(self, module, config, distributed=False): continue setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed - self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) - self.module_device = next(module.parameters()).device.type - self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) - from intel_extension_for_pytorch.llm.modules import RMSNorm - - if self.module_device == "xpu": - self.input_layernorm = RMSNorm(self.input_layernorm.weight, self.input_layernorm.variance_epsilon) - self.post_attention_layernorm = RMSNorm( - self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon - ) - else: - self.input_layernorm = RMSNorm( - self.hidden_size, self.input_layernorm.variance_epsilon, self.input_layernorm.weight - ) - self.post_attention_layernorm = RMSNorm( - self.hidden_size, self.post_attention_layernorm.variance_epsilon, self.post_attention_layernorm.weight - ) + self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) def forward( self, @@ -558,7 +549,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, present_key_value, self_attn_weights = self.self_attn( + hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, From 60989435b7cef1927b4a610326bc26716ff30429 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 23:24:46 -0700 Subject: [PATCH 23/31] style --- optimum/exporters/ipex/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 9975494218..b1872ef622 
100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -16,6 +16,7 @@ LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, + LlamaRMSNorm, ) from optimum.intel.utils.import_utils import is_ipex_version From e0fb06e7b901554a517c3c5043bbb7d88b8d4636 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 6 Jun 2024 18:41:50 -0700 Subject: [PATCH 24/31] fix group attention --- optimum/exporters/ipex/modeling_utils.py | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index fc8676dc2d..7d4566feb4 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -414,28 +414,28 @@ def port_parameters(self, module): module.v_proj.bias.data = self.qkv_proj_bias[2] else: q_proj = module.q_proj.weight.view( - self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.hidden_size + self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size ) - k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) - v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.hidden_size) + k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) + v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] ) - module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] + module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] ) - module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view( - [self.num_kv_heads * self.head_dim, self.hidden_size] + module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] ) - module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view( - [self.num_kv_heads * self.head_dim, self.hidden_size] + module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] ) if module.q_proj.bias is not None: - q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim) - k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim) - v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim) + q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim) + k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) + v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( - [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim] + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim] ) module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) module.k_proj.bias.data = 
self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) From aa8d395f722d45e7d771909f260dd0bbd32d83b7 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 6 Jun 2024 22:55:09 -0700 Subject: [PATCH 25/31] fix weight shape --- optimum/exporters/ipex/modeling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 7d4566feb4..a5cbff3eaf 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -430,6 +430,7 @@ def port_parameters(self, module): module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape( [self.num_key_value_heads * self.head_dim, self.hidden_size] ) + self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous() if module.q_proj.bias is not None: q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim) k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) From 548d83fc4ccf132d1886d9cd5210bbaa8e0deb1e Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 7 Jun 2024 01:54:31 -0700 Subject: [PATCH 26/31] fix rebase bug --- optimum/exporters/ipex/modeling_utils.py | 285 +++++++++++++++----- optimum/exporters/openvino/model_patcher.py | 12 +- 2 files changed, 222 insertions(+), 75 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 6e90059a00..b9b7034f93 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -21,11 +21,14 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + from optimum.intel.utils.import_utils import is_ipex_version from optimum.intel.utils.modeling_utils import _setattr_from_module + _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" + def matmul_add_add(attn_output, weight, bias=None, residual=None): seq_len, bs, _ = attn_output.size() if residual is None: @@ -192,27 +195,68 @@ def __init__(self, module, config, distributed=False) -> None: _setattr_from_module(self, module) self.config = config self.distributed = distributed - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + self.module_device = next(module.parameters()).device.type + if self.module_device == "xpu": + from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU + + self.ipex_rope = _IPEXRopeXPU( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + self.port_parameters(module) + torch.xpu.empty_cache() + else: + from intel_extension_for_pytorch.llm.modules import ( + IndirectAccessKVCacheAttention, + LinearAdd, + RotaryEmbedding, + ) - if not self.distributed: - self.mha_linear_add = LinearAdd(self.o_proj) - del self.__dict__["_modules"]["o_proj"] - self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( - text_max_length=module.config.max_position_embeddings - ) - self.ipex_rope = RotaryEmbedding( - module.config.max_position_embeddings, - module.config.hidden_size // module.config.num_attention_heads, - module.config.rope_theta, - module.config.architectures[0], - ) + if not self.distributed: + self.mha_linear_add = 
LinearAdd(self.o_proj) + del self.__dict__["_modules"]["o_proj"] + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=module.config.max_position_embeddings + ) + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) def qkv_gemm(self, hidden_states): bsz, seq_len, _ = hidden_states.size() - - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) + if self.module_device == "xpu": + query_shape = (bsz, seq_len, self.num_heads * self.head_dim) + kv_shape = (bsz, seq_len, self.num_key_value_heads * self.head_dim) + dtype = hidden_states.dtype + device = hidden_states.device + if self.num_key_value_heads == self.num_heads: + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(query_shape, dtype=dtype, device=device) + value = torch.empty(query_shape, dtype=dtype, device=device) + torch.ops.torch_ipex.mm_qkv_out( + hidden_states, + self.qkv_proj_weight, + self.qkv_proj_bias, + query, + key, + value, + ) + else: + query = torch.empty(query_shape, dtype=dtype, device=device) + key = torch.empty(kv_shape, dtype=dtype, device=device) + value = torch.empty(kv_shape, dtype=dtype, device=device) + torch.ops.torch_ipex.mm_qkv_group_out( + hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value + ) + else: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) query = query.view(bsz, seq_len, self.num_heads, self.head_dim) key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) @@ -220,40 +264,55 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, kv_seq_len, position_ids, use_cache): - if use_cache: - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) + def rope(self, query, key, kv_seq_len, position_ids, use_cache, **kwargs): + if self.module_device == "xpu": + sin = kwargs.pop("sin", None) + cos = kwargs.pop("cos", None) + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + else: + if use_cache: + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) return query, key def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids): - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. 
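The sdpa_with_cache rewrite in this hunk swaps the CPU-only IndirectAccessKVCacheAttention call for torch.xpu.IpexSDP and keeps the cache as a plain (key, value) tuple that the surrounding forward() concatenates with the new step before attention. A minimal sketch of that contract in stock PyTorch, assuming a (batch, num_heads, seq, head_dim) layout and using F.scaled_dot_product_attention in place of the fused XPU kernel; the helper name is illustrative, not part of the patch:

import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def sdpa_with_tuple_cache(
    query: torch.Tensor,   # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, num_heads, new_len, head_dim)
    value: torch.Tensor,   # (batch, num_heads, new_len, head_dim)
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Grow the cache first, much as the XPU forward() does before handing
    # query/key/value to the fused kernel.
    if past_key_value is not None:
        key = torch.cat([past_key_value[0], key], dim=-2)
        value = torch.cat([past_key_value[1], value], dim=-2)
    attn_output = F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        scale=1.0 / math.sqrt(query.size(-1)),
    )
    # The updated tuple is returned as the next step's past_key_value.
    return attn_output, (key, value)

During decoding q_len is 1 while key and value grow by one position per call, which is why the cache can stay a simple tuple instead of the pre-allocated, beam-indexed buffers that the CPU op manages internally.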
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) + + if self.module_device == "xpu": + scale = 1.0 / math.sqrt(self.head_dim) + is_causal = False + attn_output = torch.xpu.IpexSDP( + query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False + ) + attn_weights = None + past_key_value = (key, value) + else: + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) return attn_output, past_key_value, attn_weights # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 @@ -319,7 +378,15 @@ def forward( kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) + query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache, **kwargs) + + if self.module_device == "xpu": + if past_key_value is not None: + key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache attn_output, past_key_value, attn_weights = sdpa( @@ -327,17 +394,75 @@ def forward( ) attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) - if hasattr(self, "mha_linear_add"): - attn_output = self.mha_linear_add(attn_output, residual) + if self.module_device == "xpu": + attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view( + [bsz, seq_len, self.hidden_size] + ) else: - attn_output = self.o_proj(attn_output) - attn_output = residual + attn_output + if hasattr(self, "mha_linear_add"): + attn_output = self.mha_linear_add(attn_output, residual) + else: + attn_output = self.o_proj(attn_output) + attn_output = residual + attn_output if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value + def port_parameters(self, module): + self.qkv_proj_bias = None + self.qkv_proj_weight = None + if self.num_heads == self.num_key_value_heads: + q_proj = module.q_proj.weight.transpose(0, 1) + k_proj = module.k_proj.weight.transpose(0, 1) + v_proj = module.v_proj.weight.transpose(0, 1) + self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) + module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) + module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) + module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) + if module.q_proj.bias is not None: + self.qkv_proj_bias = ( + torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) + .contiguous() + .view([3, -1]) + ) + module.q_proj.bias.data = self.qkv_proj_bias[0] + module.k_proj.bias.data = self.qkv_proj_bias[1] + module.v_proj.bias.data = self.qkv_proj_bias[2] + else: + q_proj = module.q_proj.weight.view( + self.num_key_value_heads, self.num_key_value_groups, self.head_dim, 
self.hidden_size + ) + k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) + v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) + self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] + ) + module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] + ) + module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] + ) + module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] + ) + self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous() + if module.q_proj.bias is not None: + q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim) + k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) + v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) + self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim] + ) + module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) + module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) + module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) + self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() + module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) + self.o_proj_bias = module.o_proj.bias + # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 class _IPEXLlamaMLP(nn.Module): @@ -350,14 +475,19 @@ def __init__(self, module, config, distributed=False) -> None: _setattr_from_module(self, module) self.config = config self.distributed = distributed - from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + self.module_device = next(module.parameters()).device.type + if self.module_device == "xpu": + self.port_parameter(module) + torch.xpu.empty_cache() + else: + from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd - if not self.distributed: - self.mlp_linear_add = LinearAdd(module.down_proj) - del self.__dict__["_modules"]["down_proj"] - self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) - del self.__dict__["_modules"]["gate_proj"] - del self.__dict__["_modules"]["up_proj"] + if not self.distributed: + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): """ @@ -365,19 +495,36 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` """ - if hasattr(self, "linear_silu_mul"): 
- mlp_gate = self.linear_silu_mul(hidden_states) - if hasattr(self, "mlp_linear_add"): - hidden_states = self.mlp_linear_add(mlp_gate, residual) + if self.module_device == "xpu": + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual) + else: + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj(mlp_gate) + hidden_states = residual + hidden_states else: - hidden_states = self.down_proj(mlp_gate) + hidden_states = self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) hidden_states = residual + hidden_states - else: - hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) - hidden_states = residual + hidden_states - return hidden_states + def port_parameter(self, module): + self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous() + module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1) + self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous() + module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1) + self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous() + module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1) + self.up_proj_bias = module.up_proj.bias + self.gate_proj_bias = module.gate_proj.bias + self.down_proj_bias = module.down_proj.bias + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayer(nn.Module): diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 7771bb35ab..7c0c07bfb7 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -364,9 +364,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -1637,9 +1637,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" From 68187e54b59e391b4813cbf72719fe78c634bee2 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 7 Jun 2024 02:05:18 -0700 Subject: [PATCH 27/31] revert openvino --- optimum/exporters/openvino/model_patcher.py | 12 ++++++------ optimum/intel/openvino/modeling_decoder.py | 8 +++++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 7c0c07bfb7..7771bb35ab 
100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -364,9 +364,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" @@ -1637,9 +1637,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index e557c5ddfa..9c14d74119 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -571,9 +571,11 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke ): past_key_values = tuple( tuple( - past_state[indicies] - if not self.config.model_type == "chatglm" - else past_state[:, indicies, ...] + ( + past_state[indicies] + if not self.config.model_type == "chatglm" + else past_state[:, indicies, ...] + ) for past_state in layer_past ) for layer_past in past_key_values From efedca4540ea4f39999bbdeac3cb23e11a58353a Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 7 Jun 2024 02:09:49 -0700 Subject: [PATCH 28/31] revert openvino --- optimum/intel/openvino/modeling_decoder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 9c14d74119..e557c5ddfa 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -571,11 +571,9 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke ): past_key_values = tuple( tuple( - ( - past_state[indicies] - if not self.config.model_type == "chatglm" - else past_state[:, indicies, ...] - ) + past_state[indicies] + if not self.config.model_type == "chatglm" + else past_state[:, indicies, ...] for past_state in layer_past ) for layer_past in past_key_values From bd03552e961f90c6a68b811ae6d3264a3dcf39c2 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 7 Jun 2024 02:12:41 -0700 Subject: [PATCH 29/31] remove duplicates --- optimum/intel/ipex/modeling_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 8693a5f0ff..ec63be60bc 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -68,8 +68,6 @@ def _is_patched_with_ipex(model, task): return False if isinstance(model, torch.jit.ScriptModule): - if is_ipex_version("<", "2.3.0"): - return False for node in model.graph.nodes(): # Jit will record the codes position so we can check if the node use ipex exporter. 
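The node-string check on the next lines of this hunk is a generic TorchScript trick: every op in a traced graph prints with its qualified name, so a substring test against str(node) is enough to tell whether a reloaded graph was produced by the IPEX exporter (torch_ipex::rotary_position_embedding). A self-contained sketch with a stock aten op, where Tiny and graph_contains_op are illustrative names only:

import torch


def graph_contains_op(scripted: torch.jit.ScriptModule, op_substring: str) -> bool:
    # Each graph node stringifies with its operator name, e.g.
    # "%2 : Float(...) = aten::relu(%x)", so a substring test suffices.
    return any(op_substring in str(node) for node in scripted.graph.nodes())


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


traced = torch.jit.trace(Tiny(), torch.randn(2, 3))
print(graph_contains_op(traced, "aten::relu"))                             # True
print(graph_contains_op(traced, "torch_ipex::rotary_position_embedding"))  # False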
if "torch_ipex::rotary_position_embedding" in node.__str__(): From 0d3930ace1463a16a9d7838770a4a1810bd47a66 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 7 Jun 2024 02:17:44 -0700 Subject: [PATCH 30/31] use the correct black --- optimum/exporters/ipex/modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index b9b7034f93..dfd62c8872 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -292,7 +292,6 @@ def rope(self, query, key, kv_seq_len, position_ids, use_cache, **kwargs): return query, key def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids): - if self.module_device == "xpu": scale = 1.0 / math.sqrt(self.head_dim) is_causal = False From 1fd464be2b788d8f5b9a0ab6bcba54fa95162226 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 10 Sep 2024 11:51:24 -0400 Subject: [PATCH 31/31] fix merge conflict Signed-off-by: Liu, Kaixuan --- notebooks/ipex/text_generation.ipynb | 1 + .../openvino/optimum_openvino_inference.ipynb | 49 +++- .../openvino/quantized_generation_demo.ipynb | 27 +- .../question_answering_quantization.ipynb | 4 +- ...stable_diffusion_hybrid_quantization.ipynb | 5 +- optimum/exporters/ipex/model_patcher.py | 4 +- optimum/exporters/ipex/modeling_utils.py | 245 +++++++++++------- optimum/exporters/openvino/model_patcher.py | 12 +- optimum/intel/ipex/modeling_base.py | 52 +--- optimum/intel/openvino/modeling_base.py | 12 +- 10 files changed, 234 insertions(+), 177 deletions(-) diff --git a/notebooks/ipex/text_generation.ipynb b/notebooks/ipex/text_generation.ipynb index d1a62d9201..dd6b8c0abb 100644 --- a/notebooks/ipex/text_generation.ipynb +++ b/notebooks/ipex/text_generation.ipynb @@ -22,6 +22,7 @@ "source": [ "import torch\n", "from transformers import AutoTokenizer\n", + "\n", "from optimum.intel.ipex import IPEXModelForCausalLM" ] }, diff --git a/notebooks/openvino/optimum_openvino_inference.ipynb b/notebooks/openvino/optimum_openvino_inference.ipynb index 76c77aec55..5106fe1fba 100644 --- a/notebooks/openvino/optimum_openvino_inference.ipynb +++ b/notebooks/openvino/optimum_openvino_inference.ipynb @@ -78,6 +78,7 @@ "source": [ "from optimum.intel import OVModelForQuestionAnswering\n", "\n", + "\n", "# Load PyTorch model from the Hub and export to OpenVINO in the background\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad\", export=True)\n", "\n", @@ -122,6 +123,7 @@ "source": [ "from transformers import AutoTokenizer\n", "\n", + "\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-distilled-squad\")\n", "tokenizer.save_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")" ] @@ -182,9 +184,11 @@ } ], "source": [ - "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForQuestionAnswering\n", + "\n", + "\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-distilled-squad\")\n", "ov_pipe = pipeline(\"question-answering\", model=model, tokenizer=tokenizer)\n", @@ -240,9 +244,11 @@ ], "source": [ "import torch\n", - "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import 
OVModelForQuestionAnswering\n", + "\n", + "\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")\n", "\n", @@ -324,9 +330,11 @@ } ], "source": [ - "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForQuestionAnswering\n", + "\n", + "\n", "model = OVModelForQuestionAnswering.from_pretrained(\n", " \"helenai/distilbert-base-uncased-distilled-squad-ov-fp32\", compile=False\n", ")\n", @@ -411,6 +419,7 @@ "source": [ "from openvino.runtime import Core\n", "\n", + "\n", "for device in Core().available_devices:\n", " print(device, Core().get_property(device, \"FULL_DEVICE_NAME\"))" ] @@ -528,10 +537,12 @@ } ], "source": [ + "from datasets import load_dataset\n", "from IPython.display import Audio\n", - "from optimum.intel import OVModelForAudioClassification\n", "from transformers import AutoFeatureExtractor, pipeline\n", - "from datasets import load_dataset\n", + "\n", + "from optimum.intel import OVModelForAudioClassification\n", + "\n", "\n", "model_id = \"helenai/MIT-ast-finetuned-speech-commands-v2-ov\"\n", "model = OVModelForAudioClassification.from_pretrained(model_id)\n", @@ -638,9 +649,11 @@ } ], "source": [ - "from optimum.intel import OVModelForCausalLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForCausalLM\n", + "\n", + "\n", "model_id = \"helenai/gpt2-ov\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = OVModelForCausalLM.from_pretrained(model_id)\n", @@ -704,9 +717,11 @@ ], "source": [ "from IPython.display import Image\n", - "from optimum.intel import OVModelForImageClassification\n", "from transformers import AutoImageProcessor, pipeline\n", "\n", + "from optimum.intel import OVModelForImageClassification\n", + "\n", + "\n", "model_id = \"helenai/microsoft-swin-tiny-patch4-window7-224-ov\"\n", "model = OVModelForImageClassification.from_pretrained(model_id, compile=False)\n", "image_processor = AutoImageProcessor.from_pretrained(model_id)\n", @@ -766,9 +781,11 @@ } ], "source": [ - "from optimum.intel import OVModelForMaskedLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForMaskedLM\n", + "\n", + "\n", "model_id = \"helenai/bert-base-uncased-ov\"\n", "model = OVModelForMaskedLM.from_pretrained(model_id)\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", @@ -835,9 +852,11 @@ } ], "source": [ - "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForQuestionAnswering\n", + "\n", + "\n", "# Load the model and tokenizer saved in Part 1 of this notebook. 
Or use the line below to load them from the hub\n", "# model_id = \"helenai/distilbert-base-uncased-distilled-squad-ov-fp32\"\n", "model_id = \"distilbert-base-uncased-distilled-squad-ov-fp32\"\n", @@ -890,9 +909,11 @@ } ], "source": [ - "from optimum.intel import OVModelForSeq2SeqLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForSeq2SeqLM\n", + "\n", + "\n", "model_id = \"helenai/t5-small-ov\"\n", "model = OVModelForSeq2SeqLM.from_pretrained(model_id, compile=False, trust_remote_code=True)\n", "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n", @@ -998,9 +1019,11 @@ } ], "source": [ - "from optimum.intel import OVModelForSequenceClassification\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForSequenceClassification\n", + "\n", + "\n", "model_id = \"helenai/papluca-xlm-roberta-base-language-detection-ov\"\n", "model = OVModelForSequenceClassification.from_pretrained(model_id)\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", @@ -1047,9 +1070,11 @@ } ], "source": [ - "from optimum.intel import OVModelForTokenClassification\n", "from transformers import AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVModelForTokenClassification\n", + "\n", + "\n", "model_id = \"helenai/dslim-bert-base-NER-ov-fp32\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = OVModelForTokenClassification.from_pretrained(model_id)\n", diff --git a/notebooks/openvino/quantized_generation_demo.ipynb b/notebooks/openvino/quantized_generation_demo.ipynb index 5673243cb2..c160e735b0 100644 --- a/notebooks/openvino/quantized_generation_demo.ipynb +++ b/notebooks/openvino/quantized_generation_demo.ipynb @@ -45,6 +45,7 @@ "import os\n", "\n", "from transformers import AutoTokenizer\n", + "\n", "from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig" ] }, @@ -211,6 +212,7 @@ "source": [ "from transformers import TextStreamer\n", "\n", + "\n", "# Tokenize the sample\n", "inputs = tokenizer([sample], return_tensors='pt')\n", "\n", @@ -294,7 +296,7 @@ "\n", "\n", "# Tokenize the sample\n", - "inputs = tokenizer([sample], return_tensors='pt') \n", + "inputs = tokenizer([sample], return_tensors='pt')\n", "\n", "out = stateless_model.generate(\n", " **inputs,\n", @@ -302,7 +304,7 @@ " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", " pad_token_id=tokenizer.eos_token_id,\n", " prompt_lookup_num_tokens=3,\n", - ") " + ")" ] }, { @@ -442,6 +444,7 @@ "outputs": [], "source": [ "from functools import wraps\n", + "\n", "import numpy as np\n", "\n", "\n", @@ -458,15 +461,15 @@ " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n", " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n", " self.model_forward = self.model.forward\n", - " \n", + "\n", " @wraps(self.model_forward)\n", " def forward_wrapper(**kwargs):\n", " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n", " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n", " return self.model_forward(**kwargs)\n", - " \n", + "\n", " self.model.forward = forward_wrapper\n", - " \n", + "\n", " # wrap generate method\n", " self.model_generate = self.model.generate\n", "\n", @@ -494,7 +497,7 @@ " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n", " # Add window size for output to ease calculation later\n", " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", - " ws.append(0) \n", + " ws.append(0)\n", "\n", " 
def acceptance_rate(self, return_mean=True, normalize=False):\n", " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n", @@ -533,8 +536,9 @@ "metadata": {}, "outputs": [], "source": [ - "from tqdm import tqdm\n", "from datasets import load_dataset\n", + "from tqdm import tqdm\n", + "\n", "\n", "dataset_name = \"openai_humaneval\"\n", "dataset_subset_name = None\n", @@ -590,10 +594,10 @@ "from threading import Thread\n", "\n", "from transformers import (\n", - " TextIteratorStreamer,\n", + " GenerationConfig,\n", " StoppingCriteria,\n", " StoppingCriteriaList,\n", - " GenerationConfig,\n", + " TextIteratorStreamer,\n", ")\n", "\n", "\n", @@ -690,7 +694,7 @@ " prompt_char = \"▌\"\n", " history[-1][1] = prompt_char\n", " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", - " \n", + "\n", " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", "\n", " # Create a stopping criteria to prevent the model from playing the role of the user aswell.\n", @@ -770,6 +774,7 @@ "source": [ "import gradio as gr\n", "\n", + "\n", "try:\n", " demo.close()\n", "except:\n", @@ -808,7 +813,7 @@ " history: conversation history\n", " Returns:\n", " updated history\n", - " \"\"\" \n", + " \"\"\"\n", " history[-1][1] = None\n", " return history\n", "\n", diff --git a/notebooks/openvino/question_answering_quantization.ipynb b/notebooks/openvino/question_answering_quantization.ipynb index 2481c9b904..247a6f868b 100644 --- a/notebooks/openvino/question_answering_quantization.ipynb +++ b/notebooks/openvino/question_answering_quantization.ipynb @@ -51,9 +51,11 @@ "import transformers\n", "from evaluate import evaluator\n", "from openvino.runtime import Core\n", - "from optimum.intel import OVModelForQuestionAnswering, OVQuantizer, OVQuantizationConfig, OVConfig\n", "from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline\n", "\n", + "from optimum.intel import OVConfig, OVModelForQuestionAnswering, OVQuantizationConfig, OVQuantizer\n", + "\n", + "\n", "transformers.logging.set_verbosity_error()\n", "datasets.logging.set_verbosity_error()" ] diff --git a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb index 8ef2e8ad6c..798aede77a 100644 --- a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb +++ b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb @@ -46,15 +46,18 @@ "outputs": [], "source": [ "import time\n", + "from pathlib import Path\n", + "\n", "import datasets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import transformers\n", - "from pathlib import Path\n", "from openvino.runtime import Core\n", + "\n", "from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig\n", "from optimum.intel.openvino.configuration import OVQuantizationMethod\n", "\n", + "\n", "transformers.logging.set_verbosity_error()\n", "datasets.logging.set_verbosity_error()" ] diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 484fd38077..2a9af1cd52 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -29,11 +29,11 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, _gpt2_block_forward, - _ipex_rms_layer_norm_forward, _IPEXFalconDecoderLayer, _IPEXGPT2Attention, _IPEXIntermediate, _IPEXLlamaDecoderLayer, + 
_llama_layer_norm_forward, _llama_model_forward, ) @@ -79,7 +79,7 @@ def _patch_llama_model(model): 2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add) """ convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward) + convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e415a40771..2e73fb9076 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -84,14 +84,12 @@ def padding_attn_mask(attn_mask, alignment): return new_attn_mask -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 def _llama_layer_norm_forward(self, hidden_states): - return rms_norm(hidden_states, self.weight, self.variance_epsilon) - - -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 -def _ipex_rms_layer_norm_forward(self, hidden_states): - return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) + if hidden_states.device.type == "xpu": + return rms_norm(hidden_states, self.weight, self.variance_epsilon) + else: + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 + return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 @@ -226,14 +224,82 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings) - if hasattr(config, "rope_theta"): - self.ipex_rope = RotaryEmbedding( - config.max_position_embeddings, - config.hidden_size // config.num_attention_heads, - config.rope_theta, - config.architectures[0], + self.module_device = next(module.parameters()).device.type + if self.module_device == "xpu": + from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU + + self.ipex_rope = _IPEXRopeXPU( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], ) + self.port_parameters(module) + torch.xpu.empty_cache() + else: + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=config.max_position_embeddings + ) + if hasattr(config, "rope_theta"): + self.ipex_rope = RotaryEmbedding( + config.max_position_embeddings, + config.hidden_size // config.num_attention_heads, + config.rope_theta, + config.architectures[0], + ) + + def port_parameters(self, module): + self.qkv_proj_bias = None + self.qkv_proj_weight = None + if self.num_heads == self.num_key_value_heads: + q_proj = module.q_proj.weight.transpose(0, 1) + k_proj = module.k_proj.weight.transpose(0, 1) + v_proj = module.v_proj.weight.transpose(0, 1) + self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) + module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) + module.k_proj.weight.data = 
self.qkv_proj_weight[1, :, :].transpose(0, 1) + module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) + if module.q_proj.bias is not None: + self.qkv_proj_bias = ( + torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) + .contiguous() + .view([3, -1]) + ) + module.q_proj.bias.data = self.qkv_proj_bias[0] + module.k_proj.bias.data = self.qkv_proj_bias[1] + module.v_proj.bias.data = self.qkv_proj_bias[2] + else: + q_proj = module.q_proj.weight.view( + self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size + ) + k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) + v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) + self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] + ) + module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] + ) + module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] + ) + module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape( + [self.num_key_value_heads * self.head_dim, self.hidden_size] + ) + self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous() + if module.q_proj.bias is not None: + q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim) + k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) + v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) + self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( + [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim] + ) + module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) + module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) + module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) + self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() + module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) + self.o_proj_bias = module.o_proj.bias def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") @@ -244,16 +310,25 @@ def rope(self, *args, **kwargs): def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): # This ipex op pre-allocates buffers for past_key_values and use beam index history # which to decide which beam should be used to make attention scale dot more efficient. 
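The port_parameters packing above is what lets the XPU path issue a single fused GEMM (torch.ops.torch_ipex.mm_qkv_out) instead of three separate q/k/v projections. A small sketch of the idea with plain torch ops, assuming the multi-head case where num_heads equals num_key_value_heads; the toy sizes and variable names are made up for illustration:

import torch

hidden_size, bsz, seq_len = 64, 2, 5
q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)

# Pack the three weights the way port_parameters does for this case:
# transpose each to (in_features, out_features) and stack along a new dim.
packed = torch.stack([q_proj.weight.t(), k_proj.weight.t(), v_proj.weight.t()])
# packed: (3, hidden_size, hidden_size)

x = torch.randn(bsz, seq_len, hidden_size)
# One batched matmul stands in for the fused mm_qkv_out kernel and yields
# query, key and value in a single pass over the activations.
qkv = torch.matmul(x.view(-1, hidden_size), packed)   # (3, bsz*seq_len, hidden_size)
query, key, value = (t.view(bsz, seq_len, hidden_size) for t in qkv)

assert torch.allclose(query, q_proj(x), atol=1e-5)
assert torch.allclose(key, k_proj(x), atol=1e-5)
assert torch.allclose(value, v_proj(x), atol=1e-5)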
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - kwargs.get("head_mask", None), - attention_mask, - kwargs.get("alibi", None), - ) + if self.module_device == "xpu": + scale = 1.0 / math.sqrt(self.head_dim) + is_causal = False + attn_output = torch.xpu.IpexSDP( + query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False + ) + attn_weights = None + past_key_value = (key, value) + else: + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + kwargs.get("head_mask", None), + attention_mask, + kwargs.get("alibi", None), + ) return attn_output, past_key_value, attn_weights def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): @@ -287,10 +362,18 @@ def forward( qkv_out = self.qkv_gemm(hidden_states) if isinstance(qkv_out, tuple) and len(qkv_out) == 3: query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids=position_ids) + query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids, **kwargs) else: query, key, value = self.rope(qkv_out, kv_seq_len, use_cache, past_len=past_len) + if self.module_device == "xpu": + if past_key_value is not None: + key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1) + value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attention_mask = self.prepare_attention_mask_float(attention_mask, query.dtype) sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache attn_output, past_key_value, attn_weights = sdpa( @@ -315,9 +398,10 @@ def forward( class _IPEXLlamaAttention(_IPEXAttention): def __init__(self, module, config) -> None: super().__init__(module, config) - if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mha_linear_add = LinearAdd(module.o_proj) - del self.__dict__["_modules"]["o_proj"] + if self.module_device == "cpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = LinearAdd(module.o_proj) + del self.__dict__["_modules"]["o_proj"] def qkv_gemm(self, hidden_states): bsz, seq_len, _ = hidden_states.size() @@ -327,11 +411,16 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, kv_seq_len, use_cache, position_ids): - if use_cache: - args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len) - key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args) - query = self.ipex_rope(query, position_ids, self.num_heads, *args) + def rope(self, query, key, kv_seq_len, use_cache, position_ids, **kwargs): + if self.module_device == "xpu": + sin = kwargs.pop("sin", None) + cos = kwargs.pop("cos", None) + self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key) + else: + if use_cache: + args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len) + key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args) + query = self.ipex_rope(query, position_ids, self.num_heads, *args) return query, key # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 @@ -418,59 +507,6 @@ def postprocess_attention_output(self, attn_output, bsz, seq_len): attn_output = self.resid_dropout(attn_output) return 
attn_output - def port_parameters(self, module): - self.qkv_proj_bias = None - self.qkv_proj_weight = None - if self.num_heads == self.num_key_value_heads: - q_proj = module.q_proj.weight.transpose(0, 1) - k_proj = module.k_proj.weight.transpose(0, 1) - v_proj = module.v_proj.weight.transpose(0, 1) - self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]]) - module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1) - module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1) - module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1) - if module.q_proj.bias is not None: - self.qkv_proj_bias = ( - torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias]) - .contiguous() - .view([3, -1]) - ) - module.q_proj.bias.data = self.qkv_proj_bias[0] - module.k_proj.bias.data = self.qkv_proj_bias[1] - module.v_proj.bias.data = self.qkv_proj_bias[2] - else: - q_proj = module.q_proj.weight.view( - self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size - ) - k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) - v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size) - self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view( - [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size] - ) - module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape( - [self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size] - ) - module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape( - [self.num_key_value_heads * self.head_dim, self.hidden_size] - ) - module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape( - [self.num_key_value_heads * self.head_dim, self.hidden_size] - ) - self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous() - if module.q_proj.bias is not None: - q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim) - k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) - v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim) - self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view( - [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim] - ) - module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1) - module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1) - module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1) - self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous() - module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1) - self.o_proj_bias = module.o_proj.bias - # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 class _IPEXLlamaMLP(nn.Module): @@ -478,19 +514,34 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - # LinearAllreduce and LinearLayer cannot use fused op LinearAdd - if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mlp_linear_add = LinearAdd(module.down_proj) - del self.__dict__["_modules"]["down_proj"] - self.linear_silu_mul = 
Linear2SiluMul(module.gate_proj, module.up_proj) - del self.__dict__["_modules"]["gate_proj"] - del self.__dict__["_modules"]["up_proj"] + self.module_device = next(module.parameters()).device.type + if self.module_device == "xpu": + self.port_parameter(module) + torch.xpu.empty_cache() + else: + # LinearAllreduce and LinearLayer cannot use fused op LinearAdd + if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): - if hasattr(self, "linear_silu_mul"): - mlp_gate = self.linear_silu_mul(hidden_states) - if hasattr(self, "mlp_linear_add"): - hidden_states = self.mlp_linear_add(mlp_gate, residual) + if self.module_device == "xpu": + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual) + else: + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + hidden_states = residual + hidden_states else: hidden_states = self.down_proj( self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 59d4bedb51..57185de6c1 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -404,9 +404,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -1966,9 +1966,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 6a820b4feb..b2be8a6b1d 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -154,7 +154,7 @@ def __init__( self._device = torch.device("cpu") # CPU only support jit model for now. 
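The XPU branch of _IPEXLlamaMLP.forward above chains mm_silu, mm_resmul and matmul_add_add over the pre-transposed ("ported") weights. The kernel semantics are inferred here from the eager branch they replace; the following sketch checks that the two formulations agree, with toy sizes and no biases:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 172
x = torch.randn(2, 5, hidden_size)
residual = torch.randn_like(x)

gate_w = torch.randn(intermediate_size, hidden_size)   # gate_proj.weight
up_w = torch.randn(intermediate_size, hidden_size)     # up_proj.weight
down_w = torch.randn(hidden_size, intermediate_size)   # down_proj.weight

# Eager Llama MLP: down(silu(gate(x)) * up(x)) + residual
eager = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w) + residual

# XPU-style path on transposed weights (kernel behaviour inferred):
#   mm_silu(x, gate_w.T)                        ~ silu(x @ gate_w.T)
#   mm_resmul(x, up_w.T, up)                    ~ (x @ up_w.T) * up
#   matmul_add_add(h, down_w.T, None, residual) ~ h @ down_w.T + residual
up = F.silu(x @ gate_w.t())
h = (x @ up_w.t()) * up
fused_like = h @ down_w.t() + residual

assert torch.allclose(eager, fused_like, atol=1e-4)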
- if export: + if export and self._device.type == "cpu": if isinstance(model, torch.jit.RecursiveScriptModule): logger.warning("The model has been exported already.") else: @@ -251,7 +251,6 @@ def _from_pretrained( ) token = use_auth_token - task = cls.export_feature commit_hash = kwargs.pop("_commit_hash", None) model_kwargs = { @@ -263,49 +262,11 @@ def _from_pretrained( "force_download": force_download, } - model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - - if is_torch_xpu_available(check_device=True): - model.to("xpu:0") - if _is_patched_with_ipex(model, task): - model = _patch_model(model) - else: - model = ipex_jit_trace(model, task, use_cache) - config.torchscript = True - config.torch_dtype = torch_dtype - return cls(model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: str = HUGGINGFACE_HUB_CACHE, - file_name: Optional[str] = WEIGHTS_NAME, - local_files_only: bool = False, - subfolder: str = "", - **kwargs, - ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." - ) - token = use_auth_token - if not getattr(config, "torchscript", False): logger.warning("Detect torchscript is false. Convert to torchscript model!") if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.0.0` is needed to trace your model") + raise ImportError("`torch>=2.1.0` is needed to trace your model") task = cls.export_feature config.torch_dtype = torch_dtype @@ -318,6 +279,15 @@ def _from_pretrained( _commit_hash=commit_hash, **model_kwargs, ) + if is_torch_xpu_available(check_device=True): + model.to("xpu:0") + if _is_patched_with_ipex(model, task): + model = _patch_model(model) + else: + use_cache = kwargs.get("use_cache", True) + model = ipex_jit_trace(model, task, use_cache) + config.torchscript = True + config.torch_dtype = torch_dtype return cls(model, config=config, export=True, **kwargs) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 89da349c82..8944ef6da2 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -84,9 +84,9 @@ def __init__( for idx, key in enumerate(model.inputs): names = tuple(key.get_names()) input_names[next((name for name in names if "/" not in name), names[0])] = idx - input_dtypes[ - next((name for name in names if "/" not in name), names[0]) - ] = key.get_element_type().get_type_name() + input_dtypes[next((name for name in names if "/" not in name), names[0])] = ( + key.get_element_type().get_type_name() + ) self.input_names = input_names self.input_dtypes = input_dtypes @@ -95,9 +95,9 @@ def __init__( for idx, key in enumerate(model.outputs): names = tuple(key.get_names()) output_names[next((name for name in names if "/" not in name), names[0])] = idx - output_dtypes[ - next((name for name in names if "/" not in name), names[0]) - ] = key.get_element_type().get_type_name() + output_dtypes[next((name for name in names if "/" not in name), 
names[0])] = ( + key.get_element_type().get_type_name() + ) self.output_names = output_names self.output_dtypes = output_dtypes
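Taken together, the optimum/intel/ipex/modeling_base.py changes in this series make device selection happen once at load time: on a machine with an XPU the eager model is moved to the device and the known architectures are patched with the fused modules, while on CPU the existing TorchScript export path is kept. A condensed sketch of that dispatch; has_xpu and the callable parameters are stand-ins rather than the actual optimum-intel helpers, and the exact nesting lives in the diff above:

import torch


def has_xpu() -> bool:
    # Stand-in for transformers' is_torch_xpu_available(check_device=True):
    # true only when the XPU backend exists and at least one device is present.
    xpu = getattr(torch, "xpu", None)
    return xpu is not None and xpu.is_available() and xpu.device_count() > 0


def prepare_ipex_model(model, task, use_cache, is_patched, patch_model, jit_trace):
    # Rough shape of the logic added to IPEXModel._from_pretrained:
    # move to XPU when one is available, patch architectures the exporter
    # knows how to fuse, and fall back to the TorchScript trace otherwise.
    if has_xpu():
        model = model.to("xpu:0")
    if is_patched(model, task):
        model = patch_model(model)
    else:
        model = jit_trace(model, task, use_cache)
    return model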