Figure 1: PrePrompt two-stage framework (left) and feature extrapolation mechanism (right).
🔥 Official PyTorch implementation of PrePrompt, a two-stage predictive prompting framework that enables pre-trained models to first predict task-specific prompts and then perform label prediction, effectively balancing stability and plasticity in class-incremental learning.
Figure 2: Main difference between conventional prompt-based CIL methods and PrePrompt.
PrePrompt introduces a predictive prompting mechanism that leverages pre-trained models' natural classification ability to predict task-specific prompts.
Unlike conventional prompt-based CIL methods that rely on correlation-based strategies, in which an image's classification feature serves as a query to retrieve the most related key prompts and select the corresponding value prompts for training, PrePrompt sidesteps the core limitation of that approach: fitting the entire feature space of all tasks with only a few trainable prompts. This yields robust knowledge retention, minimal forgetting, and efficient adaptation.
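The contrast with query-key matching can be made concrete. Below is a minimal sketch of the two-stage idea, assuming a frozen ViT backbone whose [CLS] feature is available; the names (`TaskPromptPredictor`, `task_head`, `task_prompts`) are illustrative placeholders, not this repository's API:

```python
import torch
import torch.nn as nn

class TaskPromptPredictor(nn.Module):
    """Sketch of two-stage predictive prompting: stage 1 predicts the
    task an input belongs to, stage 2 selects that task's dedicated
    prompts. No shared prompt pool has to cover every task's features."""

    def __init__(self, feat_dim: int, num_tasks: int, prompt_len: int):
        super().__init__()
        # Stage 1: task prediction reuses the pre-trained model's
        # classification ability via a lightweight linear head.
        self.task_head = nn.Linear(feat_dim, num_tasks)
        # One dedicated prompt set per task.
        self.task_prompts = nn.Parameter(
            torch.randn(num_tasks, prompt_len, feat_dim))

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        task_id = self.task_head(feat).argmax(dim=-1)  # stage 1: which task?
        return self.task_prompts[task_id]              # stage 2: its prompts

# Usage: a batch of [CLS] features from a frozen ViT picks task prompts.
predictor = TaskPromptPredictor(feat_dim=768, num_tasks=10, prompt_len=5)
prompts = predictor(torch.randn(4, 768))  # -> shape (4, 5, 768)
```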
- 🧩 Predictive Prompting: Learns to anticipate task evolution, improving long-term adaptability.
- 📈 State-of-the-Art Results: Outperforms all prior prompt-based CIL methods across multiple benchmarks.
- ⚡ Lightweight Integration: Minimal computation overhead — plug-and-play for any ViT-based model.
- 🔁 Stable & Scalable: Balances plasticity (learning new tasks) and stability (preserving old knowledge).
Results on 10 tasks with an equal number of classes per task for CIFAR-100, ImageNet-R, and CUB-200, and 5 tasks for 5-Datasets:
| Dataset | Final Accuracy (%) ↑ | Average Incremental Accuracy (%) ↑ | Forgetting Rate (%) ↓ |
|---|---|---|---|
| CIFAR-100 | 93.74 | 95.41 | 1.27 |
| ImageNet-R | 75.09 | 78.96 | 1.11 |
| CUB-200 | 88.27 | 88.29 | 1.81 |
| 5-Datasets | 94.54 | 95.78 | 0.21 |
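For reference, the three columns follow the standard class-incremental learning definitions. The sketch below computes them from an accuracy matrix `acc[i, j]` (accuracy on task `j` after training through task `i`), using made-up numbers for a hypothetical 3-task run, not this repo's output:

```python
import numpy as np

# acc[i, j] = accuracy on task j after training through task i (j <= i).
# Illustrative values only; upper-triangle zeros are never-evaluated cells.
acc = np.array([
    [96.0,  0.0,  0.0],
    [95.0, 94.0,  0.0],
    [94.5, 93.0, 95.0],
])
T = acc.shape[0]

# Final Accuracy: mean accuracy over all tasks after the last task.
final_acc = acc[-1, :].mean()
# Average Incremental Accuracy: mean of the running averages after each task.
avg_inc = np.mean([acc[i, :i + 1].mean() for i in range(T)])
# Forgetting: mean drop from each old task's best accuracy to its final one.
forget = np.mean([acc[:-1, j].max() - acc[-1, j] for j in range(T - 1)])

print(f"final={final_acc:.2f}  avg_inc={avg_inc:.2f}  forgetting={forget:.2f}")
```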
# Create and activate conda environment
conda create -n preprompt python=3.8 -y
conda activate preprompt
# Install dependencies (retry if network issues occur)
pip install -r requirements.txt

requirements.txt pins the following versions:

timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
torchprofile==0.0.4
torch==1.13.1
torchvision==0.14.1
urllib3==2.0.3
scipy==1.7.3
scikit-learn==1.0.2
numpy==1.21.6
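After installation, a quick sanity check confirms that the pinned versions resolved and that PyTorch can see a GPU (a CPU-only setup will simply print `False`):

```python
import torch
import torchvision
import timm

# Verify the pinned core dependencies imported cleanly.
print(torch.__version__, torchvision.__version__, timm.__version__)
# Training is far faster with a visible CUDA device.
print("CUDA available:", torch.cuda.is_available())
```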
PrePrompt automatically handles downloading and preprocessing for the following datasets:
- 🖼️ CIFAR-100 — 100-class object recognition
- 🎨 ImageNet-R — artistic renditions of ImageNet classes
- 🐦 CUB-200 — fine-grained bird classification
- 🔢 5-Datasets — composite of SVHN, MNIST, CIFAR-10, notMNIST, and Fashion-MNIST
💡 Tip: If your network is unstable, pre-download datasets into ./datasets/.
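For example, CIFAR-100 can be staged ahead of time with torchvision; the `./datasets/` root matches the tip above (adjust it if your config points elsewhere), and the other torchvision-backed benchmarks can be fetched the same way:

```python
from torchvision import datasets

# Download the CIFAR-100 archives into ./datasets/ once; torchvision
# skips the download when the files are already present and verified.
datasets.CIFAR100(root="./datasets", train=True, download=True)
datasets.CIFAR100(root="./datasets", train=False, download=True)
```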
Run the corresponding training scripts for each benchmark:
# CIFAR-100 Experiments
bash training_scripts/train_cifar100_vit.sh
# ImageNet-R Experiments
bash training_scripts/train_imr_vit.sh
# CUB-200 Fine-grained Classification
bash training_scripts/train_cub_vit.sh
# 5-Datasets Sequential Learning
bash training_scripts/train_5datasets_vit.sh

Logs and checkpoints will be stored in ./outputs/.
This repository builds upon the following excellent open-source projects:
- DualPrompt — continual prompting foundations
- HiDe-Prompt — hierarchical prompt architecture
This project is released under the MIT License. See the LICENSE file for details.


