From c67449227bd4230525624395678a7cfc86883a0d Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 22 Apr 2024 20:13:40 +0400
Subject: [PATCH] fix baichuan-13b

---
 optimum/exporters/openvino/convert.py       | 21 ++---
 optimum/exporters/openvino/model_configs.py |  4 +-
 optimum/exporters/openvino/model_patcher.py | 94 +++++++++++++++++++--
 optimum/intel/openvino/quantization.py      |  6 +-
 4 files changed, 105 insertions(+), 20 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 55e3318017..5c90dc7b71 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -347,6 +347,7 @@ def ts_patched_forward(*args, **kwargs):
 
         with patcher:
             check_dummy_inputs_are_allowed(model, dummy_inputs)
+            sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
             inputs = config.ordered_inputs(model)
             input_names = list(inputs.keys())
             output_names = list(config.outputs.keys())
@@ -376,7 +377,6 @@ def ts_patched_forward(*args, **kwargs):
                 ov_config=ov_config,
             )
 
-        sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         if not ordered_dummy_inputs:
             ordered_dummy_inputs = dummy_inputs
@@ -388,15 +388,16 @@ def ts_patched_forward(*args, **kwargs):
             out_tensor.get_tensor().set_names({output_names[idx]})
 
         for idx, inp_tensor in enumerate(ov_model.inputs):
-            input_name = ordered_input_names[idx]
-            inp_tensor.get_tensor().set_names({input_name})
-            inp_data = flatten_inputs[idx]
-            static_shape = PartialShape(inp_data.shape)
-            dims = inputs[input_name]
-            for dim in dims:
-                static_shape[dim] = -1
-            inp_tensor.get_node().set_partial_shape(static_shape)
-            inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
+            if idx < len(ordered_input_names):
+                input_name = ordered_input_names[idx]
+                inp_tensor.get_tensor().set_names({input_name})
+                inp_data = flatten_inputs[idx]
+                static_shape = PartialShape(inp_data.shape)
+                dims = inputs.get(input_name, [])
+                for dim in dims:
+                    static_shape[dim] = -1
+                inp_tensor.get_node().set_partial_shape(static_shape)
+                inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
 
         if stateful:
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 9ecf6bcaee..b10e744801 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -34,11 +34,11 @@
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
+    InternLMPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
-    QwenModelPatcher,
     MPTModelPatcher,
-    InternLMPatcher,
+    QwenModelPatcher,
 )
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 48fc5f60ea..996003f1ea 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -12,6 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+import inspect
 import logging as log
 import math
 import types
@@ -328,9 +329,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-            mask_slice
-        )
+        causal_mask[
+            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+        ] = mask_slice
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -601,6 +602,46 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.config.fp16 = self.original_fp16
 
 
+def _baichuan13b_atten_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = True,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+    attn_output = attn_output.transpose(1, 2)
+    attn_weights = None
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
 class BaichuanModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
@@ -613,6 +654,50 @@ def __init__(
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
 
+    def __enter__(self):
+        super().__enter__()
+        # override signature to have position_ids
+        if "position_ids" not in inspect.signature(self._model.forward).parameters:
+            self._model._orig_forward = self._model.forward
+
+            def forward(
+                self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = False,
+                output_hidden_states: Optional[bool] = False,
+                return_dict: Optional[bool] = True,
+                position_ids: Optional[torch.LongTensor] = None,
+            ):
+                return self._orig_forward(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    use_cache=past_key_values is not None,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=self.config.return_dict,
+                )
+
+            self._model.forward = types.MethodType(forward, self._model)
+            for layer in self._model.model.layers:
+                layer.self_attn._orig_forward = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model, "_orig_forward"):
+            self._model.forward = self._model._orig_forward
+
+            for layer in self._model.model.layers:
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
 
 def _mpt_attention_forward(
     self,
@@ -679,8 +764,7 @@ def _internlm_attention_forward(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-    from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
     bsz, q_len, _ = hidden_states.size()
 
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index aae66c148b..75ff56d6b3 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -484,9 +484,9 @@ def _quantize_torchmodel(
                 subset_size=quantization_config.num_samples,
                 ignored_scope=quantization_config.get_ignored_scope_instance(),
                 model_type=nncf.ModelType(quantization_config.model_type),
-                preset=nncf.QuantizationPreset.PERFORMANCE
-                if quantization_config.sym
-                else nncf.QuantizationPreset.MIXED,
+                preset=(
+                    nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED
+                ),
                 fast_bias_correction=quantization_config.fast_bias_correction,
                 advanced_parameters=nncf.AdvancedQuantizationParameters(
                     overflow_fix=OverflowFix(quantization_config.overflow_fix)
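
Example (not part of the patch): a minimal sketch of exercising the patched export path end to end. The checkpoint id, prompt, and generation length are illustrative assumptions rather than values taken from this change; any Baichuan-13B-style checkpoint whose remote-code forward lacks position_ids is expected to go through the BaichuanModelPatcher route added above.

# Hypothetical usage sketch, assuming optimum-intel is installed with the OpenVINO extra.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "baichuan-inc/Baichuan2-13B-Chat"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch model to OpenVINO IR via optimum.exporters.openvino;
# during tracing the export config is expected to apply BaichuanModelPatcher, which wraps
# forward with a position_ids-aware signature and swaps in the SDPA attention override.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))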