[DeepseekR1] How does ragged prefill manage the kv_cache? #3849

Open
AlvL1225 opened this issue Feb 25, 2025 · 3 comments
@AlvL1225

I'm investigating the chunked prefill method in DeepSeek V3/R1. The code shows that it uses self.prefill_wrapper_ragged.forward_return_lse for both prefill and chunked prefill operations. However, I haven't been able to locate where the KV cache is provided in the code. Could you help me identify this part of the implementation?

[Screenshots of the relevant code attached]
minleminzui self-assigned this Feb 25, 2025

Fridge003 commented Feb 25, 2025

Hi @AlvL1225, the KV cache in the flashinfer backend is provided by the forward batch as forward_batch.token_to_kv_pool. The KV cache indices are set in the call_begin_forward function of the FlashInferIndicesUpdaterDecode and FlashInferIndicesUpdaterPrefill classes, and this function is called from the init_forward_metadata function of the attention backend.

For DeepSeek V3/R1 models, the flashinfer MLA logic has been moved to flashinfer_mla_backend.py. You can refer to #3785.
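
To make the flow concrete, here is a minimal, self-contained sketch of how the per-request KV-cache indices could be built from a token-to-KV-pool mapping. The class and function names mirror the description above, but the bodies are simplified placeholders, not the actual SGLang code:

```python
import torch

class TokenToKVPool:
    """Stand-in for forward_batch.token_to_kv_pool: a flat pool with one
    KV slot per cached token; kv_indices below point into these buffers."""
    def __init__(self, size, head_dim, dtype=torch.float16):
        self.k_buffer = torch.zeros(size, head_dim, dtype=dtype)
        self.v_buffer = torch.zeros(size, head_dim, dtype=dtype)

class IndicesUpdaterPrefill:
    """Builds the index tensors the attention kernel uses to read the cache."""
    def call_begin_forward(self, req_to_token, req_pool_indices, seq_lens):
        # For every request, gather the pool slots holding its cached tokens.
        kv_indices = torch.cat([
            req_to_token[r, :l]
            for r, l in zip(req_pool_indices.tolist(), seq_lens.tolist())
        ])
        # Prefix-sum layout so the kernel knows where each request's slots start.
        kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
        kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
        return kv_indptr, kv_indices

def init_forward_metadata(updater, req_to_token, req_pool_indices, seq_lens):
    # Called once per forward batch, before the attention kernels run.
    return updater.call_begin_forward(req_to_token, req_pool_indices, seq_lens)

# Usage: 2 requests with 3 and 2 cached tokens respectively.
req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]])
indptr, indices = init_forward_metadata(
    IndicesUpdaterPrefill(), req_to_token,
    torch.tensor([0, 1]), torch.tensor([3, 2]))
print(indptr.tolist(), indices.tolist())  # [0, 3, 5] [10, 11, 12, 20, 21]
```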


AlvL1225 commented Feb 26, 2025


Hi @Fridge003, thank you for your reply!
I checked flashinfer_mla_backend.py, but I am still confused:

[Screenshot of the relevant code attached]

During chunked prefill, we use multi-head attention (MHA) with 192-dimensional queries and keys and 128-dimensional values and outputs. However, the cache is saved in compressed form as "compressed_kv" (a 512-dimensional non-rotary latent plus a 64-dimensional rotary position embedding part of the keys).

When executing chunked prefill, do we decompress the "lora_kv" (latent KV) every time, or is an "MHA kv cache" used directly at some point in the process?
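
For reference, this is roughly what "decompressing" the latent KV into full MHA keys/values looks like, using the DeepSeek V3/R1 dimensions quoted above (the kv_b_proj weight here is a stand-in linear layer, not the actual SGLang module):

```python
import torch

num_heads, kv_lora_rank, qk_rope_dim = 8, 512, 64
qk_nope_dim, v_head_dim = 128, 128
seq_len = 16

# What the cache stores per token: a 512-dim latent plus a 64-dim rotary
# key part that is shared across heads.
compressed_kv = torch.randn(seq_len, kv_lora_rank)
k_pe = torch.randn(seq_len, qk_rope_dim)

# kv_b_proj expands the latent into per-head "nope" keys and values.
kv_b_proj = torch.nn.Linear(
    kv_lora_rank, num_heads * (qk_nope_dim + v_head_dim), bias=False)

kv = kv_b_proj(compressed_kv).view(seq_len, num_heads, qk_nope_dim + v_head_dim)
k_nope, v = kv.split([qk_nope_dim, v_head_dim], dim=-1)

# Full MHA keys for the ragged kernel: 128-dim nope part + 64-dim rope part = 192.
k = torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, num_heads, -1)], dim=-1)

print(k.shape, v.shape)  # torch.Size([16, 8, 192]) torch.Size([16, 8, 128])
```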


Fridge003 commented Feb 26, 2025

Hi @AlvL1225, the ragged part of the flashinfer MLA backend is still under development, so the code is unfinished. Please stay tuned for the related PR in the next two or three days.

For the flashinfer MLA backend, paged prefilling (prefilling with a prefix) calls forward_absorb and saves the KV cache inside the forward_extend function of the attention backend. Ragged prefilling (prefilling without a prefix) calls forward_normal and saves the KV cache before forward_extend.
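
A toy control-flow sketch of that dispatch, assuming the ordering described above. The names follow this comment (forward_absorb / forward_normal / forward_extend), but the bodies are placeholders that only illustrate where the cache write happens relative to the attention call, not the real SGLang implementation:

```python
def forward_absorb(q, kv_pool):
    # Placeholder for attention computed in the compressed (latent) space.
    return f"absorb: q={q}, cached_tokens={len(kv_pool)}"

def forward_normal(q, chunk_kv):
    # Placeholder for plain MHA over this chunk (latent KV decompressed
    # as in the sketch above).
    return f"normal: q={q}, chunk_tokens={len(chunk_kv)}"

def forward_extend(q, latent_kv, kv_pool, slots, has_prefix):
    if has_prefix:
        # Paged prefill: the KV cache is written inside forward_extend.
        kv_pool.update(dict(zip(slots, latent_kv)))
        return forward_absorb(q, kv_pool)
    # Ragged prefill: the latent KV was already written before forward_extend,
    # so here we only run normal attention over the current chunk.
    return forward_normal(q, latent_kv)

pool = {}
# Ragged prefill of a fresh request: cache is saved first, then attention runs.
pool.update({0: "kv0", 1: "kv1"})
print(forward_extend("q0", ["kv0", "kv1"], pool, [0, 1], has_prefix=False))
# Paged prefill of a follow-up chunk: cache is saved inside forward_extend.
print(forward_extend("q1", ["kv2"], pool, [2], has_prefix=True))
```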
