Add support for DeepSeek V3 #11049
Conversation
src/llama.cpp
Outdated
// add experts selection bias - introduced in DeepSeek V3
ggml_tensor * selection_probs = probs;
if (expert_weights_b != nullptr) {
    selection_probs = ggml_add(ctx, probs, expert_weights_b);
    cb(selection_probs, "ffn_moe_sigm_biased", il);
}
Can be simplified to:
-// add experts selection bias - introduced in DeepSeek V3
-ggml_tensor * selection_probs = probs;
-if (expert_weights_b != nullptr) {
-    selection_probs = ggml_add(ctx, probs, expert_weights_b);
-    cb(selection_probs, "ffn_moe_sigm_biased", il);
-}
+// add experts selection bias - introduced in DeepSeek V3
+if (expert_weights_b != nullptr) {
+    probs = ggml_add(ctx, probs, expert_weights_b);
+    cb(probs, "ffn_moe_sigm_b", il);
+}
I'm afraid this won't work correctly, as the original unmodified weights are still needed for multiplication with the experts' output at the end of the function. Biased weights are used only for expert selection. See the DeepSeek V3 technical report:
Note that the bias term is only used for routing. The gating value, which will be multiplied with the FFN output, is still derived from the original affinity score.
Edit: I'm going to add a comment in the code to make it clear
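For reference, the distinction in a small numpy sketch (toy numbers only, not the ggml code):

import numpy as np

probs = np.array([0.70, 0.20, 0.10])     # unbiased expert affinities (after sigmoid)
bias  = np.array([-0.65, 0.40, 0.30])    # per-expert selection bias

selection_probs = probs + bias           # [0.05, 0.60, 0.40] - used ONLY to pick the experts
top2 = np.argsort(-selection_probs)[:2]  # -> experts 1 and 2 (the unbiased top-2 would be 0 and 1)

gates = probs[top2]                      # [0.20, 0.10] - the expert outputs are still weighted
print(top2, gates)                       # by the original, unbiased probs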
I see - I missed this.
gguf-py/gguf/constants.py
Outdated
@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
    FFN_GATE_SHEXP = auto()
    FFN_DOWN_SHEXP = auto()
    FFN_UP_SHEXP = auto()
    FFN_EXPERT_WEIGHTS_B = auto()
For more consistency in the names, let's change the EXPERT to EXP. Also, it seems that PROBS is a better name, since this is a bias for the computed expert probabilities:
-    FFN_EXPERT_WEIGHTS_B = auto()
+    FFN_EXP_PROBS_B = auto()
Done
gguf-py/gguf/constants.py
Outdated
@@ -496,6 +499,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
    MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
    MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
-    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
+    MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
Done
src/llama.cpp
Outdated
@@ -2912,6 +2934,7 @@ struct llama_layer {
    struct ggml_tensor * ffn_down_b = nullptr; // b2
    struct ggml_tensor * ffn_up_b = nullptr; // b3
    struct ggml_tensor * ffn_act = nullptr;
    struct ggml_tensor * ffn_expert_weights_bias = nullptr;
-    struct ggml_tensor * ffn_expert_weights_bias = nullptr;
+    struct ggml_tensor * ffn_exp_probs_b = nullptr;
Done
src/llama.cpp
Outdated
ggml_tensor * probs = nullptr;
switch (gating_op) {
    case LLM_EXPERT_GATING_FUNC_SOFTMAX:
        {
            probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_probs", il);
        } break;
    case LLM_EXPERT_GATING_FUNC_SIGMOID:
        {
            probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_sigm", il);
        } break;
    default:
        GGML_ABORT("fatal error");
}
Don't set names here:
-ggml_tensor * probs = nullptr;
-switch (gating_op) {
-    case LLM_EXPERT_GATING_FUNC_SOFTMAX:
-        {
-            probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-            cb(probs, "ffn_moe_probs", il);
-        } break;
-    case LLM_EXPERT_GATING_FUNC_SIGMOID:
-        {
-            probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-            cb(probs, "ffn_moe_sigm", il);
-        } break;
-    default:
-        GGML_ABORT("fatal error");
-}
+ggml_tensor * probs = nullptr;
+switch (gating_op) {
+    case LLM_EXPERT_GATING_FUNC_SOFTMAX:
+        {
+            probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+        } break;
+    case LLM_EXPERT_GATING_FUNC_SIGMOID:
+        {
+            probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+        } break;
+    default:
+        GGML_ABORT("fatal error");
+}
Instead, after applying the probs bias, call cb(probs, "ffn_moe_probs", il); for the final probs result.
I moved the name setting after the switch, but I kept it separate from the biased probs for the reasons mentioned earlier.
@ggerganov I extended your "collapsed" regex workaround with \p{M} and \p{S} - DeepSeek V3 has these in its pre-tokenizer regex. Take a look to see if it looks sane when you have a moment. I checked with test-tokenizer-0, and tokenization of wiki.test.raw now matches the original.
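(In case the categories are unfamiliar: \p{M} covers combining marks and \p{S} covers symbols. A tiny illustration using Python's third-party regex module - the pattern below is only illustrative, it is not the actual DeepSeek V3 pre-tokenizer regex:)

import regex  # pip install regex - supports \p{...} Unicode property classes

s = "e\u0301 \u2260 \u2603"  # "e" + combining acute accent, a math symbol, a snowman

with_marks    = regex.findall(r"[\p{L}\p{M}]+|\p{S}+|\s+|.", s)
without_marks = regex.findall(r"\p{L}+|\s+|.", s)

print(with_marks)     # ['é', ' ', '≠', ' ', '☃'] - the accent stays attached to its letter
print(without_marks)  # ['e', <combining acute>, ' ', '≠', ' ', '☃'] - the mark gets split off on its own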
@ggerganov Also, since you merged #10902 I had to put the expert_gating_func enum in a file included in both
Let's place it in We can merge after you move the
* convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type
* vocab : add DeepSeek V3 pre-tokenizer regexes
* unicode : handle ACCENT_MARK and SYMBOL categories in regex
* llama : add DeepSeek V3 chat template, handle new model parameters and tensor types

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Related to #11141. While DeepSeek V3 support has been added, there appears to be an ongoing issue specific to the ROCm backend when attempting to run DeepSeek models (both V2 and V3) with ROCm. This behavior is consistent across both DeepSeek V2 and V3 models. I would appreciate it if this ROCm-specific issue could be investigated.
@emuchogu Does it happen even with DeepSeek-V2-Lite?
Yes. Same behavior with deepseek-v2-16b-lite-chat-q4_K_M.
Just wondered if there are plans to add "Multi-head Latent Attention"? Is there something about the way
The goal, after the #11213 refactoring is done, is to be able to implement all kinds of attention mechanisms independently. Right now we have the "unified" KV cache, which is being overloaded with a lot of logic, and it's very difficult to add new things to it.
Ah, I see. This might also open up the possibility of decomposing existing "dense" projection matrices into MLA eventually (as an alternative to KV-cache quantisation), so I can see it being useful for more than just DeepSeek.
@jukofyork I have my experimental DeepSeek V2 branch with MLA here: https://github.com/fairydreaming/llama.cpp/tree/deepseek2-mla-exp

From my limited testing on an Epyc CPU, the token generation rate goes down very fast on this branch as the prompt length increases. I guess it's because of the additional overhead of recalculating the K and V vectors from the cached latent representations for all previous tokens. It's possible that MLA makes sense only when running on a GPU.
Thanks - how did the KV-cache memory usage compare? Somebody in another thread said it was ~10GB per 2048 tokens for the existing (non-MLA) implementation, so just wondered if it was much better in comparison?
@jukofyork There are 60 layers in DeepSeek V3. If you use an f16 KV cache (2 bytes per vector element), then the number of bytes cached for one token by the regular KV cache is: 60 * 2 * 128 * (192 + 128) = 4.6875MB per token (128 heads, 192-element K and 128-element V per head), so roughly 9.4GB for 2048 tokens.

MLA is vastly superior in terms of memory requirements, as it only caches: 60 * 2 * (64 + 512) = 67.5kB per token. For 2048 tokens that will be 135MB. Full 128k context would require only 8.4375GB of memory.
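The same arithmetic as a quick script, in case anyone wants to plug in other context sizes (this uses the 60-layer / 128-head / 192+128 head-dim figures from above and prints GiB):

layers, heads, f16 = 60, 128, 2

kv_bytes_per_token  = layers * f16 * heads * (192 + 128)  # regular cache: per-head K (192) + V (128)
mla_bytes_per_token = layers * f16 * (512 + 64)           # MLA: latent (512) + RoPE part (64)

for n_ctx in (2048, 131072):
    print(n_ctx, kv_bytes_per_token * n_ctx / 2**30, mla_bytes_per_token * n_ctx / 2**30)
# 2048   -> 9.375 GiB vs 0.132 GiB
# 131072 -> 600 GiB   vs 8.4375 GiB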
@fairydreaming Thanks! I see the problem now: the way this is set up means you aren't really "caching" anything of use, and it's almost as bad as saving the hidden state and then re-multiplying by the projection matrices each time...

Looking through your code, I can see this:

Which looks like it is manifesting all the K_t^c matrices for the whole sequence (i.e. 32768 x 512 x sequence_len)? If you look at the bottom of the page of the paper above, then I think it's saying this isn't needed, as the 32768 x 512 projection can be absorbed into the calculation for Q (and the same for V into O).

Sorry, I'm on my phone and don't really know any GGML, but I got DeepSeek itself to try to flesh out the idea:

To optimize the attention computation while retaining the compressed c vectors in the KV cache and deferring up-projections as late as possible, we can restructure the operations as follows:

Key Strategy: Deferred Up-Projection

Leverage the associativity of matrix multiplications to delay the application of W^{UK} and W^{UV} until they are absolutely needed. This avoids materializing full K/V matrices and keeps computations in the low-dimensional c space for as long as possible.

Step-by-Step Implementation
Advantages
Complexity Analysis
This is strictly better than materializing the K/V matrices, which is O(d_h n_h \cdot n_{\text{tokens}}).

Practical Optimizations
Conclusion

By reordering computations to keep W^{UK} and W^{UV} on the "outside" of the critical path, we maintain the compressed KV cache while performing most operations in the low-dimensional space. This approach aligns with the original goal of MLA (reducing memory footprint without sacrificing efficiency) and ensures that up-projections are deferred until the final steps, minimizing redundant computation.
(nice to see it screws up the LaTeX for GitHub just the same as o1 does 😁)
Weirdly, their own implementation doesn't actually use the trick the paper mentions: https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py

kv = (
    self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
    .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    .transpose(1, 2)
)
.
.
.
attn_output = torch.matmul(attn_weights, value_states)
.
.
.
attn_output = self.o_proj(attn_output)

Maybe for GPU use it doesn't really matter, or maybe they didn't care and just copied/pasted from llama (based on the class comment)? It will likely make a much bigger difference for CPU use though, and do way more in cache than manifesting these giant matrices!?
@jukofyork I doubt it's the same code that they use internally. But thanks for reminding me about the trick with absorbing matrices, I completely forgot about this. I'll try to optimize the implementation this way.
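For reference, the identity being absorbed, as a toy numpy sketch (single head, "nope"/non-RoPE part only, shapes made up for illustration):

import numpy as np

rank, head_dim, n_cached = 512, 128, 16
rng = np.random.default_rng(0)

c    = rng.standard_normal((n_cached, rank))  # cached compressed KV - all that gets stored
W_uk = rng.standard_normal((head_dim, rank))  # K up-projection (a per-head slice of wkv_b)
q    = rng.standard_normal((1, head_dim))     # "nope" part of the current query

# naive: materialize K for every cached token, then score
k = c @ W_uk.T                                # (n_cached, head_dim) - grows with context length
scores_naive = q @ k.T

# absorbed: fold W_uk into the query once, then score directly against the latent cache
q_absorbed = q @ W_uk                         # (1, rank)
scores_absorbed = q_absorbed @ c.T

print(np.allclose(scores_naive, scores_absorbed))  # True - same scores, no per-token K needed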
Cool! If it works then I think we should be able to compress any existing projection matrices too:
IIRC, the above is equivalent to minimizing:

No idea how well it would work though, but it might allow very long context if your 128k example is anything to go by!
Somewhat related discussion from last year: |
@fairydreaming Just found their GitHub has the code from the paper: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

if attn_impl == "naive":
    self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
    self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
else:
    self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
    self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
.
.
.
if attn_impl == "naive":
    q = torch.cat([q_nope, q_pe], dim=-1)
    kv = self.wkv_b(self.kv_norm(kv))
    kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
    self.k_cache[:bsz, start_pos:end_pos] = k
    self.v_cache[:bsz, start_pos:end_pos] = v
    scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
else:
    wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
    wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
    q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
    self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
    self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
    scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
              torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
@ggerganov Is there a "recommended" way to precalculate some tensors for a model in llama.cpp? For example, I want to split some model tensor into two to avoid doing it during inference in the main model graph. Or perhaps it's best to do it during model conversion?
It's best to do it during model conversion. Though it would be nice to figure out some runtime method at some point.
You can split tensors easily in
@fairydreaming I see you've updated the experimental MLA branch: Did you manage to get this working?
TLDR: it works, but it's not faster (yet).

@jukofyork Sure, I optimized the implementation in my branch based on the DeepSeek code; on my CPU it's now almost as fast as the original "naive" implementation. Note that using the branch now requires reconverting the model. Prompt processing performance is still worse, though. Some unknowns to research:
Thanks for the heads up! I've just downloaded

All this opens up a lot of interesting potential for using SVD to decompose other weight matrices too. I already outlined this for the attention matrices:
But actually there is no reason you can't concatenate different (or even all) sets of matrices like this if you transpose to have all with

The QuaRot method is essentially doing the opposite of this and smearing the singular values. The "abliteration" and control vectors stuff also shows there is likely a lot of redundancy in the hidden state too (likely meaning the same

I'm not sure how feasible it would be to compute the SVD for large models, but it definitely is something that can be explored for small models to see how quickly the singular values drop off and if it's even worth pursuing further.
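A toy numpy sketch of that kind of experiment (truncated SVD of a single made-up projection matrix; how much rank can be dropped depends entirely on how quickly the singular values decay, which is exactly what would need measuring on real weights):

import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4096, 512))  # stand-in for a dense projection (output_dim x input_dim)

U, S, Vt = np.linalg.svd(W, full_matrices=False)
k = 128                               # retained rank
A = U[:, :k] * S[:k]                  # (4096, k) "up" factor
B = Vt[:k, :]                         # (k, 512)  "down" factor - only x @ B.T would need caching

x = rng.standard_normal((1, 512))
exact  = x @ W.T                      # original projection
approx = x @ B.T @ A.T                # low-rank two-step projection
print(np.linalg.norm(approx - exact) / np.linalg.norm(exact))  # large here: random W has a flat spectrum
print(S[:8] / S[0])                   # inspect how quickly the spectrum actually decays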
I added a second transposed copy of

The disadvantage of this solution is that it doubles the cache memory requirements (but it's still way lower than the regular KV cache).

Also some numbers for the full V3/R1 model:

Update: modified permutations to multiple larger matrices and got a nice speed bump:
Hi, I would like to know if grammar / GBNF works for this model?
This PR adds support for the recently released DeepSeek V3 model (MoE, 671B).
The model is architecturally very similar to DeepSeek V2; there are only minor changes in the expert weights calculation.
Summary of changes:

- expert_weights_norm model parameter indicating whether expert weights shall be normalized or not - they were not normalized in DeepSeek V2, but they are in DeepSeek V3,
- expert_gating_func model parameter corresponding to an enum value indicating the function used to calculate expert probs - usually it's softmax, but DeepSeek V3 uses sigmoid for this purpose,
- exp_probs_b (initially expert_weights_b) tensor type containing the expert weights bias tensors - DeepSeek V3 introduced a bias term added to the calculated expert probs, and the biased probs are the input to the top-k expert selection process,
- updated llm_build_moe_ffn() API and implementation to handle the mentioned differences.

Note: DeepSeek V3 also introduced multi-token prediction (MTP), but I decided to skip this feature for now. The MTP layer is ignored during model conversion and is not present in the resulting GGUF file.
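To make the routing changes concrete, here is a rough per-token numpy sketch of the behaviour described above (group-limited expert selection and the routed scaling factor are left out; this is an illustration, not the llama.cpp code):

import numpy as np

def route_token_v3(logits, exp_probs_b, n_expert_used):
    probs = 1.0 / (1.0 + np.exp(-logits))         # expert_gating_func = sigmoid (V2 used softmax)
    selection = probs + exp_probs_b               # exp_probs_b biases expert *selection* only
    top = np.argsort(-selection)[:n_expert_used]  # top-k experts chosen by the biased score
    weights = probs[top]                          # gating values come from the unbiased probs
    return top, weights / weights.sum()           # expert_weights_norm: normalized in V3

rng = np.random.default_rng(0)
experts, weights = route_token_v3(rng.standard_normal(256), np.zeros(256), 8)  # V3 routes 8 of 256 experts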