5 changes: 5 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -572,6 +572,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Override to return default layers for Llama

Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

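For intuition, the Llama default picks an early, a middle, and a late decoder layer. A standalone sketch of that selection (the helper name is illustrative, not part of this change):

def default_eagle3_aux_layers(num_layers: int) -> tuple[int, ...]:
    """Mirror of the Llama default: early, middle, and late layers."""
    return (2, num_layers // 2, num_layers - 3)

# A 32-layer Llama (e.g. Llama-3-8B) yields layers 2, 16, and 29.
assert default_eagle3_aux_layers(32) == (2, 16, 29)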
1 change: 1 addition & 0 deletions vllm/model_executor/models/llama4.py
@@ -671,6 +671,7 @@ def _init_model(self,
prefix=prefix,
layer_type=layer_type)


def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
69 changes: 61 additions & 8 deletions vllm/model_executor/models/llama_eagle3.py
@@ -21,8 +21,9 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from vllm.multimodal.inputs import NestedTensors

-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings

logger = init_logger(__name__)

@@ -147,12 +148,20 @@

def forward(
self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
-        assert hidden_states.shape[-1] == input_embeds.shape[-1]
+        if inputs_embeds is not None:
+            input_embeds = inputs_embeds
+        else:
+            input_embeds = self.embed_tokens(input_ids)
+
+        # Only check dimension compatibility once we have the input
+        # embeddings. For multimodal cases, hidden_states dimensions may
+        # differ and need adaptation.
+        if hidden_states.shape[-1] != input_embeds.shape[-1]:
+            hidden_states = self.fc(hidden_states)

residual = None
hidden_states, residual = self.layers[0](
@@ -200,6 +209,7 @@
class Eagle3LlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        logger.info("Initializing Eagle3LlamaForCausalLM")
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
@@ -232,18 +242,61 @@
requires_grad=False,
)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids)

# Check if this drafter is configured for text-only inference
inference_type = getattr(self.config, 'inference_type', 'multimodal')

        if multimodal_embeddings is not None and inference_type != 'text':
            # Multimodal drafter: the verifier has already processed the
            # multimodal content, and the auxiliary hidden states carry
            # that context. Here we only splice the multimodal embeddings
            # into the text embeddings at the image token positions.
            # Note: merge_multimodal_embeddings requires image_token_index.
            if hasattr(self.config, 'image_token_index'):
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    multimodal_embeddings,
                    self.config.image_token_index,
                )
        elif multimodal_embeddings is not None and inference_type == 'text':
            # Text-only drafter: ignore multimodal embeddings. The verifier
            # handles all multimodal processing; the drafter only processes
            # text tokens.
            pass

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                f"{type(self).__name__} does not support multimodal inputs yet."
-            )
-        return self.model(input_ids, positions, hidden_states)
+        # The Eagle3 drafter consumes auxiliary hidden states produced by
+        # the verifier model. For multimodal inputs, the verifier has
+        # already processed the multimodal content and baked it into those
+        # auxiliary hidden states, so a drafter configured with
+        # inference_type "text" only ever sees text tokens here.
+        if inputs_embeds is not None:
+            # Edge cases (e.g. warmup) where pre-computed embeddings are
+            # provided.
+            input_embeds = inputs_embeds
+        else:
+            # Standard case: use text embeddings for current-token
+            # prediction.
+            input_embeds = self.model.embed_tokens(input_ids)
+
+        # Adapt auxiliary hidden state dimensions if they don't match the
+        # text embeddings; critical for multimodal models where the
+        # auxiliary hidden states may have a different width.
+        if hidden_states.shape[-1] != input_embeds.shape[-1]:
+            hidden_states = self.model.fc(hidden_states)
+
+        # Eagle3: combine text embeddings with the (projected) auxiliary
+        # hidden states.
+        return self.model(None, positions, hidden_states, input_embeds)

def compute_logits(
self,
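To make the role of the fc projection concrete, here is a minimal, self-contained sketch of the combine step. The sizes (hidden size 4096, three auxiliary layers concatenated by the verifier) are assumptions for illustration, not values from this change; EAGLE3-style drafters typically feed the concatenation of text embeddings and projected auxiliary states into their first decoder layer:

import torch
import torch.nn as nn

hidden_size = 4096      # assumed draft model hidden size
num_aux = 3             # assumed number of auxiliary layers

# The verifier concatenates auxiliary hidden states along the feature
# dim, so they arrive wider than the text embeddings.
fc = nn.Linear(num_aux * hidden_size, hidden_size, bias=False)

input_embeds = torch.randn(8, hidden_size)
aux_hidden = torch.randn(8, num_aux * hidden_size)

# Mirrors the diff: project only when the widths disagree.
if aux_hidden.shape[-1] != input_embeds.shape[-1]:
    aux_hidden = fc(aux_hidden)

combined = torch.cat([input_embeds, aux_hidden], dim=-1)
print(combined.shape)  # torch.Size([8, 8192])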
21 changes: 19 additions & 2 deletions vllm/model_executor/models/mllama4.py
@@ -54,7 +54,8 @@
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsEagle3,
+                         SupportsMultiModal, SupportsPP)
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
@@ -710,7 +711,7 @@ def get_dummy_mm_data(
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                     SupportsPP):
+                                     SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -759,6 +760,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, 'set_aux_hidden_state_layers')
self.language_model.set_aux_hidden_state_layers(layers)

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.

Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
        # Delegate to the underlying language model (Llama4ForCausalLM)
        assert hasattr(self.language_model,
                       'get_eagle3_aux_hidden_state_layers')
        return self.language_model.get_eagle3_aux_hidden_state_layers()

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
# num_images, 1, num_chunks, channel, image_size, image_size
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2.py
@@ -333,7 +333,7 @@ def __init__(self,
else:
self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers = tuple[int, ...]()
+        self.aux_hidden_state_layers = tuple()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
18 changes: 16 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -66,7 +66,7 @@
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP, SupportsQuant)
+                         SupportsEagle3, SupportsMultiModal, SupportsPP,
+                         SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
@@ -912,7 +912,7 @@ def _get_mm_fields_config(
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP,
-                                         SupportsQuant):
+                                         SupportsQuant, SupportsEagle3):

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1137,6 +1137,20 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, 'set_aux_hidden_state_layers')
self.get_language_model().set_aux_hidden_state_layers(layers)

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.

Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
return self.language_model.get_eagle3_aux_hidden_state_layers()

def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
6 changes: 6 additions & 0 deletions vllm/transformers_utils/configs/speculators/algos.py
@@ -30,3 +30,9 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
vllm_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids"]
if config_dict.get("inference_type"):
vllm_config["inference_type"] = config_dict["inference_type"]

6 changes: 5 additions & 1 deletion vllm/v1/spec_decode/eagle.py
@@ -808,8 +808,12 @@ def load_model(self, target_model: nn.Module) -> None:

if supports_multimodal(target_model):
# handle multimodality
-            self.model.config.image_token_index = (
-                target_model.config.image_token_index)
+            if hasattr(target_model.config, "image_token_index"):
+                self.model.config.image_token_index = (
+                    target_model.config.image_token_index)
+            elif hasattr(draft_model_config, "image_token_id"):
+                self.model.config.image_token_index = (
+                    draft_model_config.image_token_id)
target_language_model = target_model.get_language_model()
else:
target_language_model = target_model
29 changes: 27 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2587,8 +2587,15 @@
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
if supports_eagle3(self.model):
-                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
+                # Get auxiliary layers from the speculative config if
+                # available.
+                aux_layers = self._get_eagle3_aux_layers_from_config()
+                if aux_layers is not None:
+                    logger.info(
+                        "Using auxiliary layers from speculative "
+                        "config: %s", aux_layers)
+                    self.model.set_aux_hidden_state_layers(aux_layers)
+                else:
+                    # Fall back to the model's default implementation.
+                    self.model.set_aux_hidden_state_layers(
+                        self.model.get_eagle3_aux_hidden_state_layers())
else:
raise RuntimeError(
"Model does not support EAGLE3 interface but "
@@ -2638,6 +2645,24 @@
else:
self.model = UBatchWrapper(self.model, self.vllm_config,
CUDAGraphMode.NONE, self.device)

    def _get_eagle3_aux_layers_from_config(
            self) -> Optional[tuple[int, ...]]:
        """
        Extract Eagle3 auxiliary layer IDs from the speculative config.

        Returns:
            Tuple of layer indices from the draft model config, or None
            if not found.
        """
        try:
            if (self.speculative_config
                    and self.speculative_config.draft_model_config):
                hf_config = (
                    self.speculative_config.draft_model_config.hf_config)
                layer_ids = getattr(
                    hf_config, 'eagle_aux_hidden_state_layer_ids', None)
                if layer_ids and isinstance(layer_ids, (list, tuple)):
                    return tuple(layer_ids)
        except Exception as e:
            logger.warning(
                "Failed to read auxiliary layers from speculative "
                "config: %s", e)
        return None

def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
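Putting the runner-side pieces together, the precedence is: layer IDs from the speculative config win, and the model's own default is only a fallback. A condensed, standalone sketch of that logic (the helper name is illustrative, not part of the change):

from typing import Optional

def resolve_aux_layers(
    config_layer_ids: Optional[list[int]],
    model_default: tuple[int, ...],
) -> tuple[int, ...]:
    """Config-supplied Eagle3 aux layers first, model default second."""
    if config_layer_ids:
        return tuple(config_layer_ids)
    return model_default

# Config present: it wins.
assert resolve_aux_layers([1, 2, 3], (2, 16, 29)) == (1, 2, 3)
# Config absent: fall back to the model default.
assert resolve_aux_layers(None, (2, 16, 29)) == (2, 16, 29)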