Authors: Mengmi Zhang, Tao Wang, Joo Hwee Lim, Gabriel Kreiman, and Jiashi Feng
This repository contains an implementation of a few-shot continual learning method for preventing catastrophic forgetting in object classification tasks.
The arXiv version is HERE.
Continual learning refers to the ability to acquire and transfer knowledge without catastrophically forgetting what was previously learned. In this work, we consider few-shot continual learning in classification tasks, and we propose a novel method, Variational Prototype Replays, that efficiently consolidates and recalls previous knowledge to avoid catastrophic forgetting. In each classification task, our method learns a set of variational prototypes with their means and variances, where embeddings of samples from the same class are represented by a prototypical distribution and class-representative prototypes are kept well separated. To alleviate catastrophic forgetting, our method replays one sample per class from previous tasks and matches newly predicted embeddings to their nearest class-representative prototypes stored from previous tasks. Compared with recent continual learning approaches, our method can readily adapt to new tasks with more classes without requiring the addition of new units. Furthermore, our method is more memory efficient, since only the class-representative prototypes with their means and variances, plus one sample per class from previous tasks, need to be stored. Without degrading performance on initial tasks, our method learns novel concepts from only a few training examples of each class in new tasks.
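To make the idea concrete, here is a minimal, self-contained sketch of variational prototypes, prototype matching, and nearest-prototype classification. This is our illustration under assumed names (class_prototypes, proto_match_loss, predict) and an assumed Gaussian negative log-likelihood matching loss, not the repository's implementation:

```python
# Illustrative sketch only -- NOT the repository's implementation.
import torch

def class_prototypes(embeddings, labels):
    """Store one variational prototype (mean, variance) per class."""
    protos = {}
    for c in labels.unique():
        e = embeddings[labels == c]
        protos[int(c)] = (e.mean(dim=0), e.var(dim=0, unbiased=False) + 1e-6)
    return protos

def proto_match_loss(embeddings, labels, protos):
    """Gaussian negative log-likelihood of each embedding under its class
    prototype: pulls new embeddings toward the stored prototype means."""
    loss = 0.0
    for e, y in zip(embeddings, labels):
        mu, var = protos[int(y)]
        loss = loss + 0.5 * (((e - mu) ** 2) / var + var.log()).sum()
    return loss / embeddings.size(0)

def predict(embedding, protos):
    """Classify an embedding by its nearest prototype mean."""
    dists = {c: ((embedding - mu) ** 2).sum() for c, (mu, _) in protos.items()}
    return min(dists, key=dists.get)

# Toy usage: two classes, 8-dim embeddings; keep one replay sample per class.
emb, lab = torch.randn(10, 8), torch.tensor([0, 1] * 5)
protos = class_prototypes(emb, lab)             # consolidated after a task
replay = {c: emb[lab == c][0] for c in (0, 1)}  # one stored sample per class
print(proto_match_loss(emb, lab, protos), predict(emb[0], protos))
```

In this sketch, replaying the single stored sample per class and matching its new embedding to the stored prototype is what counters forgetting on earlier tasks.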
The code has been tested on Ubuntu 18.04 with a single GPU (NVIDIA RTX 2080 Ti). It requires the following (a quick version check is sketched after the dependency list):
- PyTorch = 1.1.0
- python = 2.7
- CUDA = 10.2
- torchvision = 0.3.0
Dependencies:
- numpy
- opencv
- scipy
- matplotlib
- skimage
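After installation, the versions can be verified against the tested configuration above with a small sanity-check sketch:

```python
import torch
import torchvision

print(torch.__version__)          # tested with 1.1.0
print(torchvision.__version__)    # tested with 0.3.0
print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.is_available())  # expect True on a working GPU setup
```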
Refer to link for Anaconda installation. Alternatively, execute the following commands:
curl -O https://repo.anaconda.com/archive/Anaconda3-2019.03-Linux-x86_64.sh
bash Anaconda3-2019.03-Linux-x86_64.sh
After Anaconda installation, create a conda environment:
conda create -n pytorch27 python=2.7
Activate the conda environment:
conda activate pytorch27
In the conda environment, refer to link for PyTorch installation.
Download our repository:
git clone https://github.com/kreimanlab/VariationalPrototypeReplaysCL.git
The instructions below are for the permuted MNIST protocol. The same instructions apply to the Split CIFAR10 protocol, except that the folder in which to execute commands changes to CIFARincrement/variationalProtoCL.
Navigate to MINSTpermuted/variationalProtoCL/ using the following command:
cd MINSTpermuted/variationalProtoCL/
Run the script with a GPU ID (the example below selects GPU ID 0) to start training, evaluation, and testing for all 50 tasks on the GPU.
./fewshots.sh 0
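The argument is the GPU ID. If you adapt the scripts, the usual PyTorch pattern for pinning a single GPU looks like this (a generic sketch, not the actual contents of fewshots.sh):

```python
import os
import sys

gpu_id = sys.argv[1] if len(sys.argv) > 1 else "0"  # mirrors ./fewshots.sh 0
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id         # set before the first CUDA call

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("running on", device)
```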
The trained models are saved in MINSTpermuted/variationalProtoCL/models. The test results and variational prototypes for each task are stored in MINSTpermuted/variationalProtoCL/results.
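To list and inspect what was saved, something like the following works (a hedged sketch; the exact filenames inside models/ and results/ depend on the scripts):

```python
import glob
import torch

for path in sorted(glob.glob("models/*") + glob.glob("results/*")):
    print(path)
# Loading a checkpoint on CPU (the filename here is hypothetical):
# state = torch.load("models/task01.pth", map_location="cpu")
```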
Use the following script to re-format the test results and save them to .mat for MATLAB plots with plots/MNIST_full.m:
python EvalDataPlot.py
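For reference, saving arrays to .mat for MATLAB is done with scipy.io.savemat; a minimal sketch (the array name and shape are illustrative, not what EvalDataPlot.py actually writes):

```python
import numpy as np
from scipy.io import savemat

# Hypothetical 50x50 accuracy matrix: row = tasks trained so far, column = task tested.
acc = np.zeros((50, 50))
savemat("mnist_results.mat", {"acc": acc})  # load in MATLAB with: load('mnist_results.mat')
```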
NOTE: The output directory paths may need to be modified for evaluation and plotting.
Download the Mini-ImageNet dataset from HERE, unzip it, and place the downloaded mini_imagenet dataset in the folder IMAGENETincrement/variationalProtoCL/.
Run the following commands only ONCE to pre-process the miniImageNet images:
cd IMAGENETincrement/variationalProtoCL/dataloaders/
python preprocess_mini_imagenet.py
NOTE: The directory paths in preprocess_mini_imagenet.py need to be modified.
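For orientation, preprocessing scripts of this kind typically resize the raw images to the common 84x84 mini-ImageNet resolution and re-save them; a generic sketch with hypothetical paths (the actual preprocess_mini_imagenet.py may differ):

```python
import os
import cv2

src_dir = "mini_imagenet/images"     # hypothetical input path
dst_dir = "mini_imagenet/processed"  # hypothetical output path
if not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
for name in os.listdir(src_dir):
    img = cv2.imread(os.path.join(src_dir, name))
    if img is None:
        continue                     # skip non-image files
    img = cv2.resize(img, (84, 84))  # 84x84 is the common mini-ImageNet size
    cv2.imwrite(os.path.join(dst_dir, name), img)
```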
Follow the permuted MNIST instructions above to train, evaluate, and test our method in the split miniImageNet protocol.
We adapted the iCaRL code and the code for other methods to test the comparative methods used in our paper. We strongly recommend downloading them from their original repositories; here, we provide our revised versions for each paradigm. Run the following shell scripts in each dataset folder (MINSTpermuted, CIFARincrement, and IMAGENETincrement) to get the iCaRL results:
icarl/icarlBCE.sh #binary cross-entropy loss on logits
icarl/icarlKLD.sh #KLDivergence loss on logits
icarl/icarlMSE.sh #mean squared error loss on logits
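The three variants differ only in the distillation loss applied to the logits of the previous model; a hedged sketch of the standard forms of these losses (our illustration, not the repository code):

```python
import torch
import torch.nn.functional as F

def distill_loss(new_logits, old_logits, kind="bce", T=2.0):
    if kind == "bce":  # binary cross-entropy on logits, iCaRL-style
        return F.binary_cross_entropy_with_logits(new_logits, torch.sigmoid(old_logits))
    if kind == "kld":  # KL divergence between temperature-softened distributions
        return F.kl_div(F.log_softmax(new_logits / T, dim=1),
                        F.softmax(old_logits / T, dim=1),
                        reduction="batchmean") * (T * T)
    return F.mse_loss(new_logits, old_logits)  # mean squared error on logits

print(distill_loss(torch.randn(4, 10), torch.randn(4, 10), "kld"))
```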
Similarly, to get the results of all other comparative methods, use the following scripts:
MINSTpermuted/Continual-Learning-Benchmark-master/scripts/permuted_MNIST_incremental_domain.sh
CIFARincrement/Continual-Learning-Benchmark-master/scripts/split_CIFAR10_incremental_class.sh
IMAGENETincrement/Continual-Learning-Benchmark-master/scripts/split_ImageNet_incremental_class.sh
To run any of the shell scripts above, simply follow the example below:
cd IMAGENETincrement/Continual-Learning-Benchmark-master/scripts/
./split_ImageNet_incremental_class.sh 0 #where 0 is GPU_ID
The source code is for illustration purposes only. We do not provide technical support, but we would be happy to discuss the SCIENCE!
See the Kreiman lab for license agreements before downloading and using our source code and datasets.