
🌟 PrePrompt: Predictive Prompting for Class-Incremental Learning



Figure 1: PrePrompt two-stage framework (left) and feature extrapolation mechanism (right).

🔥 Official PyTorch implementation of PrePrompt (arXiv, May 2025), a two-stage predictive prompting framework in which a pre-trained model first predicts task-specific prompts and then performs label prediction, effectively balancing stability and plasticity in class-incremental learning.


🧠 Overview


Figure 2: Main difference between conventional prompt-based CIL methods and PrePrompt.

PrePrompt introduces a predictive prompting mechanism that leverages a pre-trained model's natural classification ability to predict task-specific prompts.

Conventional prompt-based CIL methods rely on a correlation-based strategy: an image's classification feature is used as a query to retrieve the most related key prompts, and the corresponding value prompts are selected for training. This forces a few trainable prompts to fit the entire feature space of all tasks. PrePrompt circumvents this limitation, ensuring robust knowledge retention, minimal forgetting, and efficient adaptation.
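
The resulting two-stage inference flow can be sketched as below (a minimal illustration: the task predictor, prompt shapes, and the way prompts are attached to the backbone are assumptions for exposition, not the repository's actual API):

import torch
import timm

# Frozen pre-trained backbone (timm 0.6.x: forward_features returns all tokens).
backbone = timm.create_model("vit_base_patch16_224", pretrained=True).eval()

num_tasks, embed_dim, classes_per_task = 10, 768, 10
prompt_pool = [torch.randn(5, embed_dim) for _ in range(num_tasks)]    # assumed per-task prompts
task_heads = [torch.nn.Linear(embed_dim, classes_per_task) for _ in range(num_tasks)]
task_predictor = torch.nn.Linear(embed_dim, num_tasks)                 # stage 1: predict the task

@torch.no_grad()
def predict(image):                                      # image: [1, 3, 224, 224]
    feat = backbone.forward_features(image)[:, 0]        # [CLS] token feature
    task_id = task_predictor(feat).argmax(dim=-1).item() # stage 1: task/prompt prediction
    # Stage 2: condition on the predicted task's prompts, then predict the label.
    # Adding the mean prompt here is a stand-in for the real prompt-attachment step.
    prompted = feat + prompt_pool[task_id].mean(dim=0)
    return task_heads[task_id](prompted)

logits = predict(torch.randn(1, 3, 224, 224))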

🚀 Key Highlights

  • 🧩 Predictive Prompting: Learns to anticipate task evolution, improving long-term adaptability.
  • 📈 State-of-the-Art Results: Outperforms all prior prompt-based CIL methods across multiple benchmarks.
  • ⚡ Lightweight Integration: Minimal computational overhead; plug-and-play for any ViT-based model.
  • 🔁 Stable & Scalable: Balances plasticity (learning new tasks) and stability (preserving old knowledge).

📊 Benchmark Results

10 tasks with an equal number of classes per task on CIFAR-100, ImageNet-R, and CUB-200, and 5 tasks on 5-Datasets:

| Dataset | Final Accuracy (%) ↑ | Average Incremental Accuracy (%) ↑ | Forgetting Rate (%) ↓ |
|---|---|---|---|
| CIFAR-100 | 93.74 | 95.41 | 1.27 |
| ImageNet-R | 75.09 | 78.96 | 1.11 |
| CUB-200 | 88.27 | 88.29 | 1.81 |
| 5-Datasets | 94.54 | 95.78 | 0.21 |
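
These metrics follow standard CIL conventions; for reference, a minimal sketch of how they are commonly computed from an accuracy matrix (the repository's own evaluation code may differ):

import numpy as np

# acc[i, j]: accuracy (%) on task j after training on task i (toy 3-task numbers).
acc = np.array([[98.0,  0.0,  0.0],
                [97.0, 96.0,  0.0],
                [96.5, 95.0, 94.0]])
T = acc.shape[0]

final_acc  = acc[-1].mean()                                      # accuracy over all tasks after the last one
avg_incr   = np.mean([acc[i, :i + 1].mean() for i in range(T)])  # average incremental accuracy
forgetting = np.mean([acc[:-1, j].max() - acc[-1, j] for j in range(T - 1)])

print(f"Final: {final_acc:.2f}  AIA: {avg_incr:.2f}  Forgetting: {forgetting:.2f}")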

🛠️ Installation

Environment Setup

# Create and activate conda environment
conda create -n preprompt python=3.8 -y
conda activate preprompt

# Install dependencies (retry if network issues occur)
pip install -r requirements.txt

Dependencies

timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
torchprofile==0.0.4
torch==1.13.1
torchvision==0.14.1
urllib3==2.0.3
scipy==1.7.3
scikit-learn==1.0.2
numpy==1.21.6

📁 Datasets

PrePrompt automatically handles downloading and preprocessing for the following datasets:

  • 🖼️ CIFAR-100 — 100-class object recognition
  • 🎨 ImageNet-R — artistic renditions of ImageNet classes
  • 🐦 CUB-200 — fine-grained bird classification
  • 🔢 5-Datasets — composite of SVHN, MNIST, CIFAR-10, notMNIST, and Fashion-MNIST

💡 Tip: If your network is unstable, pre-download datasets into ./datasets/.
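
For example, CIFAR-100 can be fetched ahead of time with torchvision (a sketch; the exact layout the training scripts expect under ./datasets/ is an assumption):

from torchvision.datasets import CIFAR100

# Download once so later training runs do not depend on the network.
CIFAR100(root="./datasets", train=True, download=True)
CIFAR100(root="./datasets", train=False, download=True)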

🎯 Quick Start

Run the corresponding training scripts for each benchmark:

# CIFAR-100 Experiments
bash training_scripts/train_cifar100_vit.sh

# ImageNet-R Experiments  
bash training_scripts/train_imr_vit.sh

# CUB-200 Fine-grained Classification
bash training_scripts/train_cub_vit.sh

# 5-Datasets Sequential Learning
bash training_scripts/train_5datasets_vit.sh

Logs and checkpoints will be stored in ./outputs/.

🙏 Acknowledgments

This repository builds upon excellent prior open-source projects.

📜 License

This project is released under the MIT License. See the LICENSE file for details.
