diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5ca2156c08b5..c705a70b93f5 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
-| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index f8ddb5a22b31..1d6d819ff58a 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# Intern-S1
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
- model_name = "internlm/Intern-S1"
+ model_name = "internlm/Intern-S1-mini"
engine_args = EngineArgs(
model=model_name,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 51b41f34b2ff..e0d95758a822 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
- model_name = "internlm/Intern-S1"
+ model_name = "internlm/Intern-S1-mini"
engine_args = EngineArgs(
model=model_name,
diff --git a/tests/conftest.py b/tests/conftest.py
index 66106d1bf779..c61a8f8dd539 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -756,7 +756,7 @@ def __init__(
def get_inputs(
self,
- prompts: Union[list[str], list[torch.Tensor], list[int]],
+ prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index e0ecb02d4f56..5af4327b65d0 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -86,3 +86,16 @@ def test_max_model_len():
# It can be less if generation finishes due to other reasons (e.g., EOS)
# before reaching the absolute model length limit.
assert num_total_tokens <= max_model_len
+
+
+def test_log_stats():
+ llm = LLM(
+ model=MODEL_NAME,
+ disable_log_stats=False,
+ gpu_memory_utilization=0.10,
+ enforce_eager=True, # reduce test time
+ )
+ outputs = llm.generate(PROMPTS, sampling_params=None)
+
+ # disable_log_stats is False, so every output should have metrics
+ assert all(output.metrics is not None for output in outputs)
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index e60a86075b8b..9d67b46f2e3e 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -240,12 +240,12 @@ def test_distributed_correctness(
num_logprobs: int,
) -> None:
with vllm_runner(model, tensor_parallel_size=1,
- max_num_seqs=2) as vllm_model:
+ max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, tensor_parallel_size=2,
- max_num_seqs=2) as vllm_model:
+ max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py
index 0292845f819c..e5caf0eae37d 100644
--- a/vllm/model_executor/models/interns1.py
+++ b/vllm/model_executor/models/interns1.py
@@ -25,7 +25,7 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
- MultiModalKwargsItems, NestedTensors)
+ MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
@@ -304,7 +304,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ def _call_hf_processor(
image_placeholder, 1)
num_patches = [len(item) for item in image_pixel_values]
- image_outputs: dict[str, NestedTensors] = {
+ image_outputs = {
"pixel_values": torch.concat(image_pixel_values),
"image_num_patches": torch.tensor(num_patches),
"image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ def _call_hf_processor(
video_placeholder, 1)
num_frames = [len(item) for item in video_pixel_values]
- video_outputs: dict[str, NestedTensors] = {
+ video_outputs = {
"pixel_values_videos": torch.concat(video_pixel_values),
"video_num_patches": torch.tensor(num_frames),
"video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ def _call_hf_processor(
prompt)
text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
- combined_outputs = dict(
- **text_outputs,
- **image_outputs,
- **video_outputs,
- )
- return BatchFeature(combined_outputs)
+ return BatchFeature({**text_outputs, **image_outputs, **video_outputs})
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
@@ -487,6 +482,7 @@ def get_replacement_interns1_video(item_idx: int):
dummy_inputs=InternS1DummyInputsBuilder)
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsLoRA):
+ merge_by_field_config = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -561,7 +557,7 @@ def _init_vision_model(
prefix=prefix,
)
- def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
return InternS1MultiModalProjector(config)
def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,13 +595,9 @@ def _parse_and_validate_image_input(
return None
if image_embeds is not None:
- if not isinstance(image_embeds, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image embeddings. "
- f"Got type: {type(image_embeds)}")
-
return InternS1ImageEmbeddingInputs(
type="image_embeds",
- data=flatten_bn(image_embeds),
+ data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
@@ -613,17 +605,6 @@ def _parse_and_validate_image_input(
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values is not None:
- if not isinstance(pixel_values, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values)}")
-
- if not isinstance(image_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(image_num_patches)}")
-
- pixel_values = flatten_bn(pixel_values, concat=True)
- image_num_patches = flatten_bn(image_num_patches, concat=True)
-
h, w = self.config.vision_config.image_size
return InternS1ImagePixelInputs(
type="pixel_values",
@@ -638,7 +619,7 @@ def _parse_and_validate_image_input(
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
- self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
+ self, **kwargs: object) -> Optional[InternS1VideoInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
@@ -647,13 +628,9 @@ def _parse_and_validate_video_input(
return None
if video_embeds is not None:
- if not isinstance(video_embeds, (torch.Tensor, list)):
- raise ValueError("Incorrect type of video embeddings. "
- f"Got type: {type(video_embeds)}")
-
- return InternS1ImageEmbeddingInputs(
+ return InternS1VideoEmbeddingInputs(
type="video_embeds",
- data=flatten_bn(video_embeds),
+ data=video_embeds,
)
video_token_id = kwargs["video_token_id"]
@@ -661,18 +638,6 @@ def _parse_and_validate_video_input(
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
- if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values_flat_video)}")
-
- if not isinstance(video_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(video_num_patches)}")
-
- pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
- concat=True)
- video_num_patches = flatten_bn(video_num_patches, concat=True)
-
h, w = self.config.vision_config.image_size
return InternS1VideoPixelInputs(
type="pixel_values_videos",
@@ -686,11 +651,12 @@ def _parse_and_validate_video_input(
raise AssertionError("This line should be unreachable.")
- def _process_image_input(
+ def _process_vision_input(
self,
- image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
+ image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
) -> tuple[torch.Tensor, ...]:
- if image_input["type"] == "image_embeds":
+ if (image_input["type"] == "image_embeds"
+ or image_input["type"] == "video_embeds"):
return image_input["data"]
assert self.vision_tower is not None
@@ -753,11 +719,11 @@ def get_multimodal_embeddings(self,
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
- vision_embeddings = self._process_image_input(image_input)
+ vision_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
- video_embeddings = self._process_image_input(video_input)
+ video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
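
The interns1.py changes above (and the matching internvl.py, nemotron_vl.py, and skyworkr1v.py changes below) all follow from opting into merge_by_field_config = True: per-item multimodal tensors now arrive already batched by their field config, so the manual flatten_bn calls and isinstance checks become redundant. A minimal sketch of that batching idea, with illustrative names only (this is not vLLM's actual merge implementation):

import torch

def merge_batched_field(items: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate per-item tensors along the first dim, which is what the
    # per-model flatten_bn(..., concat=True) calls previously did by hand.
    return torch.cat(items, dim=0)

# Example: two images with 3 and 5 patches each -> one (8, 3, 448, 448) batch.
pixel_values = [torch.randn(3, 3, 448, 448), torch.randn(5, 3, 448, 448)]
assert merge_batched_field(pixel_values).shape[0] == 8
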
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 0c95c49f90b1..1f3224f9ac58 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -17,7 +17,7 @@
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -28,7 +28,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
- MultiModalKwargsItems, NestedTensors)
+ MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -42,8 +42,7 @@
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
- maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = '<img>'
IMG_END = '</img>'
@@ -471,7 +470,7 @@ def _preprocess_image(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
- image_inputs: dict[str, NestedTensors] = {
+ image_inputs = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
@@ -502,7 +501,7 @@ def __call__(
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
text, images = [self._make_batch_input(x) for x in (text, images)]
text, image_inputs = self._preprocess_image(
@@ -515,10 +514,9 @@ def __call__(
text_inputs = self.tokenizer(text)
- return {
- **BatchEncoding(text_inputs, tensor_type=return_tensors),
- **image_inputs,
- }
+ combined_outputs = {**text_inputs, **image_inputs}
+
+ return BatchFeature(combined_outputs, tensor_type=return_tensors)
class InternVLProcessor(BaseInternVLProcessor):
@@ -598,7 +596,7 @@ def _preprocess_video(
videos,
dynamic_image_size=dynamic_image_size,
)
- video_inputs: dict[str, NestedTensors] = {
+ video_inputs = {
"pixel_values_flat_video":
torch.cat(pixel_values_lst_video),
"video_num_patches":
@@ -622,7 +620,7 @@ def __call__(
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
text, images, videos = [
self._make_batch_input(x) for x in (text, images, videos)
]
@@ -643,11 +641,9 @@ def __call__(
text_inputs = self.tokenizer(text)
- return {
- **BatchEncoding(text_inputs, tensor_type=return_tensors),
- **image_inputs,
- **video_inputs,
- }
+ combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+
+ return BatchFeature(combined_outputs, tensor_type=return_tensors)
def get_image_repl(
self,
@@ -773,7 +769,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
@@ -793,7 +789,7 @@ def _call_hf_processor(
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -948,7 +944,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs, tok_kwargs)
@@ -960,7 +956,7 @@ def _call_hf_processor(
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs,
@@ -1033,6 +1029,7 @@ def get_video_replacement_internvl(item_idx: int):
dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
+ merge_by_field_config = True
supports_encoder_tp_data = True
@@ -1126,7 +1123,7 @@ def _init_vision_model(
else:
return InternVisionPatchModel(config.vision_config)
- def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -1175,13 +1172,9 @@ def _parse_and_validate_image_input(
return None
if image_embeds is not None:
- if not isinstance(image_embeds, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image embeddings. "
- f"Got type: {type(image_embeds)}")
-
return InternVLImageEmbeddingInputs(
type="image_embeds",
- data=flatten_bn(image_embeds),
+ data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
@@ -1189,16 +1182,6 @@ def _parse_and_validate_image_input(
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None:
- if not isinstance(pixel_values_flat, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values_flat)}")
-
- if not isinstance(image_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(image_num_patches)}")
-
- pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
- image_num_patches = flatten_bn(image_num_patches, concat=True)
expected_h = expected_w = self.config.vision_config.image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1223,7 +1206,7 @@ def _parse_and_validate_video_input(
if video_embeds is not None:
return InternVLVideoEmbeddingInputs(
type="video_embeds",
- data=flatten_bn(video_embeds),
+ data=video_embeds,
)
video_token_id = kwargs["video_token_id"]
@@ -1231,17 +1214,6 @@ def _parse_and_validate_video_input(
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
- if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values_flat_video)}")
-
- if not isinstance(video_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(video_num_patches)}")
-
- pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
- concat=True)
- video_num_patches = flatten_bn(video_num_patches, concat=True)
expected_h = expected_w = self.config.vision_config.image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1254,11 +1226,12 @@ def _parse_and_validate_video_input(
raise AssertionError("This line should be unreachable.")
- def _process_image_input(
+ def _process_vision_input(
self,
- image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
+ image_input: Union[InternVLImageInputs, InternVLVideoInputs],
) -> tuple[torch.Tensor, ...]:
- if image_input["type"] == "image_embeds":
+ if (image_input["type"] == "image_embeds"
+ or image_input["type"] == "video_embeds"):
return image_input["data"]
assert self.vision_model is not None
@@ -1326,11 +1299,11 @@ def get_multimodal_embeddings(self,
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
- vision_embeddings = self._process_image_input(image_input)
+ vision_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
- video_embeddings = self._process_image_input(video_input)
+ video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index a6081d331511..e15dc43ec824 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -578,6 +578,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+ """Override to return default layers for Llama
+
+ Note: The GPU model runner will override this with layers from
+ the speculative config if available, providing dynamic configuration.
+ """
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
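
For reference, the default above selects an early, a middle, and a late layer; a hypothetical 32-layer Llama would report (2, 16, 29). A quick standalone check of that arithmetic:

def default_eagle3_aux_layers(num_layers: int) -> tuple[int, ...]:
    # Mirrors the default returned by get_eagle3_aux_hidden_state_layers above.
    return (2, num_layers // 2, num_layers - 3)

assert default_eagle3_aux_layers(32) == (2, 16, 29)
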
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 3fb6f2f8d5ec..4e4d8d21d057 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -20,6 +20,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
+from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
@@ -242,7 +243,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
requires_grad=False,
)
- def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ is_multimodal: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index db5a9fbc6a33..ffa659a5c3f9 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -54,7 +54,8 @@
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsEagle3,
+ SupportsMultiModal, SupportsPP)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
@@ -708,8 +709,8 @@ def get_dummy_mm_data(
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
- SupportsPP):
+class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
+ SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -758,6 +759,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
+ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+ """Set which layers should output auxiliary hidden states for EAGLE3."""
+ # Delegate to underlying language model (Llama4ForCausalLM)
+ assert hasattr(self.language_model, 'set_aux_hidden_state_layers')
+ self.language_model.set_aux_hidden_state_layers(layers)
+
+ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+ """Get the layer indices for auxiliary hidden state outputs.
+
+ Note: The GPU model runner will override this with layers from
+ the speculative config if available, providing dynamic configuration.
+ """
+ # Delegate to underlying language model (Llama4ForCausalLM)
+ assert hasattr(self.language_model,
+ 'get_eagle3_aux_hidden_state_layers')
+ return self.language_model.get_eagle3_aux_hidden_state_layers()
+
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
# num_images, 1, num_chunks, channel, image_size, image_size
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 2d0ebdc90277..b1d59f77f59d 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -18,8 +18,7 @@
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
-from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
- TensorType)
+from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -38,8 +37,7 @@
maybe_prefix)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
- MultiModalKwargs, MultiModalKwargsItems,
- NestedTensors)
+ MultiModalKwargs, MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -298,7 +296,7 @@ def _preprocess_image(
else:
pixel_values_lst = self._images_to_pixel_values_lst(
images, max_num_tiles)
- image_inputs: dict[str, NestedTensors] = {
+ image_inputs = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
@@ -326,7 +324,7 @@ def __call__(
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
max_num_tiles: Optional[int] = None,
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
# Use default if not provided
if max_num_tiles is None:
max_num_tiles = 12
@@ -341,10 +339,9 @@ def __call__(
text_inputs = self.tokenizer(text, add_special_tokens=False)
- return {
- **BatchEncoding(text_inputs, tensor_type=return_tensors),
- **image_inputs,
- }
+ combined_outputs = {**text_inputs, **image_inputs}
+
+ return BatchFeature(combined_outputs, tensor_type=return_tensors)
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
@@ -420,7 +417,7 @@ def _preprocess_video(
dynamic_image_size=dynamic_image_size,
)
- video_inputs: dict[str, NestedTensors] = {
+ video_inputs = {
"pixel_values_flat_video":
torch.cat(pixel_values_lst_video),
"video_num_patches":
@@ -443,7 +440,7 @@ def __call__(
return_tensors: Optional[Union[str, TensorType]] = None,
max_num_tiles: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
# Use default if not provided
if max_num_tiles is None:
max_num_tiles = 12
@@ -467,11 +464,9 @@ def __call__(
text_inputs = self.tokenizer(text, add_special_tokens=False)
- return BatchFeature({
- **BatchEncoding(text_inputs, tensor_type=return_tensors),
- **image_inputs,
- **video_inputs,
- })
+ combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+
+ return BatchFeature(combined_outputs, tensor_type=return_tensors)
def get_image_repl(
self,
@@ -625,7 +620,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
@@ -645,7 +640,7 @@ def _call_hf_processor(
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -724,7 +719,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs, tok_kwargs)
@@ -736,7 +731,7 @@ def _call_hf_processor(
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs,
diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py
index 0e7ec8e458cf..e6c4c5b022dc 100644
--- a/vllm/model_executor/models/nemotron_vl.py
+++ b/vllm/model_executor/models/nemotron_vl.py
@@ -28,7 +28,6 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import (
@@ -37,8 +36,7 @@
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
- maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = '<img>'
IMG_END = '</img>'
@@ -289,7 +287,7 @@ def _preprocess_image(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
- image_inputs: dict[str, NestedTensors] = {
+ image_inputs = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
@@ -344,6 +342,7 @@ def get_image_processor(self, **kwargs: object):
dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
+ merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -414,7 +413,7 @@ def _init_vision_model(
return AutoModel.from_config(config.vision_config,
trust_remote_code=True)
- def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vit_hidden_size
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -467,13 +466,9 @@ def _parse_and_validate_image_input(
return None
if image_embeds is not None:
- if not isinstance(image_embeds, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image embeddings. "
- f"Got type: {type(image_embeds)}")
-
return InternVLImageEmbeddingInputs(
type="image_embeds",
- data=flatten_bn(image_embeds),
+ data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
@@ -481,17 +476,6 @@ def _parse_and_validate_image_input(
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None:
- if not isinstance(pixel_values_flat, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values_flat)}")
-
- if not isinstance(image_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(image_num_patches)}")
-
- pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
- image_num_patches = flatten_bn(image_num_patches, concat=True)
-
return InternVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=pixel_values_flat,
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 3bbf4c67604c..0f993b0dc62f 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -159,7 +159,7 @@ def get_replacement_nvlm(item_idx: int):
dummy_inputs=NVLMDummyInputsBuilder)
class NVLM_D_Model(InternVLChatModel):
- def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size
llm_intermediate_size = config.text_config.intermediate_size
llm_hidden_size = config.text_config.hidden_size
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index f03022aa719c..8556c3847041 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -14,7 +14,7 @@
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -25,7 +25,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
- MultiModalKwargsItems, NestedTensors)
+ MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -37,8 +37,7 @@
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
- maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = '<img>'
IMG_END = '</img>'
@@ -399,7 +398,7 @@ def __call__(
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
@@ -418,7 +417,7 @@ def __call__(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
- image_inputs: dict[str, NestedTensors] = {
+ image_inputs = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
@@ -435,10 +434,9 @@ def __call__(
text_inputs = self.tokenizer(text)
- return {
- **BatchEncoding(text_inputs, tensor_type=return_tensors),
- **image_inputs,
- }
+ combined_outputs = {**text_inputs, **image_inputs}
+
+ return BatchFeature(combined_outputs, tensor_type=return_tensors)
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@@ -529,7 +527,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
+ ) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
@@ -549,7 +547,7 @@ def _call_hf_processor(
def _get_mm_fields_config(
self,
- hf_inputs: Mapping[str, NestedTensors],
+ hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -617,6 +615,7 @@ def get_replacement_skyworkr1v(item_idx: int):
info=SkyworkR1VProcessingInfo,
dummy_inputs=SkyworkR1VDummyInputsBuilder)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+ merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -703,7 +702,7 @@ def _init_vision_model(
else:
return InternVisionPatchModel(config.vision_config)
- def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -756,13 +755,9 @@ def _parse_and_validate_image_input(
return None
if image_embeds is not None:
- if not isinstance(image_embeds, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image embeddings. "
- f"Got type: {type(image_embeds)}")
-
return SkyworkR1VImageEmbeddingInputs(
type="image_embeds",
- data=flatten_bn(image_embeds),
+ data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
@@ -770,17 +765,6 @@ def _parse_and_validate_image_input(
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None:
- if not isinstance(pixel_values_flat, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel values. "
- f"Got type: {type(pixel_values_flat)}")
-
- if not isinstance(image_num_patches, (torch.Tensor, list)):
- raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(image_num_patches)}")
-
- pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
- image_num_patches = flatten_bn(image_num_patches, concat=True)
-
return SkyworkR1VImagePixelInputs(
type="pixel_values",
pixel_values_flat=pixel_values_flat,
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 4d8206bb2d83..1ed20461def1 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -14,6 +14,7 @@
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sequence import RequestMetrics
+from vllm.v1.metrics.stats import RequestStateStats
logger = init_logger(__name__)
@@ -108,7 +109,7 @@ def __init__(
prompt_logprobs: Optional[PromptLogprobs],
outputs: list[CompletionOutput],
finished: bool,
- metrics: Optional[RequestMetrics] = None,
+ metrics: Optional[Union[RequestMetrics, RequestStateStats]] = None,
lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[list[int]] = None,
diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py
index efc87b6bcf26..73d9f87527b5 100644
--- a/vllm/transformers_utils/configs/speculators/algos.py
+++ b/vllm/transformers_utils/configs/speculators/algos.py
@@ -17,11 +17,15 @@ def decorator(fn):
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
"""
Apply Eagle-3 specific configuration transformations.
-
+
Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
+ - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
+ model to use as auxiliary inputs for the Eagle3 drafter. These layers
+ provide intermediate hidden states that help the drafter make better
+ predictions. This is the standard field used in Eagle3 checkpoints.
"""
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@@ -30,3 +34,6 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
vllm_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+ if config_dict.get("eagle_aux_hidden_state_layer_ids"):
+ vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
+ "eagle_aux_hidden_state_layer_ids"]
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 38b2d6824b47..46cb97d4e7b5 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -248,16 +248,15 @@ def _new_request_output(
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)
- return RequestOutput(
- request_id=request_id,
- prompt=self.prompt,
- prompt_token_ids=prompt_token_ids,
- prompt_logprobs=prompt_logprobs,
- outputs=cast(list[CompletionOutput], outputs),
- finished=finished,
- kv_transfer_params=kv_transfer_params,
- num_cached_tokens=self.num_cached_tokens,
- )
+ return RequestOutput(request_id=request_id,
+ prompt=self.prompt,
+ prompt_token_ids=prompt_token_ids,
+ prompt_logprobs=prompt_logprobs,
+ outputs=cast(list[CompletionOutput], outputs),
+ finished=finished,
+ kv_transfer_params=kv_transfer_params,
+ num_cached_tokens=self.num_cached_tokens,
+ metrics=self.stats)
def _new_completion_output(
self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8b92cb052efd..ace6d8c100dd 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2746,13 +2746,23 @@ def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
- if supports_eagle3(self.model):
- self.model.set_aux_hidden_state_layers(
- self.model.get_eagle3_aux_hidden_state_layers())
- else:
+ if not supports_eagle3(self.model):
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested")
+
+ # Try to get auxiliary layers from the speculative config,
+ # otherwise fall back to the model's default layers.
+ aux_layers = self._get_eagle3_aux_layers_from_config()
+ if aux_layers:
+ logger.info(
+ "Using auxiliary layers from speculative config: %s",
+ aux_layers)
+ else:
+ aux_layers = self.model.get_eagle3_aux_hidden_state_layers(
+ )
+
+ self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
@@ -2803,6 +2813,31 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.model = UBatchWrapper(self.model, self.vllm_config,
CUDAGraphMode.NONE, self.device)
+ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
+ """Extract Eagle3 auxiliary layer indices from speculative config.
+
+ These indices specify which hidden states from the base model should
+ be used as auxiliary inputs for the Eagle3 drafter model during
+ speculative decoding.
+
+ Returns:
+ Tuple of layer indices if found in draft model config,
+ None otherwise.
+ """
+ if not (self.speculative_config
+ and self.speculative_config.draft_model_config):
+ return None
+
+ hf_config = self.speculative_config.draft_model_config.hf_config
+ if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
+ return None
+
+ layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+ if layer_ids and isinstance(layer_ids, (list, tuple)):
+ return tuple(layer_ids)
+
+ return None
+
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."