diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 592cd85a4..859360e8b 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -49,7 +49,7 @@
 )


-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}

 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 825eaac48..7a6b2998c 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -2710,26 +2710,6 @@ def patched_forward(*args, **kwargs):

         self.patched_forward = patched_forward

-    def __enter__(self):
-        super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-

 def _decilm_attn_forward(
     self,
diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index fe85f9212..1c0e35cca 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -53,7 +53,7 @@


 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image

 logger = logging.getLogger(__name__)

@@ -2100,6 +2100,8 @@ def __init__(
             quantization_config=quantization_config,
             **kwargs,
         )
+        self.rope_deltas = None  # cache rope_deltas here
+
         if is_transformers_version(">=", "4.45.0"):
             from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                 Qwen2VLForConditionalGeneration,
@@ -2197,6 +2199,7 @@ def get_multimodal_embeddings(
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        cache_position=None,
         **kwargs,
     ):
         inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
@@ -2209,6 +2212,26 @@ def get_multimodal_embeddings(
             video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
             video_mask = input_ids == self.config.video_token_id
             inputs_embeds[video_mask] = video_embeds
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
         return inputs_embeds, attention_mask, position_ids

     def forward(
diff --git a/setup.py b/setup.py
index f78052b4b..0f02ef15c 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
 INSTALL_REQUIRE = [
     "torch>=1.11",
     "optimum@git+https://github.com/huggingface/optimum.git",
-    "transformers>=4.36,<4.47",
+    "transformers>=4.36,<4.48",
     "datasets>=1.4.0",
     "sentencepiece",
     "setuptools",
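To illustrate what the new decode branch in `get_multimodal_embeddings` computes, here is a minimal standalone sketch (not part of the diff): a hypothetical helper `decode_step_position_ids` that mimics how the cached `rope_deltas` (assumed to have shape `(batch, 1)`, as returned by `get_rope_index`) turn the current `cache_position` into the 3D position ids used for mRoPE during decoding.

```python
import torch


def decode_step_position_ids(cache_position, rope_deltas, batch_size, seq_length):
    # Mirrors the decode branch of the diff: shift the running positions by the
    # per-sample delta that was cached at pre-fill time.
    delta = cache_position[0] + rope_deltas                      # (batch, 1)
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)                       # broadcasts over seq
    return position_ids.unsqueeze(0).expand(3, -1, -1)           # (3, batch, seq) for mRoPE


# Decoding one token at cache slot 6 for a batch of 2, with deltas cached at pre-fill.
ids = decode_step_position_ids(torch.tensor([6]), torch.tensor([[0], [-2]]), batch_size=2, seq_length=1)
print(ids.shape, ids[0].tolist())  # torch.Size([3, 2, 1]) [[6], [4]]
```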