-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate_explainability.py
155 lines (129 loc) · 6.68 KB
/
evaluate_explainability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import random
import torch
from captum.attr import IntegratedGradients, InputXGradient
from models.resnet import resnet50
from models.vgg import vgg16
from models.ViT.ViT_new import vit_base_patch16_224
from models.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from models.model_wrapper import StandardModel, ViTModel
from evaluation_protocols import accuracy_protocol, controlled_synthetic_data_check_protocol, single_deletion_protocol, preservation_check_protocol, deletion_check_protocol, target_sensitivity_protocol, distractibility_protocol, background_independence_protocol
from explainers.explainer_wrapper import CaptumAttributionExplainer, ViTGradCamExplainer, ViTRolloutExplainer, ViTCheferLRPExplainer, CustomExplainer
parser = argparse.ArgumentParser(description='FunnyBirds - Explanation Evaluation')
parser.add_argument('--data', metavar='DIR', required=True,
help='path to dataset (default: imagenet)')
parser.add_argument('--model', required=True,
choices=['resnet50', 'vgg16', 'vit_b_16'],
help='model architecture')
parser.add_argument('--explainer', required=True,
choices=['IntegratedGradients', 'InputXGradient', 'Rollout', 'CheferLRP', 'CustomExplainer'],
help='explainer')
parser.add_argument('--checkpoint_name', type=str, required=False, default=None,
help='checkpoint name (including dir)')
parser.add_argument('--gpu', default=0, type=int,
help='GPU id to use.')
parser.add_argument('--seed', default=0, type=int,
help='seed')
parser.add_argument('--batch_size', default=32, type=int,
help='batch size for protocols that do not require custom BS such as accuracy')
parser.add_argument('--nr_itrs', default=2501, type=int,
help='batch size for protocols that do not require custom BS such as accuracy')
parser.add_argument('--accuracy', default=False, action='store_true',
help='compute accuracy')
parser.add_argument('--controlled_synthetic_data_check', default=False, action='store_true',
help='compute controlled synthetic data check')
parser.add_argument('--single_deletion', default=False, action='store_true',
help='compute single deletion')
parser.add_argument('--preservation_check', default=False, action='store_true',
help='compute preservation check')
parser.add_argument('--deletion_check', default=False, action='store_true',
help='compute deletion check')
parser.add_argument('--target_sensitivity', default=False, action='store_true',
help='compute target sensitivity')
parser.add_argument('--distractibility', default=False, action='store_true',
help='compute distractibility')
parser.add_argument('--background_independence', default=False, action='store_true',
help='compute background dependence')
def main():
args = parser.parse_args()
device = 'cuda:' + str(args.gpu)
random.seed(args.seed)
torch.manual_seed(args.seed)
# create model
if args.model == 'resnet50':
model = resnet50(num_classes = 50)
model = StandardModel(model)
elif args.model == 'vgg16':
model = vgg16(num_classes = 50)
model = StandardModel(model)
elif args.model == 'vit_b_16':
if args.explainer == 'CheferLRP':
model = vit_LRP(num_classes=50)
else:
model = vit_base_patch16_224(num_classes = 50)
model = ViTModel(model)
else:
print('Model not implemented')
if args.checkpoint_name:
model.load_state_dict(torch.load(args.checkpoint_name, map_location=torch.device('cpu'))['state_dict'])
model = model.to(device)
model.eval()
# create explainer
if args.explainer == 'InputXGradient':
explainer = InputXGradient(model)
explainer = CaptumAttributionExplainer(explainer)
elif args.explainer == 'IntegratedGradients':
explainer = IntegratedGradients(model)
baseline = torch.zeros((1,3,256,256)).to(device)
explainer = CaptumAttributionExplainer(explainer, baseline=baseline)
elif args.explainer == 'Rollout':
explainer = ViTRolloutExplainer(model)
elif args.explainer == 'CheferLRP':
explainer = ViTCheferLRPExplainer(model)
elif args.explainer == 'CustomExplainer':
...
else:
print('Explainer not implemented')
accuracy, csdc, pc, dc, distractibility, sd, ts = -1, -1, -1, -1, -1, -1, -1
if args.accuracy:
print('Computing accuracy...')
accuracy = accuracy_protocol(model, args)
accuracy = round(accuracy, 5)
if args.controlled_synthetic_data_check:
print('Computing controlled synthetic data check...')
csdc = controlled_synthetic_data_check_protocol(model, explainer, args)
if args.target_sensitivity:
print('Computing target sensitivity...')
ts = target_sensitivity_protocol(model, explainer, args)
ts = round(ts, 5)
if args.single_deletion:
print('Computing single deletion...')
sd = single_deletion_protocol(model, explainer, args)
sd = round(sd, 5)
if args.preservation_check:
print('Computing preservation check...')
pc = preservation_check_protocol(model, explainer, args)
if args.deletion_check:
print('Computing deletion check...')
dc = deletion_check_protocol(model, explainer, args)
if args.distractibility:
print('Computing distractibility...')
distractibility = distractibility_protocol(model, explainer, args)
if args.background_independence:
print('Computing background independence...')
background_independence = background_independence_protocol(model, args)
background_independence = round(background_independence, 5)
# select completeness and distractability thresholds such that they maximize the sum of both
max_score = 0
best_threshold = -1
for threshold in csdc.keys():
max_score_tmp = csdc[threshold]/3. + pc[threshold]/3. + dc[threshold]/3. + distractibility[threshold]
if max_score_tmp > max_score:
max_score = max_score_tmp
best_threshold = threshold
print('FINAL RESULTS:')
print('Accuracy, CSDC, PC, DC, Distractability, Background independence, SD, TS')
print('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(accuracy, round(csdc[best_threshold],5), round(pc[best_threshold],5), round(dc[best_threshold],5), round(distractibility[best_threshold],5), background_independence, sd, ts))
print('Best threshold:', best_threshold)
if __name__ == '__main__':
main()