Skip to content

fix: skip unchanged layers entirely during partial KV recompute#8

Open
rinarina0429 wants to merge 4 commits intomainfrom
fix/#7
Open

fix: skip unchanged layers entirely during partial KV recompute#8
rinarina0429 wants to merge 4 commits intomainfrom
fix/#7

Conversation

@rinarina0429
Copy link
Member

Stage 전환 시 boundary 이전 레이어(가중치 불변)에 대한
KV-only path(LayerNorm+QKV+Rotary+write)를 완전 스킵으로 교체.

  • Layer 0~boundary-1은 Stage 간 가중치 동일 → K,V 값도 동일
  • vLLM prefix caching이 Stage 1 KV 블록을 재사용하므로 재계산 불필요
  • hidden states는 CPU 캐시에서 GPU로 복원만 수행
  • cache miss 시 normal forward로 안전하게 폴백

제거: _kv_only_forward_layer (더 이상 불필요)
성능: partial recompute의 skip 구간 2.57x 빠름 (~15ms → ~6ms, 200 tokens 기준) (측정마다 미세하게 다름)

Stage 전환 시 boundary 이전 레이어(가중치 불변)에 대한
KV-only path(LayerNorm+QKV+Rotary+write)를 완전 스킵으로 교체.

- Layer 0~boundary-1은 Stage 간 가중치 동일 → K,V 값도 동일
- vLLM prefix caching이 Stage 1 KV 블록을 재사용하므로 재계산 불필요
- hidden states는 CPU 캐시에서 GPU로 복원만 수행
- cache miss 시 normal forward로 안전하게 폴백

제거: _kv_only_forward_layer (더 이상 불필요)
성능: partial recompute의 skip 구간 2.57x 빠름 (~15ms → ~6ms, 200 tokens 기준)
@rinarina0429 rinarina0429 self-assigned this Feb 18, 2026
Stage 전환 직후 한 번도 inference 없이 다음 stage를 요청할 경우
clear_persistent_buffers()로 zeros가 된 레이어가 CPU 캐시에 올라가
partial recompute에서 잘못된 hidden states가 사용되는 버그 수정.

- _populated_layers(set) 추가: 실제 inference 데이터가 기록된 레이어 추적
- forward(): buffer 기록 시 _populated_layers에 레이어 추가
- clear_persistent_buffers(): _populated_layers 함께 초기화
- sync_persistent_cache(): _populated_layers에 없는 레이어는 건너뜀
  (cache miss → 해당 레이어는 normal forward로 안전하게 폴백)
… skip

KV-only path에서 complete skip 방식으로 변경된 구현을 반영하여
chatbot_partial_cache.py의 주석 3곳 업데이트.

- 파일 상단: KV-only forward → 완전 스킵 (가중치 불변 → KV cache 유효)
- ProgressiveChatbotPartial 클래스 docstring: KV-only → 완전 스킵 (GPU 연산 없음)
- _trigger_partial_recompute() 동작 설명을 새 방식으로 전면 재작성
  (GPU persistent buffer → CPU cache → layer skip 흐름 반영)
Stage 전환 시 Front layers의 K,V를 GPU 캐시에서 직접 읽어 CPU에 저장(snapshot)하고,
Partial recompute 시 재계산 없이 GPU memcopy로 복원하는 최적화.

■ chatbot_partial_cache.py
- chat(): generate() 완료 후 vLLM 실제 token IDs(prompt+generated) 저장
  → _build_prompt() 재토크나이징 시 chat template 차이로 발생하는
    KV cache block hash 불일치 문제 방지
- _save_kv_snapshot(): [신규] Stage 전환 직전 GPU KV 블록에서 K,V 직접 읽기
  · PrefixCachingBlock.hash_block_tokens()로 block_id 조회
  · CpuGpuBlockAllocator._allocators[Device.GPU]._cached_blocks 접근
  · 중간 block 미발견 시 전체 abort 대신 찾은 블록까지만 사용
- _clear_kv_prefix_cache(): [신규] llm.reset_prefix_cache()로 stale KV blocks 퇴출
- _trigger_partial_recompute(): KV snapshot 파이프라인으로 재구성
  (snapshot 저장 → stale cache 퇴출 → model에 snapshot 전달 → generate)

■ progressive_serve/progressive_model_dual_path.py
- _kv_snapshot, _snapshot_num_full_tokens 인스턴스 변수 추가
- set_kv_snapshot(): [신규] chatbot에서 snapshot 수신하는 public API
- forward() partial recompute 로직 전면 개편:
  · 방법 A (Full blocks): _kv_snapshot에서 write_kv_to_cache() memcopy → FLOPs 0
  · 방법 B (Partial last block + fallback): qkv_proj + rotary_emb → write_kv_to_cache
    (flash_attn / o_proj / MLP 생략)
  · hidden_states는 기존대로 CPU cache에서 복원
  · partial recompute 완료 후 snapshot 자동 클리어 (1회성)
- set_partial_recompute(): docstring 및 로그 메시지 업데이트

■ progressive_serve/progressive_for_causal_lm.py
- Falcon 모델 checkpoint key aliasing 추가
  (transformer.h.* → model.layers.*, lm_head 등 네임스페이스 정규화)

■ progressive_serve/model_config.py
- Mistral layer class 수정: MistralDecoderLayer → LlamaDecoderLayer
  (vLLM에서 Mistral이 LLaMA 아키텍처를 공유하므로)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix: hidden states만 skip하는게 아닌, kv cache를 skip하도록 수정

1 participant