This repository contains code for pretraining Wide Residual Networks (WRNs) [1] on the downsampled [2] ImageNet 32x32 and ImageNet 64x64 datasets, as well as on full-resolution ImageNet (224x224), using either cross-entropy or triplet loss [3].
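To make the triplet objective concrete, here is a minimal sketch of the common hinge formulation, loss = max(0, d(a, p) - d(a, n) + margin), in plain Python. It assumes Euclidean distances and precomputed (anchor, positive, negative) embedding triplets. The function names are hypothetical, and the repository's TensorFlow implementation may differ (for example, it may use squared distances or its own triplet-mining strategy). The default margin of 1.0 mirrors the `--margin` default of `pretrain.py`.

```python
import math
from typing import Iterable, Sequence, Tuple

Vector = Sequence[float]


def euclidean(u: Vector, v: Vector) -> float:
    """Euclidean (L2) distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor: Vector, positive: Vector, negative: Vector,
                 margin: float = 1.0) -> float:
    """Hinge triplet loss: max(0, d(a, p) - d(a, n) + margin).

    The loss is zero once the negative is at least `margin` farther
    from the anchor than the positive is.
    """
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def batch_triplet_loss(triplets: Iterable[Tuple[Vector, Vector, Vector]],
                       margin: float = 1.0) -> float:
    """Mean triplet loss over (anchor, positive, negative) tuples."""
    losses = [triplet_loss(a, p, n, margin) for a, p, n in triplets]
    return sum(losses) / len(losses)
```

For example, with anchor (0, 0) and positive (1, 0), a negative at (3, 0) already satisfies the margin and contributes 0, while a negative at (1.5, 0) contributes 1 - 1.5 + 1 = 0.5.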
To replicate the setup, create the conda environment from the provided tf2.yml file:
conda env create -f tf2.yml
conda activate tf2
The full ImageNet dataset can be downloaded from the provided link. After downloading, set base_dir in data.py to the path of the dataset.
The ImageNet 32x32 and ImageNet 64x64 datasets can be generated either with the scripts provided by the Downsampled ImageNet authors [2] or with the TensorFlow Datasets package, which can be installed with pip:
pip install tensorflow_datasets
At the time of writing, the current version of the tensorflow_datasets package (4.4.0) has a broken link for downloading ImageNet 32x32 and ImageNet 64x64. A workaround is available on GitHub.
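As an illustration, the sketch below maps an input resolution to the dataset names accepted by the `--d` flag of `pretrain.py` and loads the downsampled variants with `tfds.load`. The mapping of 224 to `imagenet-full` follows the resolutions listed at the top of this README; the helper names are hypothetical and this is not the repository's own loading code. The `tensorflow_datasets` import is deferred so that the name mapping works without the package installed, and downloading may still require the workaround mentioned above.

```python
# Dataset names copied from the --d flag choices of pretrain.py.
DATASETS = {
    32: "imagenet_resized/32x32",
    64: "imagenet_resized/64x64",
    224: "imagenet-full",  # assumption: full ImageNet, read from base_dir in data.py
}


def dataset_name(resolution: int) -> str:
    """Return the --d value for a given input resolution (hypothetical helper)."""
    try:
        return DATASETS[resolution]
    except KeyError:
        raise ValueError(f"unsupported resolution: {resolution}") from None


def load_resized_imagenet(resolution: int, split: str = "train"):
    """Load downsampled ImageNet through TensorFlow Datasets (hypothetical helper)."""
    name = dataset_name(resolution)
    if not name.startswith("imagenet_resized/"):
        # Full ImageNet is not loaded through tfds in this sketch.
        raise ValueError(f"{name} is not available through tensorflow_datasets here")
    import tensorflow_datasets as tfds  # deferred: only needed for actual loading

    # as_supervised=True yields (image, label) pairs instead of feature dicts.
    return tfds.load(name, split=split, as_supervised=True)
```

Calling `load_resized_imagenet(32)` would return a `tf.data.Dataset` of (image, label) pairs once the download succeeds.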
To pretrain from scratch with different setups, use pretrain.py. Its command-line arguments are self-explanatory; pass --help to list them:
python pretrain.py --help
USAGE: pretrain.py [flags]
flags:
pretrain.py:
--bs: batch_size
(default: '128')
(an integer)
--d: <imagenet_resized/32x32|imagenet_resized/64x64|imagenet-full>: dataset
(default: 'imagenet_resized/32x32')
--e: number of epochs
(default: '50')
(an integer)
--g: gpu id
(default: '0')
--lbl: <lda|knn>: Specify labelling method either LDA or KNN.
(default: 'lda')
--lr: learning_rate
(default: '0.001')
(a number)
--lt: <cross-entropy|triplet>: loss_type either cross-entropy or triplet.
(default: 'cross-entropy')
--margin: margin for triplet loss
(default: '1.0')
(a number)
--n: network
(default: 'wrn-28-2')
--[no]sw: save weights
(default: 'false')
Try --helpfull to get a list of all flags.
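The `--n` flag takes a network name such as the default `wrn-28-2`. Under the WRN-d-k convention of [1], d is the network depth and k the widening factor; each of the three groups holds N = (d - 4) / 6 basic blocks, and the groups output 16k, 32k and 64k channels after a 16-channel stem convolution. The sketch below decodes such a name. It assumes `wrn.py` follows this convention, and `parse_wrn` is a hypothetical helper rather than part of the repository.

```python
import re
from typing import NamedTuple, Tuple


class WRNConfig(NamedTuple):
    depth: int                  # d: total depth under the WRN convention
    widen_factor: int           # k: channel multiplier
    blocks_per_group: int       # N: basic residual blocks in each of the 3 groups
    widths: Tuple[int, ...]     # channels of the stem conv and the 3 groups


def parse_wrn(name: str) -> WRNConfig:
    """Decode a name like 'wrn-28-2' into WRN hyper-parameters."""
    match = re.fullmatch(r"wrn-(\d+)-(\d+)", name)
    if match is None:
        raise ValueError(f"not a WRN name: {name!r}")
    depth, k = int(match.group(1)), int(match.group(2))
    # Each group has N basic blocks of two 3x3 convs; WRN uses depth = 6N + 4.
    if (depth - 4) % 6 != 0:
        raise ValueError(f"depth {depth} is not of the form 6N + 4")
    n = (depth - 4) // 6
    return WRNConfig(depth, k, n, (16, 16 * k, 32 * k, 64 * k))
```

For the default `wrn-28-2`, this gives N = 4 blocks per group and channel widths (16, 32, 64, 128).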
When --sw is passed, pretrained weights are saved to the weights/ directory. We also provide pretrained weights: download them from the repository's releases and save them into the weights/ directory. The path of the downloaded weights can be set in wrn.py.
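If several weight files end up in weights/, a small helper like the one below can pick out those belonging to a given network. The file-naming scheme is an assumption (only that file names contain the network name, e.g. `wrn-28-2`), and both functions are hypothetical. The match rejects a following digit so that `wrn-28-2` does not also select `wrn-28-20`.

```python
import re
from pathlib import Path
from typing import Iterable, List


def select_weights(network: str, filenames: Iterable[str]) -> List[str]:
    """Keep file names that mention `network` not followed by another digit."""
    pattern = re.compile(re.escape(network) + r"(?!\d)")
    return sorted(name for name in filenames if pattern.search(name))


def find_weights(network: str, weights_dir: str = "weights") -> List[Path]:
    """Return matching weight files in weights_dir (empty if it does not exist)."""
    root = Path(weights_dir)
    if not root.is_dir():
        return []
    names = [p.name for p in root.iterdir() if p.is_file()]
    return [root / name for name in select_weights(network, names)]
```

A selected path could then be passed to Keras' `model.load_weights(str(path))` on a model built with the matching architecture.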
An example notebook showing how to use the pretrained weights is provided in cifar_example.ipynb.
If you use the provided weights, kindly cite our paper.
@inproceedings{sahito2022better,
title={Better self-training for image classification through self-supervision},
author={Sahito, Attaullah and Frank, Eibe and Pfahringer, Bernhard},
booktitle={Australasian Joint Conference on Artificial Intelligence},
pages={645--657},
year={2022},
organization={Springer}
}
[1] Wide Residual Networks. Sergey Zagoruyko and Nikos Komodakis. In British Machine Vision Conference 2016. British Machine Vision Association, 2016.
[2] A downsampled variant of ImageNet as an alternative to the CIFAR datasets. Patryk Chrabaszcz, Ilya Loshchilov, and Frank Hutter. arXiv preprint arXiv:1707.08819, 2017.
[3] Distance metric learning for large margin nearest neighbor classification. Kilian Q. Weinberger and Lawrence K. Saul. Journal of Machine Learning Research, 10(2), 2009.