# infer_cls.py
import numpy as np
import torch
from torch.backends import cudnn
cudnn.enabled = True
import voc12.data
import scipy.misc
import importlib
from torch.utils.data import DataLoader
import torchvision
from tool import imutils, pyutils
import argparse
from PIL import Image
import torch.nn.functional as F
import os.path
# Number of foreground classes: the label vector and the CAM channel axis are
# both indexed 0..19 throughout this script.
_NUM_CLASSES = 20

# Constant background score stacked in front of the CAMs when producing the
# hard arg-max prediction map (--out_cam_pred).
_CAM_PRED_BG_SCORE = 0.2


def _crf_with_alpha(orig_img, cam_dict, alpha):
    """Refine per-class CAMs with ``imutils.crf_inference``.

    A background score ``(1 - max_over_classes(cam)) ** alpha`` is stacked in
    front of the class CAMs and the stack is passed to the CRF.

    Args:
        orig_img: Image array handed unchanged to ``imutils.crf_inference``.
        cam_dict: Mapping ``class index -> CAM``; all CAMs share one shape.
        alpha: Exponent applied to the background score.

    Returns:
        dict mapping ``0`` to the refined background score and ``key + 1`` to
        the refined score of every class ``key`` present in ``cam_dict``.

    Raises:
        ValueError: If ``cam_dict`` is empty (no class CAM to derive a
            background score from).
    """
    if not cam_dict:
        raise ValueError(
            "cam_dict is empty: cannot build a background score without any class CAM")

    # Freeze the key order once so score rows and output keys stay aligned.
    class_keys = list(cam_dict)
    fg_scores = np.array([cam_dict[key] for key in class_keys])
    bg_score = np.power(1 - np.max(fg_scores, axis=0, keepdims=True), alpha)
    stacked = np.concatenate((bg_score, fg_scores), axis=0)
    crf_score = imutils.crf_inference(orig_img, stacked, labels=stacked.shape[0])

    refined = {0: crf_score[0]}
    for row, key in enumerate(class_keys, start=1):
        # Output keys are shifted by one because index 0 is the background.
        refined[key + 1] = crf_score[row]
    return refined


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", required=True, type=str,
                        help="checkpoint file loaded with torch.load")
    parser.add_argument("--network", default="network.vgg16_cls", type=str,
                        help="module path whose 'Net' class is instantiated")
    parser.add_argument("--infer_list", default="voc12/val.txt", type=str,
                        help="image list passed to VOC12ClsDatasetMSF")
    parser.add_argument("--num_workers", default=8, type=int,
                        help="worker count for the DataLoader and the BatchThreader")
    parser.add_argument("--voc12_root", required=True, type=str,
                        help="VOC12 root directory")
    parser.add_argument("--low_alpha", default=4, type=int,
                        help="background exponent for the --out_la_crf output")
    parser.add_argument("--high_alpha", default=32, type=int,
                        help="background exponent for the --out_ha_crf output")
    parser.add_argument("--out_cam", default=None, type=str,
                        help="directory for per-image CAM dicts (.npy)")
    parser.add_argument("--out_la_crf", default=None, type=str,
                        help="directory for low-alpha CRF results (.npy)")
    parser.add_argument("--out_ha_crf", default=None, type=str,
                        help="directory for high-alpha CRF results (.npy)")
    parser.add_argument("--out_cam_pred", default=None, type=str,
                        help="directory for arg-max prediction maps (.png)")
    args = parser.parse_args()

    # Create the requested output directories up front so that a missing
    # folder does not surface only after the first image has been processed.
    for out_dir in (args.out_cam, args.out_la_crf, args.out_ha_crf, args.out_cam_pred):
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    model = getattr(importlib.import_module(args.network), 'Net')()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(args.weights))
    model.eval()
    model.cuda()

    infer_dataset = voc12.data.VOC12ClsDatasetMSF(
        args.infer_list, voc12_root=args.voc12_root,
        scales=(1, 0.5, 1.5, 2.0),
        inter_transform=torchvision.transforms.Compose(
            [np.asarray,
             model.normalize,
             imutils.HWC_to_CHW]))

    infer_data_loader = DataLoader(infer_dataset, shuffle=False,
                                   num_workers=args.num_workers, pin_memory=True)

    # One model replica per visible GPU; inputs are spread round-robin.
    n_gpus = torch.cuda.device_count()
    model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus)))

    for step, (img_name, img_list, label) in enumerate(infer_data_loader):
        # The loader uses the default batch size of 1: strip the batch axis.
        img_name = img_name[0]
        label = label[0]

        img_path = voc12.data.get_img_path(img_name, args.voc12_root)
        with Image.open(img_path) as img_file:
            orig_img = np.asarray(img_file)
        orig_img_size = orig_img.shape[:2]

        # Per-class mask (20, 1, 1) that zeroes CAM channels of absent classes;
        # computed once per image instead of once per worker call.
        label_mask = label.clone().view(_NUM_CLASSES, 1, 1).numpy()

        def _work(i, img):
            """Return the label-masked CAM for entry ``i`` of ``img_list``,
            resized to the original image size."""
            gpu = i % n_gpus
            with torch.no_grad(), torch.cuda.device(gpu):
                cam = model_replicas[gpu].forward_cam(img.cuda())
                cam = F.interpolate(cam, orig_img_size, mode='bilinear',
                                    align_corners=False)[0]
                cam = cam.cpu().numpy() * label_mask
                if i % 2 == 1:
                    # Odd entries are flipped back along the last axis;
                    # presumably the dataset emits (image, mirrored image)
                    # pairs per scale -- confirm against VOC12ClsDatasetMSF.
                    cam = np.flip(cam, axis=-1)
                return cam

        thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)),
                                            batch_size=12, prefetch_size=0,
                                            processes=args.num_workers)
        cam_list = thread_pool.pop_results()

        # Sum all per-input CAMs, then scale each class channel by its own
        # maximum (epsilon avoids division by zero for absent classes).
        sum_cam = np.sum(cam_list, axis=0)
        norm_cam = sum_cam / (np.max(sum_cam, (1, 2), keepdims=True) + 1e-5)

        # Keep only the classes that are present in the image-level label.
        cam_dict = {cls: norm_cam[cls]
                    for cls in range(_NUM_CLASSES) if label[cls] > 1e-5}

        if args.out_cam is not None:
            np.save(os.path.join(args.out_cam, img_name + '.npy'), cam_dict)

        if args.out_cam_pred is not None:
            bg_score = np.full_like(norm_cam[:1], _CAM_PRED_BG_SCORE)
            pred = np.argmax(np.concatenate((bg_score, norm_cam), axis=0), 0)
            # scipy.misc.imsave was removed from SciPy; write the uint8 label
            # map with PIL instead.
            Image.fromarray(pred.astype(np.uint8)).save(
                os.path.join(args.out_cam_pred, img_name + '.png'))

        if args.out_la_crf is not None:
            crf_la = _crf_with_alpha(orig_img, cam_dict, args.low_alpha)
            np.save(os.path.join(args.out_la_crf, img_name + '.npy'), crf_la)

        if args.out_ha_crf is not None:
            crf_ha = _crf_with_alpha(orig_img, cam_dict, args.high_alpha)
            np.save(os.path.join(args.out_ha_crf, img_name + '.npy'), crf_ha)

        print(step)