-
Notifications
You must be signed in to change notification settings - Fork 33
/
eval_sde_adv_bpda.py
279 lines (214 loc) · 11 KB
/
eval_sde_adv_bpda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import argparse
import logging
import yaml
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from bpda_eot.bpda_eot_attack import BPDA_EOT_Attack
import utils
from utils import str2bool, get_accuracy, get_image_classifier, load_data
from runners.diffpure_ddpm import Diffusion
from runners.diffpure_guided import GuidedDiffusion
from runners.diffpure_sde import RevGuidedDiffusion
class ResNet_Adv_Model(nn.Module):
def __init__(self, args, config):
super().__init__()
# image classifier
self.resnet = get_image_classifier(args.classifier_name).to(config.device)
def purify(self, x):
return x
def forward(self, x, mode='purify_and_classify'):
if mode == 'purify':
out = self.purify(x)
elif mode == 'classify':
out = self.resnet(x) # x in [0, 1]
elif mode == 'purify_and_classify':
x = self.purify(x)
out = self.resnet(x) # x in [0, 1]
else:
raise NotImplementedError(f'unknown mode: {mode}')
return out
class SDE_Adv_Model(nn.Module):
def __init__(self, args, config):
super().__init__()
self.args = args
# image classifier
self.resnet = get_image_classifier(args.classifier_name).to(config.device)
# diffusion model
print(f'diffusion_type: {args.diffusion_type}')
if args.diffusion_type == 'ddpm':
self.runner = GuidedDiffusion(args, config, device=config.device)
elif args.diffusion_type == 'sde':
self.runner = RevGuidedDiffusion(args, config, device=config.device)
elif args.diffusion_type == 'celebahq-ddpm':
self.runner = Diffusion(args, config, device=config.device)
else:
raise NotImplementedError('unknown diffusion type')
self.register_buffer('counter', torch.zeros(1, device=config.device))
self.tag = None
# use `counter` to record the the sampling time every 5 NFEs (note we hardcoded print freq to 5,
# and you may want to change the freq)
def reset_counter(self):
self.counter = torch.zeros(1, dtype=torch.int, device=config.device)
def set_tag(self, tag=None):
self.tag = tag
def purify(self, x):
counter = self.counter.item()
if counter % 5 == 0:
print(f'diffusion times: {counter}')
# imagenet [3, 224, 224] -> [3, 256, 256] -> [3, 224, 224]
if 'imagenet' in self.args.domain:
x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
start_time = time.time()
x_re = self.runner.image_editing_sample((x - 0.5) * 2, bs_id=counter, tag=self.tag)
minutes, seconds = divmod(time.time() - start_time, 60)
if 'imagenet' in self.args.domain:
x_re = F.interpolate(x_re, size=(224, 224), mode='bilinear', align_corners=False)
if counter % 5 == 0:
print(f'x shape (before diffusion models): {x.shape}')
print(f'x shape (before resnet): {x_re.shape}')
print("Sampling time per batch: {:0>2}:{:05.2f}".format(int(minutes), seconds))
self.counter += 1
return (x_re + 1) * 0.5
def forward(self, x, mode='purify_and_classify'):
if mode == 'purify':
out = self.purify(x)
elif mode == 'classify':
out = self.resnet(x) # x in [0, 1]
elif mode == 'purify_and_classify':
x = self.purify(x)
out = self.resnet(x) # x in [0, 1]
else:
raise NotImplementedError(f'unknown mode: {mode}')
return out
def eval_bpda(args, config, model, x_val, y_val, adv_batch_size, log_dir):
ngpus = torch.cuda.device_count()
model_ = model
if ngpus > 1:
model_ = model.module
x_val, y_val = x_val.to(config.device), y_val.to(config.device)
# ------------------ apply the attack to resnet ------------------
print(f'apply the bpda attack to resnet...')
resnet_bpda = ResNet_Adv_Model(args, config)
if ngpus > 1:
resnet_bpda = torch.nn.DataParallel(resnet_bpda)
start_time = time.time()
init_acc = get_accuracy(resnet_bpda, x_val, y_val, bs=adv_batch_size)
print('initial accuracy: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, time.time() - start_time))
adversary_resnet = BPDA_EOT_Attack(resnet_bpda, adv_eps=args.adv_eps, eot_defense_reps=args.eot_defense_reps,
eot_attack_reps=args.eot_attack_reps)
start_time = time.time()
class_batch, ims_adv_batch = adversary_resnet.attack_all(x_val, y_val, batch_size=adv_batch_size)
init_acc = float(class_batch[0, :].sum()) / class_batch.shape[1]
robust_acc = float(class_batch[-1, :].sum()) / class_batch.shape[1]
print('init acc: {:.2%}, robust acc: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, robust_acc, time.time() - start_time))
print(f'x_adv_resnet shape: {ims_adv_batch.shape}')
torch.save([ims_adv_batch, y_val], f'{log_dir}/x_adv_resnet_sd{args.seed}.pt')
# ------------------ apply the attack to sde_adv ------------------
print(f'apply the bpda attack to sde_adv...')
start_time = time.time()
model_.reset_counter()
model_.set_tag('no_adv')
init_acc = get_accuracy(model, x_val, y_val, bs=adv_batch_size)
print('initial accuracy: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, time.time() - start_time))
adversary_sde = BPDA_EOT_Attack(model, adv_eps=args.adv_eps, eot_defense_reps=args.eot_defense_reps,
eot_attack_reps=args.eot_attack_reps)
start_time = time.time()
model_.reset_counter()
model_.set_tag()
class_batch, ims_adv_batch = adversary_sde.attack_all(x_val, y_val, batch_size=adv_batch_size)
init_acc = float(class_batch[0, :].sum()) / class_batch.shape[1]
robust_acc = float(class_batch[-1, :].sum()) / class_batch.shape[1]
print('init acc: {:.2%}, robust acc: {:.2%}, time elapsed: {:.2f}s'.format(init_acc, robust_acc, time.time() - start_time))
print(f'x_adv_sde shape: {ims_adv_batch.shape}')
torch.save([ims_adv_batch, y_val], f'{log_dir}/x_adv_sde_sd{args.seed}.pt')
def robustness_eval(args, config):
middle_name = '_'.join([args.diffusion_type, 'bpda'])
log_dir = os.path.join(args.image_folder, args.classifier_name, middle_name,
'seed' + str(args.seed), 'data' + str(args.data_seed))
os.makedirs(log_dir, exist_ok=True)
args.log_dir = log_dir
logger = utils.Logger(file_name=f'{log_dir}/log.txt', file_mode="w+", should_flush=True)
ngpus = torch.cuda.device_count()
adv_batch_size = args.adv_batch_size * ngpus
print(f'ngpus: {ngpus}, adv_batch_size: {adv_batch_size}')
# load model
print('starting the model and loader...')
model = SDE_Adv_Model(args, config)
if ngpus > 1:
model = torch.nn.DataParallel(model)
model = model.eval().to(config.device)
# load data
x_val, y_val = load_data(args, adv_batch_size)
# eval classifier and sde_adv against bpda attack
eval_bpda(args, config, model, x_val, y_val, adv_batch_size, log_dir)
logger.close()
def parse_args_and_config():
parser = argparse.ArgumentParser(description=globals()['__doc__'])
# diffusion models
parser.add_argument('--config', type=str, required=True, help='Path to the config file')
parser.add_argument('--data_seed', type=int, default=0, help='Random seed')
parser.add_argument('--seed', type=int, default=1234, help='Random seed')
parser.add_argument('--exp', type=str, default='exp', help='Path for saving running related data.')
parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")
parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher")
parser.add_argument('--sample_step', type=int, default=1, help='Total sampling steps')
parser.add_argument('--t', type=int, default=400, help='Sampling noise scale')
parser.add_argument('--t_delta', type=int, default=15, help='Perturbation range of sampling noise scale')
parser.add_argument('--rand_t', type=str2bool, default=False, help='Decide if randomize sampling noise scale')
parser.add_argument('--diffusion_type', type=str, default='ddpm', help='[ddpm, sde, celebahq-ddpm]')
parser.add_argument('--score_type', type=str, default='guided_diffusion', help='[guided_diffusion, score_sde]')
parser.add_argument('--eot_iter', type=int, default=20, help='only for rand version of autoattack')
parser.add_argument('--use_bm', action='store_true', help='whether to use brownian motion')
parser.add_argument('--eot_defense_reps', type=int, default=150)
parser.add_argument('--eot_attack_reps', type=int, default=15)
# adv
parser.add_argument('--domain', type=str, default='celebahq', help='which domain: celebahq, cat, car, imagenet')
parser.add_argument('--classifier_name', type=str, default='Eyeglasses', help='which classifier to use')
parser.add_argument('--partition', type=str, default='val')
parser.add_argument('--adv_batch_size', type=int, default=64)
parser.add_argument('--num_sub', type=int, default=1000, help='imagenet subset')
parser.add_argument('--adv_eps', type=float, default=0.07)
args = parser.parse_args()
# parse config file
with open(os.path.join('configs', args.config), 'r') as f:
config = yaml.safe_load(f)
new_config = utils.dict2namespace(config)
level = getattr(logging, args.verbose.upper(), None)
if not isinstance(level, int):
raise ValueError('level {} not supported'.format(args.verbose))
handler1 = logging.StreamHandler()
formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
handler1.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler1)
logger.setLevel(level)
args.image_folder = os.path.join(args.exp, args.image_folder)
os.makedirs(args.image_folder, exist_ok=True)
# add device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info("Using device: {}".format(device))
new_config.device = device
# set random seed
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = True
return args, new_config
if __name__ == '__main__':
args, config = parse_args_and_config()
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
robustness_eval(args, config)