diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 3649c163c6..35983975ed 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -291,11 +291,17 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
-    # for compatibility with https://github.com/huggingface/transformers/pull/30047
-    current_length = kwargs.get("current_length", cache_position[-1])
+    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+        # in order to dispatch on Flash Attention 2.
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+        ):
+            return None
+
     dtype, device = input_tensor.dtype, input_tensor.device
     # using minimum from dtype with larger bandwith (floa32) may lead to overflow
     # during execution on platforms with default lower precision (bfloat16, float16)
@@ -305,7 +311,13 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+        if past_seen_tokens is not None:
+            current_length = past_seen_tokens + sequence_length + 1
+        # TODO : remove after support of transformers >= v4.40.0
+        else:
+            current_length = cache_position[-1] + 1
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
 
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
diff --git a/setup.py b/setup.py
index ea87e6ad59..a7937bd1e4 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@
 
 INSTALL_REQUIRE = [
     "torch>=1.11",
-    "transformers>=4.36.0,<4.40.0",
+    "transformers>=4.36.0,<4.41.0",
     "optimum~=1.19",
     "datasets>=1.4.0",
     "sentencepiece",