Merged
Contributor
Pull request overview
This PR refines speculative decoding (MTP/EAGLE) execution and attention metadata handling, while also restructuring scheduler/model-runner output plumbing and MTP statistics reporting.
Changes:
- Updates attention metadata preparation (slot mapping initialization, kv_indices generation/buffer sizing) to better support speculative decoding paths.
- Refactors scheduler/model-runner output formats to use ordered `req_ids` + `token_ids` lists with O(1) req-id indexing.
- Revises MTP stats logging behavior and routes stats printing through `Scheduler.spec_stats`.
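The ordered `req_ids` + `token_ids` layout with O(1) req-id lookup could be sketched roughly as follows. The class name mirrors the `ScheduledBatchOutput` mentioned in the per-file summary, but the fields and methods here are illustrative assumptions, not the PR's actual code:

```python
from dataclasses import dataclass, field

# Illustrative sketch (not the PR's actual code): ordered req_ids/token_ids
# lists kept in lockstep, with a dict giving O(1) lookup from req_id to index.
@dataclass
class ScheduledBatchOutput:
    req_ids: list[str]
    token_ids: list[list[int]]  # token_ids[i] belongs to req_ids[i]
    _index: dict[str, int] = field(default_factory=dict, repr=False)

    def __post_init__(self) -> None:
        self._index = {rid: i for i, rid in enumerate(self.req_ids)}

    def get_idx(self, req_id: str) -> int:
        # O(1) dict lookup instead of an O(n) list.index() scan
        return self._index[req_id]

    def tokens_for(self, req_id: str) -> list[int]:
        return self.token_ids[self.get_idx(req_id)]
```

Keeping the two lists ordered and in lockstep preserves batch order for the model runner, while the side index avoids repeated linear scans when the scheduler resolves a request id.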
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| atom/utils/forward_context.py | Removes unused `fake_block_tables` from `AttentionMetaData`. |
| atom/spec_decode/eagle.py | Adjusts speculative proposer position/index handling and updates attention metadata for MTP decode. |
| atom/models/deepseek_mtp.py | Disables masked embedding behavior and comments out `support_torch_compile` usage. |
| atom/model_ops/sampler.py | Introduces cached exponential tensor helper for sampling path. |
| atom/model_ops/embed_head.py | Minor import reordering. |
| atom/model_ops/attentions/backends.py | Initializes slot mapping with -1 for scheduled tokens; copies full scheduled range to GPU. |
| atom/model_ops/attentions/aiter_mla.py | Increases `kv_indices` buffer sizing and generates `kv_indices` via Triton; various decode/prefill path adjustments. |
| atom/model_ops/attentions/aiter_attention.py | Similar `kv_indices` buffer sizing and generation changes for persistent attention. |
| atom/model_ops/attention_mha.py | Removes prefill-time fake block table handling. |
| atom/model_engine/scheduler.py | Changes MTP stats logging cadence; refactors `ScheduledBatchOutput` structure and adds O(1) req-id lookup. |
| atom/model_engine/model_runner.py | Batch token-id postprocessing; adapts to new `ScheduledBatchOutput` API; removes old MTP stats APIs. |
| atom/model_engine/engine_core.py | Prints MTP stats via scheduler instead of runner RPC. |
| atom/model_engine/async_proc.py | Avoids enqueueing outputs when no output address is configured. |
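The `backends.py` change (slot mapping initialized with -1 over the scheduled-token range) can be sketched roughly as below. The function name and the use of plain lists are illustrative assumptions; the real code operates on torch tensors:

```python
# Rough sketch (names assumed): pre-fill the slot mapping for the whole
# scheduled-token range with -1 so entries not written by the real mapping
# (e.g. rejected speculative tokens) are explicit sentinels, not stale values.
# A real implementation would use a torch tensor; plain lists keep this minimal.
def build_slot_mapping(num_scheduled_tokens: int, real_slots: list[int]) -> list[int]:
    slot_mapping = [-1] * num_scheduled_tokens
    slot_mapping[: len(real_slots)] = real_slots
    return slot_mapping
```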
Contributor
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (2)
atom/model_engine/scheduler.py:416
`num_rejected` is only assigned inside `if is_deferred_out or self.use_spec:`, but it's used later unconditionally when computing `num_tokens = seq.num_tokens - self.mtp_k - num_rejected`. In the non-speculative, non-deferred path this will raise `UnboundLocalError` (or reuse a stale value from a previous loop iteration). Initialize `num_rejected = 0` per-sequence (or compute `num_tokens` differently) so the non-spec path is safe.
```python
if self.mtp_k > 0:
    # idx already resolved above via get_idx
    seq.spec_token_ids = draft_token_ids[idx]
    if seq.num_completion_tokens == 1 and seq.first_token_time == 0.0:
        seq.first_token_time = time.time()
num_tokens = seq.num_tokens - self.mtp_k - num_rejected
leave_reason = None
```
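The suggested fix can be sketched as below. The helper function is hypothetical and only isolates the control flow the comment describes:

```python
# Hypothetical helper isolating the flow from the review comment: without the
# `num_rejected = 0` default, the non-speculative, non-deferred path would read
# an unbound (or stale) variable when computing num_tokens.
def compute_num_tokens(seq_num_tokens: int, mtp_k: int,
                       is_deferred_out: bool, use_spec: bool,
                       rejected_count: int) -> int:
    num_rejected = 0  # safe default for the non-speculative path
    if is_deferred_out or use_spec:
        num_rejected = rejected_count
    return seq_num_tokens - mtp_k - num_rejected
```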
atom/model_engine/engine_core.py:301
`print_mtp_statistics()` now calls the private `SpecStats._log()` unconditionally when `spec_stats` exists. `_log()` divides by `iv_steps`, which will be 0 if no decode steps have been recorded yet, causing a `ZeroDivisionError`. Please add a guard (e.g., `if spec_stats.total_draft_tokens > 0` / `total_steps > 0`) or expose a safe public logging method on `SpecStats` that handles the empty case.
```python
def print_mtp_statistics(self):
    if self.scheduler.spec_stats is not None:
        self.scheduler.spec_stats._log()
    else:
        logger.info(
            "\n[MTP Stats] No MTP statistics available (MTP not enabled or no tokens processed)\n"
        )
```
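A safe public logging method could look roughly like this. The `SpecStats` field names are assumed for illustration and may not match the real class:

```python
# Illustrative sketch (field names assumed): a public log method on SpecStats
# that handles the empty case instead of dividing by a zero step or token count.
class SpecStats:
    def __init__(self) -> None:
        self.total_steps = 0
        self.total_draft_tokens = 0
        self.total_accepted_tokens = 0

    def log(self) -> str:
        if self.total_steps == 0 or self.total_draft_tokens == 0:
            return "[MTP Stats] no speculative decode steps recorded yet"
        accept_rate = self.total_accepted_tokens / self.total_draft_tokens
        return (f"[MTP Stats] accept rate {accept_rate:.2%} "
                f"over {self.total_steps} steps")
```

Guarding both counters covers the freshly constructed case and the case where steps ran but no draft tokens were proposed.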
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist