-
Notifications
You must be signed in to change notification settings - Fork 5
/
detector_single.py
143 lines (100 loc) · 3.69 KB
/
detector_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import torch
from PIL import Image
import datasets
from models import build_model, load_model
from nn import DetectPostProcess
from utils.box_util import draw_object_box
from transforms import detector_transforms as transforms
def read_image(source):
    """Open an image and return it as a fully loaded RGB ``PIL.Image``.

    Args:
        source: A path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A new RGB image whose pixel data has already been read, so no file
        handle is left open behind it.
    """
    # ``Image.open`` only reads the header and keeps the file open until the
    # pixels are needed; ``convert`` forces the load, and the ``with`` block
    # then releases the file.
    # Converting to RGB normalizes grayscale / palette / RGBA inputs to three
    # channels. NOTE(review): assumes the detector expects 3-channel input --
    # the model's channel count is not visible here; confirm.
    with Image.open(source) as im:
        return im.convert('RGB')
def prepare_input(img, size, enable_letterbox):
    """Resize an image for the detector and build its batched input tensor.

    Args:
        img: Source image (a PIL image when it comes from ``read_image``).
        size: Target size forwarded to ``transforms.Resize``.
        enable_letterbox: When true, ``transforms.LetterBox`` is applied
            before resizing.

    Returns:
        ``(resized_img, batch)``: the geometrically transformed image (kept
        so results can be drawn on it) and the ``ToTensor`` + ``Normalize``
        output with a leading dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Geometric stage: optional letterboxing followed by resizing.
    # The transforms are called as ``t(image, extra)`` and return a pair; an
    # empty list is supplied as the second argument (presumably box targets,
    # which do not exist at inference time -- confirm against transforms).
    geometric = [transforms.LetterBox()] if enable_letterbox else []
    geometric.append(transforms.Resize(size))
    resized, _ = transforms.Compose(geometric)(img, [])

    # Tensor stage: convert to a tensor, then normalize.
    to_tensor = transforms.Compose((transforms.ToTensor(),
                                    transforms.Normalize()))
    tensor, _ = to_tensor(resized, [])
    return resized, tensor.unsqueeze(0)
def prepare_model(args, weight):
    """Build the detector and its post-processor, then load trained weights.

    Args:
        args: Parsed CLI namespace. It is handed to ``build_model``, and
            ``args.th_conf`` / ``args.th_iou`` configure the post-processor.
        weight: Checkpoint path passed to ``load_model``.

    Returns:
        ``(model, post_process)``. The model is moved to CUDA when
        ``torch.cuda.is_available()`` is true.
    """
    net = build_model(args, pretrained=False)

    # The post-processor receives the anchor boxes before the model is moved
    # to any device. NOTE(review): if DetectPostProcess stores the anchors as
    # a CPU tensor, confirm it copes with CUDA outputs.
    postproc = DetectPostProcess(net.get_anchor_box(),
                                 args.th_conf,
                                 args.th_iou)

    load_model(net, weight)
    return (net.cuda() if torch.cuda.is_available() else net), postproc
def inference(model, post_process, img, th_iou, th_conf, enable_letterbox):
    """Run the detector on one image and post-process the raw outputs.

    Args:
        model: Detector exposing ``get_input_size()``; calling it returns a
            ``(conf, loc)`` pair.
        post_process: Callable applied to ``(conf, loc)``.
        img: Input image, forwarded to ``prepare_input``.
        th_iou: Accepted for interface compatibility but NOT used here; the
            effective IoU threshold is whatever ``post_process`` was built
            with.
        th_conf: Same as ``th_iou``: unused here.
        enable_letterbox: Forwarded to ``prepare_input``.

    Returns:
        ``(resized_img, detections)``, where ``detections`` is the value
        returned by ``post_process``.

    Gradient tracking is not disabled inside this function.
    """
    resized, batch = prepare_input(img, model.get_input_size(),
                                   enable_letterbox)
    # Match the device choice made when the model was prepared.
    if torch.cuda.is_available():
        batch = batch.cuda()
    conf, loc = model(batch)
    return resized, post_process(conf, loc)
def single_run(model, post_process, x, dataset,
               th_iou, th_conf, enable_letterbox):
    """Detect objects in one image file and display the annotated result.

    Args:
        model: Detector; switched to evaluation mode by this call.
        post_process: Post-processor passed through to ``inference``.
        x: Image path given to ``read_image``.
        dataset: Object whose ``classes`` sequence maps a class index to its
            name.
        th_iou, th_conf, enable_letterbox: Passed through to ``inference``.

    Side effects:
        Every detection list in the post-processed output gets its class name
        appended in place, and the annotated image is opened with
        ``Image.show()``.
    """
    model.eval()
    source = read_image(x)
    with torch.no_grad():
        img, results = inference(model, post_process, source,
                                 th_iou, th_conf, enable_letterbox)

    # results[0] is enumerated per class index; each non-empty entry is a
    # collection of detections. NOTE(review): assumes post-process class
    # indices line up with dataset.classes (no background offset) -- confirm.
    labeled = []
    for cls_idx, detections in enumerate(results[0]):
        if detections:
            name = dataset.classes[cls_idx]
            for det in detections:
                det.append(name)
                labeled.append(det)

    draw_object_box(img, labeled).show()
def main():
    """Command-line entry point: run the detector on each input image.

    Parses CLI options, builds and loads the model once, then calls
    ``single_run`` for every positional image path.

    Usage errors are reported through ``argparse`` before any model is built.
    This covers a missing input path and a dataset that ``--dataset`` accepts
    but this script does not implement. argparse prints the usage message to
    stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description='Detector Single Test')
    # '+' (not '*'): with no inputs the script would build the model and
    # load its weights only to process nothing.
    parser.add_argument('inputs', type=str, nargs='+',
                        help='Input image path')
    parser.add_argument('--model', default='ssd300',
                        help='Detector model name')
    parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                        type=str, help='VOC or COCO')
    parser.add_argument('--weight', default=None,
                        help='Weight file path')
    parser.add_argument('--th_conf', default=0.5, type=float,
                        help='Confidence Threshold')
    parser.add_argument('--th_iou', default=0.5, type=float,
                        help='IOU Threshold')
    parser.add_argument('--enable_letterbox', action='store_true',
                        help='Enable letterboxing image')
    args = parser.parse_args()

    # 'COCO' is an accepted choice above, but only the VOC class list is
    # wired up here. Report it as a clean usage error instead of raising a
    # bare Exception.
    if args.dataset != 'VOC':
        parser.error("dataset '%s' is not supported yet" % args.dataset)
    dataset = datasets.VOCDetection

    # Fall back to checkpoints/<model>_latest.pth when no weight is given.
    weight = args.weight or 'checkpoints/' + args.model + '_latest.pth'

    model, post_process = prepare_model(args, weight)
    for x in args.inputs:
        single_run(model, post_process, x, dataset,
                   args.th_iou, args.th_conf, args.enable_letterbox)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()