# stage_detection.py (SmashScan)
import time
import numpy as np
import cv2

# SmashScan libraries
import util
import timeline

LABELS_LIST = ["battlefield", "dreamland", "finaldest",
               "fountain", "pokemon", "yoshis"]


# An object that takes a capture and a number of input parameters and performs
# a number of object detection operations. Parameters include a cv2 capture,
# darkflow object, save_flag for saving results, and show_flag for display.
class StageDetector:

    def __init__(self, capture, tfnet, show_flag=False, save_flag=False):
        self.capture = capture
        self.tfnet = tfnet
        self.save_flag = save_flag
        self.show_flag = show_flag

        # Predetermined parameters that have been tested to work best.
        self.end_fnum = int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT))
        self.max_num_match_frames = 30
        self.min_match_length_s = 30
        self.num_match_frames = 5
        self.step_size = 60
        self.timeline_empty_thresh = 4
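
        # Assumed interpretations of the parameters above (not documented in
        # the source): step_size is the frame stride of the initial sweep,
        # min_match_length_s is the minimum match length in seconds enforced
        # by timeline.size_filter, timeline_empty_thresh is the largest gap of
        # missed detections that timeline.fill_filter will fill, and
        # num_match_frames / max_num_match_frames bound how many random frames
        # get_match_info samples per match range.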

    #### STAGE DETECTOR TESTS ##################################################

    # Run the standard stage detector test over the entire video.
    def standard_test(self):

        # Create a timeline of the label history, where labels are stored as
        # integers and no result is stored as (-1). Also create a bounding
        # box list.
        dirty_timeline, bbox_hist = list(), list()

        # Iterate through the video and use tfnet to perform object detection.
        start_time = time.time()
        for fnum in range(0, self.end_fnum, self.step_size):
            self.capture.set(cv2.CAP_PROP_POS_FRAMES, fnum)
            _, frame = self.capture.read()

            # Get the tfnet result with the largest confidence and extract
            # its info.
            bbox, label, confidence = self.get_tfnet_result(frame)

            # Store the label if a result was found, or (-1) otherwise.
            if label:
                dirty_timeline.append(LABELS_LIST.index(label))
                bbox_hist.append(bbox)
            else:
                dirty_timeline.append(-1)
                bbox_hist.append(-1)

            # Display the frame if show_flag is enabled. Exit if q is pressed.
            if self.show_flag:
                if confidence:
                    text = '{}: {:.0f}%'.format(label, confidence * 100)
                    util.show_frame(frame, bbox_list=[bbox], text=text,
                        save_flag=self.save_flag,
                        save_name="output/{:07d}.png".format(fnum))
                else:
                    util.show_frame(frame, save_flag=self.save_flag,
                        save_name="output/{:07d}.png".format(fnum))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        # Display the time taken to complete the initial sweep.
        util.display_fps(start_time, len(dirty_timeline), "Initial Sweep")

        # Fill holes in the history timeline list, and filter out timeline
        # sections that are smaller than a particular size.
        clean_timeline = timeline.fill_filter(dirty_timeline,
            self.timeline_empty_thresh)
        clean_timeline = timeline.size_filter(clean_timeline,
            self.step_size, self.min_match_length_s)
        timeline.show_plots(dirty_timeline, clean_timeline, LABELS_LIST)

        # Get a list of the matches and avg bboxes according to clean_timeline.
        match_ranges = timeline.get_ranges(clean_timeline)
        match_bboxes = self.get_match_bboxes(match_ranges, bbox_hist)

        # Show the beginning and end of each match according to the filters.
        display_frames, display_bboxes = list(), list()
        for i, match_range in enumerate(match_ranges):
            display_frames += [match_range[0]*self.step_size,
                match_range[1]*self.step_size]
            display_bboxes += [match_bboxes[i], match_bboxes[i]]
        util.show_frames(self.capture, display_frames, display_bboxes)

    #### STAGE DETECTOR INTERNAL METHODS #######################################

    # Return the tfnet prediction with the highest confidence.
    def get_tfnet_result(self, frame):
        results = self.tfnet.return_predict(frame)

        result = dict()
        bbox, label, confidence = None, None, None
        max_confidence = 0
        for result_iter in results:
            if result_iter["confidence"] > max_confidence:
                result = result_iter
                max_confidence = result_iter["confidence"]

        if result:
            tl = (result['topleft']['x'], result['topleft']['y'])
            br = (result['bottomright']['x'], result['bottomright']['y'])
            bbox = (tl, br)
            label = result['label']
            confidence = result['confidence']

        return bbox, label, confidence
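
    # For reference, each prediction returned by darkflow's return_predict is
    # a dict of the following shape (implied by the accesses above):
    #   {'label': 'battlefield', 'confidence': 0.84,
    #    'topleft': {'x': 120, 'y': 60},
    #    'bottomright': {'x': 520, 'y': 300}}
    # The sample values are illustrative, not taken from the source.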

    # Given match ranges and bounding box history, return a list of the
    # average bounding box (top left and bottom right coordinate pair) of
    # each match.
    def get_match_bboxes(self, match_ranges, bbox_hist):
        match_bboxes = list()
        for mr in match_ranges:
            avg_bbox = util.get_avg_bbox(bbox_hist[mr[0]:mr[1]])
            match_bboxes.append(avg_bbox)
        return match_bboxes
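
    # Note: bbox_hist stores (-1) for frames with no detection, so
    # util.get_avg_bbox is presumably expected to skip those placeholders and
    # average only the ((x1, y1), (x2, y2)) corner pairs. This is an
    # assumption; the helper is defined in the project's util module.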

    # Given a list of predicted labels, determine the label that accurately
    # represents the match. If there is too much variance or no label was
    # found, assume that a highlight reel or poor quality stream was given.
    def get_match_label(self, label_list, match_range):

        # If no stages were found, return and declare a failure.
        if not label_list:
            print("\tUnidentifiable Match Range: {}".format(match_range))
            return None

        # If there is too much variance (multiple labels found), return.
        if len(set(label_list)) > 1:
            print("\tRemoved Match Range: {}".format(match_range))
            return "multiple_stages_found"

        # Find the label that occurred the most during the tested frames.
        match_label = max(set(label_list), key=label_list.count)
        return match_label
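
    # For example, get_match_label(["pokemon", "pokemon", "pokemon"], (10, 40))
    # returns "pokemon", while ["pokemon", "yoshis", "pokemon"] trips the
    # variance check and returns "multiple_stages_found". (Illustrative
    # values, not from the source.)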

    #### STAGE DETECTOR EXTERNAL METHODS #######################################

    # Given a list of match ranges, randomly select frames from each range and
    # return the average bounding box and expected label for each match. Also
    # return an updated list of the match ranges, with matches removed where
    # multiple stages were found, a.k.a. highlight reels.
    def get_match_info(self, match_ranges):

        # For each match range, generate random frame numbers to search.
        new_match_ranges, match_bboxes, match_labels = list(), list(), list()
        for match_range in match_ranges:
            random_fnum_list = np.random.randint(low=match_range[0],
                high=match_range[1], size=self.num_match_frames)
            bbox_list, label_list = list(), list()

            # Find the labels for the random frame numbers selected.
            for random_fnum in random_fnum_list:
                self.capture.set(cv2.CAP_PROP_POS_FRAMES, random_fnum)
                _, frame = self.capture.read()
                bbox, label, _ = self.get_tfnet_result(frame)
                if label:
                    bbox_list.append(bbox)
                    label_list.append(label)

            # Attempt to find a stage if none was found in the initial search.
            if not label_list:
                for _ in range(self.max_num_match_frames):
                    fnum = np.random.randint(match_range[0], match_range[1])
                    self.capture.set(cv2.CAP_PROP_POS_FRAMES, fnum)
                    _, frame = self.capture.read()
                    bbox, label, _ = self.get_tfnet_result(frame)
                    if label:
                        bbox_list, label_list = [bbox], [label]
                        break

            # Find the label that occurred the most during the tested frames.
            # If zero or multiple stages were found, declare failure.
            match_label = self.get_match_label(label_list, match_range)
            if match_label is None:
                return None, None, None
            if match_label == "multiple_stages_found":
                continue

            new_match_ranges.append(match_range)
            match_bboxes.append(util.get_avg_bbox(bbox_list))
            match_labels.append(match_label)

        return new_match_ranges, match_bboxes, match_labels
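

# A minimal usage sketch (not part of the original file). It assumes a
# darkflow model trained on the six stage labels above; the cfg, weights, and
# video paths are hypothetical, and the threshold value is an arbitrary
# example. darkflow's TFNet is constructed from an options dict.
if __name__ == "__main__":
    from darkflow.net.build import TFNet

    options = {"model": "cfg/tiny-yolo-voc-6c.cfg",       # hypothetical path
               "load": "bin/tiny-yolo-voc-6c.weights",    # hypothetical path
               "threshold": 0.25}                         # example value
    tfnet = TFNet(options)

    capture = cv2.VideoCapture("videos/example-vod.mp4")  # hypothetical path
    detector = StageDetector(capture, tfnet, show_flag=True)
    detector.standard_test()

    capture.release()
    cv2.destroyAllWindows()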