diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 0d87a5fd6c..c6ed33006e 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -13,24 +13,26 @@
 #  limitations under the License.
 
 from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
 )
 
-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXLlamaDecoderLayerRef,
-    _llama_attn_forward,
+    _IPEXLlamaDecoderLayer,
     _llama_layer_norm_forward,
     _llama_model_forward,
 )
 
 
+# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
+_TRANSFORMERS_MIN_VERSION = "4.39.0"
+_TRANSFORMERS_MAX_VERSION = "4.41.2"
+
 _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
 _IPEX_EXPORTED_TASK = ("text-generation",)
 
@@ -64,27 +66,16 @@ def patch_op(m, target_m, new_op_name, new_op):
 
 def _patch_llama_model(model):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+        raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
+    if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
+        ">", _TRANSFORMERS_MAX_VERSION
+    ):
         raise ImportError(
-            f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports RotaryEmbedding and IndirectAccessKVCacheAttention"
+            f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
         )
-
-    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding
-
-    ipex_rope = RotaryEmbedding(
-        model.config.max_position_embeddings,
-        model.config.hidden_size // model.config.num_attention_heads,
-        model.config.rope_theta,
-        model.config.architectures[0],
-    )
-    ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
-    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
-    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
-
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
-    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
-
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
     return model
 
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index a2b73e74ae..5870a8c792 100644
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -21,12 +21,10 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
-from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
+from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.modeling_utils import _setattr_from_module
 
 
-# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.39.0"
-_TRANSFORMERS_MAX_VERSION = "4.41.2"
 
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"
 
@@ -35,92 +33,6 @@ def _llama_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
-def _llama_attn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query = self.q_proj(hidden_states)
-    key = self.k_proj(hidden_states)
-    value = self.v_proj(hidden_states)
-
-    kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
-
-    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
-    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-
-    if use_cache:
-        # Use ipex op to rotary position embedding more efficient.
-        key = self.ipex_rope(
-            key,
-            position_ids,
-            self.num_key_value_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            kv_seq_len,
-        )
-        query = self.ipex_rope(
-            query,
-            position_ids,
-            self.num_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            kv_seq_len,
-        )
-        # This ipex op pre-allocates buffers for past_key_values and use beam index history
-        # which to decide which beam should be used to make attention scale dot more efficient.
-        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
-            query,
-            key,
-            value,
-            math.sqrt(self.head_dim),
-            past_key_value,
-            None,
-            attention_mask,
-        )
-    else:
-        value_states = value.transpose(1, 2)
-        query_states = query.transpose(1, 2)
-        key_states = key.transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        kv_seq_len = key_states.shape[-2]
-
-        past_key_value = None
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
 def _llama_model_forward(
     self,
@@ -224,38 +136,212 @@ def _llama_model_forward(
     )
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
-class _IPEXLlamaDecoderLayerRef(nn.Module):
-    def __init__(self, module, config, distributed=False):
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+class _IPEXLlamaAttention(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
         if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
             raise ImportError(
-                f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul and LinearAdd"
+                f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding"
             )
-        if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
-            ">", _TRANSFORMERS_MAX_VERSION
-        ):
-            raise ImportError(
-                f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.distributed = distributed
+        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
+
+        if not self.distributed:
+            self.mha_linear_add = LinearAdd(self.o_proj)
+            del self.__dict__["_modules"]["o_proj"]
+        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+            text_max_length=module.config.max_position_embeddings
+        )
+        self.ipex_rope = RotaryEmbedding(
+            module.config.max_position_embeddings,
+            module.config.hidden_size // module.config.num_attention_heads,
+            module.config.rope_theta,
+            module.config.architectures[0],
+        )
+
+    def qkv_gemm(self, hidden_states):
+        bsz, seq_len, _ = hidden_states.size()
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
+        key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+
+        return query, key, value
+
+    def rope(self, query, key, kv_seq_len, position_ids, use_cache):
+        if use_cache:
+            key = self.ipex_rope(
+                key,
+                position_ids,
+                self.num_key_value_heads,
+                self.head_dim,
+                self.head_dim // 2,
+                self.head_dim,
+                kv_seq_len,
+            )
+            query = self.ipex_rope(
+                query,
+                position_ids,
+                self.num_heads,
+                self.head_dim,
+                self.head_dim // 2,
+                self.head_dim,
+                kv_seq_len,
             )
+        return query, key
 
-        from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
-
+    def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
+        # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
+        # to decide which beam should be used, making the attention scaled dot-product more efficient.
+        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+            query,
+            key,
+            value,
+            math.sqrt(self.head_dim),
+            past_key_value,
+            None,
+            attention_mask,
+        )
+        return attn_output, past_key_value, attn_weights
+
+    # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
+    def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
+        value_states = value.transpose(1, 2)
+        query_states = query.transpose(1, 2)
+        key_states = key.transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        past_key_value = None
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        return attn_output, past_key_value, attn_weights
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        residual: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+                this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+                the complete sequence length.
+            residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        bsz, seq_len, _ = hidden_states.size()
+        kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len
+
+        query, key, value = self.qkv_gemm(hidden_states)
+        query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache)
+
+        sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
+        attn_output, past_key_value, attn_weights = sdpa(
+            query, key, value, past_key_value, attention_mask, position_ids
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size)
+
+        if hasattr(self, "mha_linear_add"):
+            attn_output = self.mha_linear_add(attn_output, residual)
+        else:
+            attn_output = self.o_proj(attn_output)
+            attn_output = residual + attn_output
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
+class _IPEXLlamaMLP(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
+        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+            raise ImportError(
+                f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd"
+            )
         super().__init__()
-        for k, v in module.__dict__.items():
-            setattr(self, k, v)
-        for k, v in module.__class__.__dict__.items():
-            if k.startswith("__") or k.startswith("forward"):
-                continue
-            setattr(self.__class__, k, getattr(module.__class__, k))
+        _setattr_from_module(self, module)
+        self.config = config
         self.distributed = distributed
+        from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
+
         if not self.distributed:
-            self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
-            self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
-            del self.__dict__["_modules"]["self_attn"].o_proj
-            del self.__dict__["_modules"]["mlp"].down_proj
-            self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
-            del self.__dict__["_modules"]["mlp"].gate_proj
-            del self.__dict__["_modules"]["mlp"].up_proj
+            self.mlp_linear_add = LinearAdd(module.down_proj)
+            del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]
+
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        if hasattr(self, "linear_silu_mul"):
+            mlp_gate = self.linear_silu_mul(hidden_states)
+            if hasattr(self, "mlp_linear_add"):
+                hidden_states = self.mlp_linear_add(mlp_gate, residual)
+            else:
+                hidden_states = self.down_proj(mlp_gate)
+                hidden_states = residual + hidden_states
+        else:
+            hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+            hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
+class _IPEXLlamaDecoderLayer(nn.Module):
+    def __init__(self, module, config, distributed=False):
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.distributed = distributed
+        self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed)
+        self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
 
     def forward(
         self,
@@ -271,15 +357,17 @@ def forward(
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                 query_sequence_length, key_sequence_length)` if default attention is used.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
            use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
         """
         residual = hidden_states
 
@@ -293,27 +381,15 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=None,
+            residual=residual,
+            **kwargs,
         )
-        if hasattr(self, "mha_linear_add"):
-            hidden_states = self.mha_linear_add(hidden_states, residual)
-        else:
-            hidden_states = self.self_attn.o_proj(hidden_states)
-            hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
-        if hasattr(self, "linear_silu_mul"):
-            mlp_gate = self.linear_silu_mul(hidden_states)
-            if hasattr(self, "mlp_linear_add"):
-                hidden_states = self.mlp_linear_add(mlp_gate, residual)
-            else:
-                hidden_states = self.mlp.down_proj(mlp_gate)
-                hidden_states = residual + hidden_states
-        else:
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = residual + hidden_states
+        hidden_states = self.mlp(hidden_states, residual, **kwargs)
 
         outputs = (hidden_states,)
 
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index a2cd728354..3541f4f933 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -182,3 +182,12 @@ def recursive_to_device(value, device):
     elif isinstance(value, torch.Tensor):
         return value.to(device)
     return value
+
+
+def _setattr_from_module(new_module, module):
+    for k, v in module.__dict__.items():
+        setattr(new_module, k, v)
+    for k, v in module.__class__.__dict__.items():
+        if k.startswith("__") or k.startswith("forward"):
+            continue
+        setattr(new_module.__class__, k, getattr(module.__class__, k))
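Note: the wrapper classes introduced above (`_IPEXLlamaAttention`, `_IPEXLlamaMLP`, `_IPEXLlamaDecoderLayer`) all lean on the new `_setattr_from_module` helper: it mirrors the wrapped module's instance attributes (registered parameters, submodules, config) and non-dunder class attributes onto the replacement object, so the wrapper keeps the original weights and only swaps out `forward`. The snippet below is a minimal, self-contained sketch of that pattern; `ToyBlock` and `PatchedToyBlock` are made-up stand-ins that are not part of the patch, while the helper body is copied verbatim from `optimum/intel/utils/modeling_utils.py` above.

import torch
from torch import nn


def _setattr_from_module(new_module, module):
    # Same helper as added in optimum/intel/utils/modeling_utils.py.
    for k, v in module.__dict__.items():
        setattr(new_module, k, v)
    for k, v in module.__class__.__dict__.items():
        if k.startswith("__") or k.startswith("forward"):
            continue
        setattr(new_module.__class__, k, getattr(module.__class__, k))


class ToyBlock(nn.Module):
    # Hypothetical stand-in for a transformers layer such as LlamaDecoderLayer.
    def __init__(self, hidden_size=8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


class PatchedToyBlock(nn.Module):
    # Hypothetical stand-in for _IPEXLlamaDecoderLayer: reuse the original
    # module's state, override only the forward pass.
    def __init__(self, module, config=None, distributed=False):
        super().__init__()
        _setattr_from_module(self, module)
        self.distributed = distributed

    def forward(self, x):
        # Fold the residual add into this layer's forward, loosely mirroring
        # how the IPEX layer passes `residual` down to LinearAdd.
        return x + self.proj(x)


if __name__ == "__main__":
    block = ToyBlock()
    patched = PatchedToyBlock(block)
    # The wrapper shares the original parameters; only forward differs.
    assert patched.proj.weight.data_ptr() == block.proj.weight.data_ptr()
    print(patched(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])

This is the same mechanism that lets `convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)` replace every decoder layer in place without copying or reloading weights.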