# ssdoil.py
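"""SSD post-processing helpers for fastai v1: a greedy non-maximum suppression
routine plus `ItemList` subclasses that decode SSD activations into labelled
bounding boxes."""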
import torch

from fastai.vision.data import ObjectCategoryList, ObjectItemList, imagenet_stats
from fastai.vision.image import ImageBBox


def nms(boxes, scores, overlap=0.5, top_k=100):
    "Greedy NMS: keep at most `top_k` boxes, dropping any whose IoU with a higher-scoring kept box exceeds `overlap`."
    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0: return keep, 0
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]       # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1: break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # clip the remaining boxes to the kept box: element-wise max of the
        # top-left corners, element-wise min of the bottom-right corners
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # clamp to zero so non-overlapping boxes contribute no intersection
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
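
# A minimal usage sketch (illustrative values, not model output): of two heavily
# overlapping boxes only the higher-scoring one survives, while a disjoint box
# is kept regardless of its score.
#
#   boxes  = torch.tensor([[0.00, 0.00, 1.0, 1.0],
#                          [0.05, 0.05, 1.0, 1.0],
#                          [2.00, 2.00, 3.0, 3.0]])
#   scores = torch.tensor([0.9, 0.8, 0.7])
#   keep, count = nms(boxes, scores, overlap=0.5)
#   keep[:count]  # tensor([0, 2]): box 1 is suppressed by box 0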

class SSDObjectCategoryList(ObjectCategoryList):
    "`ItemList` for labelled bounding boxes detected using SSD."

    def analyze_pred(self, pred, thresh=0.5, nms_overlap=0.1, ssd=None):
        "Decode one image's SSD activations into (bboxes, class ids), filtered by `thresh` and per-class NMS."
        b_clas, b_bb = pred
        # decode box activations into bounding boxes in [0, 1] using the model's anchors
        a_ic = ssd._actn_to_bb(b_bb, ssd._anchors.cpu(), ssd._grid_sizes.cpu())
        # per-class confidence scores, one row per class: (n_classes, n_anchors)
        conf_scores = b_clas.t().sigmoid()
        out1, bbox_list, class_list = [], [], []
        for cl in range(1, len(conf_scores)):  # class 0 is background
            c_mask = conf_scores[cl] > thresh
            if c_mask.sum() == 0:
                continue
            scores = conf_scores[cl][c_mask]
            l_mask = c_mask.unsqueeze(1).expand_as(a_ic)
            boxes = a_ic[l_mask].view(-1, 4)  # boxes are now in range [0, 1]
            boxes = (boxes - 0.5) * 2.0       # putting boxes in range [-1, 1]
            ids, count = nms(boxes.data, scores, nms_overlap, 50)  # FIX: top_k hardcoded to 50
            ids = ids[:count]
            out1.append(scores[ids])
            bbox_list.append(boxes.data[ids])
            class_list.append(torch.tensor([cl] * count))
        if len(bbox_list) == 0:
            return None  # torch.Tensor(size=(0,4)), torch.Tensor()
        return torch.cat(bbox_list, dim=0), torch.cat(class_list, dim=0)  # torch.cat(out1, dim=0),

    def reconstruct(self, t, x):
        if t is None: return None
        bboxes, labels = t
        # drop leading padding entries, if any
        if len((labels - self.pad_idx).nonzero()) == 0: return
        i = (labels - self.pad_idx).nonzero().min()
        bboxes, labels = bboxes[i:], labels[i:]
        return ImageBBox.create(*x.size, bboxes, labels=labels, classes=self.classes, scale=False)

class SSDObjectItemList(ObjectItemList):
    "`ItemList` suitable for object detection."
    _label_cls, _square_show_res = SSDObjectCategoryList, False
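
# A sketch of how these classes plug into fastai's data block API. `path`,
# `get_y_func` and the `ssd` model object (which `analyze_pred` expects to
# expose `_actn_to_bb`, `_anchors` and `_grid_sizes`) are assumptions, not part
# of this file; `get_transforms` and `bb_pad_collate` come from fastai.vision,
# and `imagenet_stats` is imported above.
#
#   data = (SSDObjectItemList.from_folder(path)
#           .split_by_rand_pct(0.2)
#           .label_from_func(get_y_func)
#           .transform(get_transforms(), tfm_y=True, size=224)
#           .databunch(bs=16, collate_fn=bb_pad_collate)
#           .normalize(imagenet_stats))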