2 changes: 1 addition & 1 deletion atom/model_engine/async_proc.py
@@ -110,7 +110,7 @@ def busy_loop(self):
if func is None:
continue
out = func(*args)
if out is not None:
if self.io_addrs[1] is not None and out is not None:
self.io_queues[1].put_nowait(out)
if func_name == "exit":
break
10 changes: 2 additions & 8 deletions atom/model_engine/engine_core.py
@@ -293,14 +293,8 @@ def stop_profiler(self):
logger.info("Profiler stopped.")

def print_mtp_statistics(self):
stats = self.runner_mgr.call_func("get_mtp_statistics", wait_out=True)
if stats and stats.get("total_draft_tokens", 0) > 0:
logger.info(f"\n{'='*50}")
logger.info("MTP (Multi-Token Prediction) Statistics:")
logger.info(f" Total draft tokens: {stats['total_draft_tokens']}")
logger.info(f" Accepted tokens: {stats['total_accepted_tokens']}")
logger.info(f" Acceptance rate: {stats['acceptance_rate']:.2%}")
logger.info(f"{'='*50}\n")
if self.scheduler.spec_stats is not None:
self.scheduler.spec_stats._log()
else:
logger.info(
"\n[MTP Stats] No MTP statistics available (MTP not enabled or no tokens processed)\n"
96 changes: 39 additions & 57 deletions atom/model_engine/model_runner.py
@@ -175,57 +175,59 @@ def clean(self):
None # Mapped to current batch order
)

def _process_token_id(self, token_id) -> tuple[int, ...]:
"""Helper function: process a single token_id, handling list and non-list cases.

Optimized: eliminates double traversal (removed 'in' check before 'index').
Returns tuple for better performance and immutability.
"""
if isinstance(token_id, list):
try:
idx = token_id.index(-1)
return tuple(token_id[:idx])
except ValueError:
# No -1 found, return the entire list as tuple
return tuple(token_id)
else:
return (token_id,)
@staticmethod
def _batch_process_token_ids(token_ids: list) -> list[tuple[int, ...]]:
"""Batch process token_ids: vectorized -1 truncation using numpy."""
arr = np.array(token_ids, dtype=np.int64)
mask = arr == -1
if not mask.any():
# No -1 sentinel in any row, convert each row to tuple directly
return [tuple(row) for row in arr.tolist()]
# Per-row: find first -1, truncate
# Use argmax on mask; rows without -1 get 0, disambiguate with ~mask.any(axis=1)
has_sentinel = mask.any(axis=1)
first_neg = mask.argmax(axis=1)
result = []
rows = arr.tolist()
for i, row in enumerate(rows):
if has_sentinel[i]:
result.append(tuple(row[: first_neg[i]]))
else:
result.append(tuple(row))
return result
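
A minimal standalone sketch of the truncation behavior the new helper implements (illustrative names, not part of the diff), assuming rectangular input rows as np.array() requires:

import numpy as np

def truncate_at_sentinel(token_ids: list[list[int]]) -> list[tuple[int, ...]]:
    # Cut each row at its first -1 sentinel; rows without -1 are kept whole.
    arr = np.asarray(token_ids, dtype=np.int64)
    mask = arr == -1
    has_sentinel = mask.any(axis=1)
    first_neg = mask.argmax(axis=1)  # 0 for rows without -1, hence the has_sentinel check
    return [
        tuple(row[: first_neg[i]]) if has_sentinel[i] else tuple(row)
        for i, row in enumerate(arr.tolist())
    ]

print(truncate_at_sentinel([[5, 7, -1, -1], [3, -1, -1, -1], [9, 8, 6, 2]]))
# -> [(5, 7), (3,), (9, 8, 6, 2)]
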

def prepare_sampled_ids(
self,
batch: ScheduledBatch,
sampled_token_ids: torch.Tensor,
sync_event: torch.cuda.Event,
) -> dict[int, tuple[int, ...]]:
) -> tuple[list[int], list[tuple[int, ...]]]:
if not self.is_deferred_out:
token_ids = sampled_token_ids.tolist()
req_ids = batch.req_ids
ret = {
seq_id: self._process_token_id(token_id)
for seq_id, token_id in zip(req_ids, token_ids)
}
ret[-1] = 0 # is_deferred_out flag
return ret
if token_ids and isinstance(token_ids[0], list):
processed = self._batch_process_token_ids(token_ids)
else:
processed = [(tid,) for tid in token_ids]
return req_ids, processed

token_ids = self.recv_async_output(self.token_ids_cpu)
self.send_to_cpu_async(sampled_token_ids, self.token_ids_cpu, sync_event)
token_id_dict = {}
req_ids_out: list[int] = []
processed_out: list[tuple[int, ...]] = []
self.prev_req_ids = None
if self.prev_batch is not None:
self.prev_req_ids = self.prev_batch.req_ids
token_id_dict = {
seq_id: self._process_token_id(token_id)
for seq_id, token_id in zip(self.prev_req_ids, token_ids)
}
else:
# first time, no previous tokens
token_ids = {}
req_ids_out = self.prev_req_ids
if token_ids and isinstance(token_ids[0], list):
processed_out = self._batch_process_token_ids(token_ids)
else:
processed_out = [(tid,) for tid in token_ids]

self.prev_batch = batch
self.prev_token_ids = sampled_token_ids
token_id_dict[-1] = 1

return token_id_dict
return req_ids_out, processed_out
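
A quick sketch of how the new return contract reads on the consumer side: two index-aligned lists (request ids and per-request token tuples) replace the old dict keyed by request id with a -1 flag entry. The values below are made up for illustration:

req_ids = [17, 4, 29]
token_ids = [(5, 7), (3,), (9, 8, 6, 2)]
for req_id, tokens in zip(req_ids, token_ids):
    print(f"request {req_id}: {len(tokens)} accepted token(s) -> {tokens}")
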

def get_token_locations(
self, batch: ScheduledBatch
@@ -611,29 +613,6 @@ def is_qwen_next(self) -> bool:
return True
return False

def get_mtp_statistics(self) -> dict:
if hasattr(self, "mtp_total_draft_tokens"):
acceptance_rate = (
self.mtp_total_accepted_tokens / self.mtp_total_draft_tokens
if self.mtp_total_draft_tokens > 0
else 0.0
)
return {
"total_draft_tokens": self.mtp_total_draft_tokens,
"total_accepted_tokens": self.mtp_total_accepted_tokens,
"acceptance_rate": acceptance_rate,
}
return {
"total_draft_tokens": 0,
"total_accepted_tokens": 0,
"acceptance_rate": 0.0,
}

def reset_mtp_statistics(self):
if hasattr(self, "mtp_total_draft_tokens"):
self.mtp_total_draft_tokens = 0
self.mtp_total_accepted_tokens = 0

def _make_buffer(
self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True
) -> CpuGpuBuffer:
@@ -1389,7 +1368,7 @@ def postprocess(
sampled_tokens = get_tp_group().broadcast(sampled_tokens, src=0)

self.forward_done_event.record()
token_ids = self.tokenID_processor.prepare_sampled_ids(
req_ids_out, token_ids_out = self.tokenID_processor.prepare_sampled_ids(
batch, sampled_tokens, self.forward_done_event
)

@@ -1405,6 +1384,8 @@ def postprocess(
sampled_tokens.view(bs, -1), 1, next_token_locs.view(-1, 1)
).view(bs)
self.tokenID_processor.prev_token_ids = next_token_ids
# self.debug(f"{sampled_tokens=}")
# self.debug(f"{next_token_locs=}")
draft_token_ids = self.propose_draft_token_ids(
batch,
self.tokenID_processor.input_ids.gpu[
@@ -1420,7 +1401,8 @@ def postprocess(
prev_bonus_num = np.zeros(batch.total_seqs_num, dtype=np.int32)

return ScheduledBatchOutput(
token_ids=token_ids,
req_ids=req_ids_out,
token_ids=token_ids_out,
draft_token_ids=draft_token_ids,
is_deferred_out=self.tokenID_processor.is_deferred_out,
num_rejected=prev_rejected_num,
39 changes: 22 additions & 17 deletions atom/model_engine/scheduler.py
@@ -29,7 +29,8 @@ class SpecStats:

def __init__(self, mtp_k: int, log_interval: int = 1000):
self.mtp_k = mtp_k
self._log_interval = log_interval
# Log every log_interval decode steps (in terms of draft tokens)
self._log_interval = log_interval * mtp_k
self.total_draft_tokens: int = 0
self.distribution: dict[int, int] = {k: 0 for k in range(mtp_k + 1)}
# Per-interval tracking
@@ -44,7 +45,7 @@ def update(self, num_accepted_tokens: int) -> None:
self.distribution[num_bonus] += 1
self._interval_distribution[num_bonus] += 1

if self.total_draft_tokens % self._log_interval < self.mtp_k:
if self.total_draft_tokens % self._log_interval == 0:
self._log()
self._reset_interval()

@@ -194,23 +195,27 @@ class ScheduledBatchOutput:

def __init__(
self,
token_ids: dict[int, tuple[int, ...]],
num_rejected: np.ndarray,
num_bonus: np.ndarray,
req_ids: list[int],
token_ids: list[tuple[int, ...]],
num_rejected: Optional[np.ndarray],
num_bonus: Optional[np.ndarray],
draft_token_ids: Optional[np.ndarray],
# num_bonus_tokens
is_deferred_out=False,
):
# TODO need refine
self.is_deferred_out = is_deferred_out
self.req_ids = list(token_ids.keys())
self.req_ids = req_ids
self.token_ids = token_ids
self.draft_token_ids = draft_token_ids
self.num_rejected = num_rejected
self.num_bonus = num_bonus
# logger.info(f"ScheduledBatchOutput: req_ids={self.req_ids}")
# assert len(self.req_ids) - 1 == len(draft_token_ids)
# self.num_bonus_tokens = num_bonus_tokens # num per req
self.is_deferred_out = is_deferred_out
# O(1) lookup: req_id -> index (lazy-built on first access)
self._req_id_to_idx: Optional[dict[int, int]] = None

def get_idx(self, req_id: int) -> Optional[int]:
"""O(1) lookup of request index by id."""
if self._req_id_to_idx is None:
self._req_id_to_idx = {rid: i for i, rid in enumerate(self.req_ids)}
return self._req_id_to_idx.get(req_id)
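
The lazy index map above replaces the repeated list.index() scans that postprocess used to do per sequence; a small standalone sketch of the pattern (stand-in class name, illustrative data):

from typing import Optional

class _Output:
    def __init__(self, req_ids: list[int]):
        self.req_ids = req_ids
        self._req_id_to_idx: Optional[dict[int, int]] = None

    def get_idx(self, req_id: int) -> Optional[int]:
        # Built once on first lookup; every later call is a dict hit.
        if self._req_id_to_idx is None:
            self._req_id_to_idx = {rid: i for i, rid in enumerate(self.req_ids)}
        return self._req_id_to_idx.get(req_id)

out = _Output([17, 4, 29])
assert out.get_idx(4) == 1 and out.get_idx(99) is None

Building the dict lazily keeps construction free for outputs that are never looked up, while callers that look up many request ids pay the O(n) build cost only once.
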


class Scheduler:
@@ -256,7 +261,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
num_batched_tokens = 0

num_scheduled_tokens: list[int] = []
scheduled_spec_decode_tokens: dict[int, list[int]] = {}
scheduled_spec_decode_tokens: dict[int, np.ndarray] = {}

if not self.running and not self.waiting:
# self.block_manager.reset()
@@ -375,13 +380,13 @@ def postprocess(
num_placeholder += 1
for seq in self.running:
# Update the running status
if seq.id not in fwd_output.req_ids:
idx = fwd_output.get_idx(seq.id)
if idx is None:
continue
token_ids = prev_token_ids[seq.id]
token_ids = prev_token_ids[idx]
num_new_token = len(token_ids)
if self.spec_stats:
self.spec_stats.update(num_new_token)
idx = fwd_output.req_ids.index(seq.id)
if is_deferred_out or self.use_spec:
num_rejected = fwd_output.num_rejected[idx]
num_bonus = fwd_output.num_bonus[idx]
Expand All @@ -401,7 +406,7 @@ def postprocess(
new_tokens = token_ids

if self.mtp_k > 0:
idx = fwd_output.req_ids.index(seq.id)
# idx already resolved above via get_idx
seq.spec_token_ids = draft_token_ids[idx]

if seq.num_completion_tokens == 1 and seq.first_token_time == 0.0:
6 changes: 0 additions & 6 deletions atom/model_ops/attention_mha.py
@@ -369,13 +369,7 @@ def prefill_attention_triton(
# value: [num_blocks, 1, num_kv_heads, head_size]

attn_metadata = fwd_ctx.attn_metadata
ctx = fwd_ctx.context

block_tables = attn_metadata.block_tables
if ctx.is_prefill:
k_cache = k.unsqueeze(1)
v_cache = v.unsqueeze(1)
block_tables = attn_metadata.fake_block_tables

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
26 changes: 16 additions & 10 deletions atom/model_ops/attentions/aiter_attention.py
@@ -11,7 +11,10 @@
from atom.model_engine.scheduler import ScheduledBatch
from atom.model_ops.attention_mha import Attention
from atom.utils import CpuGpuBuffer
from atom.utils.block_convert import block_table_convert_triton
from atom.utils.block_convert import (
block_table_convert_triton,
kv_indices_generate_triton,
)
from atom.utils.forward_context import AttentionMetaData, Context

from .backends import AttentionBackend, CommonAttentionBuilder
@@ -94,7 +97,7 @@ def __init__(self, model_runner):
),
"kv_indptr": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs),
"kv_indices": CpuGpuBuffer(
self.max_bs * self.max_num_blocks_per_seq // self.block_ratio,
self.max_bs * self.max_num_blocks_per_seq,
**i32_kwargs,
),
}
@@ -207,13 +210,7 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int):
num_blocks_per_seq = cdiv(context_lens, self.block_size)
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1] if len(kv_indptr) > 0 else 0
sum_blocks_before_converted = cdiv(num_blocks_per_seq, self.block_ratio).sum()

var["kv_indices"].np[:sum_blocks_before_converted] = np.fromiter(
itertools.chain.from_iterable(block_tables),
dtype=np.int32,
count=sum_blocks_before_converted,
)
var["kv_indptr"].np[0] = 0
var["kv_indptr"].np[1 : scheduled_bs + 1] = kv_indptr
var["kv_indptr"].np[scheduled_bs + 1 : bs + 1] = sum_blocks
@@ -224,13 +221,22 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int):
("cu_seqlens_q", bs + 1),
("block_tables", bs),
("kv_indptr", bs + 1),
("kv_indices", sum_blocks_before_converted),
]

ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used}
if self.block_size == 1024:
ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs)
ctx.update(ctx_pa_ps)

ctx["kv_indices"] = var["kv_indices"].gpu
max_seqlen_k = context_lens.max()
kv_indices_generate_triton(
ctx["block_tables"],
ctx["kv_indices"],
ctx["kv_indptr"],
self.block_ratio,
max_seqlen_k,
)
if self.block_ratio > 1 and "block_tables" in ctx:
block_table_convert_triton(
var["block_tables"].gpu[:bs],
@@ -262,7 +268,7 @@ def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData:
max_seqlen_q=var["max_qlen"],
cu_seqlens_q=var["cu_seqlens_q"].gpu[: bs + 1],
kv_indptr=var["kv_indptr"].gpu[: bs + 1],
kv_indices=var["kv_indices"].gpu[:],
kv_indices=var["kv_indices"].gpu,
max_seqlen_k=self.model_runner.config.max_model_len,
block_tables_converted=(
var["block_tables_converted"].gpu[:bs]
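
For orientation on the buffers involved in the prepare_decode change: kv_indptr / kv_indices describe a CSR-style flattening of the per-sequence block tables, which this diff now fills on the GPU via kv_indices_generate_triton instead of with np.fromiter on the CPU. The NumPy version below is only a sketch of that layout (block_ratio handling omitted); it is not the Triton kernel:

import numpy as np

def build_kv_indices_cpu(block_tables: list[list[int]],
                         context_lens: np.ndarray,
                         block_size: int) -> tuple[np.ndarray, np.ndarray]:
    # kv_indptr[i]:kv_indptr[i+1] is the slice of kv_indices holding the
    # KV-cache block ids used by sequence i (assumes a non-empty batch).
    num_blocks_per_seq = -(-context_lens // block_size)  # ceil division
    kv_indptr = np.zeros(len(block_tables) + 1, dtype=np.int32)
    kv_indptr[1:] = np.cumsum(num_blocks_per_seq)
    kv_indices = np.concatenate(
        [np.asarray(bt[:n], dtype=np.int32)
         for bt, n in zip(block_tables, num_blocks_per_seq)]
    )
    return kv_indptr, kv_indices

# Two sequences with context lengths 5 and 3 and block_size 2:
kv_indptr, kv_indices = build_kv_indices_cpu(
    [[10, 11, 12, 13], [20, 21]], np.array([5, 3]), block_size=2
)
print(kv_indptr.tolist(), kv_indices.tolist())  # [0, 3, 5] [10, 11, 12, 20, 21]
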