diff --git a/README.md b/README.md index ff5b261..d4be329 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,11 @@ Currently supported attack methods are as follows: | FGSM | White-box | 📃[Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572) | | I-FGSM (BIM) | White-box | 📃[Adversarial examples in the physical world](https://arxiv.org/abs/1607.02533) | | MI-FGSM (MIM) | White-box | 📃[Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081) | +| NI-FGSM | White-box | 📃[Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks](https://arxiv.org/abs/1908.06281) | +| PGD | White-box | 📃[Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | +| SI-NI-FGSM | White-box | 📃[Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks](https://arxiv.org/abs/1908.06281) | | SignHunter | Black-box | 📃[Sign Bits Are All You Need for Black-Box Attacks](https://openreview.net/forum?id=SygW0TEFwH) | +| SimBA | Black-box | 📃[Simple Black-box Adversarial Attacks](https://arxiv.org/abs/1905.07121) | | Square attack | Black-box | 📃[Square Attack: a query-efficient black-box adversarial attack via random search](https://arxiv.org/abs/1912.00049) | ### 💠 Defenses @@ -29,7 +33,9 @@ Currently supported defense methods including adversarially trained models are a | Method | Type | References | | :------------------ | :------------------ | :------------------ | -| JPEG Compression | Input transform | 📃[A study of the effect of JPG compression on adversarial images](https://arxiv.org/abs/1608.00853) | +| Bit-Red | Input transform | 📃[Feature Squeezing: Detecting Adversarial Examples in Deep Neural Networks](https://arxiv.org/abs/1704.01155) | +| JPEG | Input transform | 📃[A study of the effect of JPG compression on adversarial images](https://arxiv.org/abs/1608.00853) | +| Randomization | Input transform | 📃[Mitigating Adversarial Effects Through Randomization](https://arxiv.org/abs/1711.01991) | | TRADES | Adv. training | 📃[Theoretically Principled Trade-off between Robustness and Accuracy](https://arxiv.org/abs/1901.08573) | ### 🧩 Others @@ -72,12 +78,12 @@ python advgrads_cli/attack.py --load_config configs/mnist.yaml ### ⚙ Description format of config files The attack configs are managed by a YAML file. The main fields and variables are described below. -- `data`: _(str)_ Specify a dataset for which adversarial examples are to be generated. -- `model`: _(str)_ Select a model to be attacked. See [here](https://github.com/myuito3/AdvGrads/blob/main/advgrads/models/__init__.py) for currently supported models. -- `attacks`: _(list)_ This field allows you to specify attack methods that you wish to execute in a list format. You can set values including hyperparameters defined for each method. The parameters that can be specified for all methods are as follows: +- `data`: _(str, required)_ Specify a dataset for which adversarial examples are to be generated. +- `model`: _(str, required)_ Select a model to be attacked. See [here](https://github.com/myuito3/AdvGrads/blob/main/advgrads/models/__init__.py) for currently supported models. +- `attacks`: _(list, required)_ This field allows you to specify attack methods that you wish to execute in a list format. You can set values including hyperparameters defined for each method. The parameters that can be specified for all methods are as follows: - `method`: _(str)_ Attack method. 
See [here](https://github.com/myuito3/AdvGrads/blob/main/advgrads/adversarial/__init__.py) for currently supported attack methods. - `norm`: _(str)_ Norm for adversarial perturbations. - `eps`: _(float)_ Maximum norm constraint. - `max_iters`: _(int)_ Maximum number of iterations used in iterative methods. - `targeted`: _(bool)_ Whether or not to perform targeted attacks which aim to misclassify an adversarial example into a particular class. -- `thirdparty_defense`: _(str)_ Thirdparty defensive method. See [here](https://github.com/myuito3/AdvGrads/blob/main/advgrads/adversarial/__init__.py) for currently supported defensive methods. +- `thirdparty_defense`: _(str, optional)_ Thirdparty defensive method. See [here](https://github.com/myuito3/AdvGrads/blob/main/advgrads/adversarial/__init__.py) for currently supported defensive methods. diff --git a/advgrads/adversarial/__init__.py b/advgrads/adversarial/__init__.py index f25252d..5f1ec59 100644 --- a/advgrads/adversarial/__init__.py +++ b/advgrads/adversarial/__init__.py @@ -12,40 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Init adversarial attacks/defenses methods.""" +"""Init attack/defense method configs.""" from advgrads.adversarial.attacks.base_attack import AttackConfig from advgrads.adversarial.attacks.deepfool import DeepFoolAttackConfig +from advgrads.adversarial.attacks.di_mi_fgsm import DiMiFgsmAttackConfig from advgrads.adversarial.attacks.fgsm import FgsmAttackConfig from advgrads.adversarial.attacks.i_fgsm import IFgsmAttackConfig from advgrads.adversarial.attacks.mi_fgsm import MiFgsmAttackConfig +from advgrads.adversarial.attacks.ni_fgsm import NiFgsmAttackConfig +from advgrads.adversarial.attacks.pgd import PGDAttackConfig +from advgrads.adversarial.attacks.pi_fgsm import PiFgsmAttackConfig +from advgrads.adversarial.attacks.si_ni_fgsm import SiNiFgsmAttackConfig from advgrads.adversarial.attacks.signhunter import SignHunterAttackConfig +from advgrads.adversarial.attacks.simba import SimBAAttackConfig from advgrads.adversarial.attacks.square import SquareAttackConfig from advgrads.adversarial.defenses.input_transform.base_defense import DefenseConfig +from advgrads.adversarial.defenses.input_transform.bit_depth_reduction import ( + BitDepthReductionDefenseConfig, +) from advgrads.adversarial.defenses.input_transform.jpeg_compression import ( JpegCompressionDefenseConfig, ) +from advgrads.adversarial.defenses.input_transform.randomization import ( + RandomizationDefenseConfig, +) def get_attack_config_class(name: str) -> AttackConfig: + assert name in all_attack_names, f"Attack method named '{name}' not found." return attack_class_dict[name] def get_defense_config_class(name: str) -> DefenseConfig: + assert name in all_defense_names, f"Defense method named '{name}' not found." 
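To make the README's config fields described above concrete, a minimal attack config might look like the sketch below. This is illustrative only: the field names come from the README, the attack and defense names (`pgd`, `jpeg`, etc.) from the registries added in this diff, and the `data`/`model` values are placeholders that must match keys in the dataset/model registries.

```yaml
data: mnist            # placeholder; must be a key in the dataset registry
model: ptn_mnist       # placeholder; must be a key in the model registry
attacks:
  - method: pgd
    norm: l_inf
    eps: 0.3
    max_iters: 40
    targeted: false
thirdparty_defense: jpeg   # optional; "bit-red" and "randomization" are also registered
```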
return defense_class_dict[name] attack_class_dict = { "deepfool": DeepFoolAttackConfig, + "di-mi-fgsm": DiMiFgsmAttackConfig, "fgsm": FgsmAttackConfig, - "i_fgsm": IFgsmAttackConfig, - "mi_fgsm": MiFgsmAttackConfig, + "i-fgsm": IFgsmAttackConfig, + "mi-fgsm": MiFgsmAttackConfig, + "ni-fgsm": NiFgsmAttackConfig, + "pgd": PGDAttackConfig, + "pi-fgsm": PiFgsmAttackConfig, + "si-ni-fgsm": SiNiFgsmAttackConfig, "signhunter": SignHunterAttackConfig, + "simba": SimBAAttackConfig, "square": SquareAttackConfig, } all_attack_names = list(attack_class_dict.keys()) defense_class_dict = { + "bit-red": BitDepthReductionDefenseConfig, "jpeg": JpegCompressionDefenseConfig, + "randomization": RandomizationDefenseConfig, } all_defense_names = list(defense_class_dict.keys()) diff --git a/advgrads/adversarial/attacks/base_attack.py b/advgrads/adversarial/attacks/base_attack.py index e2a0df4..615c99d 100644 --- a/advgrads/adversarial/attacks/base_attack.py +++ b/advgrads/adversarial/attacks/base_attack.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Base class for adversarial attack methods.""" +"""Base class for attack methods.""" from abc import abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, Literal, Optional, Type +from typing import Any, Dict, List, Literal, Optional, Type import torch from torch import Tensor @@ -27,9 +27,12 @@ from advgrads.models.base_model import Model +NormType = Literal["l_0", "l_2", "l_inf"] + + @dataclass class AttackConfig(InstantiateConfig): - """Configuration for attack methods.""" + """The base configuration class for attack methods.""" _target: Type = field(default_factory=lambda: Attack) """Target class to instantiate.""" @@ -39,7 +42,7 @@ class AttackConfig(InstantiateConfig): """Min value of image used to clip perturbed images.""" max_val: float = 1.0 """Max value of image used to clip perturbed images.""" - norm: Optional[Literal["l_0", "l_2", "l_inf"]] = None + norm: Optional[NormType] = None """Norm bound of adversarial perturbations.""" eps: float = 0.0 """Radius of a l_p ball.""" @@ -48,17 +51,29 @@ class AttackConfig(InstantiateConfig): class Attack: - """Base class for attack methods. + """The base class for attack methods. Args: config: Configuration for attack methods. + norm_allow_list: List of supported perturbation norms. Each method defines this + within its own class. """ config: AttackConfig + norm_allow_list: List[NormType] def __init__(self, config: AttackConfig, **kwargs) -> None: self.config = config + if self.eps < 0: + raise ValueError(f"eps must be greater than or equal to 0, got {self.eps}.") + if self.max_iters < 0: + raise ValueError( + f"max_iters must be greater than or equal to 0, got {self.max_iters}." + ) + if self.norm not in self.norm_allow_list: + raise ValueError(f"Method does not support {self.norm} perturbation norm.") + def __call__(self, *args: Any, **kwargs: Any) -> Dict[ResultHeadNames, Any]: return self.get_outputs(*args, **kwargs) @@ -112,14 +127,15 @@ def get_outputs( Args: x: Images to be searched for adversarial examples. y: Ground truth labels of images. - model: A model under attack. + model: A model to be attacked. + thirdparty_defense: Thirdparty defense method instance. """ attack_outputs = self.run_attack(x, y, model, **kwargs) self.sanity_check(x, attack_outputs[ResultHeadNames.X_ADV]) # If a defensive method is defined, the process is performed here. 
This - # corresponds to Section 5.2 (GRAY BOX: IMAGE TRANSFORMATIONS AT TEST TIME) of - # the paper of Guo et al. + # corresponds to Section 5.2 (GRAY BOX: IMAGE TRANSFORMATIONS AT TEST TIME) in + # the paper of Guo et al [https://arxiv.org/pdf/1711.00117.pdf]. if thirdparty_defense is not None: attack_outputs[ResultHeadNames.X_ADV] = thirdparty_defense( attack_outputs[ResultHeadNames.X_ADV] @@ -142,22 +158,25 @@ def get_outputs( return attack_outputs def sanity_check(self, x: Tensor, x_adv: Tensor) -> None: - """Ensure that the amount of perturbation is properly controlled. + """Ensure that the amount of perturbation is properly controlled. This method + is specifically used to check the amount of perturbation of norm-constrained + type attack methods. Args: x: Original images. x_adv: Perturbed images. """ if self.eps > 0.0: + deltas = x_adv - x if self.norm == "l_inf": - delta = x_adv - x real = ( - delta.abs().max().half() + deltas.abs().max().half() ) # ignore slight differences within the decimal point - assert ( - real <= self.eps - ), f"Perturbations beyond the l_inf sphere ({real})." + msg = f"Perturbations beyond the l_inf sphere ({real})." elif self.norm == "l_2": - raise NotImplementedError + real = torch.norm(deltas.view(x.shape[0], -1), p=2, dim=-1).max() + msg = f"Perturbations beyond the l_2 sphere ({real})." elif self.norm == "l_0": raise NotImplementedError + + assert real <= self.eps, msg diff --git a/advgrads/adversarial/attacks/deepfool.py b/advgrads/adversarial/attacks/deepfool.py index f0bbf57..d5c1289 100644 --- a/advgrads/adversarial/attacks/deepfool.py +++ b/advgrads/adversarial/attacks/deepfool.py @@ -12,26 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the DeepFool attack. +"""The implementation of the DeepFool attack. Paper: DeepFool: a simple and accurate method to fool deep neural networks Url: https://arxiv.org/abs/1511.04599 """ from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import torch from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @dataclass class DeepFoolAttackConfig(AttackConfig): - """The configuration class for DeepFool attack.""" + """The configuration class for the DeepFool attack.""" _target: Type = field(default_factory=lambda: DeepFoolAttack) """Target class to instantiate.""" @@ -44,12 +44,14 @@ class DeepFoolAttack(Attack): Args: config: The DeepFool attack configuration. + norm_allow_list: List of supported perturbation norms. """ config: DeepFoolAttackConfig + norm_allow_list: List[NormType] = ["l_2"] - def __init__(self, config: DeepFoolAttackConfig, **kwargs) -> None: - super().__init__(config, **kwargs) + def __init__(self, config: DeepFoolAttackConfig) -> None: + super().__init__(config) if self.targeted: raise ValueError("DeepFool does not support targeted attack.") diff --git a/advgrads/adversarial/attacks/di_mi_fgsm.py b/advgrads/adversarial/attacks/di_mi_fgsm.py new file mode 100644 index 0000000..ee7c37a --- /dev/null +++ b/advgrads/adversarial/attacks/di_mi_fgsm.py @@ -0,0 +1,123 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. 
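As a quick illustration of the perturbation-norm bookkeeping in `sanity_check` shown earlier in `base_attack.py`, the following toy snippet (not part of the repository) computes the same l_inf and l_2 quantities for a random batch:

```python
import torch

eps = 0.3
x = torch.rand(2, 1, 28, 28)                       # toy "original" batch
x_adv = torch.clamp(x + eps * torch.sign(torch.randn_like(x)), 0.0, 1.0)
deltas = x_adv - x

linf = deltas.abs().max()                                         # l_inf radius
l2 = torch.norm(deltas.view(x.shape[0], -1), p=2, dim=-1).max()   # per-image l_2
print(linf <= eps, l2)   # l_inf stays within eps; the l_2 norm is generally much larger
```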
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Diverse Inputs Momentum Iterative Fast Gradient Sign +Method (DI-MI-FGSM) attack. This method is referred to as Momentum Diverse Inputs +Iterative Fast Gradient Sign Method (M-DI2-FGSM) in the original paper. + +Paper: Improving Transferability of Adversarial Examples with Input Diversity +Url: https://arxiv.org/abs/1803.06978 + +Original code is referenced from https://github.com/cihangxie/DI-2-FGSM +""" + +import random +from dataclasses import dataclass, field +from typing import Dict, List, Type + +import torch +import torch.nn.functional as F +from torch import Tensor + +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType +from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames +from advgrads.models.base_model import Model + + +@dataclass +class DiMiFgsmAttackConfig(AttackConfig): + """The configuration class for the DI-MI-FGSM attack.""" + + _target: Type = field(default_factory=lambda: DiMiFgsmAttack) + """Target class to instantiate.""" + max_resolution_ratio: float = 1.104 + """Ratio of the length of one side of the transformed image to one of the original + image. The default value is calculated w.r.t the ImageNet setting mentioned in the + original paper (330/299 = 1.1036).""" + keep_dims: bool = False + """Whether to keep the original image size.""" + prob: float = 0.5 + """Probability of using diverse inputs.""" + momentum: float = 1.0 + """Momentum about the model.""" + + +class DiMiFgsmAttack(Attack): + """The class of the DI-MI-FGSM attack. + + Args: + config: The DI-MI-FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. + """ + + config: DiMiFgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] + + def input_diversity(self, x: Tensor) -> Tensor: + """Apply diverse input patterns, i.e., random transformations, on the input + image x. + + Args: + x: Images to be transformed. + """ + h, w = x.shape[2:] + h_final = int(h * self.config.max_resolution_ratio) + w_final = int(w * self.config.max_resolution_ratio) + + # 1. random resize + h_resize = random.randint(h, h_final - 1) + w_resize = random.randint(w, w_final - 1) + x_resize = F.interpolate(x, size=[h_resize, w_resize], mode="nearest") + + # 2. 
random padding + h_remain = h_final - h_resize + w_remain = w_final - w_resize + pad_top = random.randint(0, h_remain) + pad_left = random.randint(0, w_remain) + dim = [pad_left, w_remain - pad_left, pad_top, h_remain - pad_top] + x_pad = F.pad(x_resize, dim, mode="constant", value=0) + + assert x_pad.shape[2:] == (h_final, w_final) + if self.config.keep_dims: + x_pad = F.interpolate(x_pad, size=[h, w], mode="nearest") + + return x_pad if torch.rand(1) < self.config.prob else x + + def run_attack( + self, x: Tensor, y: Tensor, model: Model + ) -> Dict[ResultHeadNames, Tensor]: + x_adv = x + alpha = self.eps / self.max_iters + accumulated_grads = torch.zeros_like(x) + + for _ in range(self.max_iters): + x_adv = x_adv.clone().detach().requires_grad_(True) + model.zero_grad() + + logits = model(self.input_diversity(x_adv)) + loss = F.cross_entropy(logits, y) + if self.targeted: + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() + + gradients = gradients / torch.mean( + torch.abs(gradients), dim=(1, 2, 3), keepdims=True + ) + gradients = gradients + self.config.momentum * accumulated_grads + accumulated_grads = gradients.clone().detach() + + x_adv = x_adv + alpha * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) + + return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/fgsm.py b/advgrads/adversarial/attacks/fgsm.py index c9b39ba..77b6100 100644 --- a/advgrads/adversarial/attacks/fgsm.py +++ b/advgrads/adversarial/attacks/fgsm.py @@ -12,27 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the FGSM attack. +"""The implementation of the Fast Gradient Sign Method (FGSM) attack. Paper: Explaining and Harnessing Adversarial Examples Url: https://arxiv.org/abs/1412.6572 """ from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import torch import torch.nn.functional as F from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @dataclass class FgsmAttackConfig(AttackConfig): - """The configuration class for FGSM attack.""" + """The configuration class for the FGSM attack.""" _target: Type = field(default_factory=lambda: FgsmAttack) """Target class to instantiate.""" @@ -43,24 +43,24 @@ class FgsmAttack(Attack): Args: config: The FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. 
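For intuition about the `input_diversity` transform above, a worked example of the resize/pad arithmetic (assuming a 28×28 MNIST-sized input and the default `max_resolution_ratio = 1.104`):

```python
h = 28                        # e.g. one side of an MNIST image
h_final = int(h * 1.104)      # = 30
# random resize lands in [28, 29], zero-padding then brings it to 30x30;
# with keep_dims=True the result is resized back to 28x28.
# The whole transform is applied with probability prob (default 0.5).
```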
""" config: FgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] def run_attack( self, x: Tensor, y: Tensor, model: Model ) -> Dict[ResultHeadNames, Tensor]: x_adv = x.clone().detach().requires_grad_(True) - - logits = model(x_adv) - loss = F.cross_entropy(logits, torch.as_tensor(y, dtype=torch.long)) model.zero_grad() - loss.backward() - gradients_raw = x_adv.grad.data.detach() + logits = model(x_adv) + loss = F.cross_entropy(logits, y) if self.targeted: - gradients_raw *= -1 + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() - x_adv = x_adv + self.eps * gradients_raw.sign() - x_adv = x_adv.clamp(min=self.min_val, max=self.max_val) + x_adv = x_adv + self.eps * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/i_fgsm.py b/advgrads/adversarial/attacks/i_fgsm.py index c253ecf..bac0cb3 100644 --- a/advgrads/adversarial/attacks/i_fgsm.py +++ b/advgrads/adversarial/attacks/i_fgsm.py @@ -12,28 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the I-FGSM attack. This method is also called Basic Iterative -Method (BIM). +"""The implementation of the Iterative Fast Gradient Sign Method (I-FGSM) attack. This +method is also called Basic Iterative Method (BIM). Paper: Adversarial examples in the physical world Url: https://arxiv.org/abs/1607.02533 """ from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import torch import torch.nn.functional as F from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @dataclass class IFgsmAttackConfig(AttackConfig): - """The configuration class for I-FGSM attack.""" + """The configuration class for the I-FGSM attack.""" _target: Type = field(default_factory=lambda: IFgsmAttack) """Target class to instantiate.""" @@ -44,9 +44,11 @@ class IFgsmAttack(Attack): Args: config: The I-FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. """ config: IFgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] def run_attack( self, x: Tensor, y: Tensor, model: Model @@ -56,17 +58,15 @@ def run_attack( for _ in range(self.max_iters): x_adv = x_adv.clone().detach().requires_grad_(True) - - logits = model(x_adv) - loss = F.cross_entropy(logits, torch.as_tensor(y, dtype=torch.long)) model.zero_grad() - loss.backward() - gradients_raw = x_adv.grad.data.detach() + logits = model(x_adv) + loss = F.cross_entropy(logits, y) if self.targeted: - gradients_raw *= -1 + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() - x_adv = x_adv + alpha * gradients_raw.sign() - x_adv = x_adv.clamp(min=self.min_val, max=self.max_val) + x_adv = x_adv + alpha * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/mi_fgsm.py b/advgrads/adversarial/attacks/mi_fgsm.py index e48fd3e..c7f2bda 100644 --- a/advgrads/adversarial/attacks/mi_fgsm.py +++ b/advgrads/adversarial/attacks/mi_fgsm.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Implementation of the MI-FGSM attack. This method is also called Momentum Iterative -Method (MIM). +"""The implementation of the Momentum Iterative Fast Gradient Sign Method (MI-FGSM) +attack. This method is also called Momentum Iterative Method (MIM). Paper: Boosting Adversarial Attacks with Momentum Url: https://arxiv.org/abs/1710.06081 @@ -23,20 +23,20 @@ """ from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import torch import torch.nn.functional as F from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @dataclass class MiFgsmAttackConfig(AttackConfig): - """The configuration class for MI-FGSM attack.""" + """The configuration class for the MI-FGSM attack.""" _target: Type = field(default_factory=lambda: MiFgsmAttack) """Target class to instantiate.""" @@ -49,36 +49,36 @@ class MiFgsmAttack(Attack): Args: config: The MI-FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. """ config: MiFgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] def run_attack( self, x: Tensor, y: Tensor, model: Model ) -> Dict[ResultHeadNames, Tensor]: x_adv = x - grad = torch.zeros_like(x).detach() alpha = self.eps / self.max_iters + accumulated_grads = torch.zeros_like(x) for _ in range(self.max_iters): x_adv = x_adv.clone().detach().requires_grad_(True) - - logits = model(x_adv) - loss = F.cross_entropy(logits, torch.as_tensor(y, dtype=torch.long)) model.zero_grad() - loss.backward() - gradients_raw = x_adv.grad.data.detach() + logits = model(x_adv) + loss = F.cross_entropy(logits, y) if self.targeted: - gradients_raw *= -1 + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() - gradients_raw = gradients_raw / torch.mean( - torch.abs(gradients_raw), dim=(1, 2, 3), keepdims=True + gradients = gradients / torch.mean( + torch.abs(gradients), dim=(1, 2, 3), keepdims=True ) - gradients_raw = gradients_raw + self.config.momentum * grad + gradients = gradients + self.config.momentum * accumulated_grads + accumulated_grads = gradients.clone().detach() - x_adv = x_adv + alpha * gradients_raw.sign() - x_adv = x_adv.clamp(min=self.min_val, max=self.max_val) - grad = gradients_raw.clone().detach() + x_adv = x_adv + alpha * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/ni_fgsm.py b/advgrads/adversarial/attacks/ni_fgsm.py new file mode 100644 index 0000000..802e3a0 --- /dev/null +++ b/advgrads/adversarial/attacks/ni_fgsm.py @@ -0,0 +1,85 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Nesterov Iterative Fast Gradient Sign Method (NI-FGSM) +attack. 
+ +Paper: Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks +Url: https://arxiv.org/abs/1908.06281 + +Original code is referenced from https://github.com/JHL-HUST/SI-NI-FGSM +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Type + +import torch +import torch.nn.functional as F +from torch import Tensor + +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType +from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames +from advgrads.models.base_model import Model + + +@dataclass +class NiFgsmAttackConfig(AttackConfig): + """The configuration class for the NI-FGSM attack.""" + + _target: Type = field(default_factory=lambda: NiFgsmAttack) + """Target class to instantiate.""" + momentum: float = 1.0 + """Momentum about the model.""" + + +class NiFgsmAttack(Attack): + """The class of the NI-FGSM attack. + + Args: + config: The NI-FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. + """ + + config: NiFgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] + + def run_attack( + self, x: Tensor, y: Tensor, model: Model + ) -> Dict[ResultHeadNames, Tensor]: + x_adv = x + alpha = self.eps / self.max_iters + accumulated_grads = torch.zeros_like(x) + + for _ in range(self.max_iters): + x_adv = x_adv.clone().detach().requires_grad_(True) + model.zero_grad() + + x_nes = x_adv + self.config.momentum * alpha * accumulated_grads + + logits = model(x_nes) + loss = F.cross_entropy(logits, y) + if self.targeted: + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() + + gradients = gradients / torch.mean( + torch.abs(gradients), dim=(1, 2, 3), keepdims=True + ) + gradients = gradients + self.config.momentum * accumulated_grads + accumulated_grads = gradients.clone().detach() + + x_adv = x_adv + alpha * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) + + return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/pgd.py b/advgrads/adversarial/attacks/pgd.py new file mode 100644 index 0000000..b55bee6 --- /dev/null +++ b/advgrads/adversarial/attacks/pgd.py @@ -0,0 +1,75 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Projected Gradient Descent (PGD) attack. 
+ +Paper: Towards Deep Learning Models Resistant to Adversarial Attacks +Url: https://arxiv.org/abs/1706.06083 +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Type + +import torch +import torch.nn.functional as F +from torch import Tensor + +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType +from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames +from advgrads.models.base_model import Model + + +@dataclass +class PGDAttackConfig(AttackConfig): + """The configuration class for the PGD attack.""" + + _target: Type = field(default_factory=lambda: PGDAttack) + """Target class to instantiate.""" + + +class PGDAttack(Attack): + """The class of the PGD attack. + + Args: + config: The PGD attack configuration. + norm_allow_list: List of supported perturbation norms. + """ + + config: PGDAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] + + def run_attack( + self, x: Tensor, y: Tensor, model: Model + ) -> Dict[ResultHeadNames, Tensor]: + alpha = self.eps / self.max_iters + + # Generate initial perturbations from a continuous uniform distribution. + init_deltas = torch.empty_like(x).uniform_(-self.eps, self.eps) + x_adv = torch.clamp(x + init_deltas, min=self.min_val, max=self.max_val) + + for _ in range(self.max_iters): + x_adv = x_adv.clone().detach().requires_grad_(True) + model.zero_grad() + + logits = model(x_adv) + loss = F.cross_entropy(logits, y) + if self.targeted: + loss *= -1 + gradients = torch.autograd.grad(loss, [x_adv])[0].detach() + + x_adv = x_adv + alpha * torch.sign(gradients) + deltas = torch.clamp(x_adv - x, min=-self.eps, max=self.eps) + x_adv = torch.clamp(x + deltas, min=self.min_val, max=self.max_val) + + return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/pi_fgsm.py b/advgrads/adversarial/attacks/pi_fgsm.py new file mode 100644 index 0000000..2ef58bc --- /dev/null +++ b/advgrads/adversarial/attacks/pi_fgsm.py @@ -0,0 +1,106 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Patch-wise Iterative Fast Gradient Sign Method (PI-FGSM) +attack. 
+
+Paper: Patch-wise Attack for Fooling Deep Neural Network
+Url: https://arxiv.org/abs/2007.06765
+
+Original code is referenced from
+https://github.com/qilong-zhang/Patch-wise-iterative-attack/tree/master/Pytorch%20version
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Type
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType
+from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames
+from advgrads.models.base_model import Model
+
+
+def project_kern(kern_size: int = 3, channels: int = 3):
+    """Generate a special uniform projection kernel."""
+    kern = torch.ones((kern_size, kern_size)) / (kern_size**2 - 1)
+    kern[kern_size // 2, kern_size // 2] = 0.0
+    stack_kern = torch.stack([kern] * channels)[:, None, :, :]
+    return stack_kern, kern_size // 2
+
+
+def project_noise(x: Tensor, stack_kern: Tensor, padding_size: int, groups: int = 3):
+    """Convolution using the project kernel."""
+    return F.conv2d(x, stack_kern, padding=(padding_size, padding_size), groups=groups)
+
+
+@dataclass
+class PiFgsmAttackConfig(AttackConfig):
+    """The configuration class for the PI-FGSM attack."""
+
+    _target: Type = field(default_factory=lambda: PiFgsmAttack)
+    """Target class to instantiate."""
+    amplification: float = 10.0
+    """Parameter to amplify the step size."""
+
+
+class PiFgsmAttack(Attack):
+    """The class of the PI-FGSM attack.
+
+    Args:
+        config: The PI-FGSM attack configuration.
+        norm_allow_list: List of supported perturbation norms.
+    """
+
+    config: PiFgsmAttackConfig
+    norm_allow_list: List[NormType] = ["l_inf"]
+
+    def run_attack(
+        self, x: Tensor, y: Tensor, model: Model
+    ) -> Dict[ResultHeadNames, Tensor]:
+        x_adv = x
+        alpha = self.eps / self.max_iters
+        alpha_beta = alpha * self.config.amplification
+
+        c = x.shape[1]
+        stack_kern, padding_size = project_kern(kern_size=3, channels=c)
+        stack_kern = stack_kern.to(x.device)
+
+        amplification = 0.0
+        for _ in range(self.max_iters):
+            x_adv = x_adv.clone().detach().requires_grad_(True)
+            model.zero_grad()
+
+            logits = model(x_adv)
+            loss = F.cross_entropy(logits, y)
+            if self.targeted:
+                loss *= -1
+            gradients = torch.autograd.grad(loss, [x_adv])[0].detach()
+
+            amplification += alpha_beta * torch.sign(gradients)
+            cut_noise = torch.clamp(
+                abs(amplification) - self.eps, 0.0, 10000.0
+            ) * torch.sign(amplification)
+            projection = alpha_beta * torch.sign(
+                project_noise(cut_noise, stack_kern, padding_size, groups=c)
+            )
+            amplification += projection
+
+            x_adv = x_adv + alpha_beta * torch.sign(gradients) + projection
+            deltas = torch.clamp(x_adv - x, min=-self.eps, max=self.eps)
+            x_adv = torch.clamp(x + deltas, min=self.min_val, max=self.max_val)
+
+        return {ResultHeadNames.X_ADV: x_adv}
diff --git a/advgrads/adversarial/attacks/si_ni_fgsm.py b/advgrads/adversarial/attacks/si_ni_fgsm.py
new file mode 100644
index 0000000..37dca37
--- /dev/null
+++ b/advgrads/adversarial/attacks/si_ni_fgsm.py
@@ -0,0 +1,113 @@
+# Copyright 2023 Makoto Yuito. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
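A quick sanity check of `project_kern` above (illustrative only): with the default `kern_size=3`, each per-channel kernel is uniform with weight 1/(3·3 − 1) and a zeroed centre, which is what lets `project_noise` spread clipped-off perturbation onto the neighbouring pixels.

```python
import torch

kern = torch.ones((3, 3)) / (3**2 - 1)
kern[1, 1] = 0.0
print(kern)
# tensor([[0.1250, 0.1250, 0.1250],
#         [0.1250, 0.0000, 0.1250],
#         [0.1250, 0.1250, 0.1250]])
```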
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Scale-Invariant Nesterov Iterative Fast Gradient Sign +Method (SI-NI-FGSM) attack. + +Paper: Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks +Url: https://arxiv.org/abs/1908.06281 + +Original code is referenced from https://github.com/JHL-HUST/SI-NI-FGSM +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Type + +import torch +import torch.nn.functional as F +from torch import Tensor + +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType +from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames +from advgrads.models.base_model import Model + + +@dataclass +class SiNiFgsmAttackConfig(AttackConfig): + """The configuration class for the SI-NI-FGSM attack.""" + + _target: Type = field(default_factory=lambda: SiNiFgsmAttack) + """Target class to instantiate.""" + momentum: float = 1.0 + """Momentum about the model.""" + num_scale_copies: int = 5 + """Number of scale copies.""" + + +class SiNiFgsmAttack(Attack): + """The class of the SI-NI-FGSM attack. + + Args: + config: The SI-NI-FGSM attack configuration. + norm_allow_list: List of supported perturbation norms. + """ + + config: SiNiFgsmAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] + + def scale_invariant( + self, + x: Tensor, + y: Tensor, + model: Model, + x_to_be_scaled: Optional[Tensor] = None, + ): + """Function that returns the mean of the gradients of multiple scale copies of + the input image by the Scale-Invariant Method (SIM). + + Args: + x: Images referenced when calculating the gradient. + y: Labels of images. + model: Model to be attacked. + x_to_be_scaled: Images used for scale copies; if None, this will be the + same as x. 
+ """ + if x_to_be_scaled is None: + x_to_be_scaled = x + gradients = torch.zeros_like(x, device=x.device) + + for i in range(self.config.num_scale_copies): + x_scaled = 1 / (2**i) * x_to_be_scaled + logits = model(x_scaled) + loss = F.cross_entropy(logits, y) + if self.targeted: + loss *= -1 + gradients += torch.autograd.grad(loss, [x])[0].detach() + + return gradients / self.config.num_scale_copies + + def run_attack( + self, x: Tensor, y: Tensor, model: Model + ) -> Dict[ResultHeadNames, Tensor]: + x_adv = x + alpha = self.eps / self.max_iters + accumulated_grads = torch.zeros_like(x) + + for _ in range(self.max_iters): + x_adv = x_adv.clone().detach().requires_grad_(True) + model.zero_grad() + + x_nes = x_adv + self.config.momentum * alpha * accumulated_grads + gradients = self.scale_invariant(x_adv, y, model, x_to_be_scaled=x_nes) + + gradients = gradients / torch.mean( + torch.abs(gradients), dim=(1, 2, 3), keepdims=True + ) + gradients = gradients + self.config.momentum * accumulated_grads + accumulated_grads = gradients.clone().detach() + + x_adv = x_adv + alpha * torch.sign(gradients) + x_adv = torch.clamp(x_adv, min=self.min_val, max=self.max_val) + + return {ResultHeadNames.X_ADV: x_adv} diff --git a/advgrads/adversarial/attacks/signhunter.py b/advgrads/adversarial/attacks/signhunter.py index 7cc581b..e881bd9 100644 --- a/advgrads/adversarial/attacks/signhunter.py +++ b/advgrads/adversarial/attacks/signhunter.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the SignHunter attack. +"""The implementation of the SignHunter attack. Paper: Sign Bits Are All You Need for Black-Box Attacks Url: https://openreview.net/forum?id=SygW0TEFwH """ from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import numpy as np import torch import torch.nn as nn from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.losses import MarginLoss from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @@ -34,7 +34,7 @@ @dataclass class SignHunterAttackConfig(AttackConfig): - """The configuration class for SignHunter attack.""" + """The configuration class for the SignHunter attack.""" _target: Type = field(default_factory=lambda: SignHunterAttack) """Target class to instantiate.""" @@ -45,12 +45,14 @@ class SignHunterAttack(Attack): Args: config: The SignHunter attack configuration. + norm_allow_list: List of supported perturbation norms. """ config: SignHunterAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] - def __init__(self, config: SignHunterAttackConfig, **kwargs) -> None: - super().__init__(config, **kwargs) + def __init__(self, config: SignHunterAttackConfig) -> None: + super().__init__(config) self.loss = ( nn.CrossEntropyLoss(reduction="none") @@ -91,7 +93,7 @@ def run_attack( margin_min_curr = margin_min[idx_to_fool] loss_min_curr = loss_min[idx_to_fool] - # Generate candidates for new adversarial examples + # Generate candidates for new adversarial examples. 
chunk_len = np.ceil(n_dim / (2**h)).astype(int) istart = i * chunk_len iend = min(n_dim, (i + 1) * chunk_len) @@ -107,7 +109,7 @@ def run_attack( loss = self.loss(logits, y_curr) margin = self.margin(logits, y_curr) - # Update current loss values and adversarial examples + # Update current loss values and adversarial examples. idx_improved = loss < loss_min_curr loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr margin_min[idx_to_fool] = ( diff --git a/advgrads/adversarial/attacks/simba.py b/advgrads/adversarial/attacks/simba.py new file mode 100644 index 0000000..41f012b --- /dev/null +++ b/advgrads/adversarial/attacks/simba.py @@ -0,0 +1,186 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Simple Black-box Attack (SimBA) attack. + +Paper: Simple Black-box Adversarial Attacks +Url: https://arxiv.org/abs/1905.07121 + +Original code is referenced from https://github.com/cg563/simple-blackbox-attack +Note that this code is simply an extension of 20-line implementation of SimBA to batch +processing. +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Type + +import torch +import torch.nn as nn +from torch import Tensor + +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType +from advgrads.adversarial.attacks.utils.losses import MarginLoss +from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames +from advgrads.models.base_model import Model + + +@dataclass +class SimBAAttackConfig(AttackConfig): + """The configuration class for the SimBA attack.""" + + _target: Type = field(default_factory=lambda: SimBAAttack) + """Target class to instantiate.""" + lr: float = 0.2 + """Step size per iteration.""" + freq_dims: int = 14 + """Dimensionality of 2D frequency space.""" + + +class SimBAAttack(Attack): + """The class of the SimBA attack. + + Args: + config: The SimBA attack configuration. + norm_allow_list: List of supported perturbation norms. + """ + + config: SimBAAttackConfig + norm_allow_list: List[NormType] = ["l_2"] + + def __init__(self, config: SimBAAttackConfig) -> None: + super().__init__(config) + + if self.eps > 0.0: + raise ValueError( + "SimBA is a minimum-norm attack, not a norm-constrained attack." + ) + if self.max_iters > 0: + raise ValueError() + + self.loss = ( + nn.CrossEntropyLoss(reduction="none") + if self.targeted + else MarginLoss(targeted=self.targeted) + ) + self.margin = MarginLoss(self.targeted) + + def reset_data(self, *args: Tuple[Tensor, ...]) -> None: + """Register data to be used for attack. + + Args: + args: Tensor data such as images and loss values. + """ + assert all( + args[0].shape[0] == arg.shape[0] for arg in args + ), "Size mismatch between tensors." + self.all_args = args + + def get_data(self, indices: Tensor) -> Tuple[Tensor, ...]: + """Returns only the elements specified by indices from registered data. + + Args: + indices: Indices of data to be extracted. 
+ """ + return (arg[indices] for arg in self.all_args) + + def update_data(self, indices: Tensor, *args: Tuple[Tensor, ...]) -> None: + """Update the data at specified indices in registered data with new data. + + Args: + indices: Indices of data to be replaced. + args: New tensor data. + """ + for arg, new_arg in zip(self.all_args, args): + arg[indices] = new_arg + + @torch.no_grad() + def step_single(self, idx_to_fool: Tensor, diffs: Tensor, model: Model) -> Tensor: + """Perform one step of attack with given additional perturbations. + + Args: + idx_to_fool: Indices of data to be used to attack. + diffs: Additional perturbations. + model: Model to be attacked. + """ + x_best, y, loss_min, margin_min = self.get_data(idx_to_fool) + + # Generate candidates for new adversarial examples. + x_new = torch.clamp(x_best + diffs, min=self.min_val, max=self.max_val) + logits = model(x_new) + loss = self.loss(logits, y) + margin = self.margin(logits, y) + + # Update current loss values and adversarial examples. + idx_improved = loss < loss_min + loss_min = idx_improved * loss + ~idx_improved * loss_min + margin_min = idx_improved * margin + ~idx_improved * margin_min + + _idx_improved = torch.reshape(idx_improved, [-1, *[1] * len(x_best.shape[:-1])]) + x_best = _idx_improved * x_new + ~_idx_improved * x_best + + self.update_data(idx_to_fool, x_best, y, loss_min, margin_min) + return idx_improved + + @torch.no_grad() + def run_attack( + self, x: Tensor, y: Tensor, model: Model + ) -> Dict[ResultHeadNames, Tensor]: + c, h, w = x.shape[1:] + n_queries = torch.zeros((x.shape[0]), dtype=torch.int16).to(x.device) + + x_best = x.clone() + logits = model(x_best) + loss_min = self.loss(logits, y) + margin_min = self.margin(logits, y) + n_queries += 1 + + # Determine index of pixels to be perturbed at random for each image. + n_dims = c * self.config.freq_dims * self.config.freq_dims + idx_pixels = torch.zeros((x.shape[0], n_dims), device=x.device).long() + for i in range(x.shape[0]): + idx_pixels[i, ...] = torch.randperm(c * h * w)[:n_dims] + + self.reset_data(x_best, y, loss_min, margin_min) + + for i_iter in range(n_dims): + idx_to_fool = torch.atleast_1d((margin_min > 0.0).nonzero().squeeze()) + if len(idx_to_fool) == 0: + break + + # Try negative direction. + diffs = torch.zeros((len(idx_to_fool), c * h * w), device=x.device) + u = torch.arange(len(idx_to_fool)) + diffs[u, idx_pixels[idx_to_fool, i_iter]] = -1 * self.config.lr + diffs = diffs.view(-1, *x.shape[1:]) + + idx_improved = self.step_single(idx_to_fool, diffs, model) + n_queries[idx_to_fool] += 1 + + # Try positive direction for samples that failed to update loss by trying + # negative direction. 
+ idx_failed = torch.nonzero(~idx_improved).squeeze() + idx_to_fool = torch.atleast_1d(idx_to_fool[idx_failed]) + if len(idx_to_fool) == 0: + continue + + diffs = torch.zeros((len(idx_to_fool), c * h * w), device=x.device) + u = torch.arange(len(idx_to_fool)) + diffs[u, idx_pixels[idx_to_fool, i_iter]] = self.config.lr + diffs = diffs.view(-1, *x.shape[1:]) + + _ = self.step_single(idx_to_fool, diffs, model) + n_queries[idx_to_fool] += 1 + + x_best, _, _, _ = self.get_data(torch.arange(x.shape[0])) + return {ResultHeadNames.X_ADV: x_best, ResultHeadNames.QUERIES: n_queries} diff --git a/advgrads/adversarial/attacks/square.py b/advgrads/adversarial/attacks/square.py index a03f771..09a405c 100644 --- a/advgrads/adversarial/attacks/square.py +++ b/advgrads/adversarial/attacks/square.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the Square attack. +"""The implementation of the Square attack. Paper: Square Attack: a query-efficient black-box adversarial attack via random search Url: https://arxiv.org/abs/1912.00049 @@ -22,14 +22,14 @@ import math from dataclasses import dataclass, field -from typing import Dict, Type +from typing import Dict, List, Type import numpy as np import torch import torch.nn as nn from torch import Tensor -from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig +from advgrads.adversarial.attacks.base_attack import Attack, AttackConfig, NormType from advgrads.adversarial.attacks.utils.losses import MarginLoss from advgrads.adversarial.attacks.utils.result_heads import ResultHeadNames from advgrads.models.base_model import Model @@ -37,7 +37,7 @@ @dataclass class SquareAttackConfig(AttackConfig): - """The configuration class for Square attack.""" + """The configuration class for the Square attack.""" _target: Type = field(default_factory=lambda: SquareAttack) """Target class to instantiate.""" @@ -46,16 +46,18 @@ class SquareAttackConfig(AttackConfig): class SquareAttack(Attack): - """The class of the Square attack . + """The class of the Square attack. Args: config: The Square attack configuration. + norm_allow_list: List of supported perturbation norms. """ config: SquareAttackConfig + norm_allow_list: List[NormType] = ["l_inf"] - def __init__(self, config: SquareAttackConfig, **kwargs) -> None: - super().__init__(config, **kwargs) + def __init__(self, config: SquareAttackConfig) -> None: + super().__init__(config) self.loss = ( nn.CrossEntropyLoss(reduction="none") @@ -142,22 +144,21 @@ def looking_window(arr: Tensor) -> Tensor: return deltas + @torch.no_grad() def run_attack( self, x: Tensor, y: Tensor, model: Model ) -> Dict[ResultHeadNames, Tensor]: c, h, w = x.shape[1:] n_queries = torch.zeros((x.shape[0]), dtype=torch.int16).to(x.device) - # [1, w, c], i.e. 
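The batched implementation above follows the "20-line" reference code mentioned in the module docstring. For readers who want the core idea without the batching and bookkeeping, here is a rough single-image sketch in the spirit of the original SimBA paper (hypothetical helper, not part of this repository): it walks a random pixel ordering and keeps a ±`lr` step whenever the true-class probability drops.

```python
import torch


@torch.no_grad()
def simba_single(model, x, y, n_iters=1000, lr=0.2):
    """Minimal untargeted SimBA sketch for a single image x with integer label y."""
    perm = torch.randperm(x.numel())
    x_adv = x.clone()
    prob_y = torch.softmax(model(x_adv.unsqueeze(0)), dim=1)[0, y]
    for i in range(min(n_iters, x.numel())):
        diff = torch.zeros_like(x).view(-1)
        diff[perm[i]] = lr
        diff = diff.view_as(x)
        for candidate in (x_adv - diff, x_adv + diff):
            candidate = candidate.clamp(0.0, 1.0)
            p = torch.softmax(model(candidate.unsqueeze(0)), dim=1)[0, y]
            if p < prob_y:                 # keep the step if confidence drops
                x_adv, prob_y = candidate, p
                break
    return x_adv
```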
vertical stripes work best for untargeted attacks init_delta = torch.from_numpy( np.random.choice([-self.eps, self.eps], size=[x.shape[0], c, 1, w]) ).to(x) x_best = torch.clamp(x + init_delta, self.min_val, self.max_val) - with torch.no_grad(): - logits = model(x_best) - loss_min = self.loss(logits, y) - margin_min = self.margin(logits, y) - n_queries += 1 # ones because we have already used 1 query + logits = model(x_best) + loss_min = self.loss(logits, y) + margin_min = self.margin(logits, y) + n_queries += 1 for i_iter in range(self.max_iters - 1): idx_to_fool = torch.atleast_1d((margin_min > 0.0).nonzero().squeeze()) @@ -171,15 +172,14 @@ def run_attack( margin_min_curr = margin_min[idx_to_fool] loss_min_curr = loss_min[idx_to_fool] - # Generate candidates for new adversarial examples + # Generate candidates for new adversarial examples. deltas = self.get_new_deltas(x_curr, x_best_curr, i_iter) x_new = torch.clamp(x_curr + deltas, self.min_val, self.max_val) - with torch.no_grad(): - logits = model(x_new) - loss = self.loss(logits, y_curr) - margin = self.margin(logits, y_curr) + logits = model(x_new) + loss = self.loss(logits, y_curr) + margin = self.margin(logits, y_curr) - # Update current loss values and adversarial examples + # Update current loss values and adversarial examples. idx_improved = loss < loss_min_curr loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr margin_min[idx_to_fool] = ( diff --git a/advgrads/adversarial/attacks/utils/losses.py b/advgrads/adversarial/attacks/utils/losses.py index 45aa263..2e3dedd 100644 --- a/advgrads/adversarial/attacks/utils/losses.py +++ b/advgrads/adversarial/attacks/utils/losses.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Collection of losses.""" +"""Collection of loss functions.""" from abc import abstractmethod @@ -23,14 +23,12 @@ class Loss(nn.Module): - """Base class for losses. + """Base class for a loss function. Args: - targeted: Objective of attacker. + targeted: Whether or not to perform targeted attacks. """ - targeted: bool = False - def __init__(self, targeted: bool = False) -> None: super().__init__() self.targeted = targeted @@ -40,13 +38,17 @@ def forward(self, logits: Tensor, y: Tensor) -> Tensor: """Calculates loss values. Args: - logits: Logits value outputed by victim model. + logits: Logits values outputed by victim model. y: Ground truth or target labels. """ class CrossEntropyLoss(Loss): - """Implementation of the cross-entropy loss.""" + """Implementation of the cross-entropy loss. + + Args: + targeted: Whether or not to perform targeted attacks. + """ def forward(self, logits: Tensor, y: Tensor) -> Tensor: loss = F.cross_entropy(logits, y, reduction="none") @@ -58,6 +60,9 @@ def forward(self, logits: Tensor, y: Tensor) -> Tensor: class MarginLoss(Loss): """Implementation of the margin loss (difference between the correct and 2nd best class). + + Args: + targeted: Whether or not to perform targeted attacks. """ def forward(self, logits: Tensor, y: Tensor) -> Tensor: diff --git a/advgrads/adversarial/attacks/utils/math.py b/advgrads/adversarial/attacks/utils/math.py new file mode 100644 index 0000000..2b6c690 --- /dev/null +++ b/advgrads/adversarial/attacks/utils/math.py @@ -0,0 +1,29 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
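The `MarginLoss` docstring above describes the loss only in words, and its forward pass is not part of this diff. A hedged sketch of what an untargeted margin loss of this kind typically computes (per sample, positive while the model is still correct) is:

```python
import torch


def margin_untargeted(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Correct-class logit minus the best other logit, per sample (illustrative)."""
    correct = logits.gather(1, y[:, None]).squeeze(1)
    other = logits.clone()
    other.scatter_(1, y[:, None], float("-inf"))   # mask out the correct class
    return correct - other.max(dim=1).values
```

Under this convention the attacks above keep querying while `margin_min > 0.0`, i.e. while every remaining sample is still classified correctly.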
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Math helper functions.""" + +import torch +from torch import Tensor + + +def normalize(vectors: Tensor, eps: float = 1e-10) -> Tensor: + """Returns normalized vectors. + + Args: + vectors: Vectors to normalize. + eps: Epsilon value to avoid division by zero. + """ + dims = list(range(vectors.ndim)) + return vectors / (torch.norm(vectors, p=2, dim=dims[1:], keepdim=True) + eps) diff --git a/advgrads/adversarial/defenses/adv_train/trades/trades_mnist.py b/advgrads/adversarial/defenses/adv_train/trades/trades_mnist.py index 8431234..c629430 100644 --- a/advgrads/adversarial/defenses/adv_train/trades/trades_mnist.py +++ b/advgrads/adversarial/defenses/adv_train/trades/trades_mnist.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The TRADES (TRadeoff-inspired Adversarial DEfense via Surrogate-loss minimization) -model. +"""The implementation of the TRadeoff-inspired Adversarial DEfense via Surrogate-loss +minimization (TRADES) model. Paper: Theoretically Principled Trade-off between Robustness and Accuracy Url: https://arxiv.org/abs/1901.08573 @@ -35,7 +35,7 @@ @dataclass class TradesMnistModelConfig(ModelConfig): - """Configuration for the TRADES model instantiation.""" + """The configuration class for the TRADES model.""" _target: Type = field(default_factory=lambda: TradesMnistModel) """Target class to instantiate.""" @@ -57,7 +57,7 @@ class TradesMnistModel(Model): config: TradesMnistModelConfig - def __init__(self, config: TradesMnistModelConfig, **kwargs) -> None: + def __init__(self, config: TradesMnistModelConfig) -> None: super().__init__(config) activ = nn.ReLU(True) diff --git a/advgrads/adversarial/defenses/input_transform/base_defense.py b/advgrads/adversarial/defenses/input_transform/base_defense.py index ce7db4b..55d94c9 100644 --- a/advgrads/adversarial/defenses/input_transform/base_defense.py +++ b/advgrads/adversarial/defenses/input_transform/base_defense.py @@ -25,14 +25,14 @@ @dataclass class DefenseConfig(InstantiateConfig): - """Configuration for defense methods.""" + """The base configuration class for defense methods.""" _target: Type = field(default_factory=lambda: Defense) """Target class to instantiate.""" class Defense: - """Base class for defense methods. + """The base class for defense methods. Args: config: Configuration for defense methods. diff --git a/advgrads/adversarial/defenses/input_transform/bit_depth_reduction.py b/advgrads/adversarial/defenses/input_transform/bit_depth_reduction.py new file mode 100644 index 0000000..bb82626 --- /dev/null +++ b/advgrads/adversarial/defenses/input_transform/bit_depth_reduction.py @@ -0,0 +1,55 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
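A small usage check for the `normalize` helper above (illustrative): each sample of the output has approximately unit l_2 norm, whatever the input scale.

```python
import torch

v = torch.randn(4, 3, 32, 32) * 5.0
flat = v.view(4, -1)
u = v / (flat.norm(p=2, dim=1).view(4, 1, 1, 1) + 1e-10)   # same effect as normalize(v)
print(u.view(4, -1).norm(p=2, dim=1))                       # ~tensor([1., 1., 1., 1.])
```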
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Bit-Depth Reduction defense. + +Paper: Feature Squeezing: Detecting Adversarial Examples in Deep Neural Networks +Url: https://arxiv.org/abs/1704.01155 +""" + +from dataclasses import dataclass, field +from typing import Type + +from torch import Tensor + +from advgrads.adversarial.defenses.input_transform.base_defense import ( + Defense, + DefenseConfig, +) + + +@dataclass +class BitDepthReductionDefenseConfig(DefenseConfig): + """The configuration class for the Bit-Depth Reduction defense.""" + + _target: Type = field(default_factory=lambda: BitDepthReductionDefense) + """Target class to instantiate.""" + num_bits: int = 4 + """Number of bits after squeezing.""" + + +class BitDepthReductionDefense(Defense): + """The class of the Bit-Depth Reduction defense. + + Args: + config: The Bit-Depth Reduction defense configuration. + """ + + config: BitDepthReductionDefenseConfig + + def run_defense(self, x: Tensor) -> Tensor: + max_val_squeezed = 2**self.config.num_bits + x_defended = (x.detach().clone() * max_val_squeezed).int() + x_defended = x_defended / max_val_squeezed + return x_defended diff --git a/advgrads/adversarial/defenses/input_transform/jpeg_compression.py b/advgrads/adversarial/defenses/input_transform/jpeg_compression.py index 4f3cd93..9e79ce4 100644 --- a/advgrads/adversarial/defenses/input_transform/jpeg_compression.py +++ b/advgrads/adversarial/defenses/input_transform/jpeg_compression.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the JPEG Compression defense. +"""The implementation of the JPEG Compression defense. 
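As a concrete illustration of the squeezing step in `run_defense` above: with the default `num_bits = 4`, every pixel is snapped onto a grid of 2^4 = 16 levels.

```python
import torch

x = torch.tensor([0.00, 0.37, 0.52, 1.00])
x_def = (x * 16).int() / 16   # same arithmetic as BitDepthReductionDefense
print(x_def)                  # tensor([0.0000, 0.3125, 0.5000, 1.0000])
```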
Paper: A study of the effect of JPG compression on adversarial images Url: https://arxiv.org/abs/1608.00853 @@ -23,9 +23,9 @@ from typing import Type import torch +import torchvision.transforms.functional as F from PIL import Image from torch import Tensor -from torchvision import transforms from advgrads.adversarial.defenses.input_transform.base_defense import ( Defense, @@ -35,7 +35,7 @@ @dataclass class JpegCompressionDefenseConfig(DefenseConfig): - """Configuration for the JPEG Compression defense.""" + """The configuration class for the JPEG Compression defense.""" _target: Type = field(default_factory=lambda: JpegCompressionDefense) """Target class to instantiate.""" @@ -52,19 +52,13 @@ class JpegCompressionDefense(Defense): config: JpegCompressionDefenseConfig - def __init__(self, config: JpegCompressionDefenseConfig) -> None: - super().__init__(config) - - self.to_pil = transforms.ToPILImage() - self.to_tensor = transforms.ToTensor() - def run_defense(self, x: Tensor) -> Tensor: x_defended = torch.zeros_like(x, device=x.device) for i_img in range(x.shape[0]): - x_i_pil = self.to_pil(x[i_img].detach().clone().cpu()) + x_i_pil = F.to_pil_image(x[i_img].detach().clone().cpu()) buffer = BytesIO() x_i_pil.save(buffer, format="JPEG", quality=self.config.quality) - x_defended[i_img] = self.to_tensor(Image.open(buffer)).to(x.device) + x_defended[i_img] = F.to_tensor(Image.open(buffer)).to(x.device) return x_defended diff --git a/advgrads/adversarial/defenses/input_transform/randomization.py b/advgrads/adversarial/defenses/input_transform/randomization.py new file mode 100644 index 0000000..bc38d08 --- /dev/null +++ b/advgrads/adversarial/defenses/input_transform/randomization.py @@ -0,0 +1,90 @@ +# Copyright 2023 Makoto Yuito. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The implementation of the Randomization defense. + +Paper: Mitigating Adversarial Effects Through Randomization +Url: https://arxiv.org/abs/1711.01991 +""" + +import random +from dataclasses import dataclass, field +from typing import Type + +import torch +import torch.nn.functional as F +from torch import Tensor + +from advgrads.adversarial.defenses.input_transform.base_defense import ( + Defense, + DefenseConfig, +) + + +@dataclass +class RandomizationDefenseConfig(DefenseConfig): + """The configuration class for the Randomization defense.""" + + _target: Type = field(default_factory=lambda: RandomizationDefense) + """Target class to instantiate.""" + max_resolution_ratio: float = 1.11 + """Ratio of the length of one side of the transformed image to one of the original + image. The default value is calculated w.r.t the ImageNet setting mentioned in the + paper (331/299 = 1.107).""" + keep_dims: bool = True + """Whether to keep the original image size.""" + + +class RandomizationDefense(Defense): + """The class of the Randomization defense. + + Args: + config: The Randomization defense configuration. 
+ """ + + config: RandomizationDefenseConfig + + def run_defense(self, x: Tensor) -> Tensor: + h, w = x.shape[2:] + h_final = int(h * self.config.max_resolution_ratio) + w_final = int(w * self.config.max_resolution_ratio) + + if self.config.keep_dims: + x_defended = torch.zeros_like(x, device=x.device) + else: + x_defended = torch.zeros((*x.shape[:2], h_final, w_final), device=x.device) + + for i_img in range(x.shape[0]): + x_i = x[i_img : i_img + 1].detach().clone() + + # 1. random resize + h_resize = random.randint(h, h_final - 1) + w_resize = random.randint(w, w_final - 1) + x_i_resize = F.interpolate(x_i, size=[h_resize, w_resize], mode="nearest") + + # 2. random padding + h_remain = h_final - h_resize + w_remain = w_final - w_resize + pad_top = random.randint(0, h_remain) + pad_left = random.randint(0, w_remain) + dim = [pad_left, w_remain - pad_left, pad_top, h_remain - pad_top] + x_i_pad = F.pad(x_i_resize, dim, mode="constant", value=0) + + assert x_i_pad.shape[2:] == (h_final, w_final) + if self.config.keep_dims: + x_i_pad = F.interpolate(x_i_pad, size=[h, w], mode="nearest") + + x_defended[i_img] = x_i_pad.squeeze(0) + + return x_defended diff --git a/advgrads/configs/experiment_config.py b/advgrads/configs/experiment_config.py index 5117b56..0697b66 100644 --- a/advgrads/configs/experiment_config.py +++ b/advgrads/configs/experiment_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Config used for running an experiment.""" +"""Collection of config classes used for running an experiment.""" from dataclasses import dataclass from pathlib import Path @@ -29,8 +29,6 @@ class ExperimentConfig: """Name of the dataset.""" model: Optional[str] = None """Name of the model.""" - checkpoint_path: Optional[str] = None - """Path to the checkpoint to be loaded into the model.""" attacks: Optional[List[dict]] = None """List of attack parameters.""" seed: Optional[int] = None @@ -45,7 +43,7 @@ class ExperimentConfig: @dataclass class ResultConfig(ExperimentConfig): - """Result config.""" + """The config class for output results of the experiment.""" output_dir: Path = Path("outputs") """Output directory to save the result of each attack.""" diff --git a/advgrads/data/__init__.py b/advgrads/data/__init__.py index d74cb97..5d5f2a5 100644 --- a/advgrads/data/__init__.py +++ b/advgrads/data/__init__.py @@ -23,6 +23,7 @@ def get_dataset_class(name: str) -> Dataset: + assert name in all_dataset_names, f"Dataset named '{name}' not found." return dataset_class_dict[name] diff --git a/advgrads/data/datasets/vision_dataset.py b/advgrads/data/datasets/vision_dataset.py index 005fca5..071dac6 100644 --- a/advgrads/data/datasets/vision_dataset.py +++ b/advgrads/data/datasets/vision_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Vision Dataset.""" +"""Datasets provided by torchvision.""" from typing import List, Optional, Tuple @@ -27,7 +27,11 @@ class MnistDataset(Dataset): - """The MNIST Dataset.""" + """The MNIST Dataset. + + Args: + indices_to_use: List of image indices to be used. + """ def __init__(self, indices_to_use: Optional[List[int]] = None) -> None: super().__init__() @@ -51,9 +55,13 @@ def num_classes(self) -> int: class Cifar10Dataset(Dataset): - """The CIFAR-10 Dataset.""" + """The CIFAR-10 Dataset. + + Args: + indices_to_use: List of image indices to be used. 
+ """ - def __init__(self, indices_to_use: list = None) -> None: + def __init__(self, indices_to_use: Optional[List[int]] = None) -> None: super().__init__() data = CIFAR10(root=DATA_PATH, train=False, download=True) arrays = (data.data, np.array(data.targets, dtype=np.longlong)) diff --git a/advgrads/models/__init__.py b/advgrads/models/__init__.py index ea3b6bd..8b953c3 100644 --- a/advgrads/models/__init__.py +++ b/advgrads/models/__init__.py @@ -23,12 +23,13 @@ def get_model_config_class(name: str) -> Model: + assert name in all_model_names, f"Model named '{name}' not found." return model_config_class_dict[name] model_config_class_dict = { - "ptpg_mnist": PtPgMnistModelConfig, - "ptpg_cifar10": PtPgCifar10ModelConfig, - "trades_mnist": TradesMnistModelConfig, + "ptpg-mnist": PtPgMnistModelConfig, + "ptpg-cifar10": PtPgCifar10ModelConfig, + "trades-mnist": TradesMnistModelConfig, } all_model_names = list(model_config_class_dict.keys()) diff --git a/advgrads/models/base_model.py b/advgrads/models/base_model.py index 9948205..ac39984 100644 --- a/advgrads/models/base_model.py +++ b/advgrads/models/base_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Base model class.""" +"""Base class for models.""" import os import requests @@ -30,7 +30,7 @@ @dataclass class ModelConfig(InstantiateConfig): - """Configuration for the base model instantiation.""" + """The base configuration class for models.""" _target: Type = field(default_factory=lambda: Model) """Target class to instantiate.""" @@ -41,10 +41,10 @@ class ModelConfig(InstantiateConfig): class Model(nn.Module): - """Base model class for PyTorch. + """The base class for models. Args: - config: The base model configuration. + config: Configuration for models. """ config: ModelConfig @@ -54,7 +54,7 @@ def __init__(self, config: ModelConfig, **kwargs) -> None: self.config = config @abstractmethod - def forward(self, x_input: Tensor) -> Tensor: + def forward(self, x_input: Tensor, **kwargs) -> Tensor: """Query the model and obtain logits output. 
    Args:
diff --git a/advgrads/models/pytorch_playground/cifar10_model.py b/advgrads/models/pytorch_playground/cifar10_model.py
index 427adbe..99a7b9f 100644
--- a/advgrads/models/pytorch_playground/cifar10_model.py
+++ b/advgrads/models/pytorch_playground/cifar10_model.py
@@ -30,7 +30,7 @@
 @dataclass
 class PtPgCifar10ModelConfig(ModelConfig):
-    """Configuration for the pytorch-playground model instantiation."""
+    """The configuration class for the pytorch-playground model."""

     _target: Type = field(default_factory=lambda: PtPgCifar10Model)
     """Target class to instantiate."""
@@ -51,7 +51,7 @@
 class PtPgCifar10Model(Model):

     config: PtPgCifar10ModelConfig

-    def __init__(self, config: PtPgCifar10ModelConfig, **kwargs) -> None:
+    def __init__(self, config: PtPgCifar10ModelConfig) -> None:
         super().__init__(config)

         n_channel = 128
diff --git a/advgrads/models/pytorch_playground/mnist_model.py b/advgrads/models/pytorch_playground/mnist_model.py
index bdf26b2..ab88d8b 100644
--- a/advgrads/models/pytorch_playground/mnist_model.py
+++ b/advgrads/models/pytorch_playground/mnist_model.py
@@ -31,7 +31,7 @@
 @dataclass
 class PtPgMnistModelConfig(ModelConfig):
-    """Configuration for the pytorch-playground model instantiation."""
+    """The configuration class for the pytorch-playground model."""

     _target: Type = field(default_factory=lambda: PtPgMnistModel)
     """Target class to instantiate."""
@@ -52,7 +52,7 @@
 class PtPgMnistModel(Model):

     config: PtPgMnistModelConfig

-    def __init__(self, config: PtPgMnistModelConfig, **kwargs) -> None:
+    def __init__(self, config: PtPgMnistModelConfig) -> None:
         super().__init__(config)

         self.input_dims = 784
diff --git a/advgrads/utils/rich_utils.py b/advgrads/utils/rich_utils.py
index d893788..c0b1dfd 100644
--- a/advgrads/utils/rich_utils.py
+++ b/advgrads/utils/rich_utils.py
@@ -24,10 +24,18 @@


 def console_print(msg: Any) -> None:
-    """Print message via rich console."""
+    """Print message via rich console.
+
+    Args:
+        msg: Message to be output to the terminal.
+    """
     CONSOLE.print(msg)


 def console_log(msg: Any) -> None:
-    """Log message via rich console."""
+    """Log message via rich console.
+
+    Args:
+        msg: Message to be output to the terminal.
+    """
     CONSOLE.log(msg)
diff --git a/configs/cifar10.yaml b/configs/cifar10.yaml
index ec5eb4e..8dbc220 100644
--- a/configs/cifar10.yaml
+++ b/configs/cifar10.yaml
@@ -1,6 +1,6 @@
 seed: 42
 data: cifar10
-model: ptpg_cifar10
+model: ptpg-cifar10

 attacks:
   -
@@ -10,13 +10,13 @@
     num_iters: 0
     targeted: false
   -
-    method: i_fgsm
+    method: i-fgsm
     norm: l_inf
     eps: 0.05
     max_iters: 10
     targeted: false
   -
-    method: mi_fgsm
+    method: mi-fgsm
     norm: l_inf
     eps: 0.05
     max_iters: 10
diff --git a/configs/mnist.yaml b/configs/mnist.yaml
index 64a91e4..bf8d9f4 100644
--- a/configs/mnist.yaml
+++ b/configs/mnist.yaml
@@ -1,6 +1,6 @@
 seed: 42
 data: mnist
-model: ptpg_mnist
+model: ptpg-mnist

 attacks:
   -
@@ -10,13 +10,13 @@
     max_iters: 0
     targeted: false
   -
-    method: i_fgsm
+    method: i-fgsm
     norm: l_inf
     eps: 0.3
     max_iters: 10
     targeted: false
   -
-    method: mi_fgsm
+    method: mi-fgsm
     norm: l_inf
     eps: 0.3
     max_iters: 10
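
The `normalize` math helper added in this patch divides each sample in a batch by its L2 norm taken over all non-batch dimensions, with `eps` guarding against division by zero. A minimal sketch of the expected behaviour; the import path below is an assumption, since the exact module location is not visible in this hunk:

```python
import torch

from advgrads.utils.math_utils import normalize  # hypothetical module path

x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
x_unit = normalize(x)
# The first row is scaled to unit L2 norm -> [0.6, 0.8];
# the all-zero row stays (near) zero thanks to the eps term.
print(x_unit)
```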
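The new Bit-Depth Reduction defense quantizes each pixel to `2**num_bits` levels by scaling, truncating to an integer, and rescaling; with the default `num_bits=4`, a pixel value of 0.37 becomes `int(0.37 * 16) / 16 = 0.3125`. A usage sketch, assuming the base `Defense` constructor accepts the config object directly (as the removed `JpegCompressionDefense.__init__` above suggests):

```python
import torch

from advgrads.adversarial.defenses.input_transform.bit_depth_reduction import (
    BitDepthReductionDefense,
    BitDepthReductionDefenseConfig,
)

# Hypothetical batch of MNIST-sized images in [0, 1], laid out as (N, C, H, W).
x = torch.rand(4, 1, 28, 28)

defense = BitDepthReductionDefense(BitDepthReductionDefenseConfig(num_bits=4))
x_squeezed = defense.run_defense(x)

# Every defended pixel now lies on the grid {0/16, 1/16, ..., 16/16}.
print(torch.unique(x_squeezed))
```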
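The Randomization defense resizes each image to a random resolution between its original size and just under `max_resolution_ratio` times that size, pads it randomly up to the maximum size, and, with `keep_dims=True`, resizes the result back so the model input shape is unchanged. A sketch under the same constructor assumption as above:

```python
import torch

from advgrads.adversarial.defenses.input_transform.randomization import (
    RandomizationDefense,
    RandomizationDefenseConfig,
)

# Hypothetical ImageNet-sized batch, matching the 299 -> 331 setting noted in the config.
x = torch.rand(2, 3, 299, 299)

defense = RandomizationDefense(RandomizationDefenseConfig(max_resolution_ratio=1.11))
x_defended = defense.run_defense(x)

# keep_dims defaults to True, so the defended batch keeps the original resolution.
assert x_defended.shape == x.shape
```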
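Finally, the registry lookups now assert that the requested name exists, and the model keys are hyphenated (`ptpg-mnist`, `ptpg-cifar10`, `trades-mnist`) to match the renamed YAML configs. A small sketch of the failure mode for a stale underscore-style name:

```python
from advgrads.models import get_model_config_class

config_class = get_model_config_class("ptpg-mnist")  # new hyphenated key

try:
    get_model_config_class("ptpg_mnist")  # old underscore key is no longer registered
except AssertionError as err:
    print(err)  # Model named 'ptpg_mnist' not found.
```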