
Add support for DeepSeek V3 #11049

Merged · 10 commits · Jan 4, 2025

Conversation

@fairydreaming (Collaborator) commented Jan 2, 2025:

This PR adds support for the recently released DeepSeek V3 model (MoE, 671B).

The model is architecturally very similar to DeepSeek V2; there are only minor changes in the expert weights calculation.

Summary of changes:

  • added a boolean expert_weights_norm model parameter indicating whether expert weights are normalized - they were not normalized in DeepSeek V2, but they are in DeepSeek V3,
  • added a numerical expert_gating_func model parameter (an enum value) indicating the function used to calculate expert probs - usually it's softmax, but DeepSeek V3 uses sigmoid for this purpose,
  • added the exp_probs_b tensor type containing expert weights bias tensors - DeepSeek V3 introduced a bias term added to the calculated expert probs, and the biased probs are the input to the top-k expert selection process,
  • updated the llm_build_moe_ffn() API and implementation to handle the differences mentioned above,
  • added a new pre-tokenization regex for DeepSeek V3 - someone wiser could take a look to check whether it needs any modifications to work correctly.

Note: DeepSeek V3 also introduced multi-token prediction (MTP), but I decided to skip this feature for now. The MTP layer is ignored during model conversion and is not present in the resulting GGUF file.
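For anyone inspecting a converted model, below is a small sketch that checks for the new metadata and bias tensors using the gguf Python package. The file path is hypothetical, and the metadata key names are assumptions based on llama.cpp's usual `{arch}.<parameter>` naming; only the per-layer tensor name `blk.{bid}.exp_probs_b` is taken from this PR.

```python
# Hedged sketch: list the DeepSeek V3-specific parameters and bias tensors in a
# converted GGUF file. Key names are assumed, not quoted from the PR.
from gguf import GGUFReader

reader = GGUFReader("deepseek-v3.gguf")  # hypothetical path

for key in ("deepseek2.expert_weights_norm", "deepseek2.expert_gating_func"):
    print(key, "->", "present" if key in reader.fields else "missing")

bias = [t.name for t in reader.tensors if t.name.endswith("exp_probs_b")]
print(f"found {len(bias)} per-layer expert-probability bias tensors")
```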

@github-actions bot added the python (python script changes) label Jan 2, 2025
@fairydreaming linked an issue Jan 2, 2025 that may be closed by this pull request
src/llama.cpp Outdated
Comment on lines 10299 to 10304
// add experts selection bias - introduced in DeepSeek V3
ggml_tensor * selection_probs = probs;
if (expert_weights_b != nullptr) {
    selection_probs = ggml_add(ctx, probs, expert_weights_b);
    cb(selection_probs, "ffn_moe_sigm_biased", il);
}
Member:

Can be simplified to:

Suggested change:
- // add experts selection bias - introduced in DeepSeek V3
- ggml_tensor * selection_probs = probs;
- if (expert_weights_b != nullptr) {
-     selection_probs = ggml_add(ctx, probs, expert_weights_b);
-     cb(selection_probs, "ffn_moe_sigm_biased", il);
- }
+ // add experts selection bias - introduced in DeepSeek V3
+ if (expert_weights_b != nullptr) {
+     probs = ggml_add(ctx, probs, expert_weights_b);
+     cb(probs, "ffn_moe_sigm_b", il);
+ }

@fairydreaming (Collaborator Author) replied Jan 3, 2025:

I'm afraid this won't work correctly, as the original unmodified weights are still needed for multiplication with the experts output at the end of the function. Biased weights are used only for expert selection. See the DeepSeek V3 technical report:

> Note that the bias term is only used for routing. The gating value, which will be multiplied with the FFN output, is still derived from the original affinity score.

Edit: I'm going to add a comment in the code to make it clear
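To make the distinction concrete, here is a tiny NumPy sketch (made-up shapes and values, not the llama.cpp code): the bias only shifts which experts win the top-k selection, while the weights applied to the expert outputs come from the unbiased sigmoid probabilities.

```python
# Illustrative sketch of DeepSeek V3-style routing for a single token.
import numpy as np

def route(logits, bias, n_top, norm_weights=True):
    probs = 1.0 / (1.0 + np.exp(-logits))       # sigmoid gating (V2 used softmax)
    selection = probs + bias                    # bias is used only for routing...
    top = np.argsort(-selection)[:n_top]        # ...to pick the top-k experts
    weights = probs[top]                        # ...while gate values stay unbiased
    if norm_weights:                            # expert_weights_norm (new in V3)
        weights = weights / weights.sum()
    return top, weights

rng = np.random.default_rng(0)
print(route(rng.standard_normal(8), 0.1 * rng.standard_normal(8), n_top=2))
```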

Member:

I see - I missed this.

@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
FFN_EXPERT_WEIGHTS_B = auto()
Member:

For more consistency in the names, let's change EXPERT to EXP. Also, it seems that PROBS is a better name, since this is a bias for the computed expert probabilities:

Suggested change:
- FFN_EXPERT_WEIGHTS_B = auto()
+ FFN_EXP_PROBS_B = auto()

fairydreaming (Collaborator Author):

Done

@@ -496,6 +499,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
Member:

Suggested change:
- MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
+ MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",

fairydreaming (Collaborator Author):

Done

src/llama.cpp Outdated
@@ -2912,6 +2934,7 @@ struct llama_layer {
struct ggml_tensor * ffn_down_b = nullptr; // b2
struct ggml_tensor * ffn_up_b = nullptr; // b3
struct ggml_tensor * ffn_act = nullptr;
struct ggml_tensor * ffn_expert_weights_bias = nullptr;
Member:

Suggested change:
- struct ggml_tensor * ffn_expert_weights_bias = nullptr;
+ struct ggml_tensor * ffn_exp_probs_b = nullptr;

fairydreaming (Collaborator Author):

Done

src/llama.cpp Outdated
Comment on lines 10283 to 10297
ggml_tensor * probs = nullptr;
switch (gating_op) {
    case LLM_EXPERT_GATING_FUNC_SOFTMAX:
        {
            probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_probs", il);
        } break;
    case LLM_EXPERT_GATING_FUNC_SIGMOID:
        {
            probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_sigm", il);
        } break;
    default:
        GGML_ABORT("fatal error");
}
Member:

Don't set names here:

Suggested change:
- ggml_tensor * probs = nullptr;
- switch (gating_op) {
-     case LLM_EXPERT_GATING_FUNC_SOFTMAX:
-         {
-             probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-             cb(probs, "ffn_moe_probs", il);
-         } break;
-     case LLM_EXPERT_GATING_FUNC_SIGMOID:
-         {
-             probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-             cb(probs, "ffn_moe_sigm", il);
-         } break;
-     default:
-         GGML_ABORT("fatal error");
- }
+ ggml_tensor * probs = nullptr;
+ switch (gating_op) {
+     case LLM_EXPERT_GATING_FUNC_SOFTMAX:
+         {
+             probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+         } break;
+     case LLM_EXPERT_GATING_FUNC_SIGMOID:
+         {
+             probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+         } break;
+     default:
+         GGML_ABORT("fatal error");
+ }

Instead, after applying the probs bias, call cb(probs, "ffn_moe_probs", il); for the final probs result.

fairydreaming (Collaborator Author):

I moved the name setting after the switch, but I kept it separate from the biased probs for the reasons mentioned earlier.

@fairydreaming (Collaborator Author):

@ggerganov I extended your "collapsed" regex workaround with \p{M} and \p{S} - DeepSeek V3 has these in its pre-tokenizer regex. Take a look when you have a moment to check that it looks sane. I checked with test-tokenizer-0, and the tokenization of wiki.test.raw now matches the original.

@fairydreaming (Collaborator Author) commented Jan 3, 2025:

@ggerganov Also, since you merged #10902, I had to put the expert_gating_func enum in a file included by llama-hparams.h, llama.cpp, and llama-model.cpp. I put it in llama.h; let me know if you have other plans for this enum.

@ggerganov (Member):

> @ggerganov Also, since you merged #10902, I had to put the expert_gating_func enum in a file included by llama-hparams.h, llama.cpp, and llama-model.cpp. I put it in llama.h; let me know if you have other plans for this enum.

Let's place it in llama-hparams.h for now. We can potentially make it public if we find some utility in the future, but for now it's better to try to hide more things from the public API - there are some other enums in llama.h that can also be moved to the implementation.

We can merge after you move the llama_expert_gating_func_type to llama-hparams.h

@fairydreaming merged commit 9394bbd into ggml-org:master Jan 4, 2025
51 checks passed
netrunnereve pushed a commit to netrunnereve/llama.cpp that referenced this pull request Jan 5, 2025
* convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type

* vocab : add DeepSeek V3 pre-tokenizer regexes

* unicode : handle ACCENT_MARK and SYMBOL categories in regex

* llama : add DeepSeek V3 chat template, handle new model parameters and tensor types

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
@x66ccff mentioned this pull request Jan 5, 2025
@emuchogu commented Jan 8, 2025:

Related to #11141

While DeepSeek V3 support has been added, there appears to be an ongoing issue specifically with the ROCm backend. When attempting to run DeepSeek models (both V2 and V3) with ROCm:

  • Models load successfully into VRAM
  • No output is generated
  • One GPU becomes pinned at 100% utilization
  • Other GPUs remain idle

This behavior is consistent across both DeepSeek V2 and V3 models. I would appreciate it if this ROCm-specific issue could be investigated.

@fairydreaming (Collaborator Author):

@emuchogu Does it happen even with DeepSeek-V2-Lite?

@emuchogu commented Jan 8, 2025:

Yes. Same behavior with deepseek-v2-16b-lite-chat-q4_K_M.
I tested this using Ollama and got the same 100% GPU pegging with no output.

@jukofyork (Contributor):

Just wondered if there are plans to add "Multi-head Latent Attention"?

Is there something about the way llama.cpp is written that would make this hard to implement without lots of refactoring (e.g. to do with the KV cache, etc.)?

@ggerganov (Member) commented Jan 24, 2025:

The goal, after the #11213 refactoring is done, is to be able to implement all kinds of attention mechanisms independently. Right now, we have the "unified" KV cache, which is being overloaded with a lot of logic, and it's very difficult to add new things to it.

@jukofyork (Contributor):

> The goal, after the #11213 refactoring is done, is to be able to implement all kinds of attention mechanisms independently. Right now, we have the "unified" KV cache, which is being overloaded with a lot of logic, and it's very difficult to add new things to it.

Ah, I see.

This might also open up the possibility of decomposing existing "dense" projection matrices into MLA eventually (as an alternative to KV-cache quantisation), so I can see it being useful for more than just DeepSeek.

@fairydreaming (Collaborator Author):

@jukofyork I have my experimental DeepSeek V2 branch with MLA here: https://github.com/fairydreaming/llama.cpp/tree/deepseek2-mla-exp

From my limited testing on an Epyc CPU the token generation rate goes down very fast on this branch as the prompt length increases. I guess it's because of the additional overhead of recalculating K and V vectors from the cached latent representations for all previous tokens. It's possible that MLA makes sense only when running on a GPU.

@jukofyork (Contributor):

> @jukofyork I have my experimental DeepSeek V2 branch with MLA here: https://github.com/fairydreaming/llama.cpp/tree/deepseek2-mla-exp
>
> From my limited testing on an Epyc CPU the token generation rate goes down very fast on this branch as the prompt length increases. I guess it's because of the additional overhead of recalculating K and V vectors from the cached latent representations for all previous tokens. It's possible that MLA makes sense only when running on a GPU.

Thanks - how did the KV-cache memory usage compare? Somebody in another thread said it was ~10GB per 2048 tokens for the existing (non-MLA) implementation, so just wondered if it was much better in comparison?

@fairydreaming (Collaborator Author):

@jukofyork There are 60 layers in DeepSeek V3. If you use an f16 KV cache (2 bytes per vector element), then the number of bytes cached for one token by the regular KV cache is 60 * 2 * (24576 + 16384) = 4.6875 MB. So for 2048 tokens it will cache 9.375 GB of data (that 10 GB estimate was correct). A full 128k context would take 600 GB of memory.

MLA is vastly superior in terms of memory requirements, as it only caches 60 * 2 * (64 + 512) = 67.5 kB per token. For 2048 tokens that will be 135 MB. A full 128k context would require only 8.4375 GB of memory.
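The same arithmetic in a few lines of Python, using the per-token element counts quoted above:

```python
# Per-token cache sizes for f16 (2 bytes per element) and 60 layers, as above.
n_layer, bytes_per_el = 60, 2
kv_per_token  = n_layer * bytes_per_el * (24576 + 16384)  # regular KV cache
mla_per_token = n_layer * bytes_per_el * (64 + 512)       # MLA: latent + RoPE part

for n_tokens in (2048, 128 * 1024):
    print(f"{n_tokens:>6} tokens:"
          f" KV {kv_per_token * n_tokens / 2**30:8.3f} GiB,"
          f" MLA {mla_per_token * n_tokens / 2**30:6.3f} GiB")
```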

@jukofyork (Contributor) commented Jan 24, 2025:

@fairydreaming Thanks!

I see the problem now:

[screenshot: Screenshot_20250124-184642]

The way this is set up means you aren't really "caching" anything of use and it's almost as bad as saving the hidden state and then re-multiplying by the projection matrices each time...

Looking through your code, I can see this:

[screenshot: Screenshot_20250124-192216]

Which looks like it is manifesting all the K_t^c matrices for the whole sequence (i.e. 32768 x 512 x sequence_len)?

If you look at the bottom of the page of the paper above, then I think it's saying this isn't needed as the 32768 x 512 projection can be absorbed into the calculation for Q (and the same for V into O).



Sorry, I'm on my phone and don't really know any GGML, but I got DeepSeek itself to try to flesh out the idea:

To optimize the attention computation while retaining the compressed c vectors in the KV cache and deferring up-projections as late as possible, we can restructure the operations as follows:


Key Strategy: Deferred Up-Projection

Leverage the associativity of matrix multiplications to delay the application of $W^{UK}$ and $W^{UV}$ until they are absolutely needed. This avoids materializing full K/V matrices and keeps computations in the low-dimensional c space for as long as possible.


Step-by-Step Implementation

  1. KV Cache:
    Store only the compressed vectors $\mathbf{c}_j^{\text{KV}} \in \mathbb{R}^{d_c}$ for all previous tokens $j$.

    • Cache Size: $d_c \cdot n_{\text{tokens}}$, far smaller than $2 d_h n_h \cdot n_{\text{tokens}}$.
  2. Query Projection with Integrated Key Up-Projection:
    For each query $\mathbf{q}_i$ (from $\mathbf{h}_i$ via $W^Q$), compute a pre-transformed query that absorbs $W^{UK}$:

    $$\mathbf{q}_i^{\text{proj}} = \mathbf{q}_i^T W^{UK} \quad \in \mathbb{R}^{d_c}$$

    • This combines the query projection ($W^Q$) and key up-projection ($W^{UK}$) into a single step.
    • Complexity: $O(d_h \cdot d_c)$ per query.
  3. Attention Scores in Compressed Space:
    Compute scores using the compressed $\mathbf{c}_j$:

    $$\text{score}_{ij} = \frac{\mathbf{q}_i^{\text{proj}} \mathbf{c}_j}{\sqrt{d_h}} \quad \forall j \leq i$$

    • No full K matrices: scores are computed directly in the $d_c$-dimensional space.
    • Complexity: $O(d_c \cdot n_{\text{tokens}})$ per query.
  4. Softmax and Attention Weights:
    Apply softmax to $\text{score}_{ij}$ to get attention weights $\alpha_{ij}$.

  5. Value Aggregation in Compressed Space:
    Compute the weighted sum of the $\mathbf{c}_j$ vectors:

    $$\mathbf{c}_{\text{agg}} = \sum_j \alpha_{ij} \mathbf{c}_j \quad \in \mathbb{R}^{d_c}$$

    • No full V matrices: aggregation occurs in the compressed space.
  6. Final Value Up-Projection:
    Project $\mathbf{c}_{\text{agg}}$ to the value space after aggregation:

    $$\mathbf{v}_{\text{out}} = W^{UV} \mathbf{c}_{\text{agg}} \quad \in \mathbb{R}^{d_h n_h}$$

    • Complexity: $O(d_h n_h \cdot d_c)$, deferred until the final step.
  7. Output Projection:
    Concatenate outputs across heads and apply $W^O$ as usual.


Advantages

  • Minimal KV Cache: Only $\mathbf{c}_j^{\text{KV}}$ are stored ($d_c \ll d_h n_h$).
  • Avoid Large Intermediate Matrices:
    • K/V matrices are never explicitly materialized.
    • Up-projections ($W^{UK}$, $W^{UV}$) are applied only when necessary.
  • Optimized GEMM Operations:
    • Batched Pre-Projection: Compute $\mathbf{q}_i^{\text{proj}}$ for multiple queries/heads in a single GEMM call.
    • Cache-Friendly Layouts: Store $W^{UK}$ and $W^{UV}$ in transposed/blocked formats to align with CPU memory access patterns.

Complexity Analysis

| Step | Complexity per Token |
|---|---|
| Query Projection ($W^{UK}$) | $O(d_h \cdot d_c)$ |
| Attention Scores | $O(d_c \cdot n_{\text{tokens}})$ |
| Value Aggregation + Projection | $O(d_c \cdot n_{\text{tokens}} + d_h n_h \cdot d_c)$ |

This is strictly better than materializing K/V matrices ($O(d_h n_h \cdot n_{\text{tokens}})$).


Practical Optimizations

  1. Fused Query-Key Projection:
    Precompute $W^{Q} W^{UK}$ offline to merge the query projection and key up-projection into a single matrix.

    • Reduces runtime overhead by $O(d_h \cdot d_c)$.
  2. Blocked Attention Computation:
    Process tokens in blocks to reuse cached $\mathbf{c}_j$ vectors and exploit temporal locality.

  3. Kernel Fusion:
    Fuse the softmax and aggregation steps to avoid intermediate memory writes.


Conclusion

By reordering computations to keep $W^{UK}$ and $W^{UV}$ on the "outside" of the critical path, we maintain the compressed KV cache while performing most operations in the low-dimensional space. This approach aligns with the original goal of MLA (reducing memory footprint without sacrificing efficiency) and ensures that up-projections are deferred until the final steps, minimizing redundant computation.
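To make the reordering concrete, here is a single-head NumPy sketch of the idea (it ignores the decoupled RoPE part and multi-head batching, and all dimensions and weights are made-up placeholders rather than DeepSeek's): the naive and "absorbed" orderings give the same output, but the absorbed one never materializes per-token K/V rows.

```python
import numpy as np

d_h, d_c, n_prev = 128, 512, 16                 # head dim, latent dim, cached tokens
rng = np.random.default_rng(0)

W_uk = rng.standard_normal((d_c, d_h)) * 0.02   # key up-projection
W_uv = rng.standard_normal((d_c, d_h)) * 0.02   # value up-projection
c_kv = rng.standard_normal((n_prev, d_c))       # cached compressed KV latents
q    = rng.standard_normal(d_h)                 # current-token query (one head)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Naive: materialize K and V for every cached token, then attend.
K, V = c_kv @ W_uk, c_kv @ W_uv                 # (n_prev, d_h) each
out_naive = softmax(K @ q / np.sqrt(d_h)) @ V

# Absorbed: fold W_uk into the query, aggregate in d_c space, up-project last.
q_proj = W_uk @ q                               # (d_c,)
alpha  = softmax(c_kv @ q_proj / np.sqrt(d_h))
out_absorbed = (alpha @ c_kv) @ W_uv            # aggregate latents, then apply W_uv

print(np.allclose(out_naive, out_absorbed))     # True
```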

@jukofyork (Contributor):

(nice to see it screws up the LaTeX for GitHub just the same as o1 does 😁)

@jukofyork (Contributor) commented Jan 24, 2025:

Weirdly, their own implementation doesn't actually use the trick the paper mentions:

https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py

        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        # ...
        attn_output = torch.matmul(attn_weights, value_states)
        # ...
        attn_output = self.o_proj(attn_output)

Maybe for GPU use it doesn't really matter or maybe they didn't care and just copied/pasted from llama (based on the class comment)?

It will likely make a much bigger difference for CPU use though, and keep far more of the work in cache than manifesting these giant matrices!?

@fairydreaming (Collaborator Author):

@jukofyork I doubt it's the same code that they use internally. But thanks for reminding me about the trick with absorbing matrices, I completely forgot about this. I'll try to optimize the implementation this way.

@jukofyork (Contributor) commented Jan 24, 2025:

> @jukofyork I doubt it's the same code that they use internally. But thanks for reminding me about the trick with absorbing matrices, I completely forgot about this. I'll try to optimize the implementation this way.

Cool!

If it works then I think we should be able to compress any existing projection matrices too:

  1. Concatenate: $W' = [W^K \mid W^V]$.

  2. Compute SVD: $W' = U \Sigma V^T$.

  3. Truncate: Keep top $r$ singular values/vectors.

  4. Factorize:

$$W^{DKV} = U_r \Sigma_r^{1/2}$$ $$W^{UK} = \Sigma_r^{1/2} V_r^T[:, 1:n]$$ $$W^{UV} = \Sigma_r^{1/2} V_r^T[:, n+1:2n]$$

IIRC, the above is equivalent to minimizing:

$$ \|W^K - W^{DKV} W^{UK}\|_F^2 + \|W^V - W^{DKV} W^{UV}\|_F^2 $$

No idea how well it would work though, but it might allow very long context if your 128k example is anything to go by!
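As a toy numerical check of that factorization (random matrices and arbitrary sizes; whether the truncation error would be acceptable for real attention weights is exactly the open question above):

```python
import numpy as np

d_model, n, r = 1024, 256, 128                   # hidden dim, per-matrix width, kept rank
rng = np.random.default_rng(0)
W_K = rng.standard_normal((d_model, n))
W_V = rng.standard_normal((d_model, n))

W = np.concatenate([W_K, W_V], axis=1)           # W' = [W_K | W_V]
U, S, Vt = np.linalg.svd(W, full_matrices=False)
U_r, S_r, Vt_r = U[:, :r], S[:r], Vt[:r, :]      # truncate to top-r

W_DKV = U_r * np.sqrt(S_r)                       # (d_model, r) down-projection
W_UK  = (np.sqrt(S_r)[:, None] * Vt_r)[:, :n]    # (r, n) key up-projection
W_UV  = (np.sqrt(S_r)[:, None] * Vt_r)[:, n:]    # (r, n) value up-projection

err_K = np.linalg.norm(W_K - W_DKV @ W_UK) / np.linalg.norm(W_K)
err_V = np.linalg.norm(W_V - W_DKV @ W_UV) / np.linalg.norm(W_V)
print(f"relative Frobenius error: K {err_K:.3f}, V {err_V:.3f}")
```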

@jukofyork (Contributor):

Somewhat related discussion from last year:

#8831

@jukofyork (Contributor):

@fairydreaming Just found their GitHub has the code from the paper:

https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
        # ...
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale


@fairydreaming (Collaborator Author):

@ggerganov Is there a "recommended" way to precalculate some tensors for a model in llama.cpp? For example I want to split some model tensor into two to avoid doing it during inference in the main model graph. Or perhaps it's best to do it during model conversion?

@ggerganov (Member):

It's best to do it during model conversion. Though it would be nice to figure out some runtime method at some point.

@slaren (Member) commented Jan 25, 2025:

You can split tensors easily in convert_hf_to_gguf.py in the modify_tensors function; just do the split logic there and return each tensor as an element of the returned list.
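A minimal sketch of that approach, assuming the modify_tensors hook keeps the (data_torch, name, bid) shape it had in convert_hf_to_gguf.py at the time of this PR; the model class, the fused tensor name, and the split point are all hypothetical:

```python
# Intended to live inside convert_hf_to_gguf.py, where Model is already defined.
from torch import Tensor

class SomeModel(Model):  # hypothetical architecture-specific subclass
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        if name.endswith("fused_proj.weight"):           # hypothetical fused weight
            half = data_torch.shape[0] // 2
            return [                                      # one entry per output tensor
                (name.replace("fused_proj", "proj_a"), data_torch[:half]),
                (name.replace("fused_proj", "proj_b"), data_torch[half:]),
            ]
        return super().modify_tensors(data_torch, name, bid)
```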

@jukofyork (Contributor):

@fairydreaming I see you've updated the experimental MLA branch:

de538aa

Did you manage to get this working?

@fairydreaming (Collaborator Author):

TLDR: it works, but it's not faster (yet).

@jukofyork Sure, I optimized the implementation in my branch based on the DeepSeek code; on my CPU it's now almost as fast as the original "naive" implementation. Note that using the branch now requires reconverting the model. Prompt processing performance is still worse, though.

Some unknowns to research:

  • I'm not sure how the ggml_mul_mat(ctx, x, y) performance depends on the tensor shapes. For example, I have a multiplication of two tensors: {512, 32, 1, 1} and {512, 1, 2, 16}. Will it be faster if I permute the second tensor to {512, 16, 2, 1}?
  • In the code I use both the kv_cache (with latent KV representations) of dimensions {kv_lora_rank, n_kv} and the transposed kv_cache of dimensions {n_kv, kv_lora_rank} as the first argument in matrix multiplications. I guess that transposing the whole cache takes a lot of time with longer context sizes; it would be nice to avoid that. Unfortunately, it would require implementing efficient multiplication of transposed matrices - not sure if it's worth the effort.

@jukofyork (Contributor) commented Jan 26, 2025:

Thanks for the heads up! I've just downloaded R1 to try on an old dual Xeon E5 system (approximately 1/3 of your memory bandwidth, but it does have dual A6000s in it), but it will be next week before I can test it out.


All this opens up a lot of interesting potential for using SVD to decompose other weight matrices too. I already outlined this for the attention matrices:

  1. Concatenate: $W' = [W^K \mid W^V]$.

  2. Compute SVD: $W' = U \Sigma V^T$.

  3. Truncate: Keep top $r$ singular values/vectors.

  4. Factorize:

$$W^{DKV} = U_r \Sigma_r^{1/2}$$ $$W^{UK} = \Sigma_r^{1/2} V_r^T[:, 1:n]$$ $$W^{UV} = \Sigma_r^{1/2} V_r^T[:, n+1:2n]$$

But actually there is no reason you can't concatenate different (or even all) sets of matrices like this if you trasnpose to have all with hidden_dim as the rows (compressing the massive MoE tensor-triplets would be huge if it worked!).

The QuaRot method is essentially doing the opposite of this and smearing the singular values. The "abliteration" and control vectors stuff also shows there is likely a lot of redundancy in the hidden state too (likely meaning the same $U_r \Sigma_r^{1/2}$ could potentially be shared between layers).

I'm not sure how feasible it would be to compute the SVD for large models, but it definitely is something that can be explored for small models to see how quickly the singular values drop off and if it's even worth pursuing further.

@fairydreaming (Collaborator Author) commented Jan 26, 2025:

I added a second, transposed copy of the $c^{KV}$ cache to check whether transposing the cache is the bottleneck, and indeed - the optimized MLA implementation is now faster than the naive one:

[benchmark chart: deepseek-mla]

The disadvantage of this solution is that it doubles the cache memory requirements (though they are still far lower than for the regular KV cache). Also, some numbers for the full V3/R1 model:

[benchmark chart: deepseek-r1-mla]

Update: modified the permutations to multiply larger matrices and got a nice speed bump:

[benchmark chart: deepseek-r1-mla]

@sirceljm:
Hi, I would like to know if grammar / GBNF works for this model?

tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
(same squashed commit message as the netrunnereve push above)
Labels: python (python script changes)

Successfully merging this pull request may close these issues:

  • Feature Request: add DeepSeek-v3 support

7 participants