Skip to content

Commit

Permalink
Update Qwen2-VL to the latest position_ids formula
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jan 6, 2025
1 parent 35c47a2 commit c6c4a25
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
12 changes: 6 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down Expand Up @@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
28 changes: 24 additions & 4 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@


if TYPE_CHECKING:
from PIL import Image
from PIL.Image import Image


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -166,9 +166,6 @@ def prepare_inputs(
if past_len:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]

if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3:
position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0)

inputs["position_ids"] = position_ids

if "beam_idx" in self.input_names:
Expand Down Expand Up @@ -2100,6 +2097,8 @@ def __init__(
quantization_config=quantization_config,
**kwargs,
)
self.rope_deltas = None # cache rope_deltas here

if is_transformers_version(">=", "4.45.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLForConditionalGeneration,
Expand Down Expand Up @@ -2197,6 +2196,7 @@ def get_multimodal_embeddings(
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
cache_position=None,
**kwargs,
):
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
Expand All @@ -2209,6 +2209,26 @@ def get_multimodal_embeddings(
video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
video_mask = input_ids == self.config.video_token_id
inputs_embeds[video_mask] = video_embeds

# if we get 4D attention mask we cannot calculate rope deltas anymore.
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

return inputs_embeds, attention_mask, position_ids

def forward(
Expand Down

0 comments on commit c6c4a25

Please sign in to comment.