This repository implements training of the Residual Attention Network on the ImageNet dataset and provides pretrained weights.
pip install 'git+ssh://git@github.com/phamquiluan/ResidualAttentionNetwork.git@v0.2.0'
import torch
from resattnet import resattnet56
m = resattnet56(in_channels=3, num_classes=10)  # pretrained weights are loaded automatically
tensor = torch.randn(1, 3, 224, 224)  # random input batch; torch.Tensor(1, 3, 224, 224) would be uninitialized memory
output = m(tensor)
print(output.shape) # torch.Size([1, 10])
Download the resattnet56 weights pretrained on ImageNet-1K: link
Eval: Acc@1 77.024 Acc@5 93.574
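The README does not define Acc@1 and Acc@5. Assuming they follow the usual top-k convention (a prediction counts as correct when the true class is among the k highest-scoring classes), the metric can be sketched in plain Python as below. The function name `topk_accuracy` and the toy scores are illustrative and are not taken from this repository's evaluation code.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Return {k: percentage of samples whose target is among the k top scores}.

    scores  : list of per-sample score lists, one score per class
    targets : list of true class indices, one per sample
    """
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    if not targets:
        raise ValueError("at least one sample is required")

    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}


# Toy example with three samples and three classes (made-up numbers).
example_scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
example_targets = [1, 1, 0]
print(topk_accuracy(example_scores, example_targets, ks=(1, 2)))
```

Under this reading, Acc@5 can never be lower than Acc@1, which is consistent with the figures reported above.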
- Install PyTorch (pytorch.org)
pip install -r requirements.txt
- Download the ImageNet dataset from http://www.image-net.org/
- Then move the validation images into labeled subfolders using the following shell script
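The shell script itself is not reproduced in this text. As a rough illustration of what the sorting step does, here is a minimal Python sketch that moves each validation image into a subfolder named after its class. The helper `sort_validation_images` and its `labels` mapping (image file name to class folder name) are hypothetical; how you obtain that mapping depends on the ImageNet release you downloaded.

```python
import shutil
from pathlib import Path


def sort_validation_images(val_dir, labels):
    """Move each image in val_dir into val_dir/<class folder>/.

    labels : dict mapping image file name -> class folder name
             (hypothetical input, not provided by this repository)
    Returns the number of files that were moved.
    """
    val_dir = Path(val_dir)
    moved = 0
    for filename, label in labels.items():
        src = val_dir / filename
        if not src.is_file():
            # Skip entries whose image is missing or was already moved.
            continue
        class_dir = val_dir / label
        class_dir.mkdir(exist_ok=True)
        shutil.move(str(src), str(class_dir / filename))
        moved += 1
    return moved


# Usage (paths and names are placeholders):
# sort_validation_images("imagenet/val", {"some_image.JPEG": "some_class"})
```

After this step, the val folder contains one subfolder per class, the same layout the train folder is expected to have.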
To train a model, run main.py with the desired model architecture and the path to the ImageNet dataset:
python main.py -a resattnet56 [imagenet-folder with train and val folders]
For multi-process distributed training on NVIDIA GPUs, use the NCCL backend, since it currently provides the best distributed training performance.
python main.py -a resattnet56 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders]
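FREEPORT in the command above is a placeholder for an unused TCP port on the machine. One common way to pick one, sketched here with Python's standard `socket` module, is to bind to port 0 and let the operating system choose. The helper name `find_free_port` is illustrative and not part of this repository.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port on the given host."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 tells the OS to assign any free port.
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
print(dist_url)  # pass this value to --dist-url
```

The port is released as soon as the socket closes, so another process could in principle claim it before training starts. For a single-machine run this is usually acceptable.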