From 0447ae2fbe7f638edce1e9770af443fa9084af31 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 12 Nov 2024 19:14:08 +0400
Subject: [PATCH] add patching for update_causal_mask to falcon for >= 4.45
 (#989)

* add patching for update_causal_mask to falcon and gpt-like models for >=4.45

* fix falcon

* enable codegen2 back

* Apply suggestions from code review

Co-authored-by: Nikita Savelyev

* Update optimum/exporters/openvino/model_patcher.py

---------

Co-authored-by: Nikita Savelyev
---
 optimum/exporters/openvino/model_configs.py |  20 ++
 optimum/exporters/openvino/model_patcher.py | 228 ++++++++++++++++++--
 tests/openvino/test_modeling.py             |   5 +-
 3 files changed, 237 insertions(+), 16 deletions(-)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 5276ade33b..e8c8e5d134 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -29,6 +29,7 @@
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
+    GPTJOnnxConfig,
     GPTNeoXOnnxConfig,
     IBertOnnxConfig,
     LlamaOnnxConfig,
@@ -66,6 +67,7 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    GptJModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     IBertModelPatcher,
@@ -726,6 +728,24 @@ def patch_model_for_export(
         return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "gptj",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTJOpenVINOConfig(GPTJOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptJModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "cohere",
     *[
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index dbbfb56629..7406e13702 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -109,11 +109,20 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
-def patch_update_causal_mask(model, transformers_version):
+def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, "model", getattr(model, "transformer", None))
+        inner_model = getattr(model, inner_model_name, None)
         if inner_model is not None:
-            inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
+            if hasattr(inner_model, "_update_causal_mask"):
+                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
+            patch_fn = patch_fn or _llama_gemma_update_causal_mask
+            inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
+
+
+def unpatch_update_causal_mask(model, inner_model_name="model"):
+    inner_model = getattr(model, inner_model_name, None)
+    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
+        inner_model._update_causal_mask = inner_model._orig_update_causal_mask
 
 
 # initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
@@ -579,13 +588,11 @@ def __enter__(self):
 
         # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill causal mask in slightly different way for avoid overflow on some platforms
-
patch_update_causal_mask(self._model, "4.39.0") + patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - inner_model = getattr(self._model, "model", getattr(self._model, "transformer", None)) - if hasattr(inner_model, "_orig_update_causal_mask"): - inner_model._update_causal_mask = inner_model._orig_update_causal_mask + unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer") # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42 @@ -1865,6 +1872,67 @@ def __exit__(self, exc_type, exc_value, traceback): layer.self_attn.forward = layer.self_attn._orig_forward +# copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168 +# for preserving backward compatibility between outdated codegen remote code and new transformers +def _codegen_wrapped_scaled_dot_product_legacy( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, +): + from optimum.bettertransformer.models.attention import raise_on_head_mask + + raise_on_head_mask(head_mask) + batch_size = query.shape[0] + mask_value = torch.finfo(value.dtype).min + mask_value = torch.full([], mask_value, dtype=value.dtype) + + if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1: + raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") + + # in codegen the query and key are always in fp32 regardless of the dtype of the model + # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 + query = query.to(value.dtype) + key = key.to(value.dtype) + + dropout_p = self.dropout_prob_attn if self.training else 0.0 + if batch_size == 1 or self.training: + if query.shape[2] > 1: + # first step of the decoding + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True + ) + else: + # in this case, which is the later decoding steps, the `causal_mask`` in + # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195 + # is [True, ..., True] so actually not causal + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False + ) + else: + query_length, key_length = query.size(-2), key.size(-2) + + # causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary + if query_length > 1: + causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + + causal_mask = torch.where(causal_mask, 0, mask_value) + + # torch.Tensor.expand does no memory copy + causal_mask = causal_mask.expand(batch_size, -1, -1, -1) + + # we use torch.min to avoid having tensor(-inf) + attention_mask = torch.min(causal_mask, attention_mask) + + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False + ) + + return sdpa_result, None + + class 
CodeGenModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() @@ -1873,14 +1941,23 @@ def __enter__(self): # For avoiding breaking model on tracing stage, we reduce area of bettertransformer patch only for _attn. from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product + attn_fn = codegen_wrapped_scaled_dot_product + if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"): + # in transformers 4.45 causal_mask const buffer was removed from the model + # if it still exists, it means legacy remote code was loaded + if hasattr(self._model.transformer.h[0].attn, "causal_mask"): + attn_fn = _codegen_wrapped_scaled_dot_product_legacy + for layer in self._model.transformer.h: if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions: orig_self_attn_fwd = layer.attn._attn - layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn) + layer.attn._attn = types.MethodType(attn_fn, layer.attn) layer.attn._orig_attn = orig_self_attn_fwd + patch_update_causal_mask(self._model, "4.45.0", "transformer") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) + unpatch_update_causal_mask(self._model, "transformer") for layer in self._model.transformer.h: if hasattr(layer.attn, "_orig_attn"): layer.attn._attn = layer.attn._orig_attn @@ -2275,8 +2352,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if hasattr(self._model.model, "_orig_update_causal_mask"): - self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask + unpatch_update_causal_mask(self._model) for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward @@ -2413,8 +2489,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if hasattr(self._model.model, "_orig_update_causal_mask"): - self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask + unpatch_update_causal_mask(self._model) class RotaryEmbPatcher(DecoderModelPatcher): @@ -2425,12 +2500,119 @@ def __enter__(self): _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) +def _falcon_update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: "Cache", + output_attentions: bool, + head_mask: torch.Tensor, + alibi: torch.Tensor, +): + # copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130 + from transformers.cache_utils import StaticCache + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"): + _prepare_4d_causal_attention_mask_with_cache_position = ( + self._prepare_4d_causal_attention_mask_with_cache_position + ) + else: + from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not using_static_cache + and not output_attentions + and head_mask is None + and alibi is None + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference from original, replace torch.finfo(dtype).min to float16 for prevent overflow for fp16/bf16 execution + min_dtype = torch.finfo(torch.float16).min + batch_size, sequence_length, _ = input_tensor.shape + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + # We take care to integrate alibi bias in the causal_mask here + if head_mask is None and alibi is not None: + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + causal_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + causal_mask < -1, + min_dtype, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + class FalconModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version("<", "4.44.99"): for layer in self._model.transformer.h: _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb) + else: + patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + unpatch_update_causal_mask(self._model, "transformer") class GptNeoxModelPatcher(DecoderModelPatcher): @@ -2439,6 +2621,22 @@ def __enter__(self): if is_transformers_version("<", "4.44.99"): for layer in self._model.gpt_neox.layers: _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb) + else: + patch_update_causal_mask(self._model, "4.45.0", "gpt_neox") + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + unpatch_update_causal_mask(self._model, "gpt_neox") + + +class GptJModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + patch_update_causal_mask(self._model, "4.45.0", "transformer") + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + unpatch_update_causal_mask(self._model, "transformer") class GptNeoxJapaneseModelPatcher(DecoderModelPatcher): @@ -2447,6 +2645,12 @@ def __enter__(self): if is_transformers_version("<", "4.44.99"): for layer in self._model.gpt_neox_japanese.layers: _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb) + else: + patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese") + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + unpatch_update_causal_mask(self._model, "gpt_neox_japanese") class Gemma2ModelPatcher(LlamaModelPatcher): diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 916833602d..0218f6d0e1 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -773,6 +773,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "bloom", "chatglm", "codegen", + "codegen2", "gpt2", "gpt_neo", "gpt_neox", @@ -821,10 +822,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "mistral-nemo", ) - # custom modeling defined in https://huggingface.co/katuni4ka/tiny-random-codegen2 differs from transformers after v4.45 resulting in unadapted patching - if is_transformers_version("<", "4.45.0"): - SUPPORTED_ARCHITECTURES += ("codegen2",) - GENERATION_LENGTH = 100 REMOTE_CODE_MODELS = ( "chatglm",
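Illustrative sketch (not part of the commit above): the `patch_update_causal_mask` / `unpatch_update_causal_mask` helpers introduced in this patch follow the usual Python idiom of binding a replacement function to the inner submodule with `types.MethodType`, keeping the original bound method in `_orig_update_causal_mask`, and restoring it when the patcher exits. The minimal, self-contained example below shows only that idiom; `ToyModelForCausalLM`, `ToyTransformer` and `_patched_update_causal_mask` are hypothetical names used for this sketch and do not exist in optimum or transformers.

    import types


    class ToyTransformer:
        """Hypothetical stand-in for the inner `model` / `transformer` submodule."""

        def _update_causal_mask(self, attention_mask):
            # original behaviour: pass the mask through unchanged
            return attention_mask


    class ToyModelForCausalLM:
        """Hypothetical stand-in for the exported top-level model."""

        def __init__(self):
            self.transformer = ToyTransformer()


    def _patched_update_causal_mask(self, attention_mask):
        # stand-in for a replacement such as _llama_gemma_update_causal_mask
        return None


    model = ToyModelForCausalLM()
    inner = getattr(model, "transformer", None)

    # patch: remember the original bound method, then bind the replacement to the instance
    if inner is not None and hasattr(inner, "_update_causal_mask"):
        inner._orig_update_causal_mask = inner._update_causal_mask
        inner._update_causal_mask = types.MethodType(_patched_update_causal_mask, inner)

    assert inner._update_causal_mask("mask") is None  # replacement is active

    # unpatch: restore the saved original, mirroring unpatch_update_causal_mask
    if inner is not None and hasattr(inner, "_orig_update_causal_mask"):
        inner._update_causal_mask = inner._orig_update_causal_mask

    assert inner._update_causal_mask("mask") == "mask"  # original behaviour restored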