
Harry24k/adversarial-attacks-pytorch


Adversarial-Attacks-PyTorch

MIT License Pypi Latest Release Documentation Status Code style: black

Torchattacks is a PyTorch library that provides adversarial attacks to generate adversarial examples.

It provides a PyTorch-like interface and functions that make it easy for PyTorch users to implement adversarial attacks.

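import torchattacks

# model: a torch.nn.Module that returns logits of shape (N, C)
# images: float tensor in [0, 1]; labels: int64 tensor of shape (N,)
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
# If the model expects normalized inputs, pass the normalization statistics:
# atk.set_normalization_used(mean=[...], std=[...])
adv_images = atk(images, labels)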
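The commented set_normalization_used call suggests that the attack keeps working in [0, 1] pixel space and applies the normalization only when images are passed to the model. As a minimal sketch of that per-channel transform (plain Python lists shaped [C][H][W]; the function names and the 0.5/0.5 statistics are placeholders for illustration, not values or APIs from torchattacks):

```python
def normalize(image, mean, std):
    """Per-channel (x - mean[c]) / std[c] for an image stored as [C][H][W] lists."""
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


def denormalize(image, mean, std):
    """Inverse of normalize(): x * std[c] + mean[c]."""
    return [
        [[x * s + m for x in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# Placeholder statistics (assumed for illustration; use your model's own values).
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# A 3 x 1 x 2 toy image with pixel values in [0, 1].
image = [[[0.0, 1.0]], [[0.5, 0.5]], [[0.25, 0.75]]]

normalized = normalize(image, mean, std)  # values now lie in [-1, 1]
restored = denormalize(normalized, mean, std)
```

Keeping the normalization inside the model call matters because eps=8/255 and alpha=2/255 are measured in [0, 1] pixel units; applying them to already-normalized tensors would silently rescale the budget by 1/std.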

Additional Recommended Packages.

Citation. If you use this package, please cite the following BibTeX (Google Scholar):

@article{kim2020torchattacks,
title={Torchattacks: A pytorch repository for adversarial attacks},
author={Kim, Hoki},
journal={arXiv preprint arXiv:2010.01950},
year={2020}
}

🔨 Requirements and Installation

Requirements

  • PyTorch version >=1.4.0
  • Python version >=3.6

Installation

#  pip
pip install torchattacks

#  source
pip install git+https://github.com/Harry24k/adversarial-attacks-pytorch.git

#  git clone
git clone https://github.com/Harry24k/adversarial-attacks-pytorch.git
cd adversarial-attacks-pytorch/
pip install -e .

🚀 Getting Started

Precautions

  • All models should return ONLY ONE tensor of shape (N, C), where N is the number of inputs and C is the number of classes. Because most models in torchvision.models return a single (N, C) output, torchattacks supports only this limited form of output. Please check the shape of your model's output carefully.
  • Inputs should be in the range [0, 1]. Since clipping to [0, 1] is always applied after the perturbation, the original inputs must also lie in [0, 1], which is the usual setting in the vision domain.
  • Set torch.backends.cudnn.deterministic = True to get the same adversarial examples with a fixed random seed. Some operations are non-deterministic with float tensors on GPU [discuss]. If you want the same results for the same inputs, run torch.backends.cudnn.deterministic = True [ref].
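To see why the first two precautions matter, here is a small pure-Python sketch. The helpers clip01 and has_single_nc_output are hypothetical names used only for illustration; they mimic the perturb-then-clip step and the single (N, C) output requirement, not torchattacks' internals:

```python
EPS = 8 / 255  # the eps used in the quick-start example


def clip01(values):
    """Clip every value to [0, 1], as is done after each perturbation."""
    return [min(max(v, 0.0), 1.0) for v in values]


def has_single_nc_output(logits, num_inputs, num_classes):
    """True if `logits` is one (N, C) table: N rows, each with C class scores."""
    return (
        isinstance(logits, list)
        and len(logits) == num_inputs
        and all(len(row) == num_classes for row in logits)
    )


pixels_01 = [0.2, 0.5, 0.9]        # correctly scaled inputs in [0, 1]
pixels_255 = [51.0, 127.5, 229.5]  # the same pixels left in [0, 255]

# With [0, 1] inputs, perturb-then-clip changes each pixel by at most EPS...
perturbed_01 = clip01([v + EPS for v in pixels_01])
# ...but with [0, 255] inputs, clipping saturates every pixel to 1.0.
perturbed_255 = clip01([v + EPS for v in pixels_255])
```

A model that returns, say, a tuple of (logits, features) would fail the shape check, which is the kind of output the first precaution rules out.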

Demos

  • Targeted mode

    • Random target label
      # random labels as target labels.
      atk.set_mode_targeted_random()
    • Least likely label
      # labels with the k-th smallest probability as target labels.
      atk.set_mode_targeted_least_likely(kth_min)
    • By custom function
      # labels obtained by mapping function as target labels.
      # shift every class label one to the right: 0=>1, 1=>2, ..., 9=>0
      atk.set_mode_targeted_by_function(target_map_function=lambda images, labels:(labels+1)%10)
    • By label
      atk.set_mode_targeted_by_label(quiet=True)
      # shift every class label one to the right: 0=>1, 1=>2, ..., 9=>0
      target_labels = (labels + 1) % 10
      adv_images = atk(images, target_labels)
    • Return to default
      atk.set_mode_default()
  • Save adversarial images

    # Save
    atk.save(data_loader, save_path="./data.pt", verbose=True)
    
    # Load
    adv_loader = atk.load(load_path="./data.pt")
  • Training/Eval during attack

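    # For RNN-based models, gradients cannot be computed in eval mode (cuDNN RNNs
    # only support backward in training mode), so switch the model to training mode
    # during the attack by setting model_training=True.
    # With all flags False, as below, the model stays in eval mode.
    atk.set_model_training_mode(model_training=False, batchnorm_training=False, dropout_training=False)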
  • Make a set of attacks

    • Strong attacks
      atk1 = torchattacks.FGSM(model, eps=8/255)
      atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
      atk = torchattacks.MultiAttack([atk1, atk2])
    • Grid search over c for CW (instead of binary search)
      atk1 = torchattacks.CW(model, c=0.1, steps=1000, lr=0.01)
      atk2 = torchattacks.CW(model, c=1, steps=1000, lr=0.01)
      atk = torchattacks.MultiAttack([atk1, atk2])
    • Random restarts
      atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
      atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
      atk = torchattacks.MultiAttack([atk1, atk2])
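The target-selection rules in the targeted-mode demos can be sketched in plain Python. The helpers below are hypothetical and only mirror the documented behaviour; they assume kth_min is 1-indexed, so that kth_min=1 selects the least likely class:

```python
def shift_targets(labels, num_classes=10):
    """Target map from the demos: class y -> (y + 1) % num_classes."""
    return [(y + 1) % num_classes for y in labels]


def least_likely_targets(probs, kth_min=1):
    """Pick, for each row of class probabilities, the class with the kth_min-th
    smallest probability (kth_min=1 selects the least likely class)."""
    targets = []
    for row in probs:
        ascending = sorted(range(len(row)), key=lambda c: row[c])
        targets.append(ascending[kth_min - 1])
    return targets


labels = [0, 1, 9]
probs = [
    [0.7, 0.2, 0.1],  # least likely: class 2, then class 1
    [0.1, 0.6, 0.3],  # least likely: class 0, then class 2
]
```

For example, shift_targets(labels) maps [0, 1, 9] to [1, 2, 0], which is the same mapping as the target_map_function lambda and the (labels + 1) % 10 expression in the demos.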

📃 Supported Attacks

The distance measure for each attack is given in parentheses.

Name Paper Remark
FGSM
(Linf)
Explaining and harnessing adversarial examples (Goodfellow et al., 2014)
BIM
(Linf)
Adversarial Examples in the Physical World (Kurakin et al., 2016) Basic iterative method or Iterative-FGSM
CW
(L2)
Towards Evaluating the Robustness of Neural Networks (Carlini et al., 2016)
RFGSM
(Linf)
Ensemble Adversarial Training: Attacks and Defences (Tramèr et al., 2017) Random initialization + FGSM
PGD
(Linf)
Towards Deep Learning Models Resistant to Adversarial Attacks (Madry et al., 2017) Projected Gradient Method
PGDL2
(L2)
Towards Deep Learning Models Resistant to Adversarial Attacks (Madry et al., 2017) Projected Gradient Method
MIFGSM
(Linf)
Boosting Adversarial Attacks with Momentum (Dong et al., 2017) 😍 Contributor zhuangzi926, huitailangyz
TPGD
(Linf)
Theoretically Principled Trade-off between Robustness and Accuracy (Zhang et al., 2019)
EOTPGD
(Linf)
Comment on "Adv-BNN: Improved Adversarial Defense through Robust Bayesian Neural Network" (Zimmermann, 2019) EOT+PGD
APGD
(Linf, L2)
Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020)
APGDT
(Linf, L2)
Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020) Targeted APGD
FAB
(Linf, L2, L1)
Minimally distorted Adversarial Examples with a Fast Adaptive Boundary Attack (Croce et al., 2019)
Square
(Linf, L2)
Square Attack: a query-efficient black-box adversarial attack via random search (Andriushchenko et al., 2019)
AutoAttack
(Linf, L2)
Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020) APGD+APGDT+FAB+Square
DeepFool
(L2)
DeepFool: A Simple and Accurate Method to Fool Deep Neural Networks (Moosavi-Dezfooli et al., 2016)
OnePixel
(L0)
One pixel attack for fooling deep neural networks (Su et al., 2019)
SparseFool
(L0)
SparseFool: a few pixels make a big difference (Modas et al., 2019)
DIFGSM
(Linf)
Improving Transferability of Adversarial Examples with Input Diversity (Xie et al., 2019) 😍 Contributor taobai
TIFGSM
(Linf)
Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks (Dong et al., 2019) 😍 Contributor taobai
NIFGSM
(Linf)
Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks (Lin, et al., 2020) 😍 Contributor Zhijin-Ge
SINIFGSM
(Linf)
Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks (Lin, et al., 2020) 😍 Contributor Zhijin-Ge
VMIFGSM
(Linf)
Enhancing the Transferability of Adversarial Attacks through Variance Tuning (Wang, et al., 2021) 😍 Contributor Zhijin-Ge
VNIFGSM
(Linf)
Enhancing the Transferability of Adversarial Attacks through Variance Tuning (Wang, et al., 2021) 😍 Contributor Zhijin-Ge
Jitter
(Linf)
Exploring Misclassifications of Robust Neural Networks to Enhance Adversarial Attacks (Schwinn, Leo, et al., 2021)
Pixle
(L0)
Pixle: a fast and effective black-box attack based on rearranging pixels (Pomponi, Jary, et al., 2022)
LGV
(Linf, L2, L1, L0)
LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity (Gubri, et al., 2022) 😍 Contributor Martin Gubri
SPSA
(Linf)
Adversarial Risk and the Dangers of Evaluating Against Weak Attacks (Uesato, Jonathan, et al., 2018) 😍 Contributor Riko Naka
JSMA
(L0)
The Limitations of Deep Learning in Adversarial Settings (Papernot, Nicolas, et al., 2016) 😍 Contributor Riko Naka
EADL1
(L1)
EAD: Elastic-Net Attacks to Deep Neural Networks (Chen, Pin-Yu, et al., 2018) 😍 Contributor Riko Naka
EADEN
(L1, L2)
EAD: Elastic-Net Attacks to Deep Neural Networks (Chen, Pin-Yu, et al., 2018) 😍 Contributor Riko Naka
PIFGSM (PIM)
(Linf)
Patch-wise Attack for Fooling Deep Neural Network (Gao, Lianli, et al., 2020) 😍 Contributor Riko Naka
PIFGSM++ (PIM++)
(Linf)
Patch-wise++ Perturbation for Adversarial Targeted Attacks (Gao, Lianli, et al., 2021) 😍 Contributor Riko Naka
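The PGD rows above share one core update: a signed-gradient ascent step, a projection back into the eps-ball around the clean input, and a clip to [0, 1]. A minimal sketch of untargeted Linf PGD on plain Python lists follows; it assumes a toy linear loss whose gradient is supplied directly and omits random_start, so it illustrates the update rule rather than reproducing torchattacks' implementation:

```python
def sign(v):
    """Return -1, 0 or 1, like torch.sign for a scalar."""
    return (v > 0) - (v < 0)


def pgd_linf(x, grad_fn, eps, alpha, steps):
    """Untargeted Linf PGD: ascend the loss with signed-gradient steps, then
    project back into the eps-ball around x and clip to [0, 1]."""
    adv = list(x)
    for _ in range(steps):
        g = grad_fn(adv)
        # Gradient-ascent step of size alpha along the sign of the gradient.
        adv = [a + alpha * sign(gi) for a, gi in zip(adv, g)]
        # Projection onto {x' : max_i |x'_i - x_i| <= eps}.
        adv = [min(max(a, xi - eps), xi + eps) for a, xi in zip(adv, x)]
        # Keep a valid image in [0, 1].
        adv = [min(max(a, 0.0), 1.0) for a in adv]
    return adv


# Toy loss L(x) = w . x, whose gradient is just w (assumed for illustration).
w = [1.0, -2.0, 0.5]
x = [0.5, 0.5, 0.99]
adv = pgd_linf(x, lambda _: w, eps=8/255, alpha=2/255, steps=10)
```

With steps * alpha larger than eps, the first two coordinates end on the boundary of the eps-ball, while the third is stopped earlier by the [0, 1] clip.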

📊 Performance Comparison

For comparison, the most cited packages that are still actively updated were selected:

  • Foolbox: 611 citations and last update 2023.10.
  • ART: 467 citations and last update 2023.10.

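The comparison reports robust accuracy against each attack and the elapsed time on the first 50 images of CIFAR10. For L2 attacks, the average L2 distance between the adversarial and original images is also recorded. Each cell reads robust accuracy (elapsed time), or robust accuracy / average L2 distance (elapsed time) for L2 attacks; Standard, Wong2020Fast, and Rice2020Overfitting are the evaluated models. All experiments were run on a GeForce RTX 2080. For the latest version, please refer to here (code, nbviewer).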
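The two metrics in this comparison can be sketched as follows, assuming predictions and flattened images are already plain Python lists. Whether the benchmark averages the L2 distance over all images or only over successful adversarial examples is not stated here; this sketch averages over all of them, and the linked notebook remains the authoritative evaluation code:

```python
import math


def robust_accuracy(adv_predictions, labels):
    """Fraction of adversarial examples that are still classified correctly."""
    correct = sum(p == y for p, y in zip(adv_predictions, labels))
    return correct / len(labels)


def average_l2(adv_images, images):
    """Mean L2 distance between flattened adversarial and original images."""
    distances = [math.dist(a, x) for a, x in zip(adv_images, images)]
    return sum(distances) / len(distances)


labels = [3, 5, 1, 0]
adv_predictions = [3, 2, 1, 7]  # two of four still correct
images = [[0.0, 0.0], [1.0, 1.0]]
adv_images = [[0.3, 0.4], [1.0, 1.0]]  # distances 0.5 and 0.0
```

A lower robust accuracy means a stronger attack; for L2 attacks, a lower average distance at the same accuracy means a less visible perturbation.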

Attack Package Standard Wong2020Fast Rice2020Overfitting Remark
FGSM (Linf) Torchattacks 34% (54ms) 48% (5ms) 62% (82ms)
Foolbox* 34% (15ms) 48% (8ms) 62% (30ms)
ART 34% (214ms) 48% (59ms) 62% (768ms)
PGD (Linf) Torchattacks 0% (174ms) 44% (52ms) 58% (1348ms) 👑 Fastest
Foolbox* 0% (354ms) 44% (56ms) 58% (1856ms)
ART 0% (1384ms) 44% (437ms) 58% (4704ms)
CW† (L2) Torchattacks 0% / 0.40 (2596ms) 14% / 0.61 (3795ms) 22% / 0.56 (43484ms) 👑 Highest Success Rate, 👑 Fastest
Foolbox* 0% / 0.40 (2668ms) 32% / 0.41 (3928ms) 34% / 0.43 (44418ms)
ART 0% / 0.59 (196738ms) 24% / 0.70 (66067ms) 26% / 0.65 (694972ms)
PGD (L2) Torchattacks 0% / 0.41 (184ms) 68% / 0.5 (52ms) 70% / 0.5 (1377ms) 👑 Fastest
Foolbox* 0% / 0.41 (396ms) 68% / 0.5 (57ms) 70% / 0.5 (1968ms)
ART 0% / 0.40 (1364ms) 68% / 0.5 (429ms) 70% / 0.5 (4777ms)

* Note that Foolbox returns accuracy and adversarial images simultaneously, so the actual time spent generating adversarial images may be shorter than recorded.

† Because the binary search over the constant c can be time-consuming, torchattacks supports MultiAttack for a grid search over c.
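A minimal sketch of how such a grid search can work, assuming (as the MultiAttack demos suggest) that attacks run in order and each later attack is only applied to images that earlier attacks failed to fool. The 1-D classifier predict and the make_attack helper are hypothetical stand-ins for a real model and for CW attacks with different c:

```python
def multi_attack(images, labels, attacks, predict):
    """Apply attacks in order. Each image keeps the first adversarial example
    that changes its prediction; later attacks only see images not yet fooled.
    Returns the final images and the indices that no attack managed to fool."""
    final = list(images)
    remaining = list(range(len(images)))
    for attack in attacks:
        if not remaining:
            break  # every image is already fooled
        still_correct = []
        for i in remaining:
            adv = attack(images[i], labels[i])
            if predict(adv) != labels[i]:
                final[i] = adv  # success: keep this adversarial example
            else:
                still_correct.append(i)
        remaining = still_correct
    return final, remaining


def predict(x):
    # Hypothetical 1-D classifier: class 1 if x >= 0.5, else class 0.
    return int(x >= 0.5)


def make_attack(c):
    # Hypothetical attack of strength c: push the input toward the other class.
    return lambda x, y: x - c if y == 1 else x + c


images = [0.55, 0.9, 0.2]
labels = [1, 1, 0]
attacks = [make_attack(c) for c in (0.1, 1.0)]  # small c first
final, remaining = multi_attack(images, labels, attacks, predict)
```

Ordering the grid from small to large c means each image keeps the example found with the smallest c that succeeds, which approximates what the binary search aims for without its sequential cost.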

Additionally, I also recommend a recently proposed package, rai-toolbox.

Attack Package Time/step (accuracy)
FGSM (Linf) rai-toolbox 58 ms (0%)
Torchattacks 81 ms (0%)
Foolbox 105 ms (0%)
ART 83 ms (0%)
PGD (Linf) rai-toolbox 58 ms (44%)
Torchattacks 79 ms (44%)
Foolbox 82 ms (44%)
ART 90 ms (44%)
PGD (L2) rai-toolbox 58 ms (70%)
Torchattacks 81 ms (70%)
Foolbox 82 ms (70%)
ART 89 ms (70%)

The rai-toolbox takes a unique approach to gradient-based perturbations: they are implemented in terms of parameter-transforming optimizers and perturbation models. This enables users to implement diverse algorithms (like universal perturbations and concept probing with sparse gradients) using the same paradigm as a standard PGD attack.
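To illustrate that paradigm in generic terms (this is not rai-toolbox's API; every name below is hypothetical), an L2 PGD step can be phrased as an optimizer that transforms the gradient of a perturbation "parameter" before stepping and then projects it, while a separate perturbation model applies the perturbation to the input:

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


class NormalizedStepL2:
    """Hypothetical optimizer: unit-L2 gradient step, then projection onto the
    L2 ball of radius eps. The perturbation `delta` is the optimized parameter."""

    def __init__(self, lr, eps):
        self.lr = lr
        self.eps = eps

    def step(self, delta, grad):
        g_norm = l2_norm(grad) or 1.0  # avoid dividing by zero
        delta = [d + self.lr * g / g_norm for d, g in zip(delta, grad)]
        d_norm = l2_norm(delta)
        if d_norm > self.eps:
            delta = [d * self.eps / d_norm for d in delta]
        return delta


def perturb(x, delta):
    """Hypothetical perturbation model: additive delta, clipped to [0, 1]."""
    return [min(max(xi + di, 0.0), 1.0) for xi, di in zip(x, delta)]


# Toy loss L(x) = w . x, so the gradient w.r.t. delta is w (assumed here).
w = [3.0, 4.0]
x = [0.2, 0.2]
opt = NormalizedStepL2(lr=0.3, eps=0.5)
delta = [0.0, 0.0]
for _ in range(5):
    delta = opt.step(delta, w)
adv = perturb(x, delta)
```

Swapping the gradient transform (for example, taking the sign instead of L2-normalizing) and the projection turns the same loop into an Linf attack, which is the kind of reuse the paragraph above describes.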