11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -45,25 +45,28 @@ def __init__(self,
         self.hidden_size = hidden_size
         self.max_num_requests = max_num_requests
         self.use_relaxed_acceptance_for_thinking = config.use_relaxed_acceptance_for_thinking
-        self.slot_manager = SlotManager(max_num_requests)
+        # Reserve one extra slot for the CUDA graph padding dummy request,
+        # which is kept alive permanently and must not consume a real slot.
+        slot_pool_size = max_num_requests + 1
+        self.slot_manager = SlotManager(slot_pool_size)
         # Optional SA manager for MTP+SA mode
         self.sa_manager = sa_manager

         # Since golden token's hidden state will always be generated after target model
         self.mtp_past_hidden_states_pool = torch.zeros(
-            (max_num_requests, self.num_nextn_predict_layers, self.hidden_size),
+            (slot_pool_size, self.num_nextn_predict_layers, self.hidden_size),
            device='cuda',
            dtype=self.dtype,
        )
        self.mtp_past_tokens_pool = torch.zeros(
-            (max_num_requests, self.num_nextn_predict_layers),
+            (slot_pool_size, self.num_nextn_predict_layers),
            device='cuda',
            dtype=torch.int,
        )
        if self.use_relaxed_acceptance_for_thinking:
            # The relaxed_delta for relaxed acceptance
            self.mtp_relaxed_delta_pool = torch.zeros(
-                (self.max_num_requests),
+                (slot_pool_size),
                dtype=torch.float,
                device='cuda',
            )
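The slot-pool change above can be sketched in isolation. The snippet below is a minimal stand-in, not the actual TensorRT-LLM `SlotManager`; it only illustrates why the pool is sized `max_num_requests + 1` when a CUDA graph padding dummy request permanently holds one slot.

```python
# Minimal sketch, NOT the real TensorRT-LLM SlotManager: just enough to
# show why the pool reserves max_num_requests + 1 slots.
class SlotManager:
    def __init__(self, pool_size: int):
        self.free = list(range(pool_size))

    def acquire(self) -> int:
        if not self.free:
            raise RuntimeError("no free slots")
        return self.free.pop()

max_num_requests = 4

# With the extra slot, the permanently held padding dummy
# does not starve real requests:
mgr = SlotManager(max_num_requests + 1)
dummy_slot = mgr.acquire()  # CUDA graph padding dummy, held forever
real_slots = [mgr.acquire() for _ in range(max_num_requests)]
assert len(real_slots) == max_num_requests

# Without it, the last real request fails once the dummy takes a slot:
mgr_no_pad = SlotManager(max_num_requests)
_ = mgr_no_pad.acquire()  # dummy consumes a real slot
try:
    for _ in range(max_num_requests):
        mgr_no_pad.acquire()
    exhausted = False
except RuntimeError:
    exhausted = True
assert exhausted
```

The same reasoning explains why every per-slot pool in the diff (`mtp_past_hidden_states_pool`, `mtp_past_tokens_pool`, `mtp_relaxed_delta_pool`) is resized to `slot_pool_size`: the dummy's slot index must be a valid row in each pool.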
4 changes: 3 additions & 1 deletion tensorrt_llm/bench/dataclasses/reporting.py
@@ -778,7 +778,9 @@ def get_max_draft_len(self) -> int:
             spec_config = self.kwargs["speculative_config"]
             # Handle both dict (from YAML) and object types
             if isinstance(spec_config, dict):
-                return spec_config.get("max_draft_len") or 0
+                draft_len = (spec_config.get("max_draft_len")
+                             or spec_config.get("num_nextn_predict_layers"))
+                return draft_len or 0
             return spec_config.max_draft_len or 0
Comment on lines +781 to 784
⚠️ Potential issue | 🟠 Major

Handle object configs and explicit zero values separately.

The new fallback only applies to dict-backed configs. Line 784 still reads spec_config.max_draft_len only, so object-based MTP configs can keep returning 0 and skip decoding stats. The or chain also treats an explicit max_draft_len=0 as “missing”, which changes the meaning of a configured value.

Proposed fix
         if ("speculative_config" in self.kwargs
                 and self.kwargs["speculative_config"] is not None):
             spec_config = self.kwargs["speculative_config"]
             # Handle both dict (from YAML) and object types
             if isinstance(spec_config, dict):
-                draft_len = (spec_config.get("max_draft_len")
-                             or spec_config.get("num_nextn_predict_layers"))
-                return draft_len or 0
-            return spec_config.max_draft_len or 0
+                if "max_draft_len" in spec_config:
+                    return spec_config["max_draft_len"] or 0
+                return spec_config.get("num_nextn_predict_layers") or 0
+            draft_len = getattr(spec_config, "max_draft_len", None)
+            if draft_len is not None:
+                return draft_len
+            return getattr(spec_config, "num_nextn_predict_layers", 0) or 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/bench/dataclasses/reporting.py` around lines 781-784, the code
treats falsy 0 as "missing" and only applies the new dict-backed fallback,
causing object-based configs to skip decoding stats; update the logic around
spec_config to (1) detect if spec_config is a dict and, if so, use "if
'max_draft_len' in spec_config: return spec_config['max_draft_len']" (allow 0),
else if "num_nextn_predict_layers" in spec_config return that value (again
allowing 0), and (2) for non-dict (object) configs, use attribute checks like
"hasattr(spec_config, 'max_draft_len') and spec_config.max_draft_len is not
None" to return the explicit attribute (including 0) or fall back similarly to
num_nextn_predict_layers; in short, stop using boolean "or" chains and use
explicit key/attribute existence and "is not None" checks for
spec_config.max_draft_len and spec_config.num_nextn_predict_layers.


return 0
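The pitfall the review describes is easy to reproduce. The snippet below uses hypothetical config values (the field names mirror `reporting.py`, but the numbers are made up) to show how a boolean `or` chain silently overrides an explicit `max_draft_len=0`, while explicit key/attribute checks preserve it:

```python
# Hypothetical dict-backed config: max_draft_len is EXPLICITLY set to 0.
dict_config = {"max_draft_len": 0, "num_nextn_predict_layers": 3}

# The `or` chain treats the explicit 0 as "missing" and falls
# through to the MTP layer count:
buggy = (dict_config.get("max_draft_len")
         or dict_config.get("num_nextn_predict_layers")) or 0
assert buggy == 3  # configured 0 was silently overridden

# Explicit key checks preserve a configured zero:
if "max_draft_len" in dict_config:
    fixed = dict_config["max_draft_len"] or 0
else:
    fixed = dict_config.get("num_nextn_predict_layers") or 0
assert fixed == 0

# Object-backed configs need the same treatment via getattr,
# as in the proposed fix:
class MTPConfig:  # hypothetical stand-in for an object-based config
    num_nextn_predict_layers = 3

obj_draft_len = getattr(MTPConfig, "max_draft_len", None)
if obj_draft_len is None:
    obj_draft_len = getattr(MTPConfig, "num_nextn_predict_layers", 0) or 0
assert obj_draft_len == 3
```

Note that `dict_config["max_draft_len"] or 0` still maps `None` to 0, so only the "key present vs. absent" distinction changes, which is exactly the separation the review asks for.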