Merged
22 changes: 15 additions & 7 deletions src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -1535,15 +1535,16 @@ def _expand_inputs_for_generation(
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
# Overwritten -- Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
# Overwritten -- Qwen3VL uses timestamps and removes second_per_grid_ts
# Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
# pixel_values.shape[0] is sum(seqlen_images for samples)
# image_grid_thw.shape[0] is sum(num_images for samples)

if expand_size == 1:
return input_ids, model_kwargs

visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -1552,6 +1553,17 @@ def _expand_dict_for_generation_visual(dict_to_expand):
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)

# video_nums: (batch_size,)
# video_nums is derived from the vision_start tokens in input_ids, but Qwen3VL inserts a
# vision_start token for every frame of every video, so the real per-sample video counts
# must be recovered from video_grid_thw
if video_grid_thw is not None:
cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
# Find video boundaries in cumulative_frame_counts
video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
# example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
@@ -1584,10 +1596,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "second_per_grid_ts":
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
)
return dict_to_expand

def _expand_dict_for_generation(dict_to_expand):
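For reference, a minimal standalone sketch of the recovery step above, using hypothetical toy numbers (not taken from the PR): every frame carries its own vision_start token, so the token-derived counts are frame totals per sample, and the cumsum/searchsorted/diff combination maps them back to per-sample video counts.

import torch

# Hypothetical example: 6 videos across a batch of 2 samples.
# video_grid_thw[:, 0] is the number of frames in each video.
video_grid_thw = torch.tensor(
    [[2, 4, 4], [3, 4, 4], [2, 4, 4], [1, 4, 4], [4, 4, 4], [2, 4, 4]]
)
# Counting vision_start tokens per sample yields frame totals, not video totals:
# sample 0 holds videos 0-3 (2+3+2+1 = 8 frames), sample 1 holds videos 4-5 (4+2 = 6 frames).
video_nums = torch.tensor([8, 6])

cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)             # [2, 5, 7, 8, 12, 14]
cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)                 # [8, 14]
video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)  # [3, 5]
real_video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
print(real_video_nums)  # tensor([4, 2]) -- 4 videos in sample 0, 2 in sample 1, matching the comment's example
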
97 changes: 96 additions & 1 deletion src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -15,7 +15,7 @@
"""PyTorch Qwen3-VL model."""

from collections.abc import Callable
from typing import Optional, Union
from typing import Any, Optional, Union

import numpy as np
import torch
@@ -1242,6 +1242,101 @@ def prepare_inputs_for_generation(

return model_inputs

def _expand_inputs_for_generation(
self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
# Overwritten -- Qwen3VL uses timestamps and removes second_per_grid_ts
# Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
# pixel_values.shape[0] is sum(seqlen_images for samples)
# image_grid_thw.shape[0] is sum(num_images for samples)

if expand_size == 1:
return input_ids, model_kwargs

visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)

# video_nums: (batch_size,)
# video_nums is derived from the vision_start tokens in input_ids, but Qwen3VL inserts a
# vision_start token for every frame of every video, so the real per-sample video counts
# must be recovered from video_grid_thw
if video_grid_thw is not None:
cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
# Find video boundaries in cumulative_frame_counts
video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
# example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
return result

for key in dict_to_expand:
if key == "pixel_values":
# split images into samples
samples = torch.split(image_grid_thw, list(image_nums))
# compute the sequence length of images for each sample
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "image_grid_thw":
# get the num of images for each sample
lengths = list(image_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "pixel_values_videos":
samples = torch.split(video_grid_thw, list(video_nums))
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "video_grid_thw":
lengths = list(video_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
return dict_to_expand

def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if (
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand

model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)

model_kwargs = _expand_dict_for_generation(model_kwargs)

if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

return input_ids, model_kwargs


class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
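As a side note, here is a small self-contained sketch (toy shapes and assumed values, not from the PR) of what _repeat_interleave_samples does for the visual tensors: pixel_values and the *_grid_thw tensors are concatenated across the batch with no batch dimension, so each sample's slice has to be repeated as a contiguous block instead of calling repeat_interleave on dim 0.

import torch

def repeat_interleave_samples(x, lengths, repeat_times):
    # Split the flat tensor into per-sample chunks, repeat each chunk as a block,
    # then re-concatenate -- mirroring the helper in the diff above.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

# Toy image_grid_thw: 2 images belong to sample 0, 1 image to sample 1.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2], [1, 6, 6]])
expanded = repeat_interleave_samples(image_grid_thw, lengths=[2, 1], repeat_times=2)
# Sample 0's rows repeat as a block, then sample 1's:
# [[1, 4, 4], [1, 2, 2], [1, 4, 4], [1, 2, 2], [1, 6, 6], [1, 6, 6]]
print(expanded)
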
22 changes: 15 additions & 7 deletions src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -1740,15 +1740,16 @@ def _expand_inputs_for_generation(
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
# Overwritten -- Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
# Overwritten -- Qwen3VLMoe uses timestamps and removes second_per_grid_ts
# Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
# pixel_values.shape[0] is sum(seqlen_images for samples)
# image_grid_thw.shape[0] is sum(num_images for samples)

if expand_size == 1:
return input_ids, model_kwargs

visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -1757,6 +1758,17 @@ def _expand_dict_for_generation_visual(dict_to_expand):
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)

# video_nums: (batch_size,)
# video_nums is derived from the vision_start tokens in input_ids, but Qwen3VLMoe inserts a
# vision_start token for every frame of every video, so the real per-sample video counts
# must be recovered from video_grid_thw
if video_grid_thw is not None:
cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
# Find video boundaries in cumulative_frame_counts
video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
# example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
@@ -1789,10 +1801,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "second_per_grid_ts":
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
)
return dict_to_expand

def _expand_dict_for_generation(dict_to_expand):
34 changes: 33 additions & 1 deletion tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
@@ -305,7 +305,6 @@ def test_video_forward(self):


@require_torch
@unittest.skip("The checkpoint is not yet released")
class Qwen3VLMoeIntegrationTest(unittest.TestCase):
def setUp(self):
cleanup(torch_device, gc_collect=True)
@@ -336,6 +335,18 @@ def setUp(self):
],
}
]
self.message3 = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
},
{"type": "text", "text": "Describe the video in short."},
],
}
]

def tearDown(self):
cleanup(torch_device, gc_collect=True)
@@ -455,6 +466,27 @@ def test_small_model_integration_test_expand(self):
EXPECTED_DECODED_TEXT,
)

@slow
def test_small_model_integration_test_expand_with_video(self):
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
)
inputs = self.processor.apply_chat_template(
self.message3, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
).to(torch_device)
Comment on lines +474 to +476
Member: I guess we need to do either beam search or sampling with num_return_sequences > 1 to trigger the needed behavior, no?

Contributor (author): Yes, the test uses num_beams=2 and num_return_sequences=2.

Member: Oh, I didn't see that. Okay.

output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2)

EXPECTED_DECODED_TEXT = [
"user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through a book.",
"user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through the pages of a book."
] # fmt: skip

self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
def test_small_model_integration_test_batch_wo_image(self):
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
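As the review thread notes, the expansion path is only exercised when generation returns more than one sequence per input. A rough usage sketch under the same setup as the new test (model, processor, and inputs assumed to be prepared as above):

# Greedy decoding keeps expand_size == 1, so _expand_inputs_for_generation is a no-op.
greedy_output = model.generate(**inputs, max_new_tokens=30, do_sample=False)

# Beam search (or sampling with num_return_sequences > 1) expands the inputs, so the
# flat visual tensors must be repeated per sample using the recovered video counts.
beam_output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2)
assert beam_output.shape[0] == 2 * inputs["input_ids"].shape[0]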