
This repo contains the official implementation of ECCV 2022 paper "What to Hide from Your Students: Attention-Guided Masked Image Modeling"


rslab-ntua/attmask


What to Hide from Your Students: Attention-Guided Masked Image Modeling

PyTorch implementation and pretrained models for AttMask. [arXiv]

AttMask illustration

Pretrained models

You can download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint, which contains the backbone and projection head weights for both the student and teacher networks. We also provide the training arguments and logs. All backbones are ViT-S/16 models.

| Percentage of ImageNet train set | Epochs | AttMask mode | k-NN (top-1 %) | Linear probing (top-1 %) | Download |
|---|---|---|---|---|---|
| 100% | 100 | Hint | 72.8 | 76.1 | backbone only / full ckpt / logs |
| 100% | 100 | High | 72.5 | 75.7 | backbone only / full ckpt / logs |
| 100% | 300 | High | 75.0 | 77.5 | backbone only / full ckpt / logs |
| 20% | 100 | Hint | 49.5 | 57.5 | backbone only / full ckpt / logs |
| 20% | 100 | High | 49.7 | 57.9 | backbone only / full ckpt / logs |
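A full checkpoint bundles student and teacher weights, while the evaluation scripts below pick one of them with --checkpoint_key teacher. The sketch below shows, on plain dictionaries, how backbone-only weights could be pulled out of such a checkpoint. It assumes a DINO/iBOT-style layout (a top-level `teacher` entry whose parameter names carry `module.` and `backbone.` prefixes); that layout and the function name `extract_backbone` are assumptions, so check the key names in the downloaded file before relying on it.

```python
def extract_backbone(checkpoint, checkpoint_key="teacher"):
    """Return backbone-only weights from a full checkpoint dict.

    Assumed layout (not confirmed by this README): checkpoint[checkpoint_key]
    maps names such as "module.backbone.blocks.0.attn.qkv.weight" to tensors.
    """
    state_dict = checkpoint.get(checkpoint_key, checkpoint)
    backbone = {}
    for name, value in state_dict.items():
        # Drop the DistributedDataParallel wrapper prefix, if present.
        if name.startswith("module."):
            name = name[len("module."):]
        # Keep only backbone parameters; projection-head weights are skipped.
        if name.startswith("backbone."):
            backbone[name[len("backbone."):]] = value
    return backbone


demo = {"teacher": {"module.backbone.cls_token": 1, "module.head.last_layer.weight": 2}}
print(extract_backbone(demo))  # {'cls_token': 1}
```

With PyTorch, the input would typically be the dictionary returned by `torch.load(path, map_location="cpu")`.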

Training

Installation

Please install PyTorch and download the ImageNet dataset. The experiments were performed with Python 3.7.6, PyTorch 1.7.0, CUDA 11.0 and torchvision 0.8.1.

The requirements are easily installed via Anaconda:

conda create -n attmask python=3.7.6
conda activate attmask
conda install pytorch==1.7.0 torchvision==0.8.1 cudatoolkit=11.0 pillow==8.0.1 -c pytorch
pip install timm==0.4.12 tensorboardx==2.5.1 scikit-learn==0.23.2 munkres==1.1.4 tqdm
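Before launching a long pre-training run, it can help to confirm that the environment matches the pinned versions above. The helper below is a small sketch (the name `check_versions` is ours, not part of the repository); it compares each package's `__version__` attribute with the values listed above, uses only the standard library so it also runs on Python 3.7, and leaves out tensorboardx and munkres for brevity.

```python
import importlib

# Pinned versions from the installation commands above, keyed by import name.
EXPECTED = {
    "torch": "1.7.0",
    "torchvision": "0.8.1",
    "timm": "0.4.12",
    "sklearn": "0.23.2",
    "PIL": "8.0.1",
}


def check_versions(expected=EXPECTED, importer=importlib.import_module):
    """Return {module: problem} for every module that is missing or mismatched."""
    problems = {}
    for module_name, wanted in expected.items():
        try:
            module = importer(module_name)
        except ImportError:
            problems[module_name] = "not installed"
            continue
        found = str(getattr(module, "__version__", "unknown"))
        # Ignore local build tags such as "+cu110".
        if found.split("+")[0] != wanted:
            problems[module_name] = "found {}, expected {}".format(found, wanted)
    return problems
```

Inside the `attmask` environment, `check_versions()` should return an empty dictionary; any entry points to a package worth reinstalling.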

AttMask training

Pre-train AttMask on ImageNet-1k. Modify --nproc_per_node and --batch_size_per_gpu based on your available GPUs.
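When changing the number of GPUs, one simple approach is to keep the total batch size fixed, since it equals --nproc_per_node × --batch_size_per_gpu (4 × 60 = 240 for the 100-epoch recipes below, 8 × 100 = 800 for the 300-epoch one). The helpers below are a hypothetical sketch of that arithmetic; whether `main_attmask.py` also rescales the learning rate with the total batch size is not stated in this README, so check the script before relying on it.

```python
def effective_batch_size(nproc_per_node, batch_size_per_gpu):
    """Total number of images per optimization step across all GPUs."""
    return nproc_per_node * batch_size_per_gpu


def per_gpu_batch(target_total, nproc_per_node):
    """Per-GPU batch that reproduces target_total exactly (raises if impossible)."""
    if target_total % nproc_per_node != 0:
        raise ValueError(
            "{} images cannot be split evenly over {} GPUs".format(target_total, nproc_per_node)
        )
    return target_total // nproc_per_node


print(effective_batch_size(4, 60))  # 240
print(per_gpu_batch(240, 2))        # 120, i.e. --nproc_per_node=2 --batch_size_per_gpu 120
```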

Full ImageNet-1k pre-training

Train AttMask-Hint with a ViT-small network for 100 epochs on the full ImageNet-1k:

python -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_hint --show_max 0.1 \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/
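In our reading, the masking flags combine as follows: --pred_ratio 0.3 --pred_ratio_var 0.2 draws the fraction of masked patches per image from the interval 0.3 ± 0.2, and --masking_prob 0.5 masks only about half of the images at all. This interpretation follows the iBOT code base this repository builds on and is an assumption, as is the helper name; the sketch only illustrates the sampling step.

```python
import random


def sample_mask_ratio(pred_ratio=0.3, pred_ratio_var=0.2, masking_prob=0.5, rng=random):
    """Sketch: fraction of patches to mask for one image (0.0 means no masking).

    Assumed semantics (not confirmed by the README):
      * with probability 1 - masking_prob the image is left unmasked;
      * otherwise the ratio is uniform in [pred_ratio - var, pred_ratio + var].
    """
    if rng.random() >= masking_prob:
        return 0.0
    if pred_ratio_var > 0:
        return rng.uniform(pred_ratio - pred_ratio_var, pred_ratio + pred_ratio_var)
    return pred_ratio


rng = random.Random(0)
ratios = [sample_mask_ratio(rng=rng) for _ in range(10000)]
masked = [r for r in ratios if r > 0]
print(len(masked) / len(ratios))  # close to 0.5
print(min(masked), max(masked))   # both within [0.1, 0.5]
```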

Train AttMask-High with a ViT-small network for 100 epochs on the full ImageNet-1k:

python -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_high \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/

Train AttMask-High with a ViT-small network for 300 epochs on the full ImageNet-1k:

python -m torch.distributed.launch --nproc_per_node=8 main_attmask.py --batch_size_per_gpu 100 \
--norm_last_layer False --momentum_teacher 0.996 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 300 \
--shared_head True --out_dim 8192 --local_crops_number 10 --global_crops_scale 0.32 1 \
--local_crops_scale 0.05 0.32 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_high \
--subset -1 --data_path /path/to/ImageNet --output_dir /path/to/output/

20% of ImageNet-1k pre-training

Train AttMask-Hint with a ViT-small network for 100 epochs on 20% of ImageNet-1k:

python3 -m torch.distributed.launch --nproc_per_node=4 main_attmask.py --batch_size_per_gpu 60 \
--norm_last_layer False --momentum_teacher 0.99 --num_workers 4 --eval_every 20 \
--arch vit_small --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --epochs 100 \
--shared_head True --out_dim 8192 --local_crops_number 6 --global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 --pred_ratio 0.3 --pred_ratio_var 0.2 --masking_prob 0.5 \
--pred_shape attmask_hint --show_max 0.1 \
--subset 260 --data_path /path/to/ImageNet --output_dir /path/to/output/

For AttMask-High or AttMask-Low, remove the --show_max 0.1 argument and set --pred_shape attmask_high or --pred_shape attmask_low, respectively.
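The --pred_shape variants differ in which patches are masked, based on the teacher's attention: AttMask-High masks the most-attended patches, AttMask-Low the least-attended ones, and AttMask-Hint masks highly attended patches but leaves a small fraction of them visible as hints (--show_max 0.1). The numpy sketch below illustrates that selection for one image given per-patch attention scores. It is a simplified illustration under our own assumptions (that `show_max` is the fraction of the selected patches that is revealed, and that revealed patches are chosen at random), not the repository's implementation.

```python
import numpy as np


def attention_guided_mask(attn, mask_ratio, mode="high", show_max=0.0, rng=None):
    """Return a boolean mask over patches (True = masked).

    attn: 1-D array of per-patch attention scores (e.g. the teacher's [CLS] attention).
    mode: "high" masks the most-attended patches, "low" the least-attended ones,
          "hint" is "high" with a random fraction `show_max` of the selected
          patches left visible (our assumption about how hints are chosen).
    """
    rng = np.random.default_rng() if rng is None else rng
    attn = np.asarray(attn)
    num_masked = int(round(mask_ratio * attn.size))
    order = np.argsort(attn)  # patch indices sorted by ascending attention
    if mode == "low":
        chosen = order[:num_masked]
    elif mode in ("high", "hint"):
        chosen = order[::-1][:num_masked]
    else:
        raise ValueError("unknown mode: " + mode)
    if mode == "hint" and num_masked > 0:
        num_hints = int(show_max * num_masked)
        hints = rng.choice(chosen, size=num_hints, replace=False)
        chosen = np.setdiff1d(chosen, hints)  # revealed patches stay visible
    mask = np.zeros(attn.size, dtype=bool)
    mask[chosen] = True
    return mask


attn = np.arange(10, dtype=float)  # patch 9 receives the most attention
print(np.flatnonzero(attention_guided_mask(attn, 0.3, "high")))  # [7 8 9]
print(np.flatnonzero(attention_guided_mask(attn, 0.3, "low")))   # [0 1 2]
```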


Evaluation

k-NN evaluation

k-NN evaluation on the full ImageNet-1k or on 20% of the training examples (set --subset 260). Modify --nproc_per_node based on your available GPUs.

python3 -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_knn.py \
--arch vit_small --checkpoint_key teacher --avgpool_patchtokens 0 \
--pretrained_weights /path/to/checkpoint.pth --data_path /path/to/ImageNet --subset -1
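The k-NN protocol classifies each validation image by a weighted vote of its nearest training features. The numpy sketch below follows the weighted k-NN of the DINO code base this repository builds on (cosine similarity on L2-normalised features, votes weighted by exp(similarity / T)); the defaults k = 20 and T = 0.07 are assumptions here, not values read from `eval_knn.py`.

```python
import numpy as np


def knn_predict(train_feats, train_labels, test_feats, num_classes, k=20, T=0.07):
    """Weighted k-NN: predicted class index for each row of test_feats.

    k=20 and T=0.07 are assumed DINO-style defaults, not read from eval_knn.py.
    """
    train_labels = np.asarray(train_labels)
    # L2-normalise so that dot products are cosine similarities.
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test @ train.T                                # (n_test, n_train)
    k = min(k, train.shape[0])
    top_idx = np.argsort(-sims, axis=1)[:, :k]           # k nearest neighbours
    top_sims = np.take_along_axis(sims, top_idx, axis=1)
    weights = np.exp(top_sims / T)                       # temperature-scaled votes
    votes = np.zeros((test.shape[0], num_classes))
    for row in range(test.shape[0]):
        # Accumulate each neighbour's weight on its class (handles repeated labels).
        np.add.at(votes[row], train_labels[top_idx[row]], weights[row])
    return votes.argmax(axis=1)


train = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
labels = np.array([0, 0, 1, 1])
queries = np.array([[0.8, 0.2], [0.2, 0.8]])
preds = knn_predict(train, labels, queries, num_classes=2, k=2)
print(preds)                               # [0 1]
print((preds == np.array([0, 1])).mean())  # top-1 accuracy: 1.0
```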

Linear probing evaluation

Linear probing evaluation on the full ImageNet-1k or on 20% of the training examples (set --subset 260). Modify --nproc_per_node and --batch_size_per_gpu based on your available GPUs.

python3 -m torch.distributed.launch --nproc_per_node=2 evaluation/eval_linear.py \
--batch_size_per_gpu 1024 --n_last_blocks 4 --avgpool_patchtokens 0 --arch vit_small --lr 0.005 \
--pretrained_weights /path/to/checkpoint.pth --data_path /path/to/ImageNet --output_dir /path/to/output/ --subset -1
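With --n_last_blocks 4 --avgpool_patchtokens 0, the linear classifier is assumed to see the [CLS] token of each of the last four transformer blocks, as in the DINO-style linear evaluation this repository builds on. The numpy sketch below shows that feature construction on dummy block outputs; the shapes (197 tokens and 384 channels, i.e. ViT-S/16 at 224×224) are standard ViT-S/16 values, and the function name is hypothetical.

```python
import numpy as np


def linear_probe_features(block_outputs, n_last_blocks=4, avgpool_patchtokens=0):
    """Build linear-probe input features from per-block token outputs.

    block_outputs: list of arrays shaped (batch, 1 + num_patches, dim), one per
    transformer block; token 0 is assumed to be [CLS].
    """
    last = block_outputs[-n_last_blocks:]
    feats = [x[:, 0] for x in last]  # [CLS] token of each of the last blocks
    if avgpool_patchtokens:
        # Optionally append the mean of the last block's patch tokens.
        feats.append(last[-1][:, 1:].mean(axis=1))
    return np.concatenate(feats, axis=-1)


blocks = [np.random.rand(2, 197, 384) for _ in range(12)]  # ViT-S has 12 blocks
print(linear_probe_features(blocks).shape)  # (2, 1536), i.e. 4 x 384
```

Setting --avgpool_patchtokens 1 in this sketch would append one more 384-dimensional vector, giving 1920 input features.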

Low-shot evaluation

Low-shot ImageNet evaluation, where only ν = 1, 5, 10 or 20 examples per class of the training set are used by the k-NN classifier.

For ν = 1, use --nb_knn 1 --subset 1 and average the results of 5 runs over different subsets, using --subset_starts 0, 5, 10, 15 and 20.

For ν = 5, use --nb_knn 20 --subset 5 and average the results of 5 runs over different subsets, using --subset_starts 0, 5, 10, 15 and 20.

For ν = 10, use --nb_knn 20 --subset 10 and average the results of 5 runs over different subsets, using --subset_starts 0, 10, 20, 30 and 40.

For ν = 20, use --nb_knn 20 --subset 20 and average the results of 5 runs over different subsets, using --subset_starts 0, 20, 40, 60 and 80.
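Our reading of --subset and --subset_starts is that they keep, for every class, `subset` training images starting at position `subset_starts` within that class. This is consistent with the non-overlapping offsets listed above and with the pre-training setting --subset 260 (260 images × 1000 classes = 260,000 images, about 20% of the 1,281,167 ImageNet-1k training images). The helper below is a hypothetical sketch of that selection, not the repository's dataset code.

```python
from collections import defaultdict


def select_per_class(samples, subset, subset_starts=0):
    """Keep `subset` samples per class, starting at offset `subset_starts`.

    samples: list of (path, class_index) pairs in a fixed order
    (e.g. sorted file names). A negative `subset`, like --subset -1, keeps everything.
    """
    if subset < 0:
        return list(samples)
    seen = defaultdict(int)  # how many images of each class we have passed so far
    kept = []
    for path, label in samples:
        position = seen[label]
        seen[label] += 1
        if subset_starts <= position < subset_starts + subset:
            kept.append((path, label))
    return kept


samples = [("img{}.jpg".format(i), i % 2) for i in range(20)]  # 10 images per class
print(len(select_per_class(samples, subset=5, subset_starts=5)))  # 10 (5 per class)
print(260 * 1000 / 1281167)  # about 0.203
```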

Example for the first run using ν = 1:

python3 -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_knn_few.py \
--checkpoint_key teacher --avgpool_patchtokens 0 --arch vit_small \
--pretrained_weights /path/to/checkpoint --data_path /path/to/ImageNet \
--nb_knn 1 --subset 1 --subset_starts 0
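Running all five low-shot evaluations by hand is error-prone. The sketch below builds the five command lines for a given ν from the settings listed above and averages the resulting accuracies. The accuracies must be copied from each run's output, since how `eval_knn_few.py` reports them is not described here; the helper names are hypothetical and the paths are the same placeholders used above.

```python
import shlex

# (--nb_knn, --subset_starts values) for each ν, as listed above.
LOW_SHOT_SETTINGS = {
    1: (1, [0, 5, 10, 15, 20]),
    5: (20, [0, 5, 10, 15, 20]),
    10: (20, [0, 10, 20, 30, 40]),
    20: (20, [0, 20, 40, 60, 80]),
}


def low_shot_commands(nu, checkpoint="/path/to/checkpoint", data_path="/path/to/ImageNet"):
    """Return one shell command string per run for ν images per class."""
    nb_knn, starts = LOW_SHOT_SETTINGS[nu]
    commands = []
    for start in starts:
        args = [
            "python3", "-m", "torch.distributed.launch", "--nproc_per_node=1",
            "evaluation/eval_knn_few.py",
            "--checkpoint_key", "teacher", "--avgpool_patchtokens", "0",
            "--arch", "vit_small",
            "--pretrained_weights", checkpoint, "--data_path", data_path,
            "--nb_knn", str(nb_knn), "--subset", str(nu), "--subset_starts", str(start),
        ]
        commands.append(" ".join(shlex.quote(a) for a in args))
    return commands


def average_accuracy(accuracies):
    """Mean top-1 accuracy over the runs."""
    return sum(accuracies) / len(accuracies)


for cmd in low_shot_commands(1):
    print(cmd)
```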

Masked ImageNet-1k evaluation

Linear probing top-1 accuracy on a more challenging, masked version of the ImageNet-1k validation set. Modify --nproc_per_node based on your available GPUs.

Example for 0.7 masking ratio:

python -m torch.distributed.launch --nproc_per_node=1 evaluation/eval_linear_acc_drop.py \
--num_workers 4 --output_dir /path/to/output --data_path /path/to/ImageNet \
--pretrained_weights /path/to/checkpoint --n_last_blocks 4 --avgpool_patchtokens 0 \
--arch vit_small --subset -1 --batch_size_per_gpu 30 --lr 0.001 --load_from checkpoint_teacher_linear.pth \
--masking_ratio 0.7
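A masking ratio of 0.7 hides 70% of each validation image. The sketch below illustrates one plausible way to build such an input: split a 224×224 image into the 14×14 grid of 16×16 patches seen by ViT-S/16 and zero out a random 70% of them. How `eval_linear_acc_drop.py` actually masks images (zeroing pixels, dropping tokens, or something else) is not stated here, so this procedure and the function name are assumptions.

```python
import numpy as np


def mask_random_patches(image, masking_ratio=0.7, patch_size=16, rng=None):
    """Zero out a random fraction of non-overlapping patches of an (H, W, C) image.

    Assumption: masking means setting pixels to zero; the repository may differ.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    grid_h, grid_w = h // patch_size, w // patch_size
    num_patches = grid_h * grid_w
    num_masked = int(masking_ratio * num_patches)
    masked_ids = rng.choice(num_patches, size=num_masked, replace=False)
    out = image.copy()
    for idx in masked_ids:
        row, col = divmod(idx, grid_w)
        out[row * patch_size:(row + 1) * patch_size,
            col * patch_size:(col + 1) * patch_size] = 0
    return out, num_masked


img = np.ones((224, 224, 3), dtype=np.float32)
masked_img, n = mask_random_patches(img, 0.7, rng=np.random.default_rng(0))
print(n)  # 137 of the 196 patches are zeroed
```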

Background robustness on ImageNet-9

Combine the pretrained model with the linear head obtained from linear probing:

python models/combine_ckpt.py \
  --checkpoint_pretraining /path/to/pretrained/checkpoint \
  --checkpoint_linear /path/to/linear_probing_head \
  --output_file /path/to/combined_checkpoint
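Conceptually, combining means merging the backbone weights and the linear-head weights into one state dict that a single classifier model can load. The sketch below shows that merge on plain dictionaries; the `module.linear.*` input names and the `head.` output prefix are assumptions for illustration and may differ from what `models/combine_ckpt.py` reads and writes.

```python
def combine_state_dicts(backbone, linear_head, head_prefix="head."):
    """Merge backbone weights with single-layer linear-head weights.

    Hypothetical naming: linear_head keys look like "module.linear.weight" and are
    renamed to head_prefix + "weight" / "bias"; check combine_ckpt.py for the real ones.
    """
    combined = dict(backbone)
    for name, value in linear_head.items():
        short_name = name.split(".")[-1]  # "weight" or "bias"
        new_name = head_prefix + short_name
        if new_name in combined:
            raise KeyError("backbone already has a parameter named " + new_name)
        combined[new_name] = value
    return combined


print(combine_state_dicts({"cls_token": 1}, {"module.linear.weight": 2, "module.linear.bias": 3}))
```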

To perform the background robustness evaluation, download and unzip the datasets, then run the following:

python3 evaluation/backgrounds_challenge/in9_eval.py \
--arch vit_small \
--checkpoint /path/to/combined_checkpoint \
--data-path /path/to/dataset/bg_challenge 

Finetuning evaluation on other image classification datasets

For finetuning on CIFAR10 run:

python3 -m torch.distributed.launch --nproc_per_node=6 evaluation/eval_transfer_finetuning.py \
--avgpool_patchtokens 0 \
--arch vit_small \
--checkpoint_key teacher \
--batch-size 150 \
--lr 7.5e-6 \
--epochs 500 \
--pretrained_weights /path/to/pretrained/checkpoint \
--output_dir /path/to/output \
--data_set CIFAR10 \
--data_path data/cifar

For finetuning on CIFAR100, set --data_set CIFAR100; for Oxford Flowers, set --data_set Flowers and --epochs 1000.
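The three transfer settings differ only in a couple of flags. The sketch below collects them in one place and assembles the argument list for `eval_transfer_finetuning.py` from the values given above (500 epochs for CIFAR10 and CIFAR100, 1000 for Oxford Flowers). The helper is hypothetical, and `/path/to/flowers` is a placeholder, since the README only shows a data path for CIFAR10.

```python
# Per-dataset overrides; every other flag matches the CIFAR10 command above.
FINETUNE_OVERRIDES = {
    "CIFAR10": {"--epochs": "500"},
    "CIFAR100": {"--epochs": "500"},
    "Flowers": {"--epochs": "1000"},
}


def finetune_args(data_set, pretrained_weights, output_dir, data_path):
    """Return the flag/value list for eval_transfer_finetuning.py (hypothetical helper)."""
    args = {
        "--avgpool_patchtokens": "0",
        "--arch": "vit_small",
        "--checkpoint_key": "teacher",
        "--batch-size": "150",
        "--lr": "7.5e-6",
        "--pretrained_weights": pretrained_weights,
        "--output_dir": output_dir,
        "--data_set": data_set,
        "--data_path": data_path,
    }
    args.update(FINETUNE_OVERRIDES[data_set])
    flat = []
    for flag, value in args.items():
        flat.extend([flag, value])
    return flat


print(" ".join(finetune_args("Flowers", "/path/to/pretrained/checkpoint",
                             "/path/to/output", "/path/to/flowers")))
```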

Common Errors

If you encounter NCCL errors during evaluation, try using --backend gloo.

Acknowledgement

This repository is built using the iBOT, DINO, BEiT and ImageNet-9 repositories.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving a star ⭐ and citation:

@article{kakogeorgiou2022attmask, 
title={What to Hide from Your Students: Attention-Guided Masked Image Modeling}, 
url={https://arxiv.org/abs/2203.12719}, 
DOI={10.48550/arXiv.2203.12719}, 
journal={arXiv.org}, 
author={Kakogeorgiou, Ioannis and Gidaris, Spyros and Psomas, Bill and Avrithis, Yannis and Bursuc, Andrei and Karantzalos, Konstantinos and Komodakis, Nikos}, 
year={2022}
}
