Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions atom/model_engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,45 +122,21 @@ def deallocate(self, seq: Sequence):
self._deallocate_block(block_id)
seq.mamba_block_table.clear()

def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool:
    """Return True if the free-block pool can absorb ``num_new_tokens``
    more tokens appended to *seq*.

    The check computes how many blocks the grown sequence will occupy in
    total and compares the shortfall (blocks not yet in the sequence's
    block table) against the number of free blocks.
    """
    total_tokens = len(seq) + num_new_tokens
    # Ceiling division: blocks required to hold the grown sequence.
    required_blocks = -(-total_tokens // self.block_size)
    shortfall = required_blocks - len(seq.block_table)
    if shortfall <= 0:
        # Already-allocated blocks cover the new tokens.
        return True
    return len(self.free_block_ids) >= shortfall

def may_append(self, seq: Sequence, num_new_tokens: int = 1):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
seq_len = len(seq)
# Check if we need to allocate a new block
# When len(seq) % block_size == 1, we need a new block for the next token
# When block_size == 1, every token needs a new block
if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1:
needed_blocks = (seq_len + self.block_size - 1) // self.block_size
while len(block_table) < needed_blocks:
# For block_size == 1, we need to update hash for each new block
# For block_size > 1, the previous block should have hash != -1 (unless it's the first block)
if self.block_size == 1:
# Allocate new block and update hash immediately (like allocate does for full blocks)
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
token_ids = [seq[-1]]
prefix = (
self.blocks[block_table[-2]].hash
if len(block_table) > 1
else -1
)
h = self.compute_hash(token_ids, prefix)
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
else:
# For block_size > 1, we only allocate new block when needed
# The hash will be updated when the block becomes full
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
last_block = block
elif seq_len % self.block_size == 0:
# Last block is now full, update its hash (similar to allocate)
# TODO: fix hash

# Phase 1: If the last block just became full, register its hash
# so it can be reused for prefix caching on future sequences.
if seq_len % self.block_size == 0 and self.block_size > 1:
token_ids = seq.block(seq.num_blocks - 1)
if len(token_ids) == self.block_size:
prefix = (
Expand All @@ -169,8 +145,26 @@ def may_append(self, seq: Sequence, num_new_tokens: int = 1):
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
pass
# Last block is not full and not at the boundary
# Hash remains -1 until block is full (consistent with allocate logic)
# assert last_block.hash == -1, last_block.block_id

# Phase 2: Allocate new blocks for the upcoming tokens.
needed_blocks = (
(seq_len + num_new_tokens + self.block_size - 1) // self.block_size
)
while len(block_table) < needed_blocks:
if self.block_size == 1:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
token_ids = [seq[-1]]
prefix = (
self.blocks[block_table[-2]].hash
if len(block_table) > 1
else -1
)
h = self.compute_hash(token_ids, prefix)
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
else:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
5 changes: 3 additions & 2 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:

# decode
num_seqs_decode = 0
num_new_tokens = self.mtp_k + 1
while self.running and num_seqs_decode < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
while not self.block_manager.can_append(seq, num_new_tokens):
if self.running:
self.preempt(self.running.pop())
else:
Expand All @@ -319,7 +320,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
if seq.spec_token_ids.size > 0:
scheduled_spec_decode_tokens[seq.id] = seq.spec_token_ids
num_seqs_decode += 1
num_new_tokens = self.mtp_k + 1
self.block_manager.may_append(seq, num_new_tokens)
scheduled_seqs[seq.id] = seq
seq.type = SequenceType.DECODE
Expand Down Expand Up @@ -382,6 +382,7 @@ def postprocess(
if self.spec_stats:
self.spec_stats.update(num_new_token)
idx = fwd_output.req_ids.index(seq.id)
num_rejected = 0
if is_deferred_out or self.use_spec:
num_rejected = fwd_output.num_rejected[idx]
num_bonus = fwd_output.num_bonus[idx]
Expand Down
195 changes: 193 additions & 2 deletions tests/test_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,85 @@ def test_cannot_append_no_free(self, seq_factory):
seq.append_token(5)
assert not bm.can_append(seq)

def test_at_block_boundary_needs_block(self, seq_factory):
    """At an exact block boundary (4 tokens, block_size=4) one new token
    requires one fresh block; a single free block makes that possible."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert manager.can_append(sequence, num_new_tokens=1)

def test_at_block_boundary_no_free(self, seq_factory):
    """At an exact block boundary with an exhausted pool, appending must
    be refused."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=1, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert not manager.can_append(sequence, num_new_tokens=1)

def test_multi_token_needs_two_blocks(self, seq_factory):
    """7 tokens + 4 new over block_size 4 total 11 → 3 blocks; with 2
    already allocated, the one remaining free block suffices."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=3, kv_cache_block_size=4))
    sequence = seq_factory(list(range(1, 8)))
    manager.allocate(sequence)
    assert len(sequence.block_table) == 2
    assert manager.can_append(sequence, num_new_tokens=4)

def test_multi_token_not_enough_free(self, seq_factory):
    """5 tokens + 4 new total 9 → 3 blocks needed, 2 allocated, but no
    free block remains → appending must fail."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4))
    sequence = seq_factory(list(range(1, 6)))
    manager.allocate(sequence)
    assert len(sequence.block_table) == 2
    assert not manager.can_append(sequence, num_new_tokens=4)

def test_multi_token_enough_free(self, seq_factory):
    """7 tokens + 4 new need one extra block; a generous pool has it."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4))
    sequence = seq_factory(list(range(1, 8)))
    manager.allocate(sequence)
    assert manager.can_append(sequence, num_new_tokens=4)

def test_multi_token_crosses_two_boundaries(self, seq_factory):
    """5 tokens + 4 new total 9 → 3 blocks, only 2 allocated; the pool
    supplies the missing block, so the append fits."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4))
    sequence = seq_factory(list(range(1, 6)))
    manager.allocate(sequence)
    assert len(sequence.block_table) == 2
    assert manager.can_append(sequence, num_new_tokens=4)

def test_multi_token_exact_fit(self, seq_factory):
    """4 tokens + 4 new fill exactly 2 blocks; the single free block is
    exactly enough."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert manager.can_append(sequence, num_new_tokens=4)

def test_multi_token_one_short(self, seq_factory):
    """4 tokens + 5 new total 9 → 3 blocks needed, but only 1 free block
    exists → appending must be refused."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert not manager.can_append(sequence, num_new_tokens=5)


class TestMayAppend:
def test_no_new_block_within_boundary(self, block_manager, seq_factory):
seq = seq_factory([1, 2, 3])
seq = seq_factory([1, 2])
block_manager.allocate(seq)
seq.append_token(4)
seq.append_token(3)
block_manager.may_append(seq)
assert len(seq.block_table) == 1

Expand All @@ -166,10 +239,128 @@ def test_new_block_on_boundary_crossing(self, block_manager, seq_factory):
assert len(seq.block_table) == 2

def test_block_size_1(self, seq_factory):
    """With block_size=1 every token owns a block: after appending a
    third token, may_append grows the table to ceil((3+1)/1) = 4."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=1))
    sequence = seq_factory([1, 2], block_size=1)
    manager.allocate(sequence)
    sequence.append_token(3)
    manager.may_append(sequence)
    assert len(sequence.block_table) == 4

def test_multi_token_allocates_enough_blocks(self, seq_factory):
    """5 tokens + 4 upcoming over block_size 4 total 9 → may_append grows
    the table from 2 to 3 blocks."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4))
    sequence = seq_factory(list(range(1, 6)))
    manager.allocate(sequence)
    assert len(sequence.block_table) == 2
    manager.may_append(sequence, num_new_tokens=4)
    assert len(sequence.block_table) == 3

def test_multi_token_at_boundary(self, seq_factory):
    """4 tokens + 4 upcoming total 8 → exactly 2 blocks after may_append."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert len(sequence.block_table) == 1
    manager.may_append(sequence, num_new_tokens=4)
    assert len(sequence.block_table) == 2

def test_multi_token_crosses_two_boundaries(self, seq_factory):
    """4 tokens + 5 upcoming total 9 → may_append must end with 3 blocks."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4))
    sequence = seq_factory([1, 2, 3, 4])
    manager.allocate(sequence)
    assert len(sequence.block_table) == 1
    manager.may_append(sequence, num_new_tokens=5)
    assert len(sequence.block_table) == 3

def test_hash_registered_at_boundary(self, seq_factory):
    """Filling a block exactly makes may_append publish the block's hash
    into the prefix-cache map."""
    config = MockConfig(
        num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True
    )
    manager = BlockManager(config)
    sequence = seq_factory([1, 2, 3])
    manager.allocate(sequence)
    sequence.append_token(4)
    manager.may_append(sequence, num_new_tokens=1)
    filled = manager.blocks[sequence.block_table[0]]
    assert filled.hash != -1
    assert filled.hash in manager.hash_to_block_id

def test_block_size_1_multi_token(self, seq_factory):
    """block_size=1 with 3 upcoming tokens: after one append the table
    must reach ceil((3+3)/1) = 6 entries."""
    manager = BlockManager(MockConfig(num_kvcache_blocks=10, kv_cache_block_size=1))
    sequence = seq_factory([1, 2], block_size=1)
    manager.allocate(sequence)
    assert len(sequence.block_table) == 2
    sequence.append_token(3)
    manager.may_append(sequence, num_new_tokens=3)
    assert len(sequence.block_table) == 6


# ── Prefix caching during decode ──────────────────────────────────────────


class TestPrefixCachingDecode:
    """Blocks completed while decoding must become prefix-cache hits."""

    @staticmethod
    def _manager():
        # Shared setup: prefix caching enabled, ample blocks of size 4.
        return BlockManager(
            MockConfig(
                num_kvcache_blocks=10,
                kv_cache_block_size=4,
                enable_prefix_caching=True,
            )
        )

    def test_hash_registered_during_decode(self, seq_factory):
        """A block filled during decode registers its hash for reuse."""
        manager = self._manager()
        sequence = seq_factory([1, 2, 3])
        manager.allocate(sequence)
        sequence.append_token(4)
        manager.may_append(sequence, num_new_tokens=1)

        filled = manager.blocks[sequence.block_table[0]]
        expected = BlockManager.compute_hash([1, 2, 3, 4])
        assert filled.hash == expected
        assert manager.hash_to_block_id[expected] == filled.block_id

    def test_decode_block_reused_by_new_sequence(self, seq_factory):
        """A hash published during decode yields a cache hit for a fresh
        sequence sharing the same 4-token prefix."""
        manager = self._manager()

        first = seq_factory([1, 2, 3])
        manager.allocate(first)
        first.append_token(4)
        manager.may_append(first, num_new_tokens=1)
        manager.deallocate(first)

        second = seq_factory([1, 2, 3, 4, 5, 6, 7, 8])
        manager.allocate(second)
        assert second.num_cached_tokens == 4

    def test_multi_step_decode_builds_prefix(self, seq_factory):
        """Several decode steps fill two blocks; a new sequence then hits
        the cache for all 8 prefix tokens."""
        manager = self._manager()

        sequence = seq_factory([1, 2, 3, 4])
        manager.allocate(sequence)
        for token in (5, 6, 7, 8):
            sequence.append_token(token)
            manager.may_append(sequence, num_new_tokens=1)
        manager.deallocate(sequence)

        fresh = seq_factory([1, 2, 3, 4, 5, 6, 7, 8, 9])
        manager.allocate(fresh)
        assert fresh.num_cached_tokens == 8
12 changes: 10 additions & 2 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def _prefill(self, scheduler, seq):

def _output(self, seq_id, tokens):
return ScheduledBatchOutput(
token_ids={seq_id: tuple(tokens)}, draft_token_ids=None
token_ids={seq_id: tuple(tokens)},
num_rejected=None,
num_bonus=None,
draft_token_ids=None,
)

def test_appends_token(self, scheduler, seq_factory):
Expand Down Expand Up @@ -166,7 +169,12 @@ def test_stop_token_ids(self, seq_factory):
sched.schedule()
finished = sched.postprocess(
list(sched.running),
ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None),
ScheduledBatchOutput(
token_ids={seq.id: (99,)},
num_rejected=None,
num_bonus=None,
draft_token_ids=None,
),
)
assert len(finished) == 1
assert "stop_99" in finished[0].leave_reason
Expand Down
Loading