diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5ca2156c08b5..c705a70b93f5 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | -| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f8ddb5a22b31..1d6d819ff58a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: # Intern-S1 def run_interns1(questions: list[str], modality: str) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 51b41f34b2ff..e0d95758a822 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, diff --git a/tests/conftest.py b/tests/conftest.py index 66106d1bf779..c61a8f8dd539 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -756,7 +756,7 @@ def __init__( def get_inputs( self, - prompts: Union[list[str], list[torch.Tensor], list[int]], + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index e0ecb02d4f56..5af4327b65d0 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -86,3 +86,16 @@ def test_max_model_len(): # It can be less if generation finishes due to other reasons (e.g., EOS) # before reaching the absolute model length limit. 
assert num_total_tokens <= max_model_len + + +def test_log_stats(): + llm = LLM( + model=MODEL_NAME, + disable_log_stats=False, + gpu_memory_utilization=0.10, + enforce_eager=True, # reduce test time + ) + outputs = llm.generate(PROMPTS, sampling_params=None) + + # disable_log_stats is False, every output should have metrics + assert all(output.metrics is not None for output in outputs) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index e60a86075b8b..9d67b46f2e3e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -240,12 +240,12 @@ def test_distributed_correctness( num_logprobs: int, ) -> None: with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=2) as vllm_model: + max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=2) as vllm_model: + max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 0292845f819c..e5caf0eae37d 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -25,7 +25,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -39,7 +39,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, +from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix) @@ -304,7 +304,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: mm_data = dict(mm_data) videos = mm_data.pop("videos", []) images = mm_data.pop("images", []) @@ -342,7 +342,7 @@ def _call_hf_processor( image_placeholder, 1) num_patches = [len(item) for item in image_pixel_values] - image_outputs: dict[str, NestedTensors] = { + image_outputs = { "pixel_values": torch.concat(image_pixel_values), "image_num_patches": torch.tensor(num_patches), "image_token_id": torch.tensor(hf_processor.image_token_id), @@ -370,7 +370,7 @@ def _call_hf_processor( video_placeholder, 1) num_frames = [len(item) for item in video_pixel_values] - video_outputs: dict[str, NestedTensors] = { + video_outputs = { "pixel_values_videos": torch.concat(video_pixel_values), "video_num_patches": torch.tensor(num_frames), "video_token_id": torch.tensor(video_token_id), @@ -382,16 +382,11 @@ def _call_hf_processor( prompt) text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt") - combined_outputs = dict( - **text_outputs, - **image_outputs, - **video_outputs, - ) - return BatchFeature(combined_outputs) + return BatchFeature({**text_outputs, **image_outputs, **video_outputs}) def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, 
hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: @@ -487,6 +482,7 @@ def get_replacement_interns1_video(item_idx: int): dummy_inputs=InternS1DummyInputsBuilder) class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -561,7 +557,7 @@ def _init_vision_model( prefix=prefix, ) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: return InternS1MultiModalProjector(config) def pixel_shuffle(self, x, scale_factor=0.5): @@ -599,13 +595,9 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternS1ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -613,17 +605,6 @@ def _parse_and_validate_image_input( self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values = flatten_bn(pixel_values, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - h, w = self.config.vision_config.image_size return InternS1ImagePixelInputs( type="pixel_values", @@ -638,7 +619,7 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]: + self, **kwargs: object) -> Optional[InternS1VideoInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_videos", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -647,13 +628,9 @@ def _parse_and_validate_video_input( return None if video_embeds is not None: - if not isinstance(video_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") - - return InternS1ImageEmbeddingInputs( + return InternS1VideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -661,18 +638,6 @@ def _parse_and_validate_video_input( self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. 
" - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) - h, w = self.config.vision_config.image_size return InternS1VideoPixelInputs( type="pixel_values_videos", @@ -686,11 +651,12 @@ def _parse_and_validate_video_input( raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs], + image_input: Union[InternS1ImageInputs, InternS1VideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if (image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds"): return image_input["data"] assert self.vision_tower is not None @@ -753,11 +719,11 @@ def get_multimodal_embeddings(self, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0c95c49f90b1..1f3224f9ac58 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig @@ -28,7 +28,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -42,8 +42,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -471,7 +470,7 @@ def _preprocess_image( max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -502,7 +501,7 @@ def __call__( max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images = [self._make_batch_input(x) for x in (text, images)] text, image_inputs = self._preprocess_image( @@ -515,10 +514,9 @@ def __call__( text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return 
BatchFeature(combined_outputs, tensor_type=return_tensors) class InternVLProcessor(BaseInternVLProcessor): @@ -598,7 +596,7 @@ def _preprocess_video( videos, dynamic_image_size=dynamic_image_size, ) - video_inputs: dict[str, NestedTensors] = { + video_inputs = { "pixel_values_flat_video": torch.cat(pixel_values_lst_video), "video_num_patches": @@ -622,7 +620,7 @@ def __call__( max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images, videos = [ self._make_batch_input(x) for x in (text, images, videos) ] @@ -643,11 +641,9 @@ def __call__( text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - **video_inputs, - } + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) def get_image_repl( self, @@ -773,7 +769,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -793,7 +789,7 @@ def _call_hf_processor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -948,7 +944,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) @@ -960,7 +956,7 @@ def _call_hf_processor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_fields = super()._get_mm_fields_config(hf_inputs, @@ -1033,6 +1029,7 @@ def get_video_replacement_internvl(item_idx: int): dummy_inputs=InternVLDummyInputsBuilder) class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True supports_encoder_tp_data = True @@ -1126,7 +1123,7 @@ def _init_vision_model( else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -1175,13 +1172,9 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -1189,16 +1182,6 @@ def _parse_and_validate_image_input( self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. 
" - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1223,7 +1206,7 @@ def _parse_and_validate_video_input( if video_embeds is not None: return InternVLVideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -1231,17 +1214,6 @@ def _parse_and_validate_video_input( self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1254,11 +1226,12 @@ def _parse_and_validate_video_input( raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs], + image_input: Union[InternVLImageInputs, InternVLVideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if (image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds"): return image_input["data"] assert self.vision_model is not None @@ -1326,11 +1299,11 @@ def get_multimodal_embeddings(self, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a6081d331511..e15dc43ec824 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -578,6 +578,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Override to return default layers for Llama + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. 
+ """ num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 3fb6f2f8d5ec..4e4d8d21d057 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -20,6 +20,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) +from vllm.multimodal.inputs import NestedTensors from .utils import AutoWeightsLoader, maybe_prefix @@ -242,7 +243,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + is_multimodal: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index db5a9fbc6a33..ffa659a5c3f9 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -54,7 +54,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, + SupportsMultiModal, SupportsPP) from .llama4 import Llama4ForCausalLM from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -708,8 +709,8 @@ def get_dummy_mm_data( info=Mllama4ProcessingInfo, dummy_inputs=Mllama4DummyInputsBuilder, ) -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, + SupportsEagle3): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -758,6 +759,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for EAGLE3.""" + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, 'set_aux_hidden_state_layers') + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get the layer indices for auxiliary hidden state outputs. + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. 
+ """ + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, + 'get_eagle3_aux_hidden_state_layers') + return self.language_model.get_eagle3_aux_hidden_state_layers() + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: # num_images, 1, num_chunks, channel, image_size, image_size diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2d0ebdc90277..b1d59f77f59d 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -18,8 +18,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import (BatchEncoding, BatchFeature, PretrainedConfig, - TensorType) +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.activation import ReLUSquaredActivation @@ -38,8 +37,7 @@ maybe_prefix) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, MultiModalKwargsItems, - NestedTensors) + MultiModalKwargs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -298,7 +296,7 @@ def _preprocess_image( else: pixel_values_lst = self._images_to_pixel_values_lst( images, max_num_tiles) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -326,7 +324,7 @@ def __call__( images: Optional[Union[Image.Image, list[Image.Image]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, max_num_tiles: Optional[int] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: max_num_tiles = 12 @@ -341,10 +339,9 @@ def __call__( text_inputs = self.tokenizer(text, add_special_tokens=False) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): @@ -420,7 +417,7 @@ def _preprocess_video( dynamic_image_size=dynamic_image_size, ) - video_inputs: dict[str, NestedTensors] = { + video_inputs = { "pixel_values_flat_video": torch.cat(pixel_values_lst_video), "video_num_patches": @@ -443,7 +440,7 @@ def __call__( return_tensors: Optional[Union[str, TensorType]] = None, max_num_tiles: Optional[int] = None, dynamic_image_size: Optional[bool] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: max_num_tiles = 12 @@ -467,11 +464,9 @@ def __call__( text_inputs = self.tokenizer(text, add_special_tokens=False) - return BatchFeature({ - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - **video_inputs, - }) + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) def get_image_repl( self, @@ -625,7 +620,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( 
prompt=prompt, mm_data=mm_data, @@ -645,7 +640,7 @@ def _call_hf_processor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -724,7 +719,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) @@ -736,7 +731,7 @@ def _call_hf_processor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_fields = super()._get_mm_fields_config(hf_inputs, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 0e7ec8e458cf..e6c4c5b022dc 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -28,7 +28,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import ( @@ -37,8 +36,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -289,7 +287,7 @@ def _preprocess_image( max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -344,6 +342,7 @@ def get_image_processor(self, **kwargs: object): dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo]) class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -414,7 +413,7 @@ def _init_vision_model( return AutoModel.from_config(config.vision_config, trust_remote_code=True) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vit_hidden_size vision_projection_hidden_size = config.projector_hidden_size llm_hidden_size = config.text_config.hidden_size @@ -467,13 +466,9 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -481,17 +476,6 @@ def _parse_and_validate_image_input( self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. 
" - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return InternVLImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat, diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 3bbf4c67604c..0f993b0dc62f 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -159,7 +159,7 @@ def get_replacement_nvlm(item_idx: int): dummy_inputs=NVLMDummyInputsBuilder) class NVLM_D_Model(InternVLChatModel): - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_intermediate_size = config.text_config.intermediate_size llm_hidden_size = config.text_config.hidden_size diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index f03022aa719c..8556c3847041 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -14,7 +14,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear @@ -25,7 +25,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix IMG_START = '' IMG_END = '' @@ -399,7 +398,7 @@ def __call__( max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: if text is None: text = [] if not isinstance(text, list): @@ -418,7 +417,7 @@ def __call__( max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { + image_inputs = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": @@ -435,10 +434,9 @@ def __call__( text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class SkyworkR1VProcessingInfo(BaseProcessingInfo): @@ -529,7 +527,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = 
super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -549,7 +547,7 @@ def _call_hf_processor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -617,6 +615,7 @@ def get_replacement_skyworkr1v(item_idx: int): info=SkyworkR1VProcessingInfo, dummy_inputs=SkyworkR1VDummyInputsBuilder) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -703,7 +702,7 @@ def _init_vision_model( else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -756,13 +755,9 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return SkyworkR1VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -770,17 +765,6 @@ def _parse_and_validate_image_input( self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return SkyworkR1VImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat, diff --git a/vllm/outputs.py b/vllm/outputs.py index 4d8206bb2d83..1ed20461def1 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sequence import RequestMetrics +from vllm.v1.metrics.stats import RequestStateStats logger = init_logger(__name__) @@ -108,7 +109,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: list[CompletionOutput], finished: bool, - metrics: Optional[RequestMetrics] = None, + metrics: Optional[Union[RequestMetrics, RequestStateStats]] = None, lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[list[int]] = None, diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index efc87b6bcf26..73d9f87527b5 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -17,11 +17,15 @@ def decorator(fn): def update_eagle3(config_dict: dict, vllm_config: dict) -> None: """ Apply Eagle-3 specific configuration transformations. 
- + Eagle-3 specific fields: - draft_vocab_size: Size of the draft model's vocabulary - target_hidden_size: Hidden size of the target model - norm_before_residual: Whether to apply norm before residual connection + - eagle_aux_hidden_state_layer_ids: List of layer indices from the base + model to use as auxiliary inputs for the Eagle3 drafter. These layers + provide intermediate hidden states that help the drafter make better + predictions. This is the standard field used in Eagle3 checkpoints. """ vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") @@ -30,3 +34,6 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: vllm_config["norm_before_residual"] = config_dict.get( "norm_before_residual", True) vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] + if config_dict.get("eagle_aux_hidden_state_layer_ids"): + vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ + "eagle_aux_hidden_state_layer_ids"] diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 38b2d6824b47..46cb97d4e7b5 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -248,16 +248,15 @@ def _new_request_output( if prompt_token_ids is None and self.prompt_embeds is not None: prompt_token_ids = [0] * len(self.prompt_embeds) - return RequestOutput( - request_id=request_id, - prompt=self.prompt, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, - outputs=cast(list[CompletionOutput], outputs), - finished=finished, - kv_transfer_params=kv_transfer_params, - num_cached_tokens=self.num_cached_tokens, - ) + return RequestOutput(request_id=request_id, + prompt=self.prompt, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=cast(list[CompletionOutput], outputs), + finished=finished, + kv_transfer_params=kv_transfer_params, + num_cached_tokens=self.num_cached_tokens, + metrics=self.stats) def _new_completion_output( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8b92cb052efd..ace6d8c100dd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2746,13 +2746,23 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: + if not supports_eagle3(self.model): raise RuntimeError( "Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested") + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers( + ) + + self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -2803,6 +2813,31 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = UBatchWrapper(self.model, self.vllm_config, CUDAGraphMode.NONE, self.device) + def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: + """Extract Eagle3 auxiliary layer indices from speculative config. 
+ + These indices specify which hidden states from the base model should + be used as auxiliary inputs for the Eagle3 drafter model during + speculative decoding. + + Returns: + Tuple of layer indices if found in draft model config, + None otherwise. + """ + if not (self.speculative_config + and self.speculative_config.draft_model_config): + return None + + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + + return None + def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ "Cannot reload weights before model is loaded."
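
The gpu_model_runner.py and speculators/algos.py hunks above establish a precedence rule for Eagle3 auxiliary hidden-state layers: layer ids carried in the draft checkpoint's `eagle_aux_hidden_state_layer_ids` field override the target model's built-in defaults (for Llama, `(2, num_layers // 2, num_layers - 3)`). A minimal standalone sketch of that precedence follows; `resolve_aux_layers` and `default_eagle3_aux_layers` are illustrative names for this note, not functions introduced by the patch.

from typing import Optional


def default_eagle3_aux_layers(num_layers: int) -> tuple[int, ...]:
    # Mirrors LlamaForCausalLM.get_eagle3_aux_hidden_state_layers().
    return (2, num_layers // 2, num_layers - 3)


def resolve_aux_layers(config_layer_ids: Optional[list[int]],
                       num_layers: int) -> tuple[int, ...]:
    # Config-supplied ids (eagle_aux_hidden_state_layer_ids) take precedence;
    # otherwise fall back to the model's defaults, as load_model() now does
    # via _get_eagle3_aux_layers_from_config().
    if config_layer_ids:
        return tuple(config_layer_ids)
    return default_eagle3_aux_layers(num_layers)


if __name__ == "__main__":
    print(resolve_aux_layers(None, 32))         # -> (2, 16, 29)
    print(resolve_aux_layers([1, 15, 30], 32))  # -> (1, 15, 30)

In the runner itself the same check happens once at load time: the resolved tuple is passed to `set_aux_hidden_state_layers`, which Llama4ForConditionalGeneration now forwards to its underlying Llama4ForCausalLM.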