From e2be8e0f96607a48c337f2c55a1cb9fc858307cf Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Sat, 28 Feb 2026 10:27:33 +0000 Subject: [PATCH 1/8] mtp_draft_fix --- atom/model_engine/engine_core.py | 10 ++----- atom/model_engine/model_runner.py | 28 +++----------------- atom/model_ops/attention_mha.py | 6 ----- atom/model_ops/attentions/aiter_attention.py | 25 ++++++++++------- atom/model_ops/attentions/aiter_mla.py | 22 +++++++++++++-- atom/model_ops/attentions/backends.py | 3 ++- atom/models/deepseek_mtp.py | 9 +++---- atom/spec_decode/eagle.py | 28 ++++++++++++++++---- atom/utils/forward_context.py | 1 - 9 files changed, 71 insertions(+), 61 deletions(-) diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index 80f85fd90..f85baa0b8 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -293,14 +293,8 @@ def stop_profiler(self): logger.info("Profiler stopped.") def print_mtp_statistics(self): - stats = self.runner_mgr.call_func("get_mtp_statistics", wait_out=True) - if stats and stats.get("total_draft_tokens", 0) > 0: - logger.info(f"\n{'='*50}") - logger.info("MTP (Multi-Token Prediction) Statistics:") - logger.info(f" Total draft tokens: {stats['total_draft_tokens']}") - logger.info(f" Accepted tokens: {stats['total_accepted_tokens']}") - logger.info(f" Acceptance rate: {stats['acceptance_rate']:.2%}") - logger.info(f"{'='*50}\n") + if self.scheduler.spec_stats is not None: + self.scheduler.spec_stats._log() else: logger.info( "\n[MTP Stats] No MTP statistics available (MTP not enabled or no tokens processed)\n" diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index bad557093..201926ef4 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -611,29 +611,6 @@ def is_qwen_next(self) -> bool: return True return False - def get_mtp_statistics(self) -> dict: - if hasattr(self, "mtp_total_draft_tokens"): - acceptance_rate = ( - self.mtp_total_accepted_tokens / self.mtp_total_draft_tokens - if self.mtp_total_draft_tokens > 0 - else 0.0 - ) - return { - "total_draft_tokens": self.mtp_total_draft_tokens, - "total_accepted_tokens": self.mtp_total_accepted_tokens, - "acceptance_rate": acceptance_rate, - } - return { - "total_draft_tokens": 0, - "total_accepted_tokens": 0, - "acceptance_rate": 0.0, - } - - def reset_mtp_statistics(self): - if hasattr(self, "mtp_total_draft_tokens"): - self.mtp_total_draft_tokens = 0 - self.mtp_total_accepted_tokens = 0 - def _make_buffer( self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True ) -> CpuGpuBuffer: @@ -707,6 +684,7 @@ def stop_profiler(self): def debug(self, *args: Any): if self.rank == 0: logger.info(*args) + # logger.info(*args) def dummy_execution(self): """Execute dummy decode batch for DP synchronization.""" @@ -1317,7 +1295,7 @@ def prepare_model(self, batch: ScheduledBatch): temperatures = self.prepare_sample(batch) input_ids = self.tokenID_processor.prepare_input_ids(batch) - # self.debug(f"{input_ids=}") + self.debug(f"{input_ids=}") self.prepare_inputs(batch, input_ids) return ( input_ids, @@ -1405,6 +1383,8 @@ def postprocess( sampled_tokens.view(bs, -1), 1, next_token_locs.view(-1, 1) ).view(bs) self.tokenID_processor.prev_token_ids = next_token_ids + logger.info(f"{sampled_tokens=}") + logger.info(f"{next_token_locs=}") draft_token_ids = self.propose_draft_token_ids( batch, self.tokenID_processor.input_ids.gpu[ diff --git a/atom/model_ops/attention_mha.py 
b/atom/model_ops/attention_mha.py index 9fcc3407d..174354b7d 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -369,13 +369,7 @@ def prefill_attention_triton( # value: [num_blocks, 1, num_kv_heads, head_size] attn_metadata = fwd_ctx.attn_metadata - ctx = fwd_ctx.context - block_tables = attn_metadata.block_tables - if ctx.is_prefill: - k_cache = k.unsqueeze(1) - v_cache = v.unsqueeze(1) - block_tables = attn_metadata.fake_block_tables o = torch.empty_like(q) descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1]) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index c0071af19..1bcaf0aec 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -11,7 +11,10 @@ from atom.model_engine.scheduler import ScheduledBatch from atom.model_ops.attention_mha import Attention from atom.utils import CpuGpuBuffer -from atom.utils.block_convert import block_table_convert_triton +from atom.utils.block_convert import ( + block_table_convert_triton, + kv_indices_generate_triton, +) from atom.utils.forward_context import AttentionMetaData, Context from .backends import AttentionBackend, CommonAttentionBuilder @@ -94,7 +97,7 @@ def __init__(self, model_runner): ), "kv_indptr": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), "kv_indices": CpuGpuBuffer( - self.max_bs * self.max_num_blocks_per_seq // self.block_ratio, + self.max_bs * self.max_num_blocks_per_seq, **i32_kwargs, ), } @@ -209,11 +212,6 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): sum_blocks = kv_indptr[-1] if len(kv_indptr) > 0 else 0 sum_blocks_before_converted = cdiv(num_blocks_per_seq, self.block_ratio).sum() - var["kv_indices"].np[:sum_blocks_before_converted] = np.fromiter( - itertools.chain.from_iterable(block_tables), - dtype=np.int32, - count=sum_blocks_before_converted, - ) var["kv_indptr"].np[0] = 0 var["kv_indptr"].np[1 : scheduled_bs + 1] = kv_indptr var["kv_indptr"].np[scheduled_bs + 1 : bs + 1] = sum_blocks @@ -224,13 +222,22 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): ("cu_seqlens_q", bs + 1), ("block_tables", bs), ("kv_indptr", bs + 1), - ("kv_indices", sum_blocks_before_converted), ] ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} if self.block_size == 1024: ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs) ctx.update(ctx_pa_ps) + + ctx["kv_indices"] = var["kv_indices"].gpu + max_seqlen_k = context_lens.max() + kv_indices_generate_triton( + ctx["block_tables"], + ctx["kv_indices"], + ctx["kv_indptr"], + self.block_ratio, + max_seqlen_k, + ) if self.block_ratio > 1 and "block_tables" in ctx: block_table_convert_triton( var["block_tables"].gpu[:bs], @@ -262,7 +269,7 @@ def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: max_seqlen_q=var["max_qlen"], cu_seqlens_q=var["cu_seqlens_q"].gpu[: bs + 1], kv_indptr=var["kv_indptr"].gpu[: bs + 1], - kv_indices=var["kv_indices"].gpu[:], + kv_indices=var["kv_indices"].gpu, max_seqlen_k=self.model_runner.config.max_model_len, block_tables_converted=( var["block_tables_converted"].gpu[:bs] diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 356bac5e0..8e8fb2cd2 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -95,7 +95,7 @@ def __init__(self, model_runner): ), "kv_indptr": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), "kv_indices": CpuGpuBuffer( - self.max_bs * 
self.max_num_blocks_per_seq // self.block_ratio, + self.max_bs * self.max_num_blocks_per_seq, **i32_kwargs, ), "kv_last_page_lens": CpuGpuBuffer(self.max_bs, **i32_kwargs), @@ -248,6 +248,24 @@ def prepare_prefill(self, batch: ScheduledBatch): attn_metadata.kv_indptr[1 : bs + 1] = torch.cumsum( attn_metadata.context_lens, 0 ) + if attn_metadata.block_tables is None: + self.prepare_block_tables(batch) + attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + kv_indices_generate_triton( + attn_metadata.block_tables, + attn_metadata.kv_indices, + attn_metadata.kv_indptr, + self.block_ratio, + attn_metadata.max_seqlen_k, + ) + else: + kv_indices_generate_triton( + attn_metadata.block_tables, + attn_metadata.kv_indices, + attn_metadata.kv_indptr, + self.block_ratio, + attn_metadata.max_seqlen_k, + ) return attn_metadata, positions @@ -389,7 +407,7 @@ def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: max_seqlen_q=max_q_len, cu_seqlens_q=var["cu_seqlens_q"].gpu[: bs + 1], kv_indptr=var["kv_indptr"].gpu[: bs + 1], - kv_indices=var["kv_indices"].gpu[:], + kv_indices=var["kv_indices"].gpu, kv_last_page_lens=var["kv_last_page_lens"].gpu[:bs], sparse_kv_indptr=sparse_kv_indptr, block_tables_converted=( diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index c02b7a707..a2226f612 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -176,6 +176,7 @@ def prepare_prefill(self, batch: ScheduledBatch): if cu_seqlens_k[-1] > batch.total_tokens_num: # prefix cache self.prepare_block_tables(batch) var["positions"].np[:sum_scheduled_tokens] = positions + var["slot_mapping"].np[:sum_scheduled_tokens] = -1 var["slot_mapping"].np[: len(slot_mapping)] = slot_mapping cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True) var["context_lens"].np[:bs] = batch.context_lens[:bs] @@ -183,7 +184,7 @@ def prepare_prefill(self, batch: ScheduledBatch): dropout_p = 0.0 vars_used = [ ("cu_seqlens_q", bs + 1), - ("slot_mapping", len(slot_mapping)), + ("slot_mapping", sum_scheduled_tokens), ("context_lens", bs), ] diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index 1ce6bef31..630388b57 100644 --- a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn - from aiter import dtypes from aiter.dist.communication_op import tensor_model_parallel_all_reduce from atom.config import Config, QuantizationConfig @@ -13,7 +12,6 @@ from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled from atom.models.utils import IntermediateTensors - from atom.utils.decorators import support_torch_compile from transformers import DeepseekV2Config, DeepseekV3Config, PretrainedConfig @@ -78,9 +76,10 @@ def forward( spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None - masked_inputs_embeds = torch.where( - positions.unsqueeze(-1) == 0, 0, inputs_embeds - ) + # masked_inputs_embeds = torch.where( + # positions.unsqueeze(-1) == 0, 0, inputs_embeds + # ) + masked_inputs_embeds = inputs_embeds inputs_embeds = self.enorm(masked_inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 7ed9ce115..46ed257e4 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ -114,7 +114,7 @@ def propose( input_ids = target_token_ids # input_ids[last_token_indices] = 
next_token_ids input_ids.scatter_(0, last_token_indices, next_token_ids) - positions = target_positions + positions = target_positions+1 hidden_states = target_hidden_states draft_token_ids = torch.empty( @@ -123,6 +123,17 @@ def propose( # return draft_token_ids.fill_(1) # for debug var = self.runner.forward_vars for i in range(self.mtp_k): + self.runner.debug(f"Draft step {i}, {hidden_states.shape=}") + self.runner.debug(f"Draft step {i}, {input_ids=}") + self.runner.debug(f"Draft step {i}, {positions=}") + self.runner.debug(f"Draft step {i}, {attn_metadata.kv_indptr=}") + # self.runner.debug( + # f"Draft step {i}, {attn_metadata.kv_indices[:attn_metadata.kv_indptr[-1].item()]=}" + # ) + # self.runner.debug(f"Draft step {i}, {attn_metadata.block_tables=}") + self.runner.debug(f"Draft step {i}, {attn_metadata.cu_seqlens_q=}") + self.runner.debug(f"Draft step {i}, {attn_metadata.slot_mapping=}") + self.runner.debug(f"Draft step {i}, {attn_metadata.context_lens=}") ret_hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -132,23 +143,30 @@ def propose( ret_hidden_states[last_token_indices] if i == 0 else ret_hidden_states ) logits = self.model.compute_logits(sample_hidden_states) + if sample_hidden_states.shape[0] ==4: + self.runner.debug(f"Draft step {i}, {sample_hidden_states=}") + self.runner.debug(f"Draft step {i}, {logits=}") new_draft_ids = logits.argmax(dim=-1) + self.runner.debug(f"Draft step {i}, {new_draft_ids=}") draft_token_ids[:, i] = new_draft_ids if i < self.mtp_k - 1: if i == 0: + attn_metadata.max_seqlen_q = 1 kv_indptr = var["kv_indptr"].gpu[: bs + 1] kv_indices = var["kv_indices"].gpu - slot_mapping = var["slot_mapping"].gpu[:bs] + slot_mapping = var["slot_mapping"].gpu[:bs* attn_metadata.max_seqlen_q] kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + cu_seqlens_q = var["cu_seqlens_q"].gpu[: bs + 1] attn_metadata.kv_indptr = kv_indptr attn_metadata.kv_indices = kv_indices + attn_metadata.cu_seqlens_q = cu_seqlens_q attn_metadata.slot_mapping = slot_mapping attn_metadata.kv_last_page_lens = kv_last_page_lens - positions = positions[last_token_indices] - attn_metadata.max_seqlen_q = 1 - attn_metadata.cu_seqlens_q[: bs + 1] = self.arrange_bs[: bs + 1] + # kv_indptr[: bs + 1] += self.arrange_bs[: bs + 1] + cu_seqlens_q[: bs + 1] = self.arrange_bs[: bs + 1] kv_indptr[1 : bs + 1] -= torch.cumsum(num_reject_tokens, dim=0) + positions = positions[last_token_indices] context.is_prefill = False # update metadata diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 341cef129..771c34053 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -168,7 +168,6 @@ class AttentionMetaData: slot_mapping: Optional[torch.Tensor] = None context_lens: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None - fake_block_tables: Optional[torch.Tensor] = None dropout_p: float = 0.0 kv_indptr: Optional[torch.Tensor] = None From 34bb768a683f9268896354a25afcb40344dfd65a Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Sat, 28 Feb 2026 13:31:20 +0000 Subject: [PATCH 2/8] remove torch compile for ds_mtp --- atom/model_engine/model_runner.py | 7 +++--- atom/model_ops/attentions/aiter_mla.py | 18 ++++++------- atom/models/deepseek_mtp.py | 2 +- atom/spec_decode/eagle.py | 35 ++++++++++++++------------ 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 201926ef4..c0fbae17f 100644 --- 
a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -684,7 +684,6 @@ def stop_profiler(self): def debug(self, *args: Any): if self.rank == 0: logger.info(*args) - # logger.info(*args) def dummy_execution(self): """Execute dummy decode batch for DP synchronization.""" @@ -1295,7 +1294,7 @@ def prepare_model(self, batch: ScheduledBatch): temperatures = self.prepare_sample(batch) input_ids = self.tokenID_processor.prepare_input_ids(batch) - self.debug(f"{input_ids=}") + # self.debug(f"{input_ids=}") self.prepare_inputs(batch, input_ids) return ( input_ids, @@ -1383,8 +1382,8 @@ def postprocess( sampled_tokens.view(bs, -1), 1, next_token_locs.view(-1, 1) ).view(bs) self.tokenID_processor.prev_token_ids = next_token_ids - logger.info(f"{sampled_tokens=}") - logger.info(f"{next_token_locs=}") + # self.debug(f"{sampled_tokens=}") + # self.debug(f"{next_token_locs=}") draft_token_ids = self.propose_draft_token_ids( batch, self.tokenID_processor.input_ids.gpu[ diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 8e8fb2cd2..dab7c5caf 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -369,15 +369,15 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): self.block_ratio, max_seqlen_k, ) - if self.block_ratio > 1: - if "block_tables" in ctx: - block_table_convert_triton( - var["block_tables"].gpu[:bs], - var["block_tables_converted"].gpu[:bs], - var["context_lens"].gpu[:bs], - self.block_ratio, - ) - ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] + # if self.block_ratio > 1: + # if "block_tables" in ctx: + # block_table_convert_triton( + # var["block_tables"].gpu[:bs], + # var["block_tables_converted"].gpu[:bs], + # var["context_lens"].gpu[:bs], + # self.block_ratio, + # ) + # ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] attn_metadata = AttentionMetaData( dropout_p=dropout_p, max_seqlen_q=max_seqlen_q, diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index 630388b57..aca35f2d3 100644 --- a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -149,7 +149,7 @@ def compute_logits( return logits -@support_torch_compile +# @support_torch_compile class DeepSeekMTP(nn.Module): def __init__(self, atom_config: Config, prefix: str = ""): diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 46ed257e4..504d0fbce 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ -123,17 +123,17 @@ def propose( # return draft_token_ids.fill_(1) # for debug var = self.runner.forward_vars for i in range(self.mtp_k): - self.runner.debug(f"Draft step {i}, {hidden_states.shape=}") - self.runner.debug(f"Draft step {i}, {input_ids=}") - self.runner.debug(f"Draft step {i}, {positions=}") - self.runner.debug(f"Draft step {i}, {attn_metadata.kv_indptr=}") - # self.runner.debug( - # f"Draft step {i}, {attn_metadata.kv_indices[:attn_metadata.kv_indptr[-1].item()]=}" - # ) - # self.runner.debug(f"Draft step {i}, {attn_metadata.block_tables=}") - self.runner.debug(f"Draft step {i}, {attn_metadata.cu_seqlens_q=}") - self.runner.debug(f"Draft step {i}, {attn_metadata.slot_mapping=}") - self.runner.debug(f"Draft step {i}, {attn_metadata.context_lens=}") + # self.runner.debug(f"Draft step {i}, {hidden_states.shape=}") + # self.runner.debug(f"Draft step {i}, {input_ids=}") + # self.runner.debug(f"Draft step {i}, {positions=}") + # self.runner.debug(f"Draft step {i}, {attn_metadata.kv_indptr=}") + # 
# self.runner.debug( + # # f"Draft step {i}, {attn_metadata.kv_indices[:attn_metadata.kv_indptr[-1].item()]=}" + # # ) + # # self.runner.debug(f"Draft step {i}, {attn_metadata.block_tables=}") + # self.runner.debug(f"Draft step {i}, {attn_metadata.cu_seqlens_q=}") + # self.runner.debug(f"Draft step {i}, {attn_metadata.slot_mapping=}") + # self.runner.debug(f"Draft step {i}, {attn_metadata.context_lens=}") ret_hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -143,11 +143,14 @@ def propose( ret_hidden_states[last_token_indices] if i == 0 else ret_hidden_states ) logits = self.model.compute_logits(sample_hidden_states) - if sample_hidden_states.shape[0] ==4: - self.runner.debug(f"Draft step {i}, {sample_hidden_states=}") - self.runner.debug(f"Draft step {i}, {logits=}") + # if sample_hidden_states.shape[0] ==4: + # self.runner.debug(f"Draft step {i}, {sample_hidden_states=}") + # self.runner.debug(f"Draft step {i}, {logits=}") new_draft_ids = logits.argmax(dim=-1) - self.runner.debug(f"Draft step {i}, {new_draft_ids=}") + # topk_weights, topk_ids = torch.topk(logits, k=10, dim=1) + # self.runner.debug(f"Draft step {i}, {new_draft_ids=}") + # self.runner.debug(f"Draft step {i}, {topk_weights=}") + # self.runner.debug(f"Draft step {i}, {topk_ids=}") draft_token_ids[:, i] = new_draft_ids if i < self.mtp_k - 1: @@ -181,7 +184,7 @@ def propose( positions += 1 hidden_states = sample_hidden_states - # self.runner.debug(f"{draft_token_ids=}") + # self.runner.debug(f"final {draft_token_ids=}") # [batch_size, mtp_k] return draft_token_ids From 55ede6846b69005e98f73a87eaceb606982a2afa Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Sat, 28 Feb 2026 16:45:02 +0000 Subject: [PATCH 3/8] clean up --- atom/model_ops/attentions/aiter_mla.py | 16 +------------ atom/model_ops/sampler.py | 27 +++++++++++++--------- atom/spec_decode/eagle.py | 31 +++++++------------------- 3 files changed, 25 insertions(+), 49 deletions(-) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index dab7c5caf..651f7942f 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -311,22 +311,11 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): var["slot_mapping"].np[:sum_scheduled_tokens] = slot_mapping var["positions"].np[:sum_scheduled_tokens] = positions var["context_lens"].np[:scheduled_bs] = context_lens - # var["context_lens"].np[scheduled_bs:bs] = 0 num_blocks_per_seq = cdiv(context_lens, self.block_size) kv_indptr = np.cumsum(num_blocks_per_seq) sum_blocks = kv_indptr[-1] - # sum_blocks_before_converted = cdiv(num_blocks_per_seq, self.block_ratio).sum() - # def prepare_kv_indices(): - # dst = var["kv_indices"].np - # offset = 0 - # for bt in block_tables: - # n = len(bt) - # dst[offset : offset + n] = bt - # offset += n - - # prepare_kv_indices() self.prepare_block_tables(batch) var["kv_indptr"].np[1 : scheduled_bs + 1] = kv_indptr var["kv_indptr"].np[scheduled_bs + 1 : bs + 1] = sum_blocks @@ -340,12 +329,9 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): ("cu_seqlens_q", bs + 1), ("kv_indptr", bs + 1), ("block_tables", bs), - # ("kv_indices", sum_blocks), ("kv_last_page_lens", bs), ] if self.is_sparse: - # self.prepare_block_tables(batch) - # vars_used.append(("block_tables", bs)) index_topk = self.index_topk sparse_context_lens = np.clip(var["context_lens"].np[:bs], None, index_topk) var["sparse_kv_indptr"].np[1 : bs + 1] = np.cumsum( @@ -386,7 +372,7 @@ def prepare_decode(self, 
batch: ScheduledBatch, bs: int): ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) - # if str(positions.device) == "cuda:0": + # if self.model_runner.rank == 0: # logger.info(f"context_lens: {ctx['context_lens']}") # # logger.info(f"{positions=}") # # for el, var in ctx.items(): diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 0ea4df9f3..0b0fab486 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from functools import lru_cache + import torch from aiter import mixed_sample_outer_exponential from aiter.ops.triton.softmax import softmax @@ -8,6 +10,16 @@ from torch import nn +@lru_cache(maxsize=1) +def get_per_token_exponential(vocab_size: int, device) -> torch.Tensor: + """Returns a tensor of shape (1, vocab_size) filled with exponential random values. + This is key to deterministic inference, as it ensures that the same random values are used for each token across different runs. + """ + return torch.empty((1, vocab_size), dtype=torch.float, device=device).exponential_( + 1 + ) + + class Sampler(nn.Module): def __init__(self): @@ -19,22 +31,15 @@ def forward( logits: torch.Tensor, # (token_num, vocab_size) temperatures: torch.Tensor, # (token_num,) ) -> torch.Tensor: # (token_num,) - sampled_tokens = torch.empty( - logits.size(0), dtype=torch.int, device=logits.device - ) - exponential = ( - torch.empty((1, logits.shape[-1]), dtype=torch.float, device=logits.device) - .exponential_(1) - .expand(*logits.shape) + token_num, vocab_size = logits.shape + sampled_tokens = torch.empty(token_num, dtype=torch.int, device=logits.device) + exponential = get_per_token_exponential(vocab_size, logits.device).expand( + token_num, vocab_size ) mixed_sample_outer_exponential( sampled_tokens, logits, exponential, temperatures, eps=self.eps ) return sampled_tokens - logits = logits.float() - return torch.where( - temperatures == 0, self.greedy_sample(logits), self.random_sample(logits) - ).to(torch.int) def greedy_sample( self, logits: torch.Tensor # (token_num, vocab_size) diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 504d0fbce..2d23f38f5 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ -114,7 +114,7 @@ def propose( input_ids = target_token_ids # input_ids[last_token_indices] = next_token_ids input_ids.scatter_(0, last_token_indices, next_token_ids) - positions = target_positions+1 + positions = target_positions + 1 hidden_states = target_hidden_states draft_token_ids = torch.empty( @@ -123,34 +123,18 @@ def propose( # return draft_token_ids.fill_(1) # for debug var = self.runner.forward_vars for i in range(self.mtp_k): - # self.runner.debug(f"Draft step {i}, {hidden_states.shape=}") - # self.runner.debug(f"Draft step {i}, {input_ids=}") - # self.runner.debug(f"Draft step {i}, {positions=}") - # self.runner.debug(f"Draft step {i}, {attn_metadata.kv_indptr=}") - # # self.runner.debug( - # # f"Draft step {i}, {attn_metadata.kv_indices[:attn_metadata.kv_indptr[-1].item()]=}" - # # ) - # # self.runner.debug(f"Draft step {i}, {attn_metadata.block_tables=}") - # self.runner.debug(f"Draft step {i}, {attn_metadata.cu_seqlens_q=}") - # self.runner.debug(f"Draft step {i}, {attn_metadata.slot_mapping=}") - # self.runner.debug(f"Draft step {i}, {attn_metadata.context_lens=}") ret_hidden_states = self.model( input_ids=input_ids, positions=positions, hidden_states=hidden_states, ) 
sample_hidden_states = ( - ret_hidden_states[last_token_indices] if i == 0 else ret_hidden_states + torch.index_select(ret_hidden_states, 0, last_token_indices) + if i == 0 + else ret_hidden_states ) logits = self.model.compute_logits(sample_hidden_states) - # if sample_hidden_states.shape[0] ==4: - # self.runner.debug(f"Draft step {i}, {sample_hidden_states=}") - # self.runner.debug(f"Draft step {i}, {logits=}") new_draft_ids = logits.argmax(dim=-1) - # topk_weights, topk_ids = torch.topk(logits, k=10, dim=1) - # self.runner.debug(f"Draft step {i}, {new_draft_ids=}") - # self.runner.debug(f"Draft step {i}, {topk_weights=}") - # self.runner.debug(f"Draft step {i}, {topk_ids=}") draft_token_ids[:, i] = new_draft_ids if i < self.mtp_k - 1: @@ -158,7 +142,9 @@ def propose( attn_metadata.max_seqlen_q = 1 kv_indptr = var["kv_indptr"].gpu[: bs + 1] kv_indices = var["kv_indices"].gpu - slot_mapping = var["slot_mapping"].gpu[:bs* attn_metadata.max_seqlen_q] + slot_mapping = var["slot_mapping"].gpu[ + : bs * attn_metadata.max_seqlen_q + ] kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] cu_seqlens_q = var["cu_seqlens_q"].gpu[: bs + 1] attn_metadata.kv_indptr = kv_indptr @@ -166,10 +152,9 @@ def propose( attn_metadata.cu_seqlens_q = cu_seqlens_q attn_metadata.slot_mapping = slot_mapping attn_metadata.kv_last_page_lens = kv_last_page_lens - # kv_indptr[: bs + 1] += self.arrange_bs[: bs + 1] cu_seqlens_q[: bs + 1] = self.arrange_bs[: bs + 1] kv_indptr[1 : bs + 1] -= torch.cumsum(num_reject_tokens, dim=0) - positions = positions[last_token_indices] + positions = torch.gather(positions, 0, last_token_indices) context.is_prefill = False # update metadata From 2c2c0f5a9ab5284a7f0d0f994ccf44f2ae2ab1de Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 1 Mar 2026 16:37:12 +0000 Subject: [PATCH 4/8] reduce mtp stats interval --- atom/model_engine/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 46405682f..6cb54a6d9 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -29,7 +29,8 @@ class SpecStats: def __init__(self, mtp_k: int, log_interval: int = 1000): self.mtp_k = mtp_k - self._log_interval = log_interval + # Log every log_interval decode steps (in terms of draft tokens) + self._log_interval = log_interval * mtp_k self.total_draft_tokens: int = 0 self.distribution: dict[int, int] = {k: 0 for k in range(mtp_k + 1)} # Per-interval tracking @@ -44,7 +45,7 @@ def update(self, num_accepted_tokens: int) -> None: self.distribution[num_bonus] += 1 self._interval_distribution[num_bonus] += 1 - if self.total_draft_tokens % self._log_interval < self.mtp_k: + if self.total_draft_tokens % self._log_interval == 0: self._log() self._reset_interval() From 0babab8977da02c26c06cb2cc804c0b5f2867b9f Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:40:33 +0000 Subject: [PATCH 5/8] update --- atom/model_engine/async_proc.py | 2 +- atom/model_engine/model_runner.py | 71 ++++++++++++++++--------------- atom/model_engine/scheduler.py | 34 ++++++++------- atom/model_ops/embed_head.py | 3 +- atom/models/deepseek_mtp.py | 3 +- 5 files changed, 60 insertions(+), 53 deletions(-) diff --git a/atom/model_engine/async_proc.py b/atom/model_engine/async_proc.py index cd14b640a..c60c78626 100644 --- a/atom/model_engine/async_proc.py +++ b/atom/model_engine/async_proc.py @@ -110,7 +110,7 
@@ def busy_loop(self): if func is None: continue out = func(*args) - if out is not None: + if self.io_addrs[1] is not None and out is not None: self.io_queues[1].put_nowait(out) if func_name == "exit": break diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index c0fbae17f..f110e8bf1 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -175,57 +175,59 @@ def clean(self): None # Mapped to current batch order ) - def _process_token_id(self, token_id) -> tuple[int, ...]: - """Helper function: process a single token_id, handling list and non-list cases. - - Optimized: eliminates double traversal (removed 'in' check before 'index'). - Returns tuple for better performance and immutability. - """ - if isinstance(token_id, list): - try: - idx = token_id.index(-1) - return tuple(token_id[:idx]) - except ValueError: - # No -1 found, return the entire list as tuple - return tuple(token_id) - else: - return (token_id,) + @staticmethod + def _batch_process_token_ids(token_ids: list) -> list[tuple[int, ...]]: + """Batch process token_ids: vectorized -1 truncation using numpy.""" + arr = np.array(token_ids, dtype=np.int64) + mask = arr == -1 + if not mask.any(): + # No -1 sentinel in any row, convert each row to tuple directly + return [tuple(row) for row in arr.tolist()] + # Per-row: find first -1, truncate + # Use argmax on mask; rows without -1 get 0, disambiguate with ~mask.any(axis=1) + has_sentinel = mask.any(axis=1) + first_neg = mask.argmax(axis=1) + result = [] + rows = arr.tolist() + for i, row in enumerate(rows): + if has_sentinel[i]: + result.append(tuple(row[: first_neg[i]])) + else: + result.append(tuple(row)) + return result def prepare_sampled_ids( self, batch: ScheduledBatch, sampled_token_ids: torch.Tensor, sync_event: torch.cuda.Event, - ) -> dict[int, tuple[int, ...]]: + ) -> tuple[list[int], list[tuple[int, ...]]]: if not self.is_deferred_out: token_ids = sampled_token_ids.tolist() req_ids = batch.req_ids - ret = { - seq_id: self._process_token_id(token_id) - for seq_id, token_id in zip(req_ids, token_ids) - } - ret[-1] = 0 # is_deferred_out flag - return ret + if token_ids and isinstance(token_ids[0], list): + processed = self._batch_process_token_ids(token_ids) + else: + processed = [(tid,) for tid in token_ids] + return req_ids, processed token_ids = self.recv_async_output(self.token_ids_cpu) self.send_to_cpu_async(sampled_token_ids, self.token_ids_cpu, sync_event) - token_id_dict = {} + req_ids_out: list[int] = [] + processed_out: list[tuple[int, ...]] = [] self.prev_req_ids = None if self.prev_batch is not None: self.prev_req_ids = self.prev_batch.req_ids - token_id_dict = { - seq_id: self._process_token_id(token_id) - for seq_id, token_id in zip(self.prev_req_ids, token_ids) - } - else: - # first time, no previous tokens - token_ids = {} + req_ids_out = self.prev_req_ids + if token_ids and isinstance(token_ids[0], list): + processed_out = self._batch_process_token_ids(token_ids) + else: + processed_out = [(tid,) for tid in token_ids] self.prev_batch = batch self.prev_token_ids = sampled_token_ids - token_id_dict[-1] = 1 - return token_id_dict + return req_ids_out, processed_out def get_token_locations( self, batch: ScheduledBatch @@ -1366,7 +1368,7 @@ def postprocess( sampled_tokens = get_tp_group().broadcast(sampled_tokens, src=0) self.forward_done_event.record() - token_ids = self.tokenID_processor.prepare_sampled_ids( + req_ids_out, token_ids_out = self.tokenID_processor.prepare_sampled_ids( batch, 
sampled_tokens, self.forward_done_event ) @@ -1399,7 +1401,8 @@ def postprocess( prev_bonus_num = np.zeros(batch.total_seqs_num, dtype=np.int32) return ScheduledBatchOutput( - token_ids=token_ids, + req_ids=req_ids_out, + token_ids=token_ids_out, draft_token_ids=draft_token_ids, is_deferred_out=self.tokenID_processor.is_deferred_out, num_rejected=prev_rejected_num, diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 6cb54a6d9..30a490dd4 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -195,23 +195,27 @@ class ScheduledBatchOutput: def __init__( self, - token_ids: dict[int, tuple[int, ...]], - num_rejected: np.ndarray, - num_bonus: np.ndarray, + req_ids: list[int], + token_ids: list[tuple[int, ...]], + num_rejected: Optional[np.ndarray], + num_bonus: Optional[np.ndarray], draft_token_ids: Optional[np.ndarray], - # num_bonus_tokens is_deferred_out=False, ): - # TODO need refine - self.is_deferred_out = is_deferred_out - self.req_ids = list(token_ids.keys()) + self.req_ids = req_ids self.token_ids = token_ids self.draft_token_ids = draft_token_ids self.num_rejected = num_rejected self.num_bonus = num_bonus - # logger.info(f"ScheduledBatchOutput: req_ids={self.req_ids}") - # assert len(self.req_ids) - 1 == len(draft_token_ids) - # self.num_bonus_tokens = num_bonus_tokens # num per req + self.is_deferred_out = is_deferred_out + # O(1) lookup: req_id -> index (lazy-built on first access) + self._req_id_to_idx: Optional[dict[int, int]] = None + + def get_idx(self, req_id: int) -> Optional[int]: + """O(1) lookup of request index by id.""" + if self._req_id_to_idx is None: + self._req_id_to_idx = {rid: i for i, rid in enumerate(self.req_ids)} + return self._req_id_to_idx.get(req_id) class Scheduler: @@ -257,7 +261,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_batched_tokens = 0 num_scheduled_tokens: list[int] = [] - scheduled_spec_decode_tokens: dict[int, list[int]] = {} + scheduled_spec_decode_tokens: dict[int, np.ndarray] = {} if not self.running and not self.waiting: # self.block_manager.reset() @@ -376,13 +380,13 @@ def postprocess( num_placeholder += 1 for seq in self.running: # Update the running status - if seq.id not in fwd_output.req_ids: + idx = fwd_output.get_idx(seq.id) + if idx is None: continue - token_ids = prev_token_ids[seq.id] + token_ids = prev_token_ids[idx] num_new_token = len(token_ids) if self.spec_stats: self.spec_stats.update(num_new_token) - idx = fwd_output.req_ids.index(seq.id) if is_deferred_out or self.use_spec: num_rejected = fwd_output.num_rejected[idx] num_bonus = fwd_output.num_bonus[idx] @@ -402,7 +406,7 @@ def postprocess( new_tokens = token_ids if self.mtp_k > 0: - idx = fwd_output.req_ids.index(seq.id) + # idx already resolved above via get_idx seq.spec_token_ids = draft_token_ids[idx] if seq.num_completion_tokens == 1 and seq.first_token_time == 0.0: diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index 77281d4f9..e5b4b3d10 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -6,9 +6,8 @@ from aiter.dist.communication_op import tensor_model_parallel_all_gather from aiter.dist.parallel_state import get_tp_group from aiter.tuned_gemm import tgemm -from torch import nn - from atom.utils.forward_context import ForwardContext, get_forward_context +from torch import nn class VocabParallelEmbedding(nn.Module): diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index aca35f2d3..1d603c8e6 100644 --- 
a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -12,7 +12,8 @@ from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled from atom.models.utils import IntermediateTensors -from atom.utils.decorators import support_torch_compile + +# from atom.utils.decorators import support_torch_compile from transformers import DeepseekV2Config, DeepseekV3Config, PretrainedConfig from .deepseek_v2 import DeepseekV2DecoderLayer From 54efce310bffe7338463dbfb5d161c1313aa5f19 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:44:15 +0000 Subject: [PATCH 6/8] fix lint --- atom/model_ops/attentions/aiter_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 1bcaf0aec..b285a0838 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -210,7 +210,6 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): num_blocks_per_seq = cdiv(context_lens, self.block_size) kv_indptr = np.cumsum(num_blocks_per_seq) sum_blocks = kv_indptr[-1] if len(kv_indptr) > 0 else 0 - sum_blocks_before_converted = cdiv(num_blocks_per_seq, self.block_ratio).sum() var["kv_indptr"].np[0] = 0 var["kv_indptr"].np[1 : scheduled_bs + 1] = kv_indptr From c664137f57f14914afa7d283d6fbf03e1d2ac524 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Tue, 3 Mar 2026 04:01:31 +0000 Subject: [PATCH 7/8] more cleanup --- atom/model_ops/attentions/aiter_mla.py | 22 +++++++--------------- atom/spec_decode/eagle.py | 1 - 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 651f7942f..fa4ec6c36 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -251,21 +251,13 @@ def prepare_prefill(self, batch: ScheduledBatch): if attn_metadata.block_tables is None: self.prepare_block_tables(batch) attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) - kv_indices_generate_triton( - attn_metadata.block_tables, - attn_metadata.kv_indices, - attn_metadata.kv_indptr, - self.block_ratio, - attn_metadata.max_seqlen_k, - ) - else: - kv_indices_generate_triton( - attn_metadata.block_tables, - attn_metadata.kv_indices, - attn_metadata.kv_indptr, - self.block_ratio, - attn_metadata.max_seqlen_k, - ) + kv_indices_generate_triton( + attn_metadata.block_tables, + attn_metadata.kv_indices, + attn_metadata.kv_indptr, + self.block_ratio, + attn_metadata.max_seqlen_k, + ) return attn_metadata, positions diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 2d23f38f5..936ebc86b 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ -107,7 +107,6 @@ def propose( forward_context = get_forward_context() context = forward_context.context attn_metadata = forward_context.attn_metadata - context.is_draft = True bs = context.batch_size assert self.runner is not None From 7a318ed388076529d504e326969d44f45eaf35cf Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Tue, 3 Mar 2026 09:07:18 +0000 Subject: [PATCH 8/8] fix --- atom/spec_decode/eagle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 936ebc86b..88fd9f204 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ 
-108,6 +108,7 @@ def propose( context = forward_context.context attn_metadata = forward_context.attn_metadata bs = context.batch_size + context.is_draft = True assert self.runner is not None input_ids = target_token_ids
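
Illustrative note on the SpecStats bookkeeping this series relies on (patches 1 and 4 move MTP acceptance logging from the runner into the scheduler's SpecStats): the sketch below is a minimal, self-contained approximation of that tracker, not the class in atom/model_engine/scheduler.py. The fields shown in the diff (mtp_k, _log_interval, total_draft_tokens, distribution) and the update() signature are taken from the hunks; total_accepted_tokens, the per-step draft-token increment, and the _log() body are assumptions filled in from the acceptance-rate formula the removed runner-side get_mtp_statistics() used (accepted / draft).

import logging

logger = logging.getLogger(__name__)


class SpecStatsSketch:
    """Hypothetical stand-in for SpecStats; tracks MTP draft-token acceptance."""

    def __init__(self, mtp_k: int, log_interval: int = 1000):
        self.mtp_k = mtp_k
        # Log every log_interval decode steps, counted in draft tokens,
        # so the modulo check in update() fires exactly once per interval.
        self._log_interval = log_interval * mtp_k
        self.total_draft_tokens = 0
        self.total_accepted_tokens = 0  # assumption: accepted-token counter not shown in the diff
        # distribution[k] = number of decode steps in which k draft tokens were accepted
        self.distribution = {k: 0 for k in range(mtp_k + 1)}

    def update(self, num_accepted_tokens: int) -> None:
        # Inference from the log-interval comment: each decode step contributes
        # mtp_k draft tokens per sequence.
        self.total_draft_tokens += self.mtp_k
        self.total_accepted_tokens += num_accepted_tokens
        self.distribution[min(num_accepted_tokens, self.mtp_k)] += 1
        if self.total_draft_tokens % self._log_interval == 0:
            self._log()

    def _log(self) -> None:
        # Assumption: the real _log() formatting is not shown in the diff;
        # acceptance rate follows the removed get_mtp_statistics() formula.
        rate = (
            self.total_accepted_tokens / self.total_draft_tokens
            if self.total_draft_tokens
            else 0.0
        )
        logger.info(
            "MTP stats: draft=%d accepted=%d acceptance=%.2f%% distribution=%s",
            self.total_draft_tokens,
            self.total_accepted_tokens,
            100.0 * rate,
            self.distribution,
        )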