System Info
Inference with Qwen3-VL, num_beams > 1, and video inputs fails:
[rank0]: File "/xx/python/transformers/src/transformers/trainer_seq2seq.py", line 255, in predict
[rank0]: return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/trainer.py", line 4567, in predict
[rank0]: output = eval_loop(
[rank0]: ^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/trainer.py", line 4685, in evaluation_loop
[rank0]: losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/LLaMA-Factory-latest/src/llamafactory/train/sft/trainer.py", line 137, in prediction_step
[rank0]: loss, generated_tokens, _ = super().prediction_step(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/trainer_seq2seq.py", line 327, in prediction_step
[rank0]: generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/miniconda3/envs/vlm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/generation/utils.py", line 2482, in generate
[rank0]: input_ids, model_kwargs = self._expand_inputs_for_generation(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 1540, in _expand_inputs_for_generation
[rank0]: model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/python/transformers/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 1513, in _expand_dict_for_generation_visual
[rank0]: samples = torch.split(video_grid_thw, list(video_nums))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/miniconda3/envs/vlm/lib/python3.11/site-packages/torch/functional.py", line 222, in split
[rank0]: return tensor.split(split_size_or_sections, dim)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/xx/miniconda3/envs/vlm/lib/python3.11/site-packages/torch/_tensor.py", line 1052, in split
[rank0]: return torch._VF.split_with_sizes(self, split_size, dim)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: split_with_sizes expects split_sizes to sum exactly to 2 (input tensor's size at dimension 0), but got split_sizes=[8, 7]
It seems that the vision_start_token insertion in Qwen3-VL differs from Qwen2-VL, i.e., one vision_start_token is inserted for each frame of a video:
transformers/src/transformers/models/qwen3_vl/processing_qwen3_vl.py, lines 200 to 215 in 87be559:
```python
video_placeholder = ""
frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
for frame_idx in range(video_grid_thw[index][0]):
    curr_time = curr_timestamp[frame_idx]
    video_placeholder += f"<{curr_time:.1f} seconds>"
    video_placeholder += (
        self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
    )
if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
    text[i] = text[i].replace(
        f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
    )
else:
    # vllm may input video token directly
    text[i] = text[i].replace(self.video_token, video_placeholder, 1)
index += 1
```
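For illustration, a minimal sketch of what the loop above produces (num_frames and frame_seqlen are made-up toy values, and the literal token strings are assumed for readability):

```python
vision_start_token, vision_end_token = "<|vision_start|>", "<|vision_end|>"
num_frames, frame_seqlen = 8, 2  # toy values; frame_seqlen kept tiny for readability

video_placeholder = ""
for frame_idx in range(num_frames):
    video_placeholder += f"<{float(frame_idx):.1f} seconds>"
    video_placeholder += (
        vision_start_token + "<|placeholder|>" * frame_seqlen + vision_end_token
    )

# One vision_start token per frame, not one per video:
print(video_placeholder.count(vision_start_token))  # 8
```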
However, the function _get_image_nums_and_video_nums in Qwen3-VL counts each vision_start_token as a separate video:
transformers/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py, lines 1491 to 1493 in 87be559:
```python
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
```
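To see why this breaks beam search, here is a minimal sketch reproducing the miscount and the resulting split error (token ids, the token layout, and the grid values are made up; the counting lines mirror the snippet above):

```python
import torch

vision_start_id, video_id = 1, 2  # hypothetical token ids
frames = 8

# Simplified layout for ONE video of 8 frames: a vision_start token
# precedes every frame's run of video tokens.
input_ids = torch.tensor([[vision_start_id, video_id] * frames])

vision_start_mask = input_ids == vision_start_id
video_mask = input_ids == video_id
# Same counting logic as _get_image_nums_and_video_nums:
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
print(video_nums)  # tensor([8]), but there is only 1 video

# video_grid_thw has one row per video, so the split sizes cannot match:
video_grid_thw = torch.tensor([[frames, 16, 16]])
torch.split(video_grid_thw, video_nums.tolist())
# RuntimeError: split_with_sizes expects split_sizes to sum exactly to 1 ...
```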
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Run inference with Qwen3-VL, video inputs, and num_beams > 1, as sketched below.
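A minimal sketch of the failing call (the checkpoint name and video path are placeholders, and this assumes the processor's apply_chat_template can load the video directly; my actual run goes through LLaMA-Factory's Seq2SeqTrainer, but any video input should trigger it):

```python
import torch
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

model_id = "Qwen/Qwen3-VL-8B-Instruct"  # placeholder checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [{"role": "user", "content": [
    {"type": "video", "video": "sample.mp4"},  # placeholder video path
    {"type": "text", "text": "Describe the video."},
]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

# num_beams > 1 goes through _expand_inputs_for_generation and hits the split error
model.generate(**inputs, num_beams=2, max_new_tokens=32)
```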
Expected behavior
Generation completes correctly with video inputs and num_beams > 1.