[#441, #439, #436] Fixed a Save & Load issue for Qwen2
VainF committed Nov 30, 2024
1 parent cd997ba commit 85aa041
Showing 2 changed files with 24 additions and 22 deletions.
39 changes: 20 additions & 19 deletions examples/LLMs/readme.md
@@ -482,11 +482,13 @@ wikitext perplexity 92795.3984375

### :rocket: Qwen/Qwen2-7B

The Qwen2-7B model has 28 attention heads with ``num_key_value_heads=4``. This limits the pruning ratio to a multiple of 1/7 (i.e., 1/7, 2/7, 3/7, 4/7, 5/7, or 6/7) if you want to save and load the pruned model with Hugging Face Transformers, since HF only supports the same in_features and out_features for ``q_proj`` and ``o_proj``.

```bash
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.5 --max_seq_len 4096
# 3/7 ≈ 0.428571428; this crafts a ~2B model
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.428571428 --max_seq_len 4096
```
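
For reference, here is a minimal sketch of the ratio arithmetic described above; the helper below is illustrative and not part of `prune_llm.py`. With 28 query heads grouped into 4 KV groups, each group holds 7 query heads, so a save/load-friendly ratio must be a multiple of 1/7.

```python
import math

# Illustrative helper (not from prune_llm.py): snap a requested pruning ratio
# down to the nearest multiple of 1/7 so every KV group keeps the same number
# of query heads. Assumes Qwen2-7B: 28 heads, num_key_value_heads=4.
NUM_HEADS = 28
NUM_KV_HEADS = 4
HEADS_PER_GROUP = NUM_HEADS // NUM_KV_HEADS   # 7 query heads per KV group

def snap_ratio(requested: float) -> float:
    k = math.floor(requested * HEADS_PER_GROUP)   # heads removed per group
    k = max(0, min(HEADS_PER_GROUP - 1, k))       # keep at least one head per group
    return k / HEADS_PER_GROUP

print(snap_ratio(0.5))   # 0.42857..., i.e. 3/7 -- the ratio used in the second command
```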


<details>
<summary>Output:</summary>

@@ -522,30 +524,30 @@ Qwen2ForCausalLM(
----------------- After Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 1792)
(embed_tokens): Embedding(152064, 2048)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=1792, out_features=2048, bias=True)
(k_proj): Linear(in_features=1792, out_features=512, bias=True)
(v_proj): Linear(in_features=1792, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=1792, bias=False)
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=512, bias=True)
(v_proj): Linear(in_features=2048, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1792, out_features=9472, bias=False)
(up_proj): Linear(in_features=1792, out_features=9472, bias=False)
(down_proj): Linear(in_features=9472, out_features=1792, bias=False)
(gate_proj): Linear(in_features=2048, out_features=10825, bias=False)
(up_proj): Linear(in_features=2048, out_features=10825, bias=False)
(down_proj): Linear(in_features=10825, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
(input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((1792,), eps=1e-06)
(norm): Qwen2RMSNorm((2048,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=1792, out_features=152064, bias=False)
(lm_head): Linear(in_features=2048, out_features=152064, bias=False)
)
Qwen2Config {
"_attn_implementation_autoset": true,
@@ -557,9 +559,9 @@ Qwen2Config {
"bos_token_id": 151643,
"eos_token_id": 151643,
"hidden_act": "silu",
"hidden_size": 1792,
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 9472,
"intermediate_size": 10825,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "qwen2",
@@ -578,13 +580,12 @@ Qwen2Config {
"vocab_size": 152064
}
num_params 2227887872
num_params 2778904576
evaluating on wikitext2
Token indices sequence length is longer than the specified maximum sequence length for this model (2541000 > 32768). Running this sequence through the model will result in indexing errors
nsamples 73
sample 0
sample 50
wikitext perplexity 13779.380859375
wikitext perplexity 44195.69140625
```

</details>
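
Since the commit is about saving and reloading the pruned model, the following is a hedged sketch of that round trip with Hugging Face Transformers. It assumes `pruned_model` is the pruned `Qwen2ForCausalLM` whose config already reflects the new shapes shown above; the directory name is illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def save_and_reload(pruned_model, save_dir="qwen2-7b-pruned-2b"):
    """Round-trip a pruned Qwen2 model through the standard HF save/load path.

    Assumes pruned_model.config already reflects the pruned hidden_size,
    intermediate_size and head counts printed above.
    """
    pruned_model.save_pretrained(save_dir)   # writes config.json + weights
    AutoTokenizer.from_pretrained("Qwen/Qwen2-7B").save_pretrained(save_dir)
    # Reloading only succeeds when the q_proj/o_proj shapes are consistent with
    # the saved config, which is why the pruning ratio must be a multiple of 1/7.
    return AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype="auto")
```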
7 changes: 4 additions & 3 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -7,6 +7,7 @@
from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency
import math

class MetaPruner:
"""
@@ -560,13 +561,13 @@ def _prune(self) -> typing.Generator:
_is_gqa = not all([self.num_heads[qkv_layer]==num_heads for qkv_layer in qkv_layers])

if not self.global_pruning: # local pruning
n_heads_removed_per_group = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp))
n_heads_removed_per_group = math.ceil(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp))
if not _is_gqa:
head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking
else: # chunk the head imp
num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
n_heads_removed_per_group = n_heads_removed_per_group // num_kv_heads
n_heads_removed_per_group = math.ceil(n_heads_removed_per_group / num_kv_heads)
head_pruning_indices = []
for kv_head_id in range(num_kv_heads):
head_imp_kv = head_imp[kv_head_id * num_heads//num_kv_heads: (kv_head_id+1) * num_heads//num_kv_heads]
@@ -578,7 +579,7 @@ def _prune(self) -> typing.Generator:
head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking
if _is_gqa:
num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
n_heads_removed_per_group = len(head_pruning_indices) // num_kv_heads
n_heads_removed_per_group = math.ceil(len(head_pruning_indices) / num_kv_heads)
head_pruning_indices = []
for kv_head_id in range(num_kv_heads):
head_imp_kv = head_imp[kv_head_id * len(head_imp)//num_kv_heads: (kv_head_id+1) * len(head_imp)//num_kv_heads]
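
To make the rounding change above concrete, here is a small standalone sketch (not the library code itself) of the per-KV-group selection the patched branch performs: head importance is chunked per KV group, `math.ceil` keeps the number of removed heads identical and whole for every group, and `torch.topk` picks the least important heads inside each chunk.

```python
import math
import torch

def select_heads_to_prune(head_imp: torch.Tensor, num_heads: int,
                          num_kv_heads: int, pruning_ratio: float) -> torch.Tensor:
    """Illustrative re-implementation of the GQA head-selection logic above."""
    n_removed = math.ceil(pruning_ratio * num_heads)      # total heads to drop
    per_group = math.ceil(n_removed / num_kv_heads)       # same whole count per KV group
    group_size = num_heads // num_kv_heads
    indices = []
    for g in range(num_kv_heads):
        chunk = head_imp[g * group_size: (g + 1) * group_size]
        local = torch.topk(chunk, k=per_group, largest=False)[1]
        indices.append(local + g * group_size)            # map back to global head ids
    return torch.cat(indices)

# Example: 28 heads, 4 KV groups, ratio 3/7 -> 3 heads dropped from each group of 7.
print(select_heads_to_prune(torch.rand(28), 28, 4, 3/7).shape)   # torch.Size([12])
```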
