fix: resolve prefix caching crashes with MTP speculative decoding #234

Draft
valarLip wants to merge 1 commit into main from ds_prefix_cache

@valarLip
Collaborator

Fix GPU memory access fault caused by double conversion of block_tables in the cached prefill path. kv_indices_generate_triton applies block_ratio internally, but it was receiving already-converted block_tables (via block_tables_converted), so indices were multiplied by block_ratio twice (e.g. block_id*256 instead of block_id*16) and exceeded KV cache bounds.
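The double conversion can be illustrated with a small pure-Python sketch. It is not the Triton kernel: `convert_block_table` and `max_kv_index` are hypothetical stand-ins, and the block_ratio of 16 and the cache size are assumed from the example above.

```python
BLOCK_RATIO = 16  # assumed from the block_id*16 example in the description


def convert_block_table(block_table, block_ratio=BLOCK_RATIO):
    """Expand each coarse block id into block_ratio fine-grained ids (hypothetical helper)."""
    return [b * block_ratio + i for b in block_table for i in range(block_ratio)]


def max_kv_index(block_table, block_ratio=BLOCK_RATIO):
    """Stand-in for a kernel that applies block_ratio internally; returns the largest index."""
    return max(convert_block_table(block_table, block_ratio))


num_coarse_blocks = 8                               # illustrative cache size
kv_cache_slots = num_coarse_blocks * BLOCK_RATIO    # valid indices are 0..127

raw_block_table = [3, 7]

# Correct: the raw table goes in, so block_ratio is applied exactly once.
ok_index = max_kv_index(raw_block_table)

# Buggy: an already-converted table goes in, so block_ratio is applied twice.
converted = convert_block_table(raw_block_table)
bad_index = max_kv_index(converted)

print(ok_index < kv_cache_slots)    # in bounds
print(bad_index >= kv_cache_slots)  # out of bounds
```

With the raw table the largest index stays inside the cache; with the pre-converted table it lands far outside it, which matches the kind of out-of-bounds access described above.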

Key changes:

  • Use raw block_tables for kv_indices generation in aiter_mla prefill
  • Route cached prefill through paged MLA attention (supports Q≠K) instead of flash_attn_varlen_func (requires Q==K)
  • Track has_cached flag through AttentionMetaData for path selection
  • Fix block_manager: hash table leak, can_allocate cache-hit accounting, can_append for multi-token decode, O(1) free block tracking
  • Add CacheStats to scheduler for prefix cache hit rate monitoring
  • Add comprehensive block_manager tests (119 passing)
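The path selection implied by the has_cached flag might look roughly like the sketch below. `PrefillShape`, `select_prefill_path`, and the returned labels are hypothetical names for illustration only; the actual AttentionMetaData fields in this PR may differ.

```python
from dataclasses import dataclass


@dataclass
class PrefillShape:
    """Hypothetical summary of one prefill request."""
    num_new_tokens: int      # query length: tokens computed in this step
    num_cached_tokens: int   # tokens whose KV was reused from the prefix cache

    @property
    def has_cached(self) -> bool:
        return self.num_cached_tokens > 0

    @property
    def kv_len(self) -> int:
        # Keys cover both the cached prefix and the new tokens.
        return self.num_cached_tokens + self.num_new_tokens


def select_prefill_path(shape: PrefillShape) -> str:
    """Pick the attention path: paged MLA when Q != K, varlen flash attention otherwise."""
    return "paged_mla" if shape.has_cached else "flash_attn_varlen"


fresh = PrefillShape(num_new_tokens=32, num_cached_tokens=0)
cached = PrefillShape(num_new_tokens=8, num_cached_tokens=24)

print(select_prefill_path(fresh))   # query length equals key length
print(select_prefill_path(cached))  # query length is shorter than key length
```

A request with a prefix-cache hit has fewer query tokens than keys, which is why it cannot take the path that requires Q==K.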

Verified: gsm8k 1319 samples, 95.83% accuracy, 0 GPU faults.

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

from atom.model_engine.scheduler import ScheduledBatch

logger = logging.getLogger("atom")
from atom.model_ops.attention_mla import MLAModules

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file


logger = logging.getLogger("atom")
from atom.model_ops.attention_mla import MLAModules
from atom.utils import CpuGpuBuffer

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

logger = logging.getLogger("atom")
from atom.model_ops.attention_mla import MLAModules
from atom.utils import CpuGpuBuffer
from atom.utils.block_convert import block_table_convert_triton

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

Comment on lines +13 to +14
import json
import re

⚠️ [ruff] <F401> reported by reviewdog 🐶
json imported but unused

Suggested change
- import json
- import re
+ import re

sys.exit(1)

model = get_model_name(base_url)
print(f"=== Prefix Cache Accuracy Test ===")

⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
- print(f"=== Prefix Cache Accuracy Test ===")
+ print("=== Prefix Cache Accuracy Test ===")
