2 changes: 1 addition & 1 deletion docs/models/supported_models.md
@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language.py
@@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:

# Intern-S1
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
model_name = "internlm/Intern-S1"
model_name = "internlm/Intern-S1-mini"

engine_args = EngineArgs(
model=model_name,
2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language_multi_image.py
@@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:


def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "internlm/Intern-S1"
model_name = "internlm/Intern-S1-mini"

engine_args = EngineArgs(
model=model_name,
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -756,7 +756,7 @@ def __init__(

def get_inputs(
self,
prompts: Union[list[str], list[torch.Tensor], list[int]],
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
13 changes: 13 additions & 0 deletions tests/entrypoints/llm/test_generate.py
@@ -86,3 +86,16 @@ def test_max_model_len():
# It can be less if generation finishes due to other reasons (e.g., EOS)
# before reaching the absolute model length limit.
assert num_total_tokens <= max_model_len


def test_log_stats():
llm = LLM(
model=MODEL_NAME,
disable_log_stats=False,
gpu_memory_utilization=0.10,
enforce_eager=True, # reduce test time
)
outputs = llm.generate(PROMPTS, sampling_params=None)

# Since disable_log_stats is False, every output should have metrics
assert all(output.metrics is not None for output in outputs)
4 changes: 2 additions & 2 deletions tests/models/language/generation/test_hybrid.py
@@ -240,12 +240,12 @@ def test_distributed_correctness(
num_logprobs: int,
) -> None:
with vllm_runner(model, tensor_parallel_size=1,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, tensor_parallel_size=2,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

72 changes: 19 additions & 53 deletions vllm/model_executor/models/interns1.py
@@ -25,7 +25,7 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)


@@ -304,7 +304,7 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ def _call_hf_processor(
image_placeholder, 1)

num_patches = [len(item) for item in image_pixel_values]
image_outputs: dict[str, NestedTensors] = {
image_outputs = {
"pixel_values": torch.concat(image_pixel_values),
"image_num_patches": torch.tensor(num_patches),
"image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ def _call_hf_processor(
video_placeholder, 1)

num_frames = [len(item) for item in video_pixel_values]
video_outputs: dict[str, NestedTensors] = {
video_outputs = {
"pixel_values_videos": torch.concat(video_pixel_values),
"video_num_patches": torch.tensor(num_frames),
"video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ def _call_hf_processor(
prompt)
text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")

combined_outputs = dict(
**text_outputs,
**image_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
return BatchFeature({**text_outputs, **image_outputs, **video_outputs})

def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:

@@ -487,6 +482,7 @@ def get_replacement_interns1_video(item_idx: int):
dummy_inputs=InternS1DummyInputsBuilder)
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsLoRA):
merge_by_field_config = True

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -561,7 +557,7 @@ def _init_vision_model(
prefix=prefix,
)

def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
return InternS1MultiModalProjector(config)

def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,31 +595,16 @@ def _parse_and_validate_image_input(
return None

if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

return InternS1ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)

image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")

pixel_values = flatten_bn(pixel_values, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)

h, w = self.config.vision_config.image_size
return InternS1ImagePixelInputs(
type="pixel_values",
@@ -638,7 +619,7 @@ def _parse_and_validate_image_input(
raise AssertionError("This line should be unreachable.")

def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
self, **kwargs: object) -> Optional[InternS1VideoInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
@@ -647,32 +628,16 @@ def _parse_and_validate_video_input(
return None

if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")

return InternS1ImageEmbeddingInputs(
return InternS1VideoEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
data=video_embeds,
)

video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()

if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}")

if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}")

pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)

h, w = self.config.vision_config.image_size
return InternS1VideoPixelInputs(
type="pixel_values_videos",
@@ -686,11 +651,12 @@ def _parse_and_validate_video_input(

raise AssertionError("This line should be unreachable.")

def _process_image_input(
def _process_vision_input(
self,
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
if (image_input["type"] == "image_embeds"
or image_input["type"] == "video_embeds"):
return image_input["data"]

assert self.vision_tower is not None
@@ -753,11 +719,11 @@ def get_multimodal_embeddings(self,
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
vision_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings

return multimodal_embeddings