From 55fec6adff9b76ec96352a7887780ed62c549c5d Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Thu, 25 Jul 2024 19:01:05 +0400
Subject: [PATCH] Fix bf16 inference accuracy for mistral, phi3, dbrx (#833)

* Fix bf16 inference accuracy for mistral, phi3, dbrx

* reuse inv_freq

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* make dim and base optional

* fix model patcher for dbrx and add bitwise fix for mistral

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
---
 optimum/exporters/openvino/model_configs.py |  20 ++
 optimum/exporters/openvino/model_patcher.py | 207 ++++++++++++++++++--
 2 files changed, 207 insertions(+), 20 deletions(-)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 8f598293a2..c968e92b2c 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -24,6 +24,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     LlamaOnnxConfig,
+    MistralOnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
     UNetOnnxConfig,
@@ -53,6 +54,7 @@
     InternLMModelPatcher,
     JaisModelPatcher,
     LlamaModelPatcher,
+    MistralModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     PersimmonModelPatcher,
@@ -839,3 +841,21 @@ def patch_model_for_export(
             )
 
         return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "mistral",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class MistralOpenVINOConfig(MistralOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 81cf58f9dc..377a0fbf43 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -510,6 +510,39 @@ def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
     return cos, sin
 
 
+def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
+    # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
+    if inv_freq is None:
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
+    emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
+    return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
+
+
+def register_sin_cos_buffer(model):
+    max_positions = model.config.max_position_embeddings
+
+    # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
+    # use precomputed
+
+    rotary_emb = model.model.layers[0].self_attn.rotary_emb
+    dim, base = None, None
+    inv_freq = getattr(rotary_emb, "inv_freq", None)
+    if inv_freq is None:
+        base = rotary_emb.base
+        dim = rotary_emb.dim
+    embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
+    for layer in model.model.layers:
+        layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+        layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+
+        layer.self_attn.rotary_emb.forward = types.MethodType(
+            llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+        )
+
+
 class LlamaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -521,39 +554,148 @@ def __enter__(self):
         self._model.model._update_causal_mask = types.MethodType(
             _llama_gemma_update_causal_mask, self._model.model
         )
+        register_sin_cos_buffer(self._model)
 
-        max_positions = self._model.config.max_position_embeddings
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
 
-        # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
-        # use precomputed
-        def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
-            # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
 
-            sinusoid_inp = torch.einsum(
-                "i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
-            ).float()
-            emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
-            return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
 
-        base = self._model.model.layers[0].self_attn.rotary_emb.base
-        dim = self._model.model.layers[0].self_attn.rotary_emb.dim
-        embed_positions = create_sinusoidal_positions(max_positions, dim, base)
+# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
+def _mistral_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    use_cache: bool,
+    output_attentions: bool,
+):
+    from transformers.cache_utils import SlidingWindowCache, StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
-        for layer in self._model.model.layers:
-            layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
-            layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
" + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None - layer.self_attn.rotary_emb.forward = types.MethodType( - llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: + target_length = past_key_values.get_max_length() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask = exclude_mask.bitwise_or( + torch.arange(target_length, device=device) + <= (cache_position.reshape(-1, 1) - self.config.sliding_window) + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+class MistralModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.42.0"):
+            # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)
+
+        # mistral has some accuracy issues with bf16 with transformers >= 4.42
+        # prefill rotary emb sin/cos for avoid this issue
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
-        for layer in self._model.model.layers:
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
@@ -1283,11 +1425,15 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
 
+        # phi3 has issue with bf16 inference, precollect sin/cos for rotary_position_embedding for avoid accuracy issues
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+                layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
 
 
 def _aquila_self_attn_sdpa_forward(
@@ -1807,6 +1953,18 @@ def __enter__(self):
             _dbrx_update_causal_mask, self._model.transformer
         )
 
+        # starting from transformers 4.41 issue also observable for calculation sin/cos for rotary_emb
+        patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")
+
+        inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq")
+        dim, base = None, None
+        if inv_freq is None:
+            dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim
+            base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base
+        max_positions = self._model.config.max_seq_len
+        if patch_rope_sin_cos:
+            embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
         for block in self._model.transformer.blocks:
             rotary_emb = block.norm_attn_norm.attn.rotary_emb
             # initialize inv_freq for torchscript tracing
@@ -1815,6 +1973,12 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
                 rotary_emb.inv_freq = inv_freq
+
+            if patch_rope_sin_cos:
+                rotary_emb.register_buffer("embed_positions", embed_positions)
+                rotary_emb._orig_forward = rotary_emb.forward
+                rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)
+
             # remove continue-operator from iteration loop over experts
             block.ffn.experts._orig_forward = block.ffn.experts.forward
             block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
@@ -1825,6 +1989,9 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.transformer.blocks:
             block.ffn.experts.forward = block.ffn.experts._orig_forward
+            if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"):
+                block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward
+
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
 def _persimmon_self_attn_sdpa_forward(
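
Illustrative sketch (not part of the patch): the values below are arbitrary demo dimensions, but they show why the patch precomputes the rotary sin/cos table in float32 with create_sinusoidal_positions and registers it as a buffer, instead of recomputing the angles on each step in bfloat16. Position indices and inverse frequencies lose precision in bf16, and the error grows with the position index.

import torch

# Demo values only (assumptions, not taken from the PR): a Mistral-like head size and context length.
num_pos, dim, base = 4096, 128, 10000

# Float32 reference, the same shape of computation create_sinusoidal_positions performs:
# inv_freq in float32 and exact integer position indices.
inv_freq_fp32 = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
angles_fp32 = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq_fp32)

# Per-step bf16 computation: both the position index and inv_freq live in bfloat16,
# which cannot represent every integer above 256 exactly, so the angles drift for long sequences.
inv_freq_bf16 = inv_freq_fp32.to(torch.bfloat16)
positions_bf16 = torch.arange(num_pos).to(torch.bfloat16)
angles_bf16 = positions_bf16[:, None] * inv_freq_bf16[None, :]

# The sin/cos built from the bf16 angles inherit this error; precomputing the table once in
# float32 and registering it as a buffer (as register_sin_cos_buffer does) avoids it.
err = (angles_bf16.float() - angles_fp32).abs().max()
print(f"max angle error from bf16 on-the-fly computation: {err.item():.4f}")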
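For context, a hedged usage sketch of how the new registration gets exercised: exporting a Mistral checkpoint through optimum-intel routes the model through the "mistral" entry added by @register_in_tasks_manager, so MistralModelPatcher is entered automatically during tracing. The checkpoint name below is an assumption for illustration, not something taken from the patch.

# Hypothetical export sketch (assumes optimum-intel with the OpenVINO extras installed).
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # illustrative; any Mistral-architecture checkpoint should behave the same
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the checkpoint to OpenVINO IR; during export the tasks manager looks up the
# registered MistralOpenVINOConfig, whose patch_model_for_export returns MistralModelPatcher, so the
# precomputed sin/cos buffers and the patched causal mask are what get baked into the traced graph.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))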