PyTorch implementation and pretrained models for AttMask.
[arXiv
]
You can download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide arguments and training logs. All backbones are ViT-S/16
models.
Percentage of ImageNet Train Set | Epochs | AttMask Mode | k-NN | Linear Probing | download | ||||
---|---|---|---|---|---|---|---|---|---|
100% | 100 | Hints | 72.8 | 76.1 | backbone only | full ckpt | logs | ||
100% | 100 | High | 72.5 | 75.7 | backbone only | full ckpt | logs | ||
100% | 300 | High | 75.0 | 77.5 | backbone only | full ckpt | logs | ||
20% | 100 | Hints | 49.5 | 57.5 | backbone only | full ckpt | logs | ||
20% | 100 | High | 49.7 | 57.9 | backbone only | full ckpt | logs |
Please install PyTorch and download the ImageNet dataset. The experiments have been performed with python version 3.7.6, PyTorch version 1.7.0, CUDA 11.0 and torchvision 0.8.1.
The requirements are easily installed via Anaconda:
conda create -n attmask python=3.7.6
conda activate attmask
conda install pytorch==1.7.0 torchvision==0.8.1 cudatoolkit=11.0 pillow==8.0.1 -c pytorch
pip install timm==0.4.12 tensorboardx==2.5.1 scikit-learn==0.23.2 munkres==1.1.4 tqdm
Pre-train AttMask on ImageNet-1k. Modify --nproc_per_node
and --batch_size_per_gpu
based on you available GPUs.
Train AttMask-Hint with a ViT-small network for 100 epochs on the full ImageNet-1k:
python -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_hint --show_max 0.1 \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/
Train AttMask-High with a ViT-small network for 100 epochs on the full ImageNet-1k:
python -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_high \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/
Train AttMask-High with a ViT-small network for 300 epochs on the full ImageNet-1k:
python -m torch.distributed.launch --nproc_per_node=8 main_attmask.py --batch_size_per_gpu 100 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 300 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.32 1 \
--local_crops_scale 0.05 0.32 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_high \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/
Train AttMask-Hint with a ViT-small network for 100 epochs on the 20% of the ImageNet-1k:
python3 -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.99 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 6 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_hint --show_max 0.1 \
--subset 260 --data_path /path/to/ImageNet --output_dir /path/to/output/
For the AttMask-High or AttMask-Low, just remove the argument
--show_max 0.1
and modify--pred_shape attmask_high
or--pred_shape attmask_low
.
K-NN ImageNet evaluation on full ImageNet-1k or on 20% of the training examples (set --subset 260
). Modify --nproc_per_node
based on you available GPUs.
python3 -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_knn.py \
--arch vit_small --checkpoint_key teacher --avgpool_patchtokens 0 \
--pretrained_weights /path/to/checkpoint.pth --data_path /path/to/ImageNet --subset -1
Linear probing ImageNet evaluation on full ImageNet-1k or on 20% of the training examples (set --subset 260
). Modify --nproc_per_node
and --batch_size_per_gpu
based on you available GPUs.
python3 -m torch.distributed.launch --nproc_per_node=2 evaluation/eval_linear.py \
--batch_size_per_gpu 1024 --n_last_blocks 4 --avgpool_patchtokens 0 --arch vit_small --lr 0.005 \
--pretrained_weights /path/to/checkpoint.pth --data_path /path/to/ImageNet --output_dir /path/to/output/ --subset -1
Low shot ImageNet evaluation task, where only ν = 1, 5, 10 or 20
examples per class of the training set are used for the k-NN classifier.
For ν = 1
use --nb_knn 1 --subset 1
and average the results of 5 different runs using different subset --subset_starts 0
, 5
, 10
, 15
and 20
For ν = 5
use --nb_knn 20 --subset 5
and average the results of 5 different runs using different subset --subset_starts 0
, 5
, 10
, 15
and 20
For ν = 10
use --nb_knn 20 --subset 10
and average the results of 5 different runs using different subset --subset_starts 0
, 10
, 20
, 30
and 40
For ν = 20
use --nb_knn 20 --subset 20
and average the results of 5 different runs using different subset --subset_starts 0
, 20
, 40
, 60
and 80
Example for the first run using ν = 1
:
python3 -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_knn_few.py \
--checkpoint_key teacher --avgpool_patchtokens 0 --arch vit_small \
--pretrained_weights /path/to/checkpoint --data_path /path/to/ImageNet \
--nb_knn 1 --subset 1 --subset_starts 0
Linear probing top-1 accuracy on a more challenging masked version of ImageNet1k validation set. Modify --nproc_per_node
based on you available GPUs.
Example for 0.7
masking ratio:
python -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_linear_acc_drop.py \
--num_workers 4 --output_dir /path/to/output --data_path /path/to/ImageNet \
--pretrained_weights /path/to/checkpoint --n_last_blocks 4 --avgpool_patchtokens 0 \
--arch vit_small --subset -1 --batch_size_per_gpu 30 --lr 0.001 --load_from checkpoint_teacher_linear.pth \
--masking_ratio 0.7
Combine the pre-trained models and the linear head from the linear probing:
python models/combine_ckpt.py \
--checkpoint_pretraining /path/to/pretrained/checkpoint \
--checkpoint_linear /path/to/linear_probing_head \
--output_file /path/to/combined_checkpoint
Perform the background robustness evaluation:
Download and unzip the datasets and run the following:
python3 evaluation/backgrounds_challenge/in9_eval.py \
--arch vit_small \
--checkpoint /path/to/combined_checkpoint \
--data-path /path/to/dataset/bg_challenge
For finetuning on CIFAR10 run:
python3 -m torch.distributed.launch --nproc_per_node=6 evaluation/eval_transfer_finetuning.py \
--avgpool_patchtokens 0 \
--arch vit_small \
--checkpoint_key teacher \
--batch-size 150 \
--lr 7.5e-6 \
--epochs 500 \
--pretrained_weights /path/to/pretrained/checkpoint \
--output_dir /path/to/output \
--data_set CIFAR10 \
--data_path data/cifar
For finetuning on CIFAR100 modify the
--data_set CIFAR100
and on Oxford Flowers modify the--data_set Flowers
and the--epochs 1000
.
If you face NCCL errors during the evaluations, try to use --backend gloo
.
This repository is built using the iBOT, DINO, BEiT and ImageNet-9 repositories.
This repository is released under the Apache 2.0 license as found in the LICENSE file.
If you find this repository useful, please consider giving a star ⭐ and citation:
@article{kakogeorgiou2022attmask,
title={What to Hide from Your Students: Attention-Guided Masked Image Modeling},
url={https://arxiv.org/abs/2203.12719},
DOI={10.48550/arXiv.2203.12719},
journal={arXiv.org},
author={Kakogeorgiou, Ioannis and Gidaris, Spyros and Psomas, Bill and Avrithis, Yannis and Bursuc, Andrei and Karantzalos, Konstantinos and Komodakis, Nikos},
year={2022}
}