APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference
- [2024/01/22] Our paper is now on arXiv! Check it out here.
- [2024/05/01] APT is accepted by ICML 2024 as oral presentation!
We propose APT, a methodology that Adaptively selects model parameters for Pruning and fine-Tuning. APT combines the benefits of PEFT and structured pruning to make fine-tuning and inference more efficient.
Our intuition is that pretrained language model (LM) parameters contain general knowledge, but their importance to downstream tasks varies. Therefore, we can remove the parameters irrelevant to the fine-tuning task in the early training stage. Early-removing these parameters improves training and inference efficiency while not substantially hurting model accuracy. Meanwhile, continuously adding more parameters for fine-tuning can improve LM performance because task-specific skills live in a subset of LM parameters.
Based on this setup, we find that using self-distillation where the main parameters between the teacher and student models are shared can vasly prune small LMs with high end-task performance retained. Meanwhile, considering in-block outliers by calculating kurtosis when pruning large LMs before training can accurately prune them with less training memory footprint.
RoBERTa-base experiment results:
Method | MNLI | SST2 | SQuAD v2 | Train Time | Train Mem. | Inf Time | Inf Mem. |
---|---|---|---|---|---|---|---|
FT | 87.6 | 94.8 | 82.9 | 100.0% | 100.0% | 100.0% | 100.0% |
LoRA | 87.5 | 95.1 | 83.0 | 2137.0% | 60.5% | 100.0% | 100.0% |
LoRA+Prune | 84.0 | 93.0 | 79.2 | 5128.3% | 60.5% | 38.0% | 75.1% |
Prune+Distill | 87.3 | 94.5 | - | 1495.3% | 168.5% | 38.6% | 79.2% |
LoRA+Prune+Distill | 84.2 | 91.9 | - | 6534.6% | 141.4% | 39.4% | 82.3% |
APT | 86.4 | 94.5 | 81.8 | 592.1% | 70.1% | 41.3% | 78.1% |
T5-base experiment results:
Method | MNLI | SST2 | CNN/DM | Train Time | Train Mem. | Inf Time | Inf Mem. |
---|---|---|---|---|---|---|---|
FT | 87.1 | 95.2 | 42.1/20.3/39.4 | 100.0% | 100.0% | 100.0% | 100.0% |
LoRA | 87.0 | 95.0 | 38.7/17.2/36.0 | 255.5% | 62.0% | 100.0% | 100.0% |
LoRA+Prune | 80.9 | 92.3 | 36.7/15.7/33.9 | 4523.5% | 62.0% | 47.1% | 73.4% |
APT | 87.0 | 95.0 | 38.6/17.0/35.8 | 484.7% | 73.9% | 74.6% | 81.5% |
LLaMA-7B experiment results:
Method | ARC | HellaSwag | MMLU | TruthfulQA | Avg. | Train Time | Train Mem. | Inf Time | Inf Mem. |
---|---|---|---|---|---|---|---|---|---|
LLaMA 2 7B | 53.1 | 77.7 | 43.8 | 39.0 | 53.4 | - | - | - | - |
LoRA | 55.6 | 79.3 | 46.9 | 49.9 | 57.9 | 100.0% | 100.0% | 100.0% | 100.0% |
LoRA+Prune | 46.8 | 65.2 | 23.9 | 46.2 | 45.5 | 180.9% | 100.0% | 115.5% | 68.9% |
LLMPruner | 39.2 | 67.0 | 24.9 | 40.6 | 42.9 | 86.9% | 253.6% | 114.8% | 74.2% |
APT | 45.4 | 71.1 | 36.9 | 46.6 | 50.0 | 106.0% | 75.8% | 117.0% | 67.2% |
conda env create -f environment.yml
conda activate apt
For finetuning RoBERTa-base models with APT, please run:
bash scripts/adaptpruning/roberta_base_sst2_momentum.sh
For finetuning T5-base models with APT, please run:
bash scripts/adaptpruning/t5_base_lm_adapt_cnndm_momentum.sh
For finetuning LLaMA2 models on Alpaca with APT, please run:
bash scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh
If you use this code or our tuned models, please cite our paper:
@misc{zhao2024apt,
title={APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference},
author={Bowen Zhao and Hannaneh Hajishirzi and Qingqing Cao},
year={2024},
eprint={2401.12200},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
This project uses modified code from the following projects:
- CoFiPruning: developed by Princeton-nlp. Model backbone codes reused for pruning BERT and RoBERTa. See
models/modeling_bert.py
andmodels/modeling_roberta.py
. - Mask Tuning: developed by Woosuk Kwon. We use the codes as a baseline to prune the post-training LMs. See
prune/fisher.py
.