Skip to content

Commit

Permalink
Merge pull request #5 from myuito3/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
myuito3 authored Dec 12, 2023
2 parents 3a7eb6b + d98f150 commit 97b6a0e
Show file tree
Hide file tree
Showing 23 changed files with 551 additions and 73 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

# AdvGrads

<p>
<!-- release badge -->
<a href="https://github.com/myuito3/AdvGrads/releases">
<img alt="Latest Release" src="https://img.shields.io/github/release/myuito3/AdvGrads.svg?&color=blue" /></a>
<!-- license badge -->
<a href="https://github.com/myuito3/AdvGrads/blob/master/LICENSE">
<img alt="License" src="https://img.shields.io/badge/License-Apache_2.0-brightgreen.svg" /></a>
</p>

</div>

## 🌐 About
Expand All @@ -18,11 +27,13 @@ Currently supported attack methods are as follows:
| Method | Type | References |
| :------------------ | :------------------ | :------------------ |
| DeepFool | White-box | 📃[DeepFool: a simple and accurate method to fool deep neural networks](https://arxiv.org/abs/1511.04599) |
| DI-MI-FGSM | White-box | 📃[Improving Transferability of Adversarial Examples with Input Diversity](https://arxiv.org/abs/1803.06978) |
| FGSM | White-box | 📃[Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572) |
| I-FGSM (BIM) | White-box | 📃[Adversarial examples in the physical world](https://arxiv.org/abs/1607.02533) |
| MI-FGSM (MIM) | White-box | 📃[Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081) |
| NI-FGSM | White-box | 📃[Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks](https://arxiv.org/abs/1908.06281) |
| PGD | White-box | 📃[Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) |
| PI-FGSM | White-box | 📃[Patch-wise Attack for Fooling Deep Neural Network](https://arxiv.org/abs/2007.06765) |
| SI-NI-FGSM | White-box | 📃[Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks](https://arxiv.org/abs/1908.06281) |
| SignHunter | Black-box | 📃[Sign Bits Are All You Need for Black-Box Attacks](https://openreview.net/forum?id=SygW0TEFwH) |
| SimBA | Black-box | 📃[Simple Black-box Adversarial Attacks](https://arxiv.org/abs/1905.07121) |
Expand Down Expand Up @@ -57,14 +68,15 @@ py -3.9 -m venv [ENV_NAME]
After creating and activating your virtual environment, you can install necessary libraries via the requirements.txt.

```bash
git clone https://github.com/myuito3/AdvGrads.git
cd AdvGrads/
pip install -r requirements.txt
```

### Installing AdvGrads
Install AdvGrads in editable mode from source code:

```bash
git clone https://github.com/myuito3/AdvGrads.git
python -m pip install -e .
```

Expand Down
87 changes: 46 additions & 41 deletions advgrads/adversarial/attacks/base_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, config: AttackConfig, **kwargs) -> None:
if self.norm not in self.norm_allow_list:
raise ValueError(f"Method does not support {self.norm} perturbation norm.")

def __call__(self, *args: Any, **kwargs: Any) -> Dict[ResultHeadNames, Any]:
def __call__(self, *args: Any, **kwargs: Any) -> Dict[ResultHeadNames, Tensor]:
return self.get_outputs(*args, **kwargs)

@property
Expand Down Expand Up @@ -114,69 +114,74 @@ def run_attack(
"""
raise NotImplementedError

def sanity_check(self, x: Tensor, x_adv: Tensor) -> None:
"""Ensure that the amount of perturbation is properly controlled. This method
is specifically used to check the amount of perturbation of norm-constrained
type attack methods.
Args:
x: Original images.
x_adv: Perturbed images.
"""
if self.eps > 0.0:
deltas = x_adv - x
if self.norm == "l_inf":
real = (
deltas.abs().max().half()
) # ignore slight differences within the decimal point
msg = f"Perturbations beyond the l_inf sphere ({real})."
elif self.norm == "l_2":
real = torch.norm(deltas.view(x.shape[0], -1), p=2, dim=-1).max()
msg = f"Perturbations beyond the l_2 sphere ({real})."
elif self.norm == "l_0":
raise NotImplementedError

assert real <= self.eps, msg

def get_outputs(
self,
x: Tensor,
y: Tensor,
batch: Dict[str, Tensor],
model: Model,
thirdparty_defense: Optional[Defense] = None,
**kwargs,
) -> Dict[ResultHeadNames, Any]:
) -> Dict[ResultHeadNames, Tensor]:
"""Returns raw attack results processed.
Args:
x: Images to be searched for adversarial examples.
y: Ground truth labels of images.
batch: A batch including original images and labels.
model: A model to be attacked.
thirdparty_defense: Thirdparty defense method instance.
"""
x, y = batch["images"], batch["labels"]
attack_outputs = self.run_attack(x, y, model, **kwargs)
self.sanity_check(x, attack_outputs[ResultHeadNames.X_ADV])
x_adv = attack_outputs[ResultHeadNames.X_ADV]
self.sanity_check(x, x_adv)

# If a defensive method is defined, the process is performed here. This
# corresponds to Section 5.2 (GRAY BOX: IMAGE TRANSFORMATIONS AT TEST TIME) in
# the paper of Guo et al [https://arxiv.org/pdf/1711.00117.pdf].
if thirdparty_defense is not None:
attack_outputs[ResultHeadNames.X_ADV] = thirdparty_defense(
attack_outputs[ResultHeadNames.X_ADV]
)

with torch.no_grad():
logits = model(attack_outputs[ResultHeadNames.X_ADV])
if thirdparty_defense is not None:
logits = model(thirdparty_defense(x_adv))
else:
logits = model(x_adv)
preds = torch.argmax(logits, dim=-1)
cond = (preds == y) if self.targeted else (preds != y)
attack_outputs[ResultHeadNames.NUM_SUCCEED] = cond.sum()
succeed = (preds == y) if self.targeted else (preds != y)

if ResultHeadNames.QUERIES in attack_outputs.keys():
attack_outputs[ResultHeadNames.QUERIES_SUCCEED] = attack_outputs[
ResultHeadNames.QUERIES
][cond]
attack_outputs[ResultHeadNames.PREDS] = preds
attack_outputs[ResultHeadNames.SUCCEED] = succeed
attack_outputs[ResultHeadNames.NUM_SUCCEED] = succeed.sum()

for key, value in attack_outputs.items():
if isinstance(value, Tensor):
attack_outputs[key] = value.cpu()
return attack_outputs

def sanity_check(self, x: Tensor, x_adv: Tensor) -> None:
"""Ensure that the amount of perturbation is properly controlled. This method
is specifically used to check the amount of perturbation of norm-constrained
type attack methods.
def get_metrics_dict(
self, outputs: Dict[ResultHeadNames, Tensor], batch: Dict[str, Tensor]
) -> Dict[str, Tensor]:
"""Compute and returns metrics.
Args:
x: Original images.
x_adv: Perturbed images.
outputs: The output to compute metrics dict to.
batch: Ground truth batch corresponding to outputs.
"""
if self.eps > 0.0:
deltas = x_adv - x
if self.norm == "l_inf":
real = (
deltas.abs().max().half()
) # ignore slight differences within the decimal point
msg = f"Perturbations beyond the l_inf sphere ({real})."
elif self.norm == "l_2":
real = torch.norm(deltas.view(x.shape[0], -1), p=2, dim=-1).max()
msg = f"Perturbations beyond the l_2 sphere ({real})."
elif self.norm == "l_0":
raise NotImplementedError

assert real <= self.eps, msg
return {}
14 changes: 14 additions & 0 deletions advgrads/adversarial/attacks/deepfool.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,17 @@ def run_attack(
)

return {ResultHeadNames.X_ADV: x_adv}

def get_metrics_dict(
self, outputs: Dict[ResultHeadNames, Tensor], batch: Dict[str, Tensor]
) -> Dict[str, Tensor]:
metrics_dict = {}
succeed = outputs[ResultHeadNames.SUCCEED]

# perturbation norm
l2_norm_succeed = torch.norm(
outputs[ResultHeadNames.X_ADV] - batch["images"], p=2, dim=[1, 2, 3]
)[succeed]
metrics_dict["l2_norm"] = l2_norm_succeed

return metrics_dict
2 changes: 1 addition & 1 deletion advgrads/adversarial/attacks/pi_fgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_attack(

c = x.shape[1]
stack_kern, padding_size = project_kern(kern_size=3, channels=c)
stack_kern.to(x.device)
stack_kern = stack_kern.to(x.device)

amplification = 0.0
for _ in range(self.max_iters):
Expand Down
12 changes: 12 additions & 0 deletions advgrads/adversarial/attacks/signhunter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,15 @@ def run_attack(
h = 0

return {ResultHeadNames.X_ADV: x_adv, ResultHeadNames.QUERIES: n_queries}

def get_metrics_dict(
self, outputs: Dict[ResultHeadNames, Tensor], batch: Dict[str, Tensor]
) -> Dict[str, Tensor]:
metrics_dict = {}
succeed = outputs[ResultHeadNames.SUCCEED]

# query
queries_succeed = outputs[ResultHeadNames.QUERIES][succeed]
metrics_dict[ResultHeadNames.QUERIES_SUCCEED] = queries_succeed

return metrics_dict
23 changes: 22 additions & 1 deletion advgrads/adversarial/attacks/simba.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def __init__(self, config: SimBAAttackConfig) -> None:
"SimBA is a minimum-norm attack, not a norm-constrained attack."
)
if self.max_iters > 0:
raise ValueError()
raise ValueError(
"The maximum number of queries for SimBA is controlled by the "
"freq_dims parameter in the config."
)

self.loss = (
nn.CrossEntropyLoss(reduction="none")
Expand Down Expand Up @@ -184,3 +187,21 @@ def run_attack(

x_best, _, _, _ = self.get_data(torch.arange(x.shape[0]))
return {ResultHeadNames.X_ADV: x_best, ResultHeadNames.QUERIES: n_queries}

def get_metrics_dict(
self, outputs: Dict[ResultHeadNames, Tensor], batch: Dict[str, Tensor]
) -> Dict[str, Tensor]:
metrics_dict = {}
succeed = outputs[ResultHeadNames.SUCCEED]

# query
queries_succeed = outputs[ResultHeadNames.QUERIES][succeed]
metrics_dict[ResultHeadNames.QUERIES_SUCCEED] = queries_succeed

# perturbation norm
l2_norm_succeed = torch.norm(
outputs[ResultHeadNames.X_ADV] - batch["images"], p=2, dim=[1, 2, 3]
)[succeed]
metrics_dict["l2_norm"] = l2_norm_succeed

return metrics_dict
12 changes: 12 additions & 0 deletions advgrads/adversarial/attacks/square.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,15 @@ def run_attack(
n_queries[idx_to_fool] += 1

return {ResultHeadNames.X_ADV: x_best, ResultHeadNames.QUERIES: n_queries}

def get_metrics_dict(
self, outputs: Dict[ResultHeadNames, Tensor], batch: Dict[str, Tensor]
) -> Dict[str, Tensor]:
metrics_dict = {}
succeed = outputs[ResultHeadNames.SUCCEED]

# query
queries_succeed = outputs[ResultHeadNames.QUERIES][succeed]
metrics_dict[ResultHeadNames.QUERIES_SUCCEED] = queries_succeed

return metrics_dict
5 changes: 4 additions & 1 deletion advgrads/adversarial/attacks/utils/result_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ class ResultHeadNames(Enum):

X_ADV = "x_adv"
SHAPE = "shape"
PREDS = "preds"
SUCCEED = "succeed"
NUM_SUCCEED = "num_succeed"
SUCCESS_RATE = "success_rate"

QUERIES = "queries"
QUERIES_SUCCEED = "queries_succeed"
MEAN_QUERY = "mean_query"
MEDIAN_QUERY = "median_query"
NUM_SUCCEED = "num_succeed"
4 changes: 3 additions & 1 deletion advgrads/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from torch.utils.data import Dataset

from advgrads.data.datasets.imagenet_dataset import ImagenetDataset
from advgrads.data.datasets.vision_dataset import (
MnistDataset,
Cifar10Dataset,
Expand All @@ -28,7 +29,8 @@ def get_dataset_class(name: str) -> Dataset:


dataset_class_dict = {
"mnist": MnistDataset,
"cifar10": Cifar10Dataset,
"imagenet": ImagenetDataset,
"mnist": MnistDataset,
}
all_dataset_names = list(dataset_class_dict.keys())
48 changes: 48 additions & 0 deletions advgrads/data/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2023 Makoto Yuito. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ImageNet dataset."""

from typing import List, Optional

from torchvision import transforms
from torchvision.datasets import ImageNet


DATA_PATH = "./data/imagenet"


class ImagenetDataset(ImageNet):
"""The ImageNet Dataset.
Args:
transform: Transform objects for image preprocessing.
indices_to_use: List of image indices to be used.
"""

def __init__(
self,
transform: transforms.Compose,
indices_to_use: Optional[List[int]] = None,
) -> None:
super().__init__(root=DATA_PATH, split="val", transform=transform)

all_samples = self.samples
self.samples = []
for i in indices_to_use:
self.samples.append(all_samples[i])

@property
def num_classes(self) -> int:
return 1000
10 changes: 10 additions & 0 deletions advgrads/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
TradesMnistModelConfig,
)
from advgrads.models.base_model import Model
from advgrads.models.imagenet.inception import InceptionV3ImagenetModelConfig
from advgrads.models.imagenet.resnet import Resnet50ImagenetModelConfig
from advgrads.models.imagenet.vgg import (
Vgg16ImagenetModelConfig,
Vgg16bnImagenetModelConfig,
)
from advgrads.models.pytorch_playground.cifar10_model import PtPgCifar10ModelConfig
from advgrads.models.pytorch_playground.mnist_model import PtPgMnistModelConfig

Expand All @@ -31,5 +37,9 @@ def get_model_config_class(name: str) -> Model:
"ptpg-mnist": PtPgMnistModelConfig,
"ptpg-cifar10": PtPgCifar10ModelConfig,
"trades-mnist": TradesMnistModelConfig,
"inceptionv3-imagenet": InceptionV3ImagenetModelConfig,
"resnet50-imagenet": Resnet50ImagenetModelConfig,
"vgg16-imagenet": Vgg16ImagenetModelConfig,
"vgg16bn-imagenet": Vgg16bnImagenetModelConfig,
}
all_model_names = list(model_config_class_dict.keys())
13 changes: 13 additions & 0 deletions advgrads/models/imagenet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 Makoto Yuito. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 97b6a0e

Please sign in to comment.