-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
sccbhxc
committed
Dec 9, 2020
0 parents
commit e22338e
Showing
22 changed files
with
1,523 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# NetGraft | ||
|
||
Knowledge distillation has demonstrated encouraging performances in deep model compression. Most existing approaches, however, require massive labeled data to accomplish the knowledge transfer, making the model compression a cumbersome and costly process. In this paper, we investigate the practical **few-shot** knowledge distillation scenario, where we assume only a few samples without human annotations are available for each category. To this end, we introduce a principled dual-stage distillation scheme tailored for few-shot data. In the first step, we graft the student blocks one by one onto the teacher, and learn the parameters of the grafted block intertwined with those of the other teacher blocks. In the second step, the trained student blocks are progressively connected and then together grafted onto the teacher network, allowing the learned student blocks to adapt themselves to each other and eventually replace the teacher network. Experiments demonstrate that our approach, with only a few unlabeled samples, achieves gratifying results on CIFAR10, CIFAR100, and ILSVRC-2012. On CIFAR10 and CIFAR100, our performances are even on par with those of knowledge distillation schemes that utilize the full datasets. | ||
|
||
|
||
|
||
![](images/framework.png) | ||
|
||
|
||
|
||
``` | ||
@inproceedings{shen2021progressive, | ||
author = {Shen, Chengchao and Wang, Xinchao and Yin, Youtan and Song, Jie and Luo, Sihui and Song, Mingli}, | ||
title = {Progressive Network Grafting for Few-Shot Knowledge Distillation}, | ||
booktitle = {AAAI Conference on Artificial Intelligence (AAAI)}, | ||
year = {2021} | ||
} | ||
``` | ||
|
||
|
||
|
||
## Requirements | ||
|
||
To install requirements: | ||
|
||
```bash | ||
conda create -n netgraft python=3.7 | ||
pip install -r requirements.txt | ||
``` | ||
|
||
|
||
|
||
## Pre-trained Teacher Models | ||
|
||
You can download pretrained teacher models here (Github Releases): | ||
|
||
- [Teacher model]() trained on full CIFAR10. | ||
- [Teacher model]() trained on full CIFAR100. | ||
|
||
**Note**: put the pre-trained teacher models in the directory: `ckpt/teacher` | ||
|
||
|
||
|
||
## Dataset Preparation | ||
|
||
To download and build datasets (CIFAR10 and CIFAR100) for few-shot distillation, run this command: | ||
|
||
```python | ||
python build_dataset.py | ||
``` | ||
|
||
|
||
|
||
## Training | ||
|
||
To train the model(s) in the paper, run this command: | ||
|
||
```python | ||
# ----------- Run on CIFAR10 ----------- | ||
python train.py --dataset CIFAR10 # Training [1~10, 20, 50]-Shot Distillation | ||
|
||
# ----------- Run on CIFAR100 ----------- | ||
python train.py --dataset CIFAR100 # Training [1~10, 20, 50]-Shot Distillation | ||
``` | ||
|
||
|
||
|
||
## Evaluation | ||
|
||
To evaluate my model on CIFAR10 and CIFAR100, run: | ||
|
||
```python | ||
# ----------- Run on CIFAR10 ----------- | ||
python evaluate.py --dataset='CIFAR10' --nshot=1 # 1-Shot Distillation | ||
python evaluate.py --dataset='CIFAR10' --nshot=5 # 5-Shot Distillation | ||
python evaluate.py --dataset='CIFAR10' --nshot=10 # 10-Shot Distillation | ||
|
||
# ----------- Run on CIFAR100 ----------- | ||
python evaluate.py --dataset='CIFAR100' --nshot=1 # 1-Shot Distillation | ||
python evaluate.py --dataset='CIFAR100' --nshot=5 # 5-Shot Distillation | ||
python evaluate.py --dataset='CIFAR100' --nshot=10 # 10-Shot Distillation | ||
``` | ||
|
||
|
||
|
||
## Experimental Results | ||
|
||
### Few-Shot Distillation on CIFAR10 and CIFAR100 | ||
|
||
| N-Shot | Accuracy on CIFAR10 (%) | Accuracy on CIFAR100 (%) | | ||
| :----: | :---------------------: | :----------------------: | | ||
| 1 | 90.74$\pm$0.49 | 64.22$\pm$0.17 | | ||
| 2 | 92.60$\pm$0.06 | 66.51$\pm$0.11 | | ||
| 3 | 92.70$\pm$0.07 | 67.35$\pm$0.10 | | ||
| 4 | 92.77$\pm$0.04 | 67.69$\pm$0.03 | | ||
| 5 | 92.88$\pm$0.07 | 68.16$\pm$0.20 | | ||
| 6 | 92.84$\pm$0.08 | 68.38$\pm$0.11 | | ||
| 7 | 92.77$\pm$0.05 | 68.46$\pm$0.10 | | ||
| 8 | 92.83$\pm$0.06 | 68.78$\pm$0.22 | | ||
| 9 | 92.88$\pm$0.05 | 68.77$\pm$0.10 | | ||
| 10 | 92.89$\pm$0.06 | 68.86$\pm$0.03 | | ||
| 20 | 92.78$\pm$0.09 | 69.04$\pm$0.08 | | ||
| 50 | 92.76$\pm$0.09 | 69.06$\pm$0.10 | | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
|
||
from dataset import extract_dataset_from_cifar10, extract_dataset_from_cifar100 | ||
|
||
from torchvision.datasets.utils import download_and_extract_archive | ||
|
||
|
||
def download_dataset(): | ||
filename_10 = "cifar-10-python.tar.gz" | ||
url_10 = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" | ||
download_and_extract_archive(url_10, download_root='./data/', filename=filename_10, extract_root='./data/') | ||
|
||
filename_100 = "cifar-100-python.tar.gz" | ||
url_100 = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" | ||
download_and_extract_archive(url_100, download_root='./data/', filename=filename_100, extract_root='./data/') | ||
|
||
def build_dataset(): | ||
os.makedirs('data/', exist_ok=True) | ||
|
||
download_dataset() | ||
|
||
num_samples = list(range(1, 11)) + [20, 50] | ||
for i in num_samples: | ||
extract_dataset_from_cifar10(i) | ||
extract_dataset_from_cifar100(i) | ||
|
||
|
||
if __name__ == "__main__": | ||
build_dataset() |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
import os | ||
import pickle | ||
import sys | ||
import random | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
from torch.utils.data import Dataset | ||
|
||
|
||
class CIFAR10Few(Dataset): | ||
def __init__(self, root, num_per_class=50, transform=None): | ||
|
||
self.root = root | ||
|
||
self.transform = transform | ||
|
||
filename = 'cifar10-random-{}-per-class.pkl'.format(num_per_class) | ||
|
||
file_path = os.path.join(self.root, filename) | ||
with open(file_path, 'rb') as f: | ||
entry = pickle.load(f) | ||
|
||
self.data = entry | ||
|
||
self.data = self.data.transpose((0, 2, 3, 1)) | ||
|
||
def __getitem__(self, index): | ||
img = self.data[index] | ||
|
||
img = Image.fromarray(img) | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
return img | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
def load_cifar10(root): | ||
train_list = [ | ||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'], | ||
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], | ||
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], | ||
['data_batch_4', '634d18415352ddfa80567beed471001a'], | ||
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], | ||
] | ||
|
||
base_folder = 'cifar-10-batches-py' | ||
|
||
data = [] | ||
targets = [] | ||
for file_name, _ in train_list: | ||
file_path = os.path.join(root, base_folder, file_name) | ||
with open(file_path, 'rb') as f: | ||
if sys.version_info[0] == 2: | ||
entry = pickle.load(f) | ||
else: | ||
entry = pickle.load(f, encoding='latin1') | ||
data.append(entry['data']) | ||
if 'labels' in entry: | ||
targets.extend(entry['labels']) | ||
else: | ||
targets.extend(entry['fine_labels']) | ||
data = np.vstack(data) | ||
return data, targets | ||
|
||
|
||
def categorize(data, targets, num_class): | ||
data_by_class = [[] for i in range(num_class)] | ||
for i, target in enumerate(targets): | ||
data_by_class[target].append(data[i]) | ||
|
||
for i in range(num_class): | ||
data_by_class[i] = np.vstack(data_by_class[i]) | ||
|
||
return data_by_class | ||
|
||
|
||
def extract_dataset_from_cifar10(num_per_class): | ||
data, targets = load_cifar10('data') | ||
|
||
num_class = 10 | ||
data_by_class = categorize(data, targets, num_class) | ||
|
||
random.seed(6) | ||
data_select = [] | ||
for i in range(num_class): | ||
idxs = list(range(5000)) | ||
random.shuffle(idxs) | ||
idx_select = idxs[:num_per_class] | ||
data_select.append(data_by_class[i][idx_select]) | ||
|
||
data_select = np.vstack(data_select).reshape(-1, 3, 32, 32) | ||
|
||
with open('data/cifar10-random-{}-per-class.pkl'.format(num_per_class), \ | ||
'wb') as f: | ||
pickle.dump(data_select, f) | ||
|
||
|
||
class CIFAR100Few(Dataset): | ||
def __init__(self, root, num_per_class=10, transform=None): | ||
self.root = root | ||
|
||
self.transform = transform | ||
|
||
filename = 'cifar100-random-{}-per-class.pkl'.format(num_per_class) | ||
|
||
file_path = os.path.join(self.root, filename) | ||
with open(file_path, 'rb') as f: | ||
entry = pickle.load(f) | ||
|
||
self.data = entry | ||
|
||
self.data = self.data.transpose((0, 2, 3, 1)) | ||
|
||
def __getitem__(self, index): | ||
img = self.data[index] | ||
|
||
img = Image.fromarray(img) | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
return img | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
def load_cifar100(root): | ||
train_list = [ | ||
['train', '16019d7e3df5f24257cddd939b257f8d'], | ||
] | ||
|
||
base_folder = 'cifar-100-python' | ||
|
||
data = [] | ||
targets = [] | ||
for file_name, _ in train_list: | ||
file_path = os.path.join(root, base_folder, file_name) | ||
with open(file_path, 'rb') as f: | ||
if sys.version_info[0] == 2: | ||
entry = pickle.load(f) | ||
else: | ||
entry = pickle.load(f, encoding='latin1') | ||
data.append(entry['data']) | ||
if 'fine_labels' in entry: | ||
targets.extend(entry['fine_labels']) | ||
else: | ||
targets.extend(entry['coarse_labels']) | ||
data = np.vstack(data) | ||
return data, targets | ||
|
||
|
||
def extract_dataset_from_cifar100(num_per_class): | ||
data, targets = load_cifar100('data') | ||
|
||
num_class = 100 | ||
data_by_class = categorize(data, targets, num_class) | ||
|
||
random.seed(10) | ||
data_select = [] | ||
for i in range(num_class): | ||
idxs = list(range(500)) | ||
random.shuffle(idxs) | ||
idx_select = idxs[:num_per_class] | ||
data_select.append(data_by_class[i][idx_select]) | ||
|
||
data_select = np.vstack(data_select).reshape(-1, 3, 32, 32) | ||
|
||
with open('data/cifar100-random-{}-per-class.pkl'.format(num_per_class), \ | ||
'wb') as f: | ||
pickle.dump(data_select, f) | ||
|
||
|
||
if __name__ == '__main__': | ||
num_samples = list(range(1, 11)) + [20, 50] | ||
for i in num_samples: | ||
extract_dataset_from_cifar10(i) | ||
|
||
|
Oops, something went wrong.