Add an example for head pruning
VainF committed Oct 16, 2023
1 parent 3ab8214 commit ac07c1e
Showing 2 changed files with 64 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/transformers/prune_timm_vit.py
@@ -23,6 +23,8 @@ def parse_args():
     parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'hessian'])
     parser.add_argument('--test_accuracy', default=False, action='store_true', help='test accuracy')
     parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
+    parser.add_argument('--prune_num_heads', default=False, action='store_true', help='prune entire attention heads')
+    parser.add_argument('--head_pruning_ratio', default=0.0, type=float, help='head pruning ratio')
     parser.add_argument('--use_imagenet_mean_std', default=False, action='store_true', help='use imagenet mean and std')
     parser.add_argument('--train_batch_size', default=64, type=int, help='train batch size')
     parser.add_argument('--val_batch_size', default=128, type=int, help='val batch size')
@@ -146,9 +148,9 @@ def main():
         pruning_ratio=args.pruning_ratio, # target pruning ratio
         ignored_layers=ignored_layers,
         num_heads=num_heads, # number of heads in self attention
-        prune_num_heads=False, # reduce num_heads by pruning entire heads (default: False)
-        prune_head_dims=True, # reduce head_dim by pruning feature dims of each head (default: True)
-        head_pruning_ratio=0.5, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
+        prune_num_heads=args.prune_num_heads, # reduce num_heads by pruning entire heads (default: False)
+        prune_head_dims=not args.prune_num_heads, # reduce head_dim by pruning feature dims of each head (default: True)
+        head_pruning_ratio=args.head_pruning_ratio, # ratio of heads to remove, only works when prune_num_heads=True (default: 0.0)
         round_to=2
     )

59 changes: 59 additions & 0 deletions examples/transformers/readme.md
@@ -93,6 +93,65 @@
wget https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth
python measure_latency.py --model pretrained/vit_b_16_pruning_taylor_uniform.pth
```

### Pruning attention heads

```bash
python prune_timm_vit.py --prune_num_heads --head_pruning_ratio 0.5
```

```
...
Head #0
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #1
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #2
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #3
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #4
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #5
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #6
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #7
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #8
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #9
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #10
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64

Head #11
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 6, Head Dim: 64
...
```
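The numbers in the log can be sanity-checked with a few lines of arithmetic (a standalone sketch, not part of the example script): with `--prune_num_heads --head_pruning_ratio 0.5`, each attention layer keeps half of its 12 heads, `head_dim` stays at 64, and the attention width therefore shrinks from 768 to 384.

```python
# Standalone sanity check for the log above; the numbers mirror
# ViT-B/16 (12 heads of dim 64) and are hard-coded, not read from a model.
num_heads, head_dim = 12, 64
head_pruning_ratio = 0.5

pruned_heads = int(num_heads * head_pruning_ratio)
kept_heads = num_heads - pruned_heads       # heads that survive pruning: 6
embed_before = num_heads * head_dim         # attention width before: 768
embed_after = kept_heads * head_dim         # attention width after: 384

print(f"Num Heads: {num_heads} -> {kept_heads}, Head Dim: {head_dim} (unchanged)")
print(f"Attention width: {embed_before} -> {embed_after}")
```

`head_dim` is unchanged because the script disables `prune_head_dims` whenever `--prune_num_heads` is set, which matches the constant `Head Dim: 64` in the log.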


## Pruning ViT-ImageNet-1K from [HF Transformers](https://huggingface.co/docs/transformers/index)

### Pruning
