From e1b6a59c55157d0feb4d53945cbbe191e5c0f243 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Tue, 30 Apr 2024 19:02:52 +0400 Subject: [PATCH] Apply sdpa for mpt and internlm (#676) * apply sdpa for mpt and internlm * fix bauchan-13b * fix accuracy * small refactoring * add test for baichuan 13b * add support output_attentions * code style --- optimum/exporters/openvino/convert.py | 4 +- optimum/exporters/openvino/model_configs.py | 18 ++ optimum/exporters/openvino/model_patcher.py | 326 ++++++++++++++++++++ tests/openvino/test_modeling.py | 13 +- tests/openvino/utils_tests.py | 1 + 5 files changed, 359 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 6c86c2c2df..3022346af5 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -358,6 +358,7 @@ def ts_patched_forward(*args, **kwargs): with patcher: check_dummy_inputs_are_allowed(model, dummy_inputs) + sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) inputs = config.ordered_inputs(model) input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) @@ -387,7 +388,6 @@ def ts_patched_forward(*args, **kwargs): ov_config=ov_config, ) - sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} if not ordered_dummy_inputs: ordered_dummy_inputs = dummy_inputs @@ -403,7 +403,7 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_tensor().set_names({input_name}) inp_data = flatten_inputs[idx] static_shape = PartialShape(inp_data.shape) - dims = inputs[input_name] + dims = inputs.get(input_name, []) for dim in dims: static_shape[dim] = -1 inp_tensor.get_node().set_partial_shape(static_shape) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index fb1415700d..575f1cc4db 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -23,6 +23,7 @@ FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, + MPTOnnxConfig, PhiOnnxConfig, UNetOnnxConfig, VaeDecoderOnnxConfig, @@ -43,8 +44,10 @@ BaichuanModelPatcher, ChatGLMModelPatcher, GemmaModelPatcher, + InternLMPatcher, LlamaModelPatcher, MixtralModelPatcher, + MPTModelPatcher, Phi3ModelPatcher, QwenModelPatcher, ) @@ -439,6 +442,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLMPatcher(self, model, model_kwargs=model_kwargs) + @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers") class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): @@ -455,6 +463,16 @@ class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +@register_in_tasks_manager( + "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers" +) +class MPTOpenVINOConfig(MPTOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) 
-> "ModelPatcher": + return MPTModelPatcher(self, model, model_kwargs=model_kwargs) + + @register_in_tasks_manager( "phi3", *[ diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 9082f718cb..f68e873d40 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging as log +import math import types from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -612,6 +614,57 @@ def __exit__(self, exc_type, exc_value, traceback): self._model.config.fp16 = self.original_fp16 +def _baichuan13b_atten_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + if attention_mask is not None: + attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :] + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + if not output_attentions: + past_key_value = (key_states, value_states) if use_cache else None + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + class BaichuanModelPatcher(DecoderModelPatcher): def __init__( self, @@ -624,6 +677,279 @@ def __init__( if hasattr(self._model.lm_head, "first_flag"): self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) + def __enter__(self): + super().__enter__() + # override signature to have position_ids + if "position_ids" not in inspect.signature(self._model.forward).parameters: + self._model._orig_forward = self._model.forward + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: 
Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + ): + return self._orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=past_key_values is not None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=self.config.return_dict, + ) + + self._model.forward = types.MethodType(forward, self._model) + for layer in self._model.model.layers: + layer.self_attn._orig_forward = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if hasattr(self._model, "_orig_forward"): + self._model.forward = self._model._orig_forward + + for layer in self._model.model.layers: + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _mpt_sdpa_attention_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, +): + batch_size, seq_length = hidden_states.shape[:2] + + mixed_qkv = self.Wqkv(hidden_states) + query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2) + query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if len(past_key_value) != 0: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) + else: + past_key_value = (key_states, value_states) + + key_length = key_states.shape[-2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + attention_mask_sdpa = torch.ones( + (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]), + dtype=query_states.dtype, + ) + if position_bias is not None: + position_bias_query_index = max(0, position_bias.size(1) - query_length) + position_bias_key_index = max(0, position_bias.size(2) - key_length) + + position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:] + attention_mask_sdpa += position_bias + attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min) + context_states = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask_sdpa, + dropout_p=self.attn_dropout_p, + scale=self.softmax_scale, + ) + + context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output = self.out_proj(context_states) + + return attn_output, None, past_key_value + + +def _mpt_block_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + 
use_cache: bool = False, + output_attentions: bool = False, +): + # hidden_states: [batch_size, seq_length, hidden_size] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.norm_1(hidden_states) + + residual = hidden_states + + if not output_attentions: + # Self attention. + attn_outputs, attn_weights, past_key_value = self.attn( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + else: + attn_outputs, attn_weights, past_key_value = self.attn._orig_forward( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + + hidden_states = self.resid_attn_dropout(attn_outputs) + residual + + layernorm_output = self.norm_2(hidden_states) + + # Get residual + residual = hidden_states + + # MLP. + output = self.ffn(layernorm_output, residual) + outputs = (output,) + + if use_cache: + outputs += (past_key_value,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class MPTModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + if is_torch_version(">=", "2.1.0"): + for block in self._model.transformer.blocks: + block._orig_forward = block.forward + block.forward = types.MethodType(_mpt_block_forward, block) + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.transformer.blocks: + if hasattr(block, "_orig_forward"): + block.forward = block._orig_forward + if hasattr(block.attn, "_orig_forward"): + block.attn.forward = block.attn._orig_forward + + +def _internlm_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + from einops import rearrange + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + if not output_attentions: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, attn_weights, past_key_value + + +class InternLMPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + if is_torch_version(">=", "2.1.0"): + for block in self._model.model.layers: + block.attention._orig_forward = block.attention.forward + block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.model.layers: + if hasattr(block.attention, "_orig_forward"): + block.attention.forward = block.attention._orig_forward + class Phi3ModelPatcher(DecoderModelPatcher): def __enter__(self): diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index a7f490a22e..d4f55c683b 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -520,6 +520,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", "baichuan2", + "baichuan2-13b", "gpt_bigcode", "blenderbot", "blenderbot-small", @@ -553,7 +554,17 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "falcon-40b", ) GENERATION_LENGTH = 100 - 
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "orion", "phi3") + REMOTE_CODE_MODELS = ( + "chatglm", + "minicpm", + "baichuan2", + "baichuan2-13b", + "jais", + "qwen", + "internlm2", + "orion", + "phi3", + ) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 2f76e5eeda..9f28e40a4b 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -24,6 +24,7 @@ "bert": "hf-internal-testing/tiny-random-bert", "bart": "hf-internal-testing/tiny-random-bart", "baichuan2": "katuni4ka/tiny-random-baichuan2", + "baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
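
The change common to the Baichuan-13B, MPT and InternLM forwards patched above is swapping the eager attention computation, softmax(QK^T / sqrt(d) + mask) V, for torch.nn.functional.scaled_dot_product_attention, and falling back to the eager path whenever output_attentions is requested, since SDPA does not return attention weights. The snippet below is an illustrative sketch and is not part of the patch: it checks on random tensors that the SDPA call used by the patched forwards matches the eager computation it replaces. The explicit scale argument only exists in torch >= 2.1, which matches the is_torch_version(">=", "2.1.0") guard in MPTModelPatcher and InternLMPatcher.

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, q_len, kv_len, head_dim = 1, 4, 3, 7, 16

query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, kv_len, head_dim)
value = torch.randn(bsz, n_heads, kv_len, head_dim)

# Additive attention mask, broadcast over heads: 0 where attention is allowed,
# the most negative representable value where it is masked out.
attention_mask = torch.zeros(bsz, 1, q_len, kv_len)
attention_mask[..., -1] = torch.finfo(attention_mask.dtype).min

# Eager path, equivalent to the original (unpatched) attention forwards.
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
eager_output = torch.matmul(attn_weights, value)

# SDPA path, as used by the patched forwards (scale= requires torch >= 2.1).
sdpa_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, scale=1 / math.sqrt(head_dim)
)

assert torch.allclose(eager_output, sdpa_output, atol=1e-5)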
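
For an end-to-end check of the new patchers, a model of one of these architectures can be exported through optimum-intel: the OpenVINO export config registered above returns the corresponding patcher (MPTModelPatcher in this case), which is entered while the model is traced for export. The following is an illustrative usage sketch only; mosaicml/mpt-7b stands in for any MPT checkpoint and is not referenced by the patch (the tests use tiny random models instead).

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "mosaicml/mpt-7b"  # example checkpoint; any MPT model should exercise MPTModelPatcher
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch model to OpenVINO IR on the fly;
# the SDPA patcher is applied only during this export step.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))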