Another version of this documentation is available in the readme.pdf file. It provides:
- an installation guide,
- details on how to run the code,
- examples to reproduce the results in the paper.
1.2 PyTorch (1.0.0), which can be installed following the details provided here: https://pytorch.org/get-started/locally/
For most users, `pip3 install torch torchvision` should work.
If you are using Anaconda, `conda install pytorch torchvision -c pytorch` should work.
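As an optional sanity check (not part of the original guide), you can verify the installation from a shell:

```bash
python3 -c "import torch, torchvision; print(torch.__version__, torchvision.__version__)"
```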
1.3 Microsoft NNI toolkit: installation details are provided here: https://github.com/Microsoft/nni/
For most users, `pip3 install nni` should work.
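Similarly, you can sanity-check the NNI install (optional; the `nnictl` CLI is installed along with the pip package):

```bash
nnictl --version
```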
2.1.3 train.py
The main training code. Given a training config, which we explain below, it trains the models.
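Before the settings list, it may help to see what such a training script optimizes. Below is a minimal, generic sketch of the standard Hinton-style knowledge distillation loss with temperature T and mixing weight lambda (the hyperparameters referenced in Equation 3 of the paper). The function name kd_loss, the argument name lam, and the default values are illustrative assumptions, not the repository's exact implementation:

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=5.0, lam=0.5):
    """Generic Hinton-style distillation loss; T and lam as in Eq. 3 of the paper (sketch)."""
    # Hard-label term: ordinary cross-entropy with the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft-label term: KL divergence between the temperature-softened teacher
    # and student distributions; the T*T factor keeps gradient magnitudes
    # comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    # lam trades off the two terms.
    return (1.0 - lam) * hard + lam * soft
```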
These settings are needed during training:
| Setting | Description |
| ------- | ----------- |
| epochs | Number of training epochs. |
| dataset | Dataset name; can be 'cifar10' or 'cifar100'. Default is 'cifar100'. |
| batch size | Mini-batch size used in training. Default is 128. |
| learning rate | Initial learning rate for the SGD optimizer. Depending on the model, it might be changed during training. |
| momentum | Momentum for the SGD optimizer. Default is 0.9. |
| weight decay | Weight decay for the SGD optimizer. Default is 0.0001. |
| teacher | Teacher model name. Can be resnetX for ResNets, where X can be any value in (8, 14, 20, 26, 32, 44, 56, 110), or planeY for plane (vanilla CNN) networks, where Y can be any value in (2, 4, 6, 8, 10). For details of the networks, please refer to the paper. |
| student | Student model name. Values take the resnetX or planeY forms explained above. |
| teacher-checkpoint | Path to a file containing a pretrained teacher. Default is empty, which means the teacher also needs to be trained. |
| cuda | Whether or not to train on the GPU (requires a CUDA-enabled PyTorch install). Values 1 and true enable GPU training. |
| dataset directory | Location of the dataset. Default is './data/'. |
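As a quick reference, here is a sketch of how these settings might map onto command-line flags. The flags --epochs, --dataset, --teacher, --student, --teacher-checkpoint, and --cuda appear in the example commands below; the remaining flag spellings and all default values here are assumptions for illustration only:

```python
import argparse

# Hypothetical argparse wiring for the settings above; flags marked (*) are
# assumed names that do not appear in this readme's example commands.
parser = argparse.ArgumentParser(description='Teacher assistant knowledge distillation')
parser.add_argument('--epochs', type=int, default=160)               # 160 as in the examples
parser.add_argument('--dataset', type=str, default='cifar100')       # 'cifar10' or 'cifar100'
parser.add_argument('--batch-size', type=int, default=128)           # (*)
parser.add_argument('--learning-rate', type=float, default=0.1)      # (*) placeholder value
parser.add_argument('--momentum', type=float, default=0.9)           # (*)
parser.add_argument('--weight-decay', type=float, default=0.0001)    # (*)
parser.add_argument('--teacher', type=str, default='')               # e.g. resnet110
parser.add_argument('--student', type=str, default='resnet8')        # e.g. resnet8 or plane2
parser.add_argument('--teacher-checkpoint', type=str, default='')    # empty: train teacher too
parser.add_argument('--cuda', type=str, default='')                  # 1/true to train on GPU
parser.add_argument('--dataset-dir', type=str, default='./data/')    # (*)
args = parser.parse_args()
```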
The NNI toolkit needs a search space file (like search_space.json) that contains T and lambda from Equation 3 of the paper. Also, in order to get reliable results, multiple seeds are used to avoid bad runs. For more details on the search space file, please refer to the examples section, the sketch below, or https://microsoft.github.io/nni/docs/SearchSpaceSpec.html.
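For concreteness, a minimal search_space.json sketch in NNI's search space format might look like the following. The parameter names T, lambda, and seed follow the description above, but the value lists are placeholders, not the ones used in the paper:

```json
{
    "T":      {"_type": "choice", "_value": [1, 2, 5, 10, 20]},
    "lambda": {"_type": "choice", "_value": [0.05, 0.25, 0.5, 0.75, 0.95]},
    "seed":   {"_type": "choice", "_value": [10, 20, 30]}
}
```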
You have to run the code using `nnictl create --config config.yml`. The hyperparameter optimizer will then run experiments with different hyperparameters, and the results will be available through a dashboard.
In all examples, you need to change the command line of config.yml to tell the nnictl runner how to run an experiment. You should change the command part of the config file like this:

`command: python3 train.py --epochs 160 --teacher resnet110 --student resnet8 --cuda 1 --dataset cifar10`
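For context, here is a minimal config.yml sketch in the NNI v1.x experiment config format; only the command line needs to change between the examples that follow, and the other field values shown here are assumptions, not the repository's exact settings:

```yaml
authorName: default
experimentName: teacher_assistant_kd   # assumed name
trialConcurrency: 1
maxTrialNum: 16                        # placeholder trial budget
trainingServicePlatform: local
searchSpacePath: search_space.json
useAnnotation: false
tuner:
  builtinTunerName: TPE
  classArgs:
    optimize_mode: maximize
trial:
  # This is the part you edit for each example below.
  command: python3 train.py --epochs 160 --teacher resnet110 --student resnet8 --cuda 1 --dataset cifar10
  codeDir: .
  gpuNum: 1
```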
Example 1 (CIFAR-100, distillation path resnet110 → resnet20 → resnet8):

- Train the teacher (resnet110). This phase is not knowledge distillation, so there is no teacher and only a student trained alone.
  `command: python3 train.py --epochs 160 --student resnet110 --cuda 1 --dataset cifar100`
- After the first step, choose the weights that had the best accuracy on the validation data and train the TA (resnet20) with the teacher (resnet110) weights. Say the best resnet110 weights file was resnet110_XXXX_best.pth.tar:
  `command: python3 train.py --epochs 160 --teacher resnet110 --teacher-checkpoint ./resnet110_XXXX_best.pth.tar --student resnet20 --cuda 1 --dataset cifar100`
- As in step two, distill knowledge from the TA to the student (the teacher is resnet20, the student is resnet8), assuming the best weights from step two were resnet20_XXXX_best.pth.tar:
  `command: python3 train.py --epochs 160 --teacher resnet20 --teacher-checkpoint ./resnet20_XXXX_best.pth.tar --student resnet8 --cuda 1 --dataset cifar100`
Example 2 (CIFAR-10, distillation path resnet110 → resnet14 → resnet8):

- Train the teacher (resnet110). This phase is not knowledge distillation, so there is no teacher and only a student trained alone.
  `command: python3 train.py --epochs 160 --student resnet110 --cuda 1 --dataset cifar10`
- After the first step, choose the weights that had the best accuracy on the validation data and train the TA (resnet14) with the teacher (resnet110) weights. Say the best resnet110 weights file was resnet110_XXXX_best.pth.tar:
  `command: python3 train.py --epochs 160 --teacher resnet110 --teacher-checkpoint ./resnet110_XXXX_best.pth.tar --student resnet14 --cuda 1 --dataset cifar10`
- As in step two, distill knowledge from the TA to the student (the teacher is resnet14, the student is resnet8), assuming the best weights from step two were resnet14_XXXX_best.pth.tar:
  `command: python3 train.py --epochs 160 --teacher resnet14 --teacher-checkpoint ./resnet14_XXXX_best.pth.tar --student resnet8 --cuda 1 --dataset cifar10`
arXiv link: https://arxiv.org/pdf/1902.03393.pdf
If you found this library useful in your research, please consider citing:
@article{mirzadeh2019improved,
title={Improved Knowledge Distillation via Teacher Assistant: Bridging the Gap Between Student and Teacher},
author={Mirzadeh, Seyed-Iman and Farajtabar, Mehrdad and Li, Ang and Ghasemzadeh, Hassan},
journal={arXiv preprint arXiv:1902.03393},
year={2019}
}