Commit 36e95fe

fix qwen3vl expand generation with video and add
1 parent 08f52e2 commit 36e95fe

4 files changed, +159 -16 lines

src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
Lines changed: 15 additions & 7 deletions

@@ -1535,15 +1535,16 @@ def _expand_inputs_for_generation(
         input_ids: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ) -> tuple[torch.LongTensor, dict[str, Any]]:
-        # Overwritten -- Support for expanding tensors without a batch size dimension
-        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # Overwritten -- Qwen3VL use timestamps and remove second_per_grid_ts
+        # Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
         # pixel_values.shape[0] is sum(seqlen_images for samples)
         # image_grid_thw.shape[0] is sum(num_images for samples)
 
         if expand_size == 1:
             return input_ids, model_kwargs
 
-        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
 
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -1552,6 +1553,17 @@ def _expand_dict_for_generation_visual(dict_to_expand):
                 input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
             )
 
+            # video_nums: (batch_size,)
+            # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
+            # but qwen3vl append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
+            if video_grid_thw is not None:
+                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+                # Find video boundaries in cumulative_frame_counts
+                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
                 repeat_args = [repeat_times] + [1] * (x.dim() - 1)
@@ -1584,10 +1596,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
                     dict_to_expand[key] = _repeat_interleave_samples(
                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                     )
-                elif key == "second_per_grid_ts":
-                    dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
-                    )
             return dict_to_expand
 
         def _expand_dict_for_generation(dict_to_expand):
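For readers following the boundary-recovery arithmetic in the hunk above, here is a standalone sketch with invented grid sizes (not part of the commit). It assumes, as the diff's comments describe, one video_grid_thw row per video with the temporal size in column 0, while the token-derived video_nums counts one vision_start per frame.

import torch

# Invented example: sample 1 holds two videos (3 and 2 temporal patches),
# sample 2 holds one video (4 temporal patches). Every temporal patch carries
# its own vision_start token, so the token-based count reports [5, 4].
video_grid_thw = torch.tensor([[3, 4, 4], [2, 4, 4], [4, 4, 4]])  # one row per video: (t, h, w)
video_nums = torch.tensor([5, 4])  # over-counted per-sample count derived from vision_start tokens

cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)  # tensor([3, 5, 9])
cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)      # tensor([5, 9])
# Each sample's cumulative frame total lands exactly on a video boundary.
video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)  # tensor([1, 2])
video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
print(video_nums)  # tensor([2, 1]): two videos in sample 1, one in sample 2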

src/transformers/models/qwen3_vl/modular_qwen3_vl.py
Lines changed: 96 additions & 1 deletion

@@ -15,7 +15,7 @@
 """PyTorch Qwen3-VL model."""
 
 from collections.abc import Callable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch
@@ -1242,6 +1242,101 @@ def prepare_inputs_for_generation(
 
         return model_inputs
 
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> tuple[torch.LongTensor, dict[str, Any]]:
+        # Overwritten -- Qwen3VL use timestamps and remove second_per_grid_ts
+        # Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(
+                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+            )
+
+            # video_nums: (batch_size,)
+            # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
+            # but qwen3vl append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
+            if video_grid_thw is not None:
+                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+                # Find video boundaries in cumulative_frame_counts
+                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+
 
 class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
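A note on why the nested _repeat_interleave_samples helper exists at all: visual tensors such as image_grid_thw and pixel_values are concatenated across samples rather than batched, so a plain repeat_interleave on dim 0 would duplicate individual rows instead of whole samples. A minimal sketch with invented sizes (not part of the commit):

import torch

def repeat_interleave_samples(x, lengths, repeat_times):
    # Same idea as the nested helper above: split the flattened tensor into
    # per-sample chunks, then repeat each chunk as a unit.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

# Invented example: sample 1 owns two images, sample 2 owns one.
image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4], [1, 6, 6]])
expanded = repeat_interleave_samples(image_grid_thw, lengths=[2, 1], repeat_times=2)
print(expanded.shape)  # torch.Size([6, 3]): rows 0-3 are sample 1's copies, rows 4-5 are sample 2's

The non-visual keys keep the ordinary path (repeat_interleave inside _expand_dict_for_generation), because they do carry a leading batch dimension.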

src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
Lines changed: 15 additions & 7 deletions

@@ -1740,15 +1740,16 @@ def _expand_inputs_for_generation(
         input_ids: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ) -> tuple[torch.LongTensor, dict[str, Any]]:
-        # Overwritten -- Support for expanding tensors without a batch size dimension
-        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # Overwritten -- Qwen3VLMoe use timestamps and remove second_per_grid_ts
+        # Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
         # pixel_values.shape[0] is sum(seqlen_images for samples)
         # image_grid_thw.shape[0] is sum(num_images for samples)
 
         if expand_size == 1:
             return input_ids, model_kwargs
 
-        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
 
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
@@ -1757,6 +1758,17 @@ def _expand_dict_for_generation_visual(dict_to_expand):
                 input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
             )
 
+            # video_nums: (batch_size,)
+            # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
+            # but Qwen3VLMoe append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
+            if video_grid_thw is not None:
+                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+                # Find video boundaries in cumulative_frame_counts
+                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
                 repeat_args = [repeat_times] + [1] * (x.dim() - 1)
@@ -1789,10 +1801,6 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
                     dict_to_expand[key] = _repeat_interleave_samples(
                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
-                elif key == "second_per_grid_ts":
-                    dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
-                    )
             return dict_to_expand
 
         def _expand_dict_for_generation(dict_to_expand):

tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
Lines changed: 33 additions & 1 deletion

@@ -305,7 +305,6 @@ def test_video_forward(self):
 
 
 @require_torch
-@unittest.skip("The checkpoint is not yet released")
 class Qwen3VLMoeIntegrationTest(unittest.TestCase):
     def setUp(self):
         cleanup(torch_device, gc_collect=True)
@@ -336,6 +335,18 @@ def setUp(self):
                 ],
             }
         ]
+        self.message3 = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "video",
+                        "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
+                    },
+                    {"type": "text", "text": "Describe the video in short."},
+                ],
+            }
+        ]
 
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
@@ -455,6 +466,27 @@ def test_small_model_integration_test_expand(self):
             EXPECTED_DECODED_TEXT,
         )
 
+    @slow
+    def test_small_model_integration_test_expand_with_video(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        inputs = self.processor.apply_chat_template(
+            self.message3, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through a book.",
+            "user\n<0.3 seconds><1.3 seconds><2.4 seconds><3.5 seconds><4.6 seconds><5.6 seconds><6.7 seconds><7.8 seconds><8.9 seconds><9.7 seconds>Describe the video in short.\nassistant\nA baby wearing glasses sits on a bed and flips through the pages of a book."
+        ]  # fmt: skip
+
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
     @slow
     def test_small_model_integration_test_batch_wo_image(self):
         model = Qwen3VLMoeForConditionalGeneration.from_pretrained(