# eval_ood_detection.py (forked from ZBox1005/CoVer)
import os
import argparse
import numpy as np
import torch
from scipy import stats
from utils.common import setup_seed, get_num_cls, get_test_labels
from utils.detection_util import print_measures, get_and_print_results, get_ood_scores_clip, get_ood_scores_resnet
from utils.file_ops import save_as_dataframe, setup_log
from utils.plot_util import plot_distribution
from utils.train_eval_util import set_model_clip, set_model_resnet50, set_id_loader, set_ood_loader_ImageNet
def process_args():
parser = argparse.ArgumentParser(description='Evaluates CoVer Score',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # settings for each run
parser.add_argument('--in_dataset', default='ImageNet', type=str,
choices=['ImageNet', 'ImageNet10', 'ImageNet20', 'ImageNet100'], help='in-distribution dataset')
parser.add_argument('--root-dir', default="datasets", type=str,
help='root dir of datasets')
parser.add_argument('--name', default="eval_ood",
type=str, help="unique ID for the run")
parser.add_argument('--seed', default=5, type=int, help="random seed")
    parser.add_argument('--gpu', default=0, type=int,
                        help='the GPU index to use')
parser.add_argument('-b', '--batch-size', default=512, type=int,
help='mini-batch size')
parser.add_argument('--T', type=int, default=1,
help='temperature parameter')
parser.add_argument('--model', default='CLIP', choices=['ResNet50', 'CLIP'],
type=str, help='model architecture')
parser.add_argument('--CLIP_ckpt', type=str, default='ViT-B-16',
choices=['RN50', 'ViT-B-32', 'ViT-B-16', 'ViT-L-14'],
help='which pretrained img encoder to use')
    parser.add_argument('--score', default='CoVer', type=str,
                        choices=['CoVer', 'energy'], help='OOD scoring function')
# for pretrained resnet and ASH
    parser.add_argument('--train_restore_file', default="resnet50-19c8e357.pth", type=str,
                        help="which pretrained ResNet checkpoint (.pth) to restore")
    parser.add_argument('--ash_method', default="ash_s@90", type=str,
                        help="which ASH variant and pruning percentile to use")
# for DICE and DICE + ReAct
    parser.add_argument('--p', default=None, type=int, help="sparsity parameter p in DICE")
    parser.add_argument('--clip_threshold', default=1.0, type=float,
                        help="activation clipping threshold in ReAct")
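    # Background: DICE sparsifies the classifier layer, keeping only the
    # highest-contributing weights (controlled by p), while ReAct clips
    # penultimate activations at a fixed threshold; both are post-hoc OOD
    # baselines configured by the two flags above.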
args = parser.parse_args()
args.n_cls = get_num_cls(args)
args.log_directory = f"results/{args.in_dataset}/{args.score}/{args.model}_{args.CLIP_ckpt}_T_{args.T}_ID_{args.name}/"
    # Append the corruption configuration (module-level imagenet_c, defined below) to the log path
    for method, severities in imagenet_c.items():
        args.log_directory += f"_{method}"
        for severity in severities:
            args.log_directory += f"_{severity}"
os.makedirs(args.log_directory, exist_ok=True)
return args
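
# For illustration, with the defaults above and the full imagenet_c dict below,
# args.log_directory expands to something like:
#   results/ImageNet/CoVer/CLIP_ViT-B-16_T_1_ID_eval_ood/_origin_1_brightness_1_2_3_4_5_..._impulse_noise_1_2_3_4_5
# i.e., one "_{method}" segment per corruption type, each followed by its severities.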
# --- Select the corruption types used for input expansion ---
# --- Recommended corruption types from the validation set ---
# imagenet_c = {
#     "origin": (1,),
#     "brightness": (1, 2),
#     "fog": (1, 2),
#     "saturate": (1, 2),
#     "motion_blur": (1, 2),
#     "defocus_blur": (1, 2),
#     "gaussian_blur": (1, 2),
# }
# ------
imagenet_c = {
    # Fixed: the original, uncorrupted input is always included
    "origin": (1,),
    # Selective: corruption types evaluated at severities 1-5
    "brightness": (1, 2, 3, 4, 5),
    "fog": (1, 2, 3, 4, 5),
    "contrast": (1, 2, 3, 4, 5),
    "motion_blur": (1, 2, 3, 4, 5),
    "defocus_blur": (1, 2, 3, 4, 5),
    "gaussian_blur": (1, 2, 3, 4, 5),
    "spatter": (1, 2, 3, 4, 5),
    "saturate": (1, 2, 3, 4, 5),
    "elastic_transform": (1, 2, 3, 4, 5),
    "jpeg_compression": (1, 2, 3, 4, 5),
    "pixelate": (1, 2, 3, 4, 5),
    "speckle_noise": (1, 2, 3, 4, 5),
    "glass_blur": (1, 2, 3, 4, 5),
    "gaussian_noise": (1, 2, 3, 4, 5),
    "shot_noise": (1, 2, 3, 4, 5),
    "zoom_blur": (1, 2, 3, 4, 5),
    "snow": (1, 2, 3, 4, 5),
    "impulse_noise": (1, 2, 3, 4, 5),
}
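
# A minimal sketch of the aggregation idea behind a CoVer-style score: each
# corrupted view of an input yields one confidence, and the final score averages
# them across views. This helper is illustrative only (the actual scoring lives
# in utils.detection_util); its name and signature are hypothetical.
def _cover_score_sketch(view_logits, T=1.0):
    """view_logits: list of [batch, n_cls] logit tensors, one per corrupted view."""
    per_view = [torch.softmax(logits / T, dim=-1).max(dim=-1).values
                for logits in view_logits]
    # Average the per-view max-softmax confidences into a single [batch] score
    return torch.stack(per_view, dim=0).mean(dim=0)
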
def main():
args = process_args()
setup_seed(args.seed)
log = setup_log(args)
assert torch.cuda.is_available()
torch.cuda.set_device(args.gpu)
# Load OOD detector
if args.model == 'CLIP':
net, preprocess = set_model_clip(args)
elif 'ResNet' in args.model:
net, preprocess = set_model_resnet50(args)
else:
raise ValueError(f"Unsupported model type: {args.model}")
net.eval()
    # ID/OOD dataset pairings below follow MCM
if args.in_dataset in ['ImageNet10']:
out_datasets = ['ImageNet20', 'ImageNet100']
elif args.in_dataset in ['ImageNet20', 'ImageNet100']:
out_datasets = ['ImageNet10']
# ImageNet OOD detection benchmark
elif args.in_dataset in ['ImageNet']:
# Test
out_datasets = ['iNaturalist', 'SUN', 'places365', 'dtd']
# Validation
# out_datasets = ['SVHN']
# Get original loader and aug loader dict for ID dataset
test_loader, in_aug_loader_dict = set_id_loader(args, preprocess)
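    # Assumption, noted for readability: in_aug_loader_dict is expected to map
    # each corruption in imagenet_c (and its severities) to a loader over the
    # matching corrupted views of the ID data, which the scorers below aggregate.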
test_labels = get_test_labels(args, test_loader)
# ID set
if args.model == 'CLIP':
in_score = get_ood_scores_clip(args, net, in_aug_loader_dict, test_labels, imagenet_c,
dataset_name='ImageNet')
elif 'ResNet' in args.model:
in_score = get_ood_scores_resnet(args, net, in_aug_loader_dict, imagenet_c, dataset_name='ImageNet')
# OOD set
auroc_list, aupr_list, fpr_list = [], [], []
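    # auroc_list/aupr_list/fpr_list accumulate one entry per OOD dataset;
    # their means are reported after the loop.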
for out_dataset in out_datasets:
log.debug(f"Evaluating OOD dataset {out_dataset}")
ood_loader, out_aug_loader_dict = set_ood_loader_ImageNet(args, out_dataset, preprocess,
root=os.path.join(args.root_dir,
'ImageNet_OOD_dataset'))
if args.model == 'CLIP':
out_score = get_ood_scores_clip(args, net, out_aug_loader_dict, test_labels, imagenet_c,
dataset_name=out_dataset)
elif 'ResNet' in args.model:
out_score = get_ood_scores_resnet(args, net, out_aug_loader_dict, imagenet_c, dataset_name=out_dataset)
log.debug(f"in scores: {stats.describe(in_score)}")
log.debug(f"out scores: {stats.describe(out_score)}")
plot_distribution(args, in_score, out_score, out_dataset)
get_and_print_results(args, log, in_score, out_score,
auroc_list, aupr_list, fpr_list)
log.debug('\n\nMean Test Results')
print_measures(log, np.mean(auroc_list), np.mean(aupr_list),
np.mean(fpr_list), method_name=args.score)
save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list)
if __name__ == '__main__':
main()
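
# Example invocations (illustrative; dataset paths and on-disk layout are assumptions):
#   python eval_ood_detection.py --in_dataset ImageNet --model CLIP --CLIP_ckpt ViT-B-16 --score CoVer
#   python eval_ood_detection.py --in_dataset ImageNet --model ResNet50 --score energy --ash_method ash_s@90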