Initial commit
sccbhxc committed Dec 9, 2020
0 parents commit e22338e
Showing 22 changed files with 1,523 additions and 0 deletions.
104 changes: 104 additions & 0 deletions README.md
@@ -0,0 +1,104 @@
# NetGraft

Knowledge distillation has demonstrated encouraging performances in deep model compression. Most existing approaches, however, require massive labeled data to accomplish the knowledge transfer, making the model compression a cumbersome and costly process. In this paper, we investigate the practical **few-shot** knowledge distillation scenario, where we assume only a few samples without human annotations are available for each category. To this end, we introduce a principled dual-stage distillation scheme tailored for few-shot data. In the first step, we graft the student blocks one by one onto the teacher, and learn the parameters of the grafted block intertwined with those of the other teacher blocks. In the second step, the trained student blocks are progressively connected and then together grafted onto the teacher network, allowing the learned student blocks to adapt themselves to each other and eventually replace the teacher network. Experiments demonstrate that our approach, with only a few unlabeled samples, achieves gratifying results on CIFAR10, CIFAR100, and ILSVRC-2012. On CIFAR10 and CIFAR100, our performances are even on par with those of knowledge distillation schemes that utilize the full datasets.



![](images/framework.png)
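Below is a minimal sketch of the first stage, assuming the teacher and student networks are split into aligned blocks; the MSE objective, the optimizer choice, and all names here are illustrative placeholders rather than the actual `train.py` implementation:

```python
import copy

import torch
import torch.nn as nn


def graft_and_train_block(teacher_blocks, student_blocks, i, loader, epochs=1, lr=1e-3):
    """Stage 1 (sketch): graft the i-th student block into the teacher and train only it."""
    grafted = nn.ModuleList(copy.deepcopy(list(teacher_blocks)))
    grafted[i] = student_blocks[i]            # swap one teacher block for its student counterpart
    for j, block in enumerate(grafted):
        block.requires_grad_(j == i)          # freeze every block except the grafted one

    teacher = nn.Sequential(*teacher_blocks).eval()
    grafted_net = nn.Sequential(*grafted)
    optimizer = torch.optim.Adam(grafted[i].parameters(), lr=lr)

    for _ in range(epochs):
        for x in loader:                      # a few unlabeled images, no annotations needed
            with torch.no_grad():
                target = teacher(x)           # the teacher's output is the only supervision
            loss = nn.functional.mse_loss(grafted_net(x), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return grafted[i]
```

In the second stage, the student blocks trained this way are progressively connected, grafted jointly onto the teacher, and adapted together under the same teacher supervision until they replace the teacher network entirely.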



```bibtex
@inproceedings{shen2021progressive,
    author    = {Shen, Chengchao and Wang, Xinchao and Yin, Youtan and Song, Jie and Luo, Sihui and Song, Mingli},
    title     = {Progressive Network Grafting for Few-Shot Knowledge Distillation},
    booktitle = {AAAI Conference on Artificial Intelligence (AAAI)},
    year      = {2021}
}
```



## Requirements

To install requirements:

```bash
conda create -n netgraft python=3.7
conda activate netgraft
pip install -r requirements.txt
```



## Pre-trained Teacher Models

You can download the pre-trained teacher models from the GitHub Releases:

- [Teacher model]() trained on full CIFAR10.
- [Teacher model]() trained on full CIFAR100.

**Note**: put the pre-trained teacher models in the directory: `ckpt/teacher`
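A quick way to verify that a downloaded checkpoint is readable (a sketch; the file name is a placeholder, and the stored object may be a raw `state_dict` or a wrapper dict depending on how it was saved):

```python
import torch

# Hypothetical file name -- use whatever the downloaded release asset is actually called.
ckpt = torch.load('ckpt/teacher/cifar10-teacher.pth', map_location='cpu')
print(type(ckpt))  # inspect whether it is a state_dict or a wrapper dict
```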



## Dataset Preparation

To download and build datasets (CIFAR10 and CIFAR100) for few-shot distillation, run this command:

```bash
python build_dataset.py
```
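The script saves one pickle per shot count under `data/` (e.g. `cifar10-random-10-per-class.pkl`). These subsets are consumed through the `CIFAR10Few`/`CIFAR100Few` classes in `dataset.py`; a minimal usage sketch (the transform here is an assumption and should match whatever `train.py` uses):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import CIFAR10Few

# ToTensor alone is an assumption; add the same normalization as train.py if it uses any.
transform = transforms.Compose([transforms.ToTensor()])

# 10 unlabeled images per class -> 100 images in total for CIFAR10
few_shot_set = CIFAR10Few(root='data', num_per_class=10, transform=transform)
loader = DataLoader(few_shot_set, batch_size=32, shuffle=True)

print(len(few_shot_set))  # 100
```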



## Training

To train the models in the paper, run the following commands:

```bash
# ----------- Run on CIFAR10 -----------
python train.py --dataset CIFAR10 # Training [1~10, 20, 50]-Shot Distillation

# ----------- Run on CIFAR100 -----------
python train.py --dataset CIFAR100 # Training [1~10, 20, 50]-Shot Distillation
```



## Evaluation

To evaluate the distilled student models on CIFAR10 and CIFAR100, run:

```bash
# ----------- Run on CIFAR10 -----------
python evaluate.py --dataset='CIFAR10' --nshot=1 # 1-Shot Distillation
python evaluate.py --dataset='CIFAR10' --nshot=5 # 5-Shot Distillation
python evaluate.py --dataset='CIFAR10' --nshot=10 # 10-Shot Distillation

# ----------- Run on CIFAR100 -----------
python evaluate.py --dataset='CIFAR100' --nshot=1 # 1-Shot Distillation
python evaluate.py --dataset='CIFAR100' --nshot=5 # 5-Shot Distillation
python evaluate.py --dataset='CIFAR100' --nshot=10 # 10-Shot Distillation
```



## Experimental Results

### Few-Shot Distillation on CIFAR10 and CIFAR100

| N-Shot | Accuracy on CIFAR10 (%) | Accuracy on CIFAR100 (%) |
| :----: | :---------------------: | :----------------------: |
| 1 | 90.74 ± 0.49 | 64.22 ± 0.17 |
| 2 | 92.60 ± 0.06 | 66.51 ± 0.11 |
| 3 | 92.70 ± 0.07 | 67.35 ± 0.10 |
| 4 | 92.77 ± 0.04 | 67.69 ± 0.03 |
| 5 | 92.88 ± 0.07 | 68.16 ± 0.20 |
| 6 | 92.84 ± 0.08 | 68.38 ± 0.11 |
| 7 | 92.77 ± 0.05 | 68.46 ± 0.10 |
| 8 | 92.83 ± 0.06 | 68.78 ± 0.22 |
| 9 | 92.88 ± 0.05 | 68.77 ± 0.10 |
| 10 | 92.89 ± 0.06 | 68.86 ± 0.03 |
| 20 | 92.78 ± 0.09 | 69.04 ± 0.08 |
| 50 | 92.76 ± 0.09 | 69.06 ± 0.10 |

29 changes: 29 additions & 0 deletions build_dataset.py
@@ -0,0 +1,29 @@
import os

from dataset import extract_dataset_from_cifar10, extract_dataset_from_cifar100

from torchvision.datasets.utils import download_and_extract_archive


def download_dataset():
    filename_10 = "cifar-10-python.tar.gz"
    url_10 = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    download_and_extract_archive(url_10, download_root='./data/', filename=filename_10, extract_root='./data/')

    filename_100 = "cifar-100-python.tar.gz"
    url_100 = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    download_and_extract_archive(url_100, download_root='./data/', filename=filename_100, extract_root='./data/')


def build_dataset():
    os.makedirs('data/', exist_ok=True)

    download_dataset()

    # Build the [1-10, 20, 50]-shot subsets for both CIFAR10 and CIFAR100.
    num_samples = list(range(1, 11)) + [20, 50]
    for i in num_samples:
        extract_dataset_from_cifar10(i)
        extract_dataset_from_cifar100(i)


if __name__ == "__main__":
    build_dataset()
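As a sanity check, each pickle produced above holds a single array of shape (num_classes × num_per_class, 3, 32, 32); for example, assuming the 5-shot CIFAR10 subset has already been built:

```python
import pickle

import numpy as np

# File name follows the format string used in dataset.py / build_dataset.py.
with open('data/cifar10-random-5-per-class.pkl', 'rb') as f:
    data = pickle.load(f)

assert isinstance(data, np.ndarray)
assert data.shape == (10 * 5, 3, 32, 32)  # 10 classes x 5 images per class, CHW layout
print(data.dtype)  # uint8, as stored in the original CIFAR batches
```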
Empty file added data/put-data-here.txt
185 changes: 185 additions & 0 deletions dataset.py
@@ -0,0 +1,185 @@
import os
import pickle
import sys
import random

import numpy as np
from PIL import Image

from torch.utils.data import Dataset


class CIFAR10Few(Dataset):
    def __init__(self, root, num_per_class=50, transform=None):

        self.root = root

        self.transform = transform

        filename = 'cifar10-random-{}-per-class.pkl'.format(num_per_class)

        file_path = os.path.join(self.root, filename)
        with open(file_path, 'rb') as f:
            entry = pickle.load(f)

        self.data = entry

        # (N, C, H, W) -> (N, H, W, C) so that PIL can read the images.
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        img = self.data[index]

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.data)


def load_cifar10(root):
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    base_folder = 'cifar-10-batches-py'

    data = []
    targets = []
    for file_name, _ in train_list:
        file_path = os.path.join(root, base_folder, file_name)
        with open(file_path, 'rb') as f:
            if sys.version_info[0] == 2:
                entry = pickle.load(f)
            else:
                entry = pickle.load(f, encoding='latin1')
            data.append(entry['data'])
            if 'labels' in entry:
                targets.extend(entry['labels'])
            else:
                targets.extend(entry['fine_labels'])
    data = np.vstack(data)
    return data, targets


def categorize(data, targets, num_class):
    data_by_class = [[] for i in range(num_class)]
    for i, target in enumerate(targets):
        data_by_class[target].append(data[i])

    for i in range(num_class):
        data_by_class[i] = np.vstack(data_by_class[i])

    return data_by_class


def extract_dataset_from_cifar10(num_per_class):
    data, targets = load_cifar10('data')

    num_class = 10
    data_by_class = categorize(data, targets, num_class)

    # Fixed seed so that the same few-shot subset is selected on every run.
    random.seed(6)
    data_select = []
    for i in range(num_class):
        idxs = list(range(5000))
        random.shuffle(idxs)
        idx_select = idxs[:num_per_class]
        data_select.append(data_by_class[i][idx_select])

    data_select = np.vstack(data_select).reshape(-1, 3, 32, 32)

    with open('data/cifar10-random-{}-per-class.pkl'.format(num_per_class), \
              'wb') as f:
        pickle.dump(data_select, f)


class CIFAR100Few(Dataset):
    def __init__(self, root, num_per_class=10, transform=None):
        self.root = root

        self.transform = transform

        filename = 'cifar100-random-{}-per-class.pkl'.format(num_per_class)

        file_path = os.path.join(self.root, filename)
        with open(file_path, 'rb') as f:
            entry = pickle.load(f)

        self.data = entry

        # (N, C, H, W) -> (N, H, W, C) so that PIL can read the images.
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        img = self.data[index]

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.data)


def load_cifar100(root):
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    base_folder = 'cifar-100-python'

    data = []
    targets = []
    for file_name, _ in train_list:
        file_path = os.path.join(root, base_folder, file_name)
        with open(file_path, 'rb') as f:
            if sys.version_info[0] == 2:
                entry = pickle.load(f)
            else:
                entry = pickle.load(f, encoding='latin1')
            data.append(entry['data'])
            if 'fine_labels' in entry:
                targets.extend(entry['fine_labels'])
            else:
                targets.extend(entry['coarse_labels'])
    data = np.vstack(data)
    return data, targets


def extract_dataset_from_cifar100(num_per_class):
    data, targets = load_cifar100('data')

    num_class = 100
    data_by_class = categorize(data, targets, num_class)

    # Fixed seed so that the same few-shot subset is selected on every run.
    random.seed(10)
    data_select = []
    for i in range(num_class):
        idxs = list(range(500))
        random.shuffle(idxs)
        idx_select = idxs[:num_per_class]
        data_select.append(data_by_class[i][idx_select])

    data_select = np.vstack(data_select).reshape(-1, 3, 32, 32)

    with open('data/cifar100-random-{}-per-class.pkl'.format(num_per_class), \
              'wb') as f:
        pickle.dump(data_select, f)


if __name__ == '__main__':
    num_samples = list(range(1, 11)) + [20, 50]
    for i in num_samples:
        extract_dataset_from_cifar10(i)

