From 36e95feaf767ffff5ff654ddd2a132f7cd64388f Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Fri, 7 Nov 2025 20:42:08 +0800
Subject: [PATCH] fix qwen3vl expand generation with video and add

---
 .../models/qwen3_vl/modeling_qwen3_vl.py      | 22 +++--
 .../models/qwen3_vl/modular_qwen3_vl.py       | 97 ++++++++++++++++++-
 .../qwen3_vl_moe/modeling_qwen3_vl_moe.py     | 22 +++--
 .../test_modeling_qwen3_vl_moe.py             | 34 ++++++-
 4 files changed, 159 insertions(+), 16 deletions(-)

diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
index 37f6a5146053..05c6204d051d 100644
--- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -1535,15 +1535,16 @@ def _expand_inputs_for_generation(
         input_ids: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ) -> tuple[torch.LongTensor, dict[str, Any]]:
-        # Overwritten -- Support for expanding tensors without a batch size dimension
-        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # Overwritten -- Qwen3VL uses timestamps and removes second_per_grid_ts
+        # Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
         # pixel_values.shape[0] is sum(seqlen_images for samples)
         # image_grid_thw.shape[0] is sum(num_images for samples)
 
         if expand_size == 1:
             return input_ids, model_kwargs
 
-        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
 
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -1552,6 +1553,17 @@ def _expand_dict_for_generation_visual(dict_to_expand):
                 input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
             )
 
+            # video_nums: (batch_size,)
+            # video_nums is the number of videos inferred from the vision_start tokens in input_ids, but Qwen3VL appends
+            # a vision_start token to each frame of each video, so we need to recover the real video_nums from video_grid_thw
+            if video_grid_thw is not None:
+                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+                # Find video boundaries in cumulative_frame_counts
+                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
                 repeat_args = [repeat_times] + [1] * (x.dim() - 1)
@@ -1584,10 +1596,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
                     dict_to_expand[key] = _repeat_interleave_samples(
                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                     )
-                elif key == "second_per_grid_ts":
-                    dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
-                    )
             return dict_to_expand
 
         def _expand_dict_for_generation(dict_to_expand):
diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
index 7758a23e2970..514fc041adf5 100644
--- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -15,7 +15,7 @@
 """PyTorch Qwen3-VL model."""
 
 from collections.abc import Callable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch
@@ -1242,6 +1242,101 @@ def prepare_inputs_for_generation(
 
         return model_inputs
 
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> tuple[torch.LongTensor, dict[str, Any]]:
+        # Overwritten -- Qwen3VL uses timestamps and removes second_per_grid_ts
+        # Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(
+                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+            )
+
+            # video_nums: (batch_size,)
+            # video_nums is the number of videos inferred from the vision_start tokens in input_ids, but Qwen3VL appends
+            # a vision_start token to each frame of each video, so we need to recover the real video_nums from video_grid_thw
+            if video_grid_thw is not None:
+                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+                # Find video boundaries in cumulative_frame_counts
+                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False): _defaults = { diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 264902c2d8a4..6476eb9150ed 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1740,15 +1740,16 @@ def _expand_inputs_for_generation( input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> tuple[torch.LongTensor, dict[str, Any]]: - # Overwritten -- Support for expanding tensors without a batch size dimension - # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # Overwritten -- Qwen3VLMoe use timestamps and remove second_per_grid_ts + # Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw # pixel_values.shape[0] is sum(seqlen_images for samples) # image_grid_thw.shape[0] is sum(num_images for samples) if expand_size == 1: return input_ids, model_kwargs - visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) @@ -1757,6 +1758,17 @@ def _expand_dict_for_generation_visual(dict_to_expand): input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) ) + # video_nums: (batch_size,) + # since video_nums is the number of videos in the input dependent on the input_ids(vision_start), + # but Qwen3VLMoe append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw + if video_grid_thw is not None: + cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) + cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) + # Find video boundaries in cumulative_frame_counts + video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) + # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] + video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) + def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) repeat_args = [repeat_times] + [1] * (x.dim() - 1) @@ -1789,10 +1801,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times): dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) - elif key == "second_per_grid_ts": - dict_to_expand[key] = 
-                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
-                    )
             return dict_to_expand
 
         def _expand_dict_for_generation(dict_to_expand):
diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
index ea48ef000d42..fcbb38260b8e 100644
--- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
+++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
@@ -305,7 +305,6 @@ def test_video_forward(self):
 
 
 @require_torch
-@unittest.skip("The checkpoint is not yet released")
 class Qwen3VLMoeIntegrationTest(unittest.TestCase):
     def setUp(self):
         cleanup(torch_device, gc_collect=True)
@@ -336,6 +335,18 @@ def setUp(self):
                 ],
             }
         ]
+        self.message3 = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "video",
+                        "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
+                    },
+                    {"type": "text", "text": "Describe the video in short."},
+                ],
+            }
+        ]
 
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
@@ -455,6 +466,27 @@ def test_small_model_integration_test_expand(self):
             EXPECTED_DECODED_TEXT,
         )
 
+    @slow
+    def test_small_model_integration_test_expand_with_video(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        inputs = self.processor.apply_chat_template(
+            self.message3, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through a book.",
+            "user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through the pages of a book."
+        ]  # fmt: skip
+
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
     @slow
     def test_small_model_integration_test_batch_wo_image(self):
         model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
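
A minimal standalone sketch (not part of the applied diff) of the video_nums recovery used in the patch: it
exercises the same cumsum/searchsorted/diff steps in isolation. The grid values and per-sample frame counts
below are made up for illustration; only standard torch ops (torch.cumsum, torch.searchsorted, torch.diff)
are assumed.

    import torch

    # One (t, h, w) row per video; sample 0 owns the first 4 videos, sample 1 the last 2.
    video_grid_thw = torch.tensor(
        [[2, 4, 4], [3, 4, 4], [1, 4, 4], [2, 4, 4], [4, 4, 4], [2, 4, 4]]
    )
    # Frame-level counts per sample, i.e. one vision_start token per frame:
    # sample 0 -> 2 + 3 + 1 + 2 = 8, sample 1 -> 4 + 2 = 6.
    frame_level_video_nums = torch.tensor([8, 6])

    cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)          # [2, 5, 6, 8, 12, 14]
    cumulative_token_video_counts = torch.cumsum(frame_level_video_nums, dim=0)  # [8, 14]
    # Index of the last video belonging to each sample.
    video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)  # [3, 5]
    # Per-sample video counts, recovered by differencing against a prepended -1.
    video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
    print(video_nums.tolist())  # [4, 2]

The left-side searchsorted is exact here because each sample's cumulative frame total always coincides with
the cumulative frame count of its last video, so every boundary index lands on that video rather than past it.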