[None][feat] Add EAGLE3 dynamic tree speculative decoding support #12062
sunnyqgg wants to merge 2 commits into NVIDIA:main from
Conversation
Signed-off-by: qgai <qgai@nvidia.com>
/bot run
PR_Github #38374 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This pull request introduces dynamic tree-based speculative decoding for EAGLE3 inference with CUDA-accelerated kernels. It adds tree construction and greedy verification kernels, integrates them with Python/PyTorch layers, implements dynamic tree sampling and acceptance logic, and extends the Eagle3 resource manager and worker infrastructure to support dynamic tree mode alongside static tree mode.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Python as Python Layer
participant Sampler as TorchSampler
participant DTree as DynamicTreeOpsConverter
participant Kernel as CUDA Kernel
participant Buf as GPU Buffers
Python->>Sampler: update() with requests
Sampler->>Sampler: _batch_verify_dynamic_tree(requests, tokens)
Sampler->>DTree: verify_dynamic_tree_greedy(draft_tokens, logits, tree_buffers)
DTree->>DTree: compute target predictions from logits
DTree->>Kernel: invoke verify_dynamic_tree_greedy_op()
Kernel->>Buf: read: candidates, retrieve_index, retrieve_next_sibling, targetPredict
Kernel->>Buf: greedy tree traversal per batch
Kernel->>Buf: write: predicts, acceptIndex, acceptTokenNum
Buf-->>Kernel: results
Kernel-->>DTree: returns VerifyTreeResults
DTree-->>Sampler: per-slot (num_accepted_tokens, accept_index)
Sampler->>Sampler: _process_draft_tokens_dynamic_tree() per request
Sampler-->>Python: updated tokens and finish reasons
sequenceDiagram
participant Worker as Eagle3DynamicTreeWorker
participant DraftModel as Draft Model
participant DTree as DynamicTreeOpsConverter
participant Kernel as CUDA Kernel
participant Cache as KV Cache
Worker->>Worker: _forward_draft_loop(initial context)
Worker->>DraftModel: forward(input_ids, position_ids)
DraftModel-->>Worker: logits
Worker->>Worker: sample_dynamic(logits, topk)
Worker->>Worker: dt_update_draft_tokens_and_scores()
Worker->>DTree: build_dynamic_tree(parent_list, topk_indices, tree_buffers)
DTree->>Kernel: invoke build_dynamic_tree_op()
Kernel->>Kernel: construct left-child/right-sibling tree
Kernel->>Kernel: compute per-node attention masks (treeMask)
Kernel->>Kernel: compute absolute positions
Kernel-->>DTree: DynamicTreeBuffers (tree_mask, positions, retrieve_index, etc.)
DTree-->>Worker: tree structure ready
Worker->>Worker: dt_prepare_tree_mask_and_position_offset()
Worker->>DraftModel: forward(growing context with tree topology)
DraftModel->>Cache: update KV cache with tree positions
DraftModel-->>Worker: logits per tree node
Worker->>Worker: _sample_and_accept_dynamic_tree(logits)
sequenceDiagram
participant App as Application
participant Executor as PyExecutor
participant ResourceMgr as Eagle3OneModelDynamicTreeResourceManager
participant Worker as Eagle3OneModelDynamicTreeWorker
participant Sampler as Eagle3OneModelDynamicTreeSampler
App->>Executor: initialize with EagleDecodingConfig(use_dynamic_tree=True)
Executor->>ResourceMgr: create with SpecTreeManager(use_dynamic_tree=True)
Executor->>Worker: initialize with spec_config
Executor->>Sampler: initialize with spec_config
App->>Executor: generate(requests)
Executor->>Worker: forward(context & draft loop)
Worker->>Worker: _forward_dynamic_tree_draft_loop()
Worker-->>Executor: draft_tokens, dynamic_tree_buffers, accepted_draft_indices
Executor->>Sampler: sample_and_accept_draft_tokens(logits, buffers)
Sampler->>Sampler: verify with dynamic tree buffers
Sampler-->>Executor: accepted tokens, accepted indices
Executor->>ResourceMgr: get_needed_resource_to_completion(request)
ResourceMgr-->>Executor: resource estimates
Executor->>Worker: prepare_1st_drafter_inputs(with tree topology targets)
Executor-->>App: generated tokens
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 17
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
3346-3386: ⚠️ Potential issue | 🔴 Critical: Zero-draft dynamic-tree requests currently fall into the static-tree verifier.
`_batch_verify_dynamic_tree()` explicitly skips requests whose draft length is 0, but the fallback here still calls `process_draft_tokens()`. With a non-null `spec_tree_manager`, that dispatches to `_process_draft_tokens_tree()`, which assumes a populated draft tree. If a dynamic-tree request reaches this branch with no drafts, it will fail instead of just emitting the verified token.

🛠️ One possible guard
```diff
-            if req.py_seq_slot in dynamic_tree_results:
+            if req.py_seq_slot in dynamic_tree_results:
                 num_accepted = self._process_draft_tokens_dynamic_tree(
                     req, new_tokens_list, finish_reasons,
                     dynamic_tree_results[req.py_seq_slot]
                 )
-
+            elif spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree:
+                num_accepted = self._process_draft_tokens_greedy(
+                    req, new_tokens=new_tokens_list, finish_reasons=finish_reasons
+                )
             else:
                 num_accepted = self.process_draft_tokens(
                     req, new_tokens_tensor=new_tokens, new_tokens_list=new_tokens_list,
```
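To make the intended dispatch order concrete, here is a minimal torch-free sketch of the precedence the guard above aims for: dynamic-tree results first, a greedy fallback for dynamic-tree requests that produced no drafts, and the generic path last. The function name `pick_verifier` and the slot ids are illustrative only, not the real TensorRT-LLM API.

```python
# Hedged sketch of the verifier dispatch precedence described in the comment.
# All names here are hypothetical stand-ins for the sampler's internal logic.

def pick_verifier(seq_slot, dynamic_tree_results, use_dynamic_tree):
    if seq_slot in dynamic_tree_results:
        return "dynamic_tree"        # batched dynamic-tree verification ran for this slot
    if use_dynamic_tree:
        return "greedy_fallback"     # zero-draft request: just emit the verified token
    return "process_draft_tokens"    # static-tree / linear path

results = {3: object()}              # slot 3 was verified by the dynamic-tree kernel
print(pick_verifier(3, results, use_dynamic_tree=True))   # dynamic_tree
print(pick_verifier(7, results, use_dynamic_tree=True))   # greedy_fallback
print(pick_verifier(7, results, use_dynamic_tree=False))  # process_draft_tokens
```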
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
435-437: Keep the wrapper imports namespaced here too.
Since this is new dispatch code, please import the `drafting_loops` module and reference these wrappers from that module rather than importing the classes directly. As per coding guidelines: When importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` around lines 435 - 437, replace the direct class imports DynamicTreeDraftingLoopWrapper, LinearDraftingLoopWrapper, StaticTreeDraftingLoopWrapper with a namespaced module import for drafting_loops and update all references to use drafting_loops.DynamicTreeDraftingLoopWrapper, drafting_loops.LinearDraftingLoopWrapper, and drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used in the dispatch/registration code inside py_executor_creator.py) so the module is imported, not individual classes.

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
8-9: Keep the drafting-loop import namespaced.
Please import the `drafting_loops` module and resolve `StaticTreeDraftingLoopWrapper` through that namespace here instead of importing the class directly. As per coding guidelines: When importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

Also applies to: 56-56
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py` around lines 8 - 9, replace the direct class import with a module-level import for the drafting_loops module and qualify the class through that namespace: change the current "from tensorrt_llm._torch.speculative.drafting_loops import StaticTreeDraftingLoopWrapper" to "import tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from tensorrt_llm._torch.speculative import drafting_loops") and update all usages to drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage at the other occurrence noted).

tensorrt_llm/_torch/attention_backend/interface.py (1)
371-381: Making trailing parameters keyword-only would improve API safety, but it is not required.
The code is already safe: all call sites either use keyword arguments (`model_engine.py:3552`, test cases) or correctly pass all 9 positional arguments in order (`sparse/dsa.py:541`). No stale callers with 7-8 positional arguments (which would silently misbind after inserting `is_target_model`) exist in the codebase.

If keyword-only enforcement is desired for defensiveness, add a bare `*` before `is_target_model`:

```diff
 def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
         is_spec_dec_tree,
         is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        *,
         is_target_model: bool = True,
         model_is_wrapped: bool = False,
         spec_tree_manager: Optional['SpecTreeManager'] = None):
```

This prevents future positional misuse and makes the intent explicit, but the current codebase is already compliant.
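The effect of the bare `*` can be shown with a simplified stand-in for the signature above (the two flags and the return tuple are illustrative, not the real `update_spec_dec_param`):

```python
# Minimal illustration of the bare-* suggestion: trailing flags become
# keyword-only, so a stale caller passing them positionally fails loudly
# instead of silently misbinding. This is a simplified stand-in signature.

def update_spec_dec_param(batch_size, max_draft_len, *,
                          is_target_model=True, model_is_wrapped=False):
    return (batch_size, max_draft_len, is_target_model, model_is_wrapped)

# Keyword use still works:
print(update_spec_dec_param(8, 4, is_target_model=False))  # (8, 4, False, False)

# Positional use of the flags now raises TypeError:
try:
    update_spec_dec_param(8, 4, False)
except TypeError as e:
    print("rejected:", e)
```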
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/interface.py` around lines 371 - 381, the function update_spec_dec_param currently accepts many trailing boolean/optional parameters positionally; make is_target_model, model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare * before is_target_model in the signature so callers cannot accidentally bind those flags positionally. Update the signature in the update_spec_dec_param definition and adjust any internal references accordingly (no other logic changes).

tensorrt_llm/_torch/attention_backend/trtllm.py (1)
502-509: Consider clarifying the reshape assumption.
The reshape from 1D `[max_num_requests * N]` to 2D `[max_num_requests, N]` assumes the 1D tensor was allocated with exactly `max_num_requests * (max_total_draft_tokens + 1)` elements. This is correct based on lines 1463-1465, but the implicit coupling between allocation and reshape could be fragile.

Consider adding an assertion or comment to make this contract explicit:
📝 Suggestion for defensive check
```diff
         # For dynamic tree, reshape 1D position_offsets to 2D for C++ kernel compatibility
         position_offsets_for_cpp = self.spec_decoding_position_offsets
         if (self.spec_decoding_position_offsets is not None
                 and self.spec_decoding_position_offsets.dim() == 1):
             # Reshape 1D [max_num_requests * N] to 2D [max_num_requests, N]
             # C++ kernel requires 2D to extract max_generation_length from sizes()[1]
+            assert self.spec_decoding_position_offsets.numel() % self.max_num_requests == 0, \
+                "1D position_offsets size must be divisible by max_num_requests"
             position_offsets_for_cpp = self.spec_decoding_position_offsets.view(
                 self.max_num_requests, -1)
```
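The allocation/reshape contract itself is easy to sketch without torch; `view_2d` below is a plain-Python stand-in for the tensor `.view()` call, showing how the divisibility assert fails fast when allocation and reshape drift apart:

```python
# Torch-free sketch of the contract: a flat buffer of max_num_requests * N
# entries is viewed as [max_num_requests, N]; the assert catches mismatches.

def view_2d(flat, max_num_requests):
    assert len(flat) % max_num_requests == 0, \
        "1D position_offsets size must be divisible by max_num_requests"
    n = len(flat) // max_num_requests
    return [flat[i * n:(i + 1) * n] for i in range(max_num_requests)]

offsets = list(range(12))
print(view_2d(offsets, 3))      # 3 rows of 4 entries each
try:
    view_2d(offsets, 5)         # 12 % 5 != 0 -> fails fast with a clear message
except AssertionError as e:
    print("caught:", e)
```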
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 502 - 509, the reshape from 1D to 2D (position_offsets_for_cpp based on self.spec_decoding_position_offsets) assumes the 1D tensor length equals max_num_requests * N; add a defensive check before the view to assert that self.spec_decoding_position_offsets.numel() is divisible by self.max_num_requests and (optionally) equals self.max_num_requests * (self.max_total_draft_tokens + 1) (or raise a clear error mentioning spec_decoding_position_offsets and max_num_requests) so the implicit allocation/reshape contract in trtllm.py is explicit and fails fast when violated.

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
2626-2668: Materialize `accept_index` before the per-request Python loop.
Line 2649's `accept_index[j].item()` performs a device read for every accepted token in a hot request loop. Convert the accepted indices once, then iterate over a Python list here, matching the existing `new_tokens.tolist()` pattern. Based on learnings: In files under `tensorrt_llm/_torch/pyexecutor`, avoid accessing `torch.Tensor` objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using `tensor.tolist()`, and then iterate over those lists.
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/sampler.py` around lines 2626 - 2668, The loop in _process_draft_tokens_dynamic_tree repeatedly calls accept_index[j].item(), causing device reads per-iteration; materialize accept_index to a Python list once before the request loop (e.g. accept_indices = accept_index.tolist() or accept_index.cpu().tolist() and cast elements to int), then iterate over accept_indices for add_token and finish_if_reason calls, and compute request.py_num_accepted_draft_tokens_indices from that list by subtracting 1 for positions after the root; keep using the same symbols (accept_index -> accept_indices, _process_draft_tokens_dynamic_tree, add_token, finish_if_reason, request.py_num_accepted_draft_tokens_indices) so you only replace tensor indexing with list indexing.
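The cost difference this nitpick targets can be demonstrated with a toy stand-in: in real code each `accept_index[j].item()` forces a host/device synchronization, while a single `.tolist()` syncs once. `CountingTensor` below is hypothetical and just counts synchronizing reads (its `item_at` models per-element `.item()` access):

```python
# Toy model of why materializing accept_index once is cheaper.

class CountingTensor:
    """Stand-in tensor that counts host-synchronizing reads."""
    def __init__(self, data):
        self._data = list(data)
        self.syncs = 0
    def item_at(self, j):        # models accept_index[j].item(): one sync per call
        self.syncs += 1
        return self._data[j]
    def tolist(self):            # models accept_index.tolist(): one sync total
        self.syncs += 1
        return list(self._data)

accept_index = CountingTensor([0, 2, 5, 6])

# Anti-pattern: one sync per accepted token.
vals = [accept_index.item_at(j) for j in range(4)]
print(accept_index.syncs)        # 4

# Suggested pattern: one sync, then plain Python indexing.
accept_index.syncs = 0
accept_indices = accept_index.tolist()
vals = [accept_indices[j] for j in range(4)]
print(accept_index.syncs)        # 1
```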
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 112-182: The kernel reuses preallocated topology buffers but
doesn’t clear stale data; before building the tree (inside the tid==0 branch of
the dynamic tree builder) explicitly reinitialize retrieveNextToken and
retrieveNextSibling entries for this batch (bid) to -1 for all draftTokenNum
slots, and clear all words of treeMask for this batch (not just word 0); also
ensure positions/retrieveIndex for all slots are set to sane defaults if needed.
Locate the tid==0 block that sets positions[bid * draftTokenNum] and the loop
that writes retrieveIndex/retrieveNextToken/retrieveNextSibling and add the
resets there (and mirror the same full-reset logic in the other build region
referenced around lines 245-317).
- Around line 191-214: The ancestor-walk loop can run past bounds when a parent
lookup misses: after the for-loop that searches selectedIndex for tokenIdx
(using curPosition and draftTokenNum) add a guard to detect "not found"
(curPosition == draftTokenNum) and break the while loop to avoid reading/writing
past selectedIndex/treeMask; apply the same defensive check to the equivalent
ancestor-walk logic around the other block referenced (uses the same symbols:
treeMask, tokenTreeIdx, curPosition, selectedIndex, draftTokenNum, parentList,
parentTbIdx, bid, topK) so both paths stop if the parent was not resampled.
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 17: The header currently uses `#pragma once` but must follow the repo guard
convention; replace the pragma with a preprocessor include guard named
TRTLLM_DYNAMICTREEKERNELS_H (matching the filename dynamicTreeKernels.h in ALL
CAPS) by adding `#ifndef` TRTLLM_DYNAMICTREEKERNELS_H / `#define`
TRTLLM_DYNAMICTREEKERNELS_H at the top and a matching `#endif` at the bottom,
ensuring no directory names or trailing underscores are used and keeping the
rest of the file (dynamicTreeKernels.h) unchanged.
In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 34-67: build_dynamic_tree_op (and its sibling
verify_dynamic_tree_greedy_op) currently access raw data pointers and call
at::cuda::getCurrentCUDAStream() without validating devices, dtypes, shapes or
the treeMaskMode enum; add TORCH_CHECKs to ensure all input/output tensors
(parentList, selectedIndex, treeMask, positions, retrieveIndex,
retrieveNextToken, retrieveNextSibling, verifiedSeqLen) are CUDA tensors
(is_cuda()), are on the same device (device.index() equality), and have the
expected scalar types (parentList/selectedIndex int64,
positions/retrieveIndex/retrieveNextToken/retrieveNextSibling/verifiedSeqLen
int32 as used by data_ptr<int32_t/int64_t>()), verify output shapes (batchSize,
numDraftTokens-1, etc.) before zero_/fill_, and check treeMaskMode is within the
valid tk::TreeMaskMode range before static_cast; perform these checks at the
start of build_dynamic_tree_op and verify_dynamic_tree_greedy_op so that
tk::invokeBuildDynamicTree and related kernel calls only receive validated
tensors and enum values.
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 540: The forward reference 'SpecTreeManager' used in the signature
(spec_tree_manager: Optional['SpecTreeManager']) is not imported under
TYPE_CHECKING; add "from tensorrt_llm._torch.speculative.spec_tree_manager
import SpecTreeManager" to the existing TYPE_CHECKING import block (alongside
any existing imports such as DecodingBaseConfig) so the name is resolved for
type checking and Ruff F821 is fixed.
In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 444-461: The leftover variable use_tree_drafter is assigned
earlier but never used after splitting into static_tree_drafter and
dynamic_tree_drafter, causing a linter F841; remove the unused use_tree_drafter
assignment (or fold its logic into the existing predicates) so only
static_tree_drafter and dynamic_tree_drafter derived from draft_spec_config
(EagleDecodingConfig) remain, leaving the branching that returns
StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper unchanged
(references: use_tree_drafter, static_tree_drafter, dynamic_tree_drafter,
draft_spec_config, spec_config, StaticTreeDraftingLoopWrapper,
DynamicTreeDraftingLoopWrapper).
In `@tensorrt_llm/_torch/speculative/drafting_loops.py`:
- Around line 709-718: The code unconditionally overwrites return_draft_logits
with zeros losing real collected logits; change the logic in drafting_loops.py
so that you only allocate the zero tensor as a fallback when return_draft_logits
is missing or has an incompatible shape (e.g., check if return_draft_logits is
None or return_draft_logits.shape != (self.max_total_draft_tokens, batch_size,
vocab_size)); otherwise preserve the existing return_draft_logits from the last
draft layer; when allocating the fallback ensure dtype/device match
(torch.float32 and 'cuda') and add a brief comment referencing
tokens_accumulated to indicate this is a temporary fallback until per-layer
gathering is implemented.
- Around line 1164-1180: The attn_metadata.use_spec_decoding flag is left False
so subsequent drafter forwards ignore the dynamic-tree metadata; set
attn_metadata.use_spec_decoding = True at the end of this preparation block
(after updating kv_lens_cuda, _seq_lens, host_request_types and before leaving
the dynamic-tree growth steps) so the next draft pass uses speculative decoding;
locate the block updating attn_metadata.kv_lens_cuda, attn_metadata._seq_lens,
attn_metadata.host_request_types and set attn_metadata.use_spec_decoding = True
there (ensure this happens before
spec_metadata.eagle3_resource_manager.is_first_draft is toggled).
- Around line 961-975: spec_decoding_position_offsets is being treated as a flat
vector but it’s stored as a 2-D buffer ([max_num_requests,
max_total_draft_tokens+1]); update the code to slice and assign it as 2-D so
rows correspond to requests: read previous_position_offsets =
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_previous_layer], build new_position_offsets by concatenating along
dim=1 (using previous_position_offsets and previous_position_offsets[:,
-self.dynamic_tree_max_topK:]+1), then write it back to
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_current_layer] (no flattening/view needed) so the correct request
rows are updated.
In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 1-12: Add the standard NVIDIA Apache-2.0 license header (with the
latest modification year) at the top of the file before the existing module
docstring in dynamic_tree_ops.py; replace the current file-starting
docstring-only content by prepending the required NVIDIA copyright/license block
so the file begins with the Apache 2.0 header followed by the existing "Dynamic
Tree Operations for EAGLE3 Speculative Decoding" docstring.
In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 98-143: The buffers currently fix max_batch_size = 256 which can
overflow for larger deployments; replace the hard-coded max_batch_size with a
runtime value derived from spec_config or the worker's actual max concurrent
sequences (e.g., spec_config.max_batch_size or a passed-in parameter) and use
that variable when allocating dt_draft_tokens_buffer, dt_position_ids_buffer,
history_draft_tokens_buffer, history_score_buffer,
history_draft_tokens_parent_buffer, tree_mask_buffer, tree_mask_init_buffer, and
tree_mask_padding_zeros, and when calling create_dynamic_tree_ops_converter
(preserve device and dtypes). Ensure the new max_batch_size value is validated
(positive int) and that any flattened-size calculations (like tree_mask_buffer
shape) are updated to use the dynamic max_batch_size to avoid out-of-bounds
writes.
- Around line 475-487: The code incorrectly reshapes
spec_decoding_position_offsets into a flat (max_reqs, tokens_per_req) layout
which conflicts with the request-major layout used elsewhere; replace the manual
flattening with a request-major view and index into the existing first
dimension. Concretely, stop computing max_reqs = total_po_size // tokens_per_req
and using pos_2d = attn_metadata.spec_decoding_position_offsets.view(max_reqs,
tokens_per_req); instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.
In `@tensorrt_llm/_torch/speculative/model_drafter.py`:
- Around line 686-690: Remove the two unused CPU buffer assignments to avoid
dead code: delete the assignments to topk_score_indices and
history_draft_tokens_parent_buffer that read from
dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.
- Around line 1004-1005: The prepare_draft_tokens method in ModelDrafter
currently requires resource_manager but the base class Drafter defines it as
optional; change the signature of ModelDrafter.prepare_draft_tokens to accept
resource_manager: Optional[ResourceManager] = None so it matches the base
contract, add or ensure Optional is imported from typing if missing, and mirror
the pattern used in ngram.py; update any internal usage of resource_manager to
handle None safely.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 1865-1876: The file contains a duplicate TypeAlias named
SpeculativeConfig that shadows the earlier discriminated union (the one defined
with Field(discriminator="decoding_type")), which removes SADecodingConfig and
PARDDecodingConfig; remove the second SpeculativeConfig definition (or rename it
if you truly need a separate non-discriminated alias) and keep the original
annotated union (including SADecodingConfig and PARDDecodingConfig alongside
DraftTargetDecodingConfig, Eagle3DecodingConfig, EagleDecodingConfig,
LookaheadDecodingConfig, MedusaDecodingConfig, MTPDecodingConfig,
NGramDecodingConfig, UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig,
AutoDecodingConfig) so the discriminator-based Pydantic union remains intact.
---
Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/interface.py`:
- Around line 371-381: The function update_spec_dec_param currently accepts many
trailing boolean/optional parameters positionally; make is_target_model,
model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare *
before is_target_model in the signature so callers cannot accidentally bind
those flags positionally—update the signature in the update_spec_dec_param
definition and adjust any internal references accordingly (no other logic
changes).
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 502-509: The reshape from 1D to 2D (position_offsets_for_cpp based
on self.spec_decoding_position_offsets) assumes the 1D tensor length equals
max_num_requests * N; add a defensive check before the view to assert that
self.spec_decoding_position_offsets.numel() is divisible by
self.max_num_requests and (optionally) equals self.max_num_requests *
(self.max_total_draft_tokens + 1) (or raise a clear error mentioning
spec_decoding_position_offsets and max_num_requests) so the implicit
allocation/reshape contract in trtllm.py is explicit and fails fast when
violated.
In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 435-437: Replace the direct class imports
DynamicTreeDraftingLoopWrapper, LinearDraftingLoopWrapper,
StaticTreeDraftingLoopWrapper with a namespaced module import for drafting_loops
and update all references to use drafting_loops.DynamicTreeDraftingLoopWrapper,
drafting_loops.LinearDraftingLoopWrapper, and
drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used
in the dispatch/registration code inside py_executor_creator.py) so the module
is imported, not individual classes.
In `@tensorrt_llm/_torch/pyexecutor/sampler.py`:
- Around line 2626-2668: The loop in _process_draft_tokens_dynamic_tree
repeatedly calls accept_index[j].item(), causing device reads per-iteration;
materialize accept_index to a Python list once before the request loop (e.g.
accept_indices = accept_index.tolist() or accept_index.cpu().tolist() and cast
elements to int), then iterate over accept_indices for add_token and
finish_if_reason calls, and compute request.py_num_accepted_draft_tokens_indices
from that list by subtracting 1 for positions after the root; keep using the
same symbols (accept_index -> accept_indices,
_process_draft_tokens_dynamic_tree, add_token, finish_if_reason,
request.py_num_accepted_draft_tokens_indices) so you only replace tensor
indexing with list indexing.
In `@tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py`:
- Around line 8-9: Replace the direct class import with a module-level import
for the drafting_loops module and qualify the class through that namespace:
change the current "from tensorrt_llm._torch.speculative.drafting_loops import
StaticTreeDraftingLoopWrapper" to "import
tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from
tensorrt_llm._torch.speculative import drafting_loops") and update all usages to
drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage
at the other occurrence noted).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 388f70bd-848c-4020-99de-5838cd97e5b3
📒 Files selected for processing (23)
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
- cpp/tensorrt_llm/thop/CMakeLists.txt
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/llmapi/llm_args.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
```cpp
// ==================== tid == 0: root thread ====================
// Builds the "left-child / right-sibling" linked-list structure
// (retrieveIndex / retrieveNextToken / retrieveNextSibling).
if (tid == 0)
{
    // The root's position = seqLen (immediately after the already-verified sequence).
    positions[bid * draftTokenNum] = seqLen;

    // retrieveIndex maps tree nodes to batch-local indices.
    // For batch bid: retrieveIndex[i] = i (the identity mapping).
    // The verify kernel adds the bx * N offset itself when indexing the
    // flattened targetPredict/predicts buffers.

    // Build the linked list in reverse order, from the last node to the first.
    // Effect of the reverse traversal: children of the same parent end up in their
    // original order in the sibling list (each new node is inserted at the head of
    // the list, so inserting in reverse yields forward order).
    for (int32_t i = draftTokenNum - 1; i > 0; --i)
    {
        // Set the current node's retrieveIndex (local index).
        retrieveIndex[bid * draftTokenNum + i] = i;

        // ---- Find the position of node i's parent in the tree ----
        // selectedIndex[i-1] is node i's global index in the history buffer.
        // Dividing by topK gives the parent's index in parentList (parentTbIdx).
        // parentTbIdx == 0 means the parent is the root.
        int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + i - 1] / topK;
        int32_t parentPosition = 0; // Parent's position inside the tree (0 = root).

        if (parentTbIdx > 0)
        {
            // Non-root parent: look up the parent's history index via parentList.
            int64_t parentTokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx];
            // Search selectedIndex for the in-tree position matching this history index.
            for (; parentPosition < draftTokenNum; ++parentPosition)
            {
                if (selectedIndex[bid * (draftTokenNum - 1) + parentPosition] == parentTokenIdx)
                {
                    ++parentPosition; // In-tree position = index in selectedIndex + 1 (0 is the root).
                    break;
                }
            }
        }
        // When parentTbIdx == 0, parentPosition stays 0, i.e. the parent is the root.

        if (parentPosition == draftTokenNum)
        {
            // Parent not found (bad data, possibly NaNs in the logprobs).
            printf(
                "WARNING: Invalid dynamic tree! Detected a token with no parent token selected. "
                "Please check if the logprob has nan. The token will be ignored.\n");
            continue;
        }

        // ---- Insert node i into its parent's child list ----
        // Left-child / right-sibling representation:
        //   retrieveNextToken[parent]  = first child
        //   retrieveNextSibling[child] = next sibling
        if (retrieveNextToken[bid * draftTokenNum + parentPosition] == -1)
        {
            // The parent has no children yet -> node i becomes the first child.
            retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
        }
        else
        {
            // The parent already has children -> insert node i at the head of the list;
            // the previous first child becomes i's next sibling.
            int32_t originNextToken = retrieveNextToken[bid * draftTokenNum + parentPosition];
            retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
            retrieveNextSibling[bid * draftTokenNum + i] = originNextToken;
        }
    }
    // The root node's own retrieveIndex (local index = 0).
    retrieveIndex[bid * draftTokenNum] = 0;
```
Reinitialize the reused topology buffers on every tree build.
These kernels reuse preallocated output tensors, but they only overwrite the entries touched by the new tree. retrieveNextToken/retrieveNextSibling are never reset to -1, and the packed builder only clears treeMask word 0. That leaves stale child links and stale higher-word mask bits from the previous iteration in place.
Also applies to: 245-317
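The stale-buffer hazard flagged above is easy to reproduce in miniature. The sketch below is a hypothetical Python stand-in for the CUDA tree builder (`build_links` is not real TRT-LLM code): a reused sibling-link buffer keeps entries from the previous, larger tree unless it is reset to -1 before each build.

```python
# Toy model of reusing a preallocated topology buffer without resetting it.

def build_links(buffer, parents, reset=False):
    """Write first-child links for the given parent list into `buffer`."""
    if reset:
        for i in range(len(buffer)):
            buffer[i] = -1           # full per-build reinitialization
    for child, parent in enumerate(parents):
        if buffer[parent] == -1:
            buffer[parent] = child   # only slots touched by this tree are written
    return buffer

buf = [-1] * 4
build_links(buf, [0, 0, 1, 2])       # first (larger) tree touches slots 0..2
build_links(buf, [0], reset=False)   # smaller tree: slots 1 and 2 keep stale links
print(buf)                           # [0, 2, 3, -1] -- stale entries remain
build_links(buf, [0], reset=True)    # with the reset, only slot 0 is live
print(buf)                           # [0, -1, -1, -1]
```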
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu` around
lines 112 - 182, The kernel reuses preallocated topology buffers but doesn’t
clear stale data; before building the tree (inside the tid==0 branch of the
dynamic tree builder) explicitly reinitialize retrieveNextToken and
retrieveNextSibling entries for this batch (bid) to -1 for all draftTokenNum
slots, and clear all words of treeMask for this batch (not just word 0); also
ensure positions/retrieveIndex for all slots are set to sane defaults if needed.
Locate the tid==0 block that sets positions[bid * draftTokenNum] and the loop
that writes retrieveIndex/retrieveNextToken/retrieveNextSibling and add the
resets there (and mirror the same full-reset logic in the other build region
referenced around lines 245-317).
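The reset the review asks for can be sketched in plain Python (this is an illustrative stand-in for the CUDA kernel, not the actual implementation; buffer names mirror the kernel, and the flat-list layout is an assumption for the sketch):

```python
# Sketch (not the actual kernel): per-batch reset of the reused topology
# buffers before a new tree build, so stale child links and stale mask
# words from the previous iteration cannot leak into the new tree.

def reset_tree_buffers(bid, draft_token_num, num_mask_words,
                       retrieve_next_token, retrieve_next_sibling, tree_mask):
    """Clear this batch's slots in the reused (preallocated) buffers."""
    base = bid * draft_token_num
    for i in range(draft_token_num):
        retrieve_next_token[base + i] = -1      # no child yet
        retrieve_next_sibling[base + i] = -1    # no sibling yet
        for w in range(num_mask_words):         # clear ALL mask words, not just word 0
            tree_mask[(base + i) * num_mask_words + w] = 0
```

In the real kernel this would run inside the `tid == 0` branch before the build loop, once per batch slot.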
| while (true) | ||
| { | ||
| position += 1; // each step up increases the depth by 1 | ||
| // Mark the ancestor's column in treeMask as 1 (the current node can see this ancestor) | ||
| treeMask[tokenTreeIdx + curPosition] = 1; | ||
|
|
||
| // Find the parent: selectedIndex[curPosition] / topK = parentTbIdx | ||
| int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + curPosition] / topK; | ||
| if (parentTbIdx == 0) | ||
| { | ||
| break; // reached the root, stop tracing | ||
| } | ||
|
|
||
| // Use parentList to find the parent's index in history, | ||
| // then look up its in-tree position in selectedIndex | ||
| int64_t tokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx]; | ||
| for (curPosition = 0; curPosition < draftTokenNum; ++curPosition) | ||
| { | ||
| if (selectedIndex[bid * (draftTokenNum - 1) + curPosition] == tokenIdx) | ||
| { | ||
| break; | ||
| } | ||
| } | ||
| } |
Guard the ancestor walk when a parent was not resampled.
The root-thread path already detects "parent not found", but the per-node ancestor walk does not. If the lookup misses, curPosition reaches draftTokenNum, and the next loop iteration reads/writes past selectedIndex and the current mask row.
Also applies to: 342-359
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu` around
lines 191 - 214, The ancestor-walk loop can run past bounds when a parent lookup
misses: after the for-loop that searches selectedIndex for tokenIdx (using
curPosition and draftTokenNum) add a guard to detect "not found" (curPosition ==
draftTokenNum) and break the while loop to avoid reading/writing past
selectedIndex/treeMask; apply the same defensive check to the equivalent
ancestor-walk logic around the other block referenced (uses the same symbols:
treeMask, tokenTreeIdx, curPosition, selectedIndex, draftTokenNum, parentList,
parentTbIdx, bid, topK) so both paths stop if the parent was not resampled.
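A guarded version of the ancestor walk can be sketched as follows (pure Python, single batch; names mirror the kernel's `selectedIndex`/`parentList`, and the flat layout is an assumption for the sketch):

```python
# Sketch of the guarded ancestor walk the review requests: stop cleanly
# when a parent lookup misses instead of indexing past the buffers.

def walk_ancestors(cur_position, selected_index, parent_list, top_k, tree_mask):
    """Mark all ancestors of the node at cur_position in tree_mask.
    Returns False if a parent lookup misses (parent was not resampled)."""
    while True:
        tree_mask[cur_position] = 1
        parent_tb_idx = selected_index[cur_position] // top_k
        if parent_tb_idx == 0:
            return True  # reached the root
        token_idx = parent_list[parent_tb_idx]
        # Find the parent's in-tree position.
        for cur_position in range(len(selected_index)):
            if selected_index[cur_position] == token_idx:
                break
        else:
            return False  # guard: parent not found -> stop, no OOB access
```

The `for`/`else` plays the role of the `curPosition == draftTokenNum` check the AI prompt describes.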
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once |
🛠️ Refactor suggestion | 🟠 Major
Use the repo's header guard convention here.
This new header uses #pragma once, but the repo requires a TRTLLM_<FILENAME_IN_CAPS>_H include guard for .h files. As per coding guidelines: **/*.{h,hpp,cuh}: Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME_IN_CAPS>_H (e.g., TRTLLM_FOO_BAR_HELLO_H). Do not use directory names or trailing underscores.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h` at line
17, The header currently uses `#pragma once` but must follow the repo guard
convention; replace the pragma with a preprocessor include guard named
TRTLLM_DYNAMICTREEKERNELS_H (matching the filename dynamicTreeKernels.h in ALL
CAPS) by adding `#ifndef` TRTLLM_DYNAMICTREEKERNELS_H / `#define`
TRTLLM_DYNAMICTREEKERNELS_H at the top and a matching `#endif` at the bottom,
ensuring no directory names or trailing underscores are used and keeping the
rest of the file (dynamicTreeKernels.h) unchanged.
| void build_dynamic_tree_op(th::Tensor& parentList, th::Tensor& selectedIndex, th::Tensor& treeMask, | ||
| th::Tensor& positions, th::Tensor& retrieveIndex, th::Tensor& retrieveNextToken, th::Tensor& retrieveNextSibling, | ||
| int64_t topK, int64_t depth, int64_t numDraftTokens, int64_t treeMaskMode) | ||
| { | ||
| // Validate inputs | ||
| TORCH_CHECK(parentList.dim() == 2, "parentList must be 2D tensor"); | ||
| TORCH_CHECK(selectedIndex.dim() == 2, "selectedIndex must be 2D tensor"); | ||
| TORCH_CHECK(parentList.scalar_type() == torch::kInt64, "parentList must be int64 tensor"); | ||
| TORCH_CHECK(selectedIndex.scalar_type() == torch::kInt64, "selectedIndex must be int64 tensor"); | ||
|
|
||
| int64_t batchSize = parentList.size(0); | ||
| TORCH_CHECK(selectedIndex.size(0) == batchSize, "Batch size mismatch"); | ||
| TORCH_CHECK(selectedIndex.size(1) == numDraftTokens - 1, "selectedIndex size mismatch"); | ||
|
|
||
| auto device = parentList.device(); | ||
| auto stream = at::cuda::getCurrentCUDAStream(device.index()); | ||
|
|
||
| // Reset output buffers | ||
| treeMask.zero_(); | ||
| positions.zero_(); | ||
| retrieveIndex.zero_(); | ||
| retrieveNextToken.fill_(-1); | ||
| retrieveNextSibling.fill_(-1); | ||
|
|
||
| // Create zero verifiedSeqLen (positions returned directly without offset) | ||
| auto verifiedSeqLen = torch::zeros({batchSize}, torch::TensorOptions().dtype(torch::kInt32).device(device)); | ||
|
|
||
| // Call kernel | ||
| tk::invokeBuildDynamicTree(parentList.data_ptr<int64_t>(), selectedIndex.data_ptr<int64_t>(), | ||
| verifiedSeqLen.data_ptr<int32_t>(), treeMask.data_ptr(), positions.data_ptr<int32_t>(), | ||
| retrieveIndex.data_ptr<int32_t>(), retrieveNextToken.data_ptr<int32_t>(), | ||
| retrieveNextSibling.data_ptr<int32_t>(), batchSize, topK, depth, numDraftTokens, | ||
| static_cast<tk::TreeMaskMode>(treeMaskMode), stream); | ||
| } |
Add device, dtype, and shape validation before accessing raw tensor pointers.
Both build_dynamic_tree_op and verify_dynamic_tree_greedy_op access raw tensor data via data_ptr<int32_t/int64_t>() and call at::cuda::getCurrentCUDAStream() without validating that:
- All tensors are CUDA tensors (is_cuda())
- All tensors reside on the same device
- Output tensor dtypes match expected types (int32 or int64)
- Output tensor shapes match expected dimensions
- Enum parameter treeMaskMode is within valid range before casting
Missing these checks allows invalid Python input to cause undefined behavior or cryptic crashes instead of clean TORCH_CHECK failures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp` around lines 34 - 67,
build_dynamic_tree_op (and its sibling verify_dynamic_tree_greedy_op) currently
access raw data pointers and call at::cuda::getCurrentCUDAStream() without
validating devices, dtypes, shapes or the treeMaskMode enum; add TORCH_CHECKs to
ensure all input/output tensors (parentList, selectedIndex, treeMask, positions,
retrieveIndex, retrieveNextToken, retrieveNextSibling, verifiedSeqLen) are CUDA
tensors (is_cuda()), are on the same device (device.index() equality), and have
the expected scalar types (parentList/selectedIndex int64,
positions/retrieveIndex/retrieveNextToken/retrieveNextSibling/verifiedSeqLen
int32 as used by data_ptr<int32_t/int64_t>()), verify output shapes (batchSize,
numDraftTokens-1, etc.) before zero_/fill_, and check treeMaskMode is within the
valid tk::TreeMaskMode range before static_cast; perform these checks at the
start of build_dynamic_tree_op and verify_dynamic_tree_greedy_op so that
tk::invokeBuildDynamicTree and related kernel calls only receive validated
tensors and enum values.
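The validation pattern can be sketched in Python with a tiny stand-in tensor type (the real fix would use `TORCH_CHECK` in C++; `FakeTensor` and the field names here are illustrative assumptions, not the repo's API):

```python
# Sketch of the up-front validation the review asks for, against a
# stand-in tensor type; raises ValueError as the TORCH_CHECK analog.

from dataclasses import dataclass

@dataclass
class FakeTensor:  # illustrative stand-in for th::Tensor
    is_cuda: bool
    device_index: int
    dtype: str
    shape: tuple

def check(cond, msg):
    if not cond:
        raise ValueError(msg)  # TORCH_CHECK analog

def validate_build_inputs(parent_list, selected_index, num_draft_tokens,
                          tree_mask_mode, num_mask_modes=2):
    for name, t in [("parentList", parent_list), ("selectedIndex", selected_index)]:
        check(t.is_cuda, f"{name} must be a CUDA tensor")
    check(parent_list.device_index == selected_index.device_index,
          "all tensors must be on the same device")
    check(parent_list.dtype == "int64", "parentList must be int64")
    check(selected_index.dtype == "int64", "selectedIndex must be int64")
    check(selected_index.shape == (parent_list.shape[0], num_draft_tokens - 1),
          "selectedIndex shape mismatch")
    check(0 <= tree_mask_mode < num_mask_modes,
          "treeMaskMode out of range")  # validate before the enum cast
```

The same checks would be repeated for the int32 output buffers (`positions`, `retrieveIndex`, etc.) at the top of both ops.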
| for i, prompt in enumerate(prompts): | ||
| num_tokens = 0 | ||
| num_iterations = 0 | ||
| for output in llm.generate_async(prompt, | ||
| sampling_params, | ||
| streaming=True): | ||
| new_tokens = output.outputs[0].token_ids | ||
| num_tokens = len(new_tokens) | ||
| num_iterations += 1 | ||
| if num_iterations > 0: | ||
| accept_rate = num_tokens / num_iterations | ||
| print(f"[{i}] Accept rate: {accept_rate:.2f} " | ||
| f"(tokens={num_tokens}, iterations={num_iterations})") | ||
| generated_text = output.outputs[0].text | ||
| print( | ||
| f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") |
Make the streaming branch handle all sequences and empty streams.
This path assumes output.outputs[0] is the only sequence and that the iterator always yields at least once. That means --n > 1 / beam-search results are silently dropped here, and an empty stream would reference output before assignment in the final print. Either iterate over every output.outputs entry like the non-streaming path does, or explicitly reject those combinations in streaming mode.
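A defensive version of the streaming loop can be sketched like this (the stream items are modeled as plain dicts rather than the real LLM API objects, which is an assumption for the sketch):

```python
# Sketch of a streaming consumer that handles every sequence and an
# empty stream, per the review.

def stream_stats(outputs_iter):
    """Consume a stream; return (num_tokens, iterations) per sequence,
    or None if the stream never yielded."""
    last = None
    iterations = 0
    for output in outputs_iter:
        last = output
        iterations += 1
    if last is None:
        return None  # empty stream: nothing to report
    return [(len(seq["token_ids"]), iterations) for seq in last["outputs"]]
```

Iterating over `last["outputs"]` covers the `--n > 1` / beam-search case, and the `None` return avoids referencing an unassigned variable when the stream is empty.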
| tokens_per_req = spec_metadata.max_total_draft_tokens + 1 | ||
| total_po_size = attn_metadata.spec_decoding_position_offsets.shape[0] | ||
| max_reqs = total_po_size // tokens_per_req | ||
| pos_2d = attn_metadata.spec_decoding_position_offsets.view( | ||
| max_reqs, tokens_per_req | ||
| ) | ||
| max_gl = int(gen_sl.max().item()) | ||
| causal_offs = torch.arange(max_gl, device="cuda", dtype=torch.int32) | ||
| for g_idx in range(num_gens): | ||
| req_idx = num_contexts + g_idx | ||
| n = int(gen_sl[g_idx].item()) | ||
| pos_2d[req_idx, :n] = causal_offs[:n] | ||
|
|
Keep spec_decoding_position_offsets in request-major layout.
This file mixes two incompatible views of spec_decoding_position_offsets: step 0 reshapes it as if it were flat, and the later update path slices it as [: batch_size * num_tokens_previous_layer]. If the metadata keeps the same [max_num_requests, max_total_draft_tokens + 1] layout used elsewhere in the speculative stack, both blocks will update the wrong memory once multiple requests are active.
Also applies to: 1090-1101
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py` around lines 475 -
487, The code incorrectly reshapes spec_decoding_position_offsets into a flat
(max_reqs, tokens_per_req) layout which conflicts with the request-major layout
used elsewhere; replace the manual flattening with a request-major view and
index into the existing first dimension. Concretely, stop computing max_reqs =
total_po_size // tokens_per_req and using pos_2d =
attn_metadata.spec_decoding_position_offsets.view(max_reqs, tokens_per_req);
instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.
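The request-major indexing the prompt recommends can be sketched with a nested list standing in for the `[max_num_requests, tokens_per_req]` buffer (an illustrative model, not the real tensor code):

```python
# Sketch of request-major causal-offset writes: index the buffer by
# request id instead of re-deriving a flat layout from its total size.

def write_causal_offsets(pos_2d, num_contexts, gen_seq_lens):
    """Write causal position offsets 0..n-1 for each generation request."""
    for g_idx, n in enumerate(gen_seq_lens):
        req_idx = num_contexts + g_idx          # request-major row
        assert req_idx < len(pos_2d)            # stay in bounds
        pos_2d[req_idx][:n] = list(range(n))    # causal offsets 0..n-1
```

Because rows are addressed by `req_idx` directly, the same write works whether or not the buffer's first dimension equals `max_num_requests`.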
| topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu( | ||
| ) # [batch_size, self.max_total_draft_tokens] | ||
| history_draft_tokens_parent_buffer = dynamic_tree_buffers[ | ||
| "history_draft_tokens_parent_buffer"].cpu( | ||
| ) # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] |
Remove unused variables flagged by static analysis.
The variables topk_score_indices and history_draft_tokens_parent_buffer are assigned but never used in this method. If they are intended for future use, consider adding a TODO comment; otherwise, remove them to avoid confusion.
🧹 Proposed fix
if isinstance(
self.spec_config,
EagleDecodingConfig) and self.spec_config.use_dynamic_tree:
dynamic_tree_buffers = outputs["dynamic_tree_buffers"]
- topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu(
- ) # [batch_size, self.max_total_draft_tokens]
- history_draft_tokens_parent_buffer = dynamic_tree_buffers[
- "history_draft_tokens_parent_buffer"].cpu(
- ) # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)]
+ # Note: topk_score_indices and history_draft_tokens_parent_buffer are available
+ # in dynamic_tree_buffers if needed for debugging or future use
tree_structure = dynamic_tree_buffers.get("tree_structure")

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu( | |
| ) # [batch_size, self.max_total_draft_tokens] | |
| history_draft_tokens_parent_buffer = dynamic_tree_buffers[ | |
| "history_draft_tokens_parent_buffer"].cpu( | |
| ) # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] | |
| # Note: topk_score_indices and history_draft_tokens_parent_buffer are available | |
| # in dynamic_tree_buffers if needed for debugging or future use |
🧰 Tools
🪛 Ruff (0.15.5)
[error] 686-686: Local variable topk_score_indices is assigned to but never used
Remove assignment to unused variable topk_score_indices
(F841)
[error] 688-688: Local variable history_draft_tokens_parent_buffer is assigned to but never used
Remove assignment to unused variable history_draft_tokens_parent_buffer
(F841)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/model_drafter.py` around lines 686 - 690,
Remove the two unused CPU buffer assignments to avoid dead code: delete the
assignments to topk_score_indices and history_draft_tokens_parent_buffer that
read from dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.
| def prepare_draft_tokens(self, scheduled_requests: ScheduledRequests, | ||
| resource_manager: ResourceManager) -> None: |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check if other implementations of prepare_draft_tokens have compatible signatures
# Find all prepare_draft_tokens method definitions
rg -n "def prepare_draft_tokens" --type=py -A 3
# Check the base class signature
cat tensorrt_llm/_torch/speculative/drafter.py | head -50
Fix LSP violation: make resource_manager parameter optional in ModelDrafter.prepare_draft_tokens().
The method signature at lines 1004-1005 requires resource_manager: ResourceManager, but the abstract base class at drafter.py:27-30 defines it as Optional[ResourceManager] = None. This violates Liskov Substitution Principle—callers cannot treat ModelDrafter instances polymorphically as Drafter without potentially failing if resource_manager is omitted. The ngram.py implementation maintains the correct signature. Update the method signature to match the base class contract: resource_manager: Optional[ResourceManager] = None.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/model_drafter.py` around lines 1004 - 1005,
The prepare_draft_tokens method in ModelDrafter currently requires
resource_manager but the base class Drafter defines it as optional; change the
signature of ModelDrafter.prepare_draft_tokens to accept resource_manager:
Optional[ResourceManager] = None so it matches the base contract, add or ensure
Optional is imported from typing if missing, and mirror the pattern used in
ngram.py; update any internal usage of resource_manager to handle None safely.
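The LSP-compatible override can be sketched as follows (class and method names mirror the review; the bodies are illustrative stubs, not the repo's implementation):

```python
# Sketch: the subclass signature matches the base class's Optional
# default, so ModelDrafter can be used polymorphically as a Drafter.

from abc import ABC, abstractmethod
from typing import Optional

class ResourceManager: ...
class ScheduledRequests: ...

class Drafter(ABC):
    @abstractmethod
    def prepare_draft_tokens(
            self, scheduled_requests: ScheduledRequests,
            resource_manager: Optional[ResourceManager] = None) -> None: ...

class ModelDrafter(Drafter):
    def prepare_draft_tokens(
            self, scheduled_requests: ScheduledRequests,
            resource_manager: Optional[ResourceManager] = None) -> None:
        # Handle the None case explicitly instead of requiring the argument.
        if resource_manager is None:
            return  # illustrative fallback; real code would take a default path
```

With the widened signature, `drafter.prepare_draft_tokens(requests)` works for any `Drafter` subclass without knowing the concrete type.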
| if self.use_dynamic_tree or self.dynamic_tree_max_topK is not None: | ||
| self.use_dynamic_tree = True | ||
| assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK is required for dynamic tree" | ||
| assert self.eagle_choices is None, "If use_dynamic_tree is True, eagle_choices should be None" | ||
| total_history_draft_tokens = self.dynamic_tree_max_topK + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK * ( | ||
| self.max_draft_len - 1) | ||
| default_max_total_draft_tokens = self.dynamic_tree_max_topK * self.max_draft_len | ||
|
|
||
| if self.max_total_draft_tokens is None: | ||
| self.max_total_draft_tokens = default_max_total_draft_tokens | ||
| logger.warning( | ||
| f"max_total_draft_tokens is not provided, use the default value {default_max_total_draft_tokens} (default_max_total_draft_tokens = dynamic_tree_max_topK * max_draft_len)" | ||
| ) | ||
| else: | ||
| assert self.max_total_draft_tokens <= total_history_draft_tokens and self.max_total_draft_tokens >= default_max_total_draft_tokens, f"max_total_draft_tokens should be between {default_max_total_draft_tokens} and {total_history_draft_tokens}" | ||
|
|
||
| # Linear tree | ||
| if self.max_total_draft_tokens is None: | ||
| self.max_total_draft_tokens = self.max_draft_len | ||
|
|
🧩 Analysis chain
🏁 Script executed:
# Get line count first to understand file size
wc -l tensorrt_llm/llmapi/llm_args.py
# Read the lines around 953-990 to see the full context of the dynamic tree validation
sed -n '940,1000p' tensorrt_llm/llmapi/llm_args.py
# Check the SpeculativeConfig definitions around lines 1847-1876
sed -n '1840,1890p' tensorrt_llm/llmapi/llm_args.py
# Check the quickstart_advanced.py file mentioned in the scratchpad
sed -n '240,260p' examples/llm-api/quickstart_advanced.py
Replace assert statements with ValueError and consolidate dynamic-tree validation logic.
The unreachable defaulting at line 981 confirms the existing validation at line 962 will reject omitted max_total_draft_tokens before the fallback runs. Additionally, lines 986-987 and 989 use assert statements that violate the coding guideline requiring ValueError to be raised in Pydantic validators. Consolidate the two dynamic-tree blocks into a single validator, remove the assert statements, and handle both the validation and the new defaulting logic within one flow path.
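The consolidated flow the review suggests can be sketched as one plain function (field names follow the config; the Pydantic validator wrapping is omitted, and the function form is an assumption for the sketch):

```python
# Sketch: one validation path for both linear and dynamic tree modes,
# raising ValueError instead of assert, with defaulting before the
# range check so the fallback is actually reachable.

def validate_dynamic_tree(use_dynamic_tree, dynamic_tree_max_topK,
                          max_draft_len, max_total_draft_tokens):
    if not (use_dynamic_tree or dynamic_tree_max_topK is not None):
        # Linear tree: fall back to max_draft_len when unset.
        return (max_total_draft_tokens
                if max_total_draft_tokens is not None else max_draft_len)
    if dynamic_tree_max_topK is None or dynamic_tree_max_topK <= 0:
        raise ValueError("dynamic_tree_max_topK is required for dynamic tree")
    total_history = (dynamic_tree_max_topK
                     + dynamic_tree_max_topK ** 2 * (max_draft_len - 1))
    default_total = dynamic_tree_max_topK * max_draft_len
    if max_total_draft_tokens is None:
        return default_total  # default applied before any range check
    if not (default_total <= max_total_draft_tokens <= total_history):
        raise ValueError(
            f"max_total_draft_tokens should be between {default_total} "
            f"and {total_history}")
    return max_total_draft_tokens
```

In a Pydantic model this body would live in a single `model_validator`, returning the resolved value instead of asserting.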
| SpeculativeConfig: TypeAlias = Optional[Union[ | ||
| DraftTargetDecodingConfig, | ||
| Eagle3DecodingConfig, # Must be before EagleDecodingConfig since it's a subclass | ||
| EagleDecodingConfig, | ||
| LookaheadDecodingConfig, | ||
| MedusaDecodingConfig, | ||
| MTPDecodingConfig, | ||
| NGramDecodingConfig, | ||
| UserProvidedDecodingConfig, | ||
| SaveHiddenStatesDecodingConfig, | ||
| AutoDecodingConfig, | ||
| ]] |
🧩 Analysis chain
🏁 Script executed:
wc -l tensorrt_llm/llmapi/llm_args.py
sed -n '1850,1900p' tensorrt_llm/llmapi/llm_args.py | cat -n
rg "SpeculativeConfig\s*:" tensorrt_llm/llmapi/llm_args.py -n
rg "speculative_config" tensorrt_llm/llmapi/llm_args.py -n -A 3
rg "class.*DecodingConfig" tensorrt_llm/llmapi/llm_args.py -n | head -20
sed -n '1840,1880p' tensorrt_llm/llmapi/llm_args.py | cat -n
sed -n '1847,1863p' tensorrt_llm/llmapi/llm_args.py | cat -n
Remove the second SpeculativeConfig definition or rename it.
The second SpeculativeConfig definition (lines 1865-1876) shadows the first (lines 1847-1863), removing the discriminator field and excluding SADecodingConfig and PARDDecodingConfig. This violates the guideline to use discriminated unions in Pydantic for fields accepting multiple config classes. Keep the first annotated definition with Field(discriminator="decoding_type") and either remove the second definition or give it a different name if an optional non-discriminated alias is needed elsewhere.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/llmapi/llm_args.py` around lines 1865 - 1876, The file contains
a duplicate TypeAlias named SpeculativeConfig that shadows the earlier
discriminated union (the one defined with Field(discriminator="decoding_type")),
which removes SADecodingConfig and PARDDecodingConfig; remove the second
SpeculativeConfig definition (or rename it if you truly need a separate
non-discriminated alias) and keep the original annotated union (including
SADecodingConfig and PARDDecodingConfig alongside DraftTargetDecodingConfig,
Eagle3DecodingConfig, EagleDecodingConfig, LookaheadDecodingConfig,
MedusaDecodingConfig, MTPDecodingConfig, NGramDecodingConfig,
UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig, AutoDecodingConfig)
so the discriminator-based Pydantic union remains intact.
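The discriminated-union idea can be sketched with `typing.Literal` as the discriminator (this is a minimal illustration with made-up two-member configs; the repo uses Pydantic's `Field(discriminator="decoding_type")` with the full config list):

```python
# Sketch: a single discriminated union. Re-declaring the alias later
# would silently shadow this one and drop members from the union,
# which is exactly the bug the review flags.

from dataclasses import dataclass
from typing import Literal, Optional, Union

@dataclass
class EagleDecodingConfig:
    decoding_type: Literal["Eagle"] = "Eagle"

@dataclass
class MTPDecodingConfig:
    decoding_type: Literal["MTP"] = "MTP"

SpeculativeConfig = Optional[Union[EagleDecodingConfig, MTPDecodingConfig]]

def parse_config(data: Optional[dict]) -> SpeculativeConfig:
    if data is None:
        return None
    by_type = {"Eagle": EagleDecodingConfig, "MTP": MTPDecodingConfig}
    return by_type[data["decoding_type"]]()  # dispatch on the discriminator
```

Because Python rebinds names at module level, a second `SpeculativeConfig = ...` assignment later in the file would win, so keeping exactly one definition is the whole fix.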
Simplify comments: keep only \param docs in English, remove verbose prose descriptions and internal implementation comments. Signed-off-by: qgai <qgai@nvidia.com>
|
PR_Github #38374 [ run ] completed with state
|
Summary
- Add `Eagle3OneModelDynamicTreeWorker` and `Eagle3OneModelDynamicTreeSampler` for one-model dynamic tree inference

Changes
- `tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py` — One-model dynamic tree worker and sampler
- `tensorrt_llm/_torch/speculative/dynamic_tree_ops.py` — Python wrappers for dynamic tree CUDA ops
- `cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu/.h` — CUDA kernels
- `cpp/tensorrt_llm/thop/dynamicTreeOp.cpp` — Torch custom op bindings
- `tensorrt_llm/_torch/speculative/eagle3.py` — Refactored `Eagle3OneModelWorker` with dispatch pattern for linear vs dynamic tree
- `tensorrt_llm/_torch/speculative/utils.py` — Route to dynamic tree components when `use_dynamic_tree=True`
- `tensorrt_llm/_torch/speculative/drafting_loops.py` — Two-model dynamic tree drafting loop
- `tensorrt_llm/_torch/speculative/model_drafter.py` — Dynamic tree spec tree manager integration
- `tensorrt_llm/_torch/speculative/spec_tree_manager.py` — Support dynamic tree token organization
- `tensorrt_llm/_torch/pyexecutor/model_engine.py` — Dynamic tree detection for target model
- `tensorrt_llm/_torch/pyexecutor/sampler.py` — Dynamic tree batch verification
- `tensorrt_llm/_torch/models/modeling_speculative.py` — Hidden states handling for dynamic tree
- `tensorrt_llm/llmapi/llm_args.py` — Configuration validation for dynamic tree

Test plan
- Unit tests (`tests/unittest/_torch/speculative/`)

Summary by CodeRabbit
New Features
Configuration
- `--max_total_draft_tokens` parameter for controlling total draft token budget
- `--streaming` flag for real-time token streaming output