Skip to content

Commit 1448e84

Browse files
committed
fix heatmap bugs
1 parent fa32e9f commit 1448e84

File tree

1 file changed

+50
-50
lines changed

1 file changed

+50
-50
lines changed

yolo.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,56 @@ def detect_image(self, image, crop = False, count = False):
208208
del draw
209209

210210
return image
211+
212+
def get_FPS(self, image, test_interval):
213+
image_shape = np.array(np.shape(image)[0:2])
214+
#---------------------------------------------------------#
215+
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
216+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
217+
#---------------------------------------------------------#
218+
image = cvtColor(image)
219+
#---------------------------------------------------------#
220+
# 给图像增加灰条,实现不失真的resize
221+
# 也可以直接resize进行识别
222+
#---------------------------------------------------------#
223+
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
224+
#---------------------------------------------------------#
225+
# 添加上batch_size维度
226+
#---------------------------------------------------------#
227+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
228+
229+
with torch.no_grad():
230+
images = torch.from_numpy(image_data)
231+
if self.cuda:
232+
images = images.cuda()
233+
#---------------------------------------------------------#
234+
# 将图像输入网络当中进行预测!
235+
#---------------------------------------------------------#
236+
outputs = self.net(images)
237+
outputs = decode_outputs(outputs, self.input_shape)
238+
#---------------------------------------------------------#
239+
# 将预测框进行堆叠,然后进行非极大抑制
240+
#---------------------------------------------------------#
241+
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
242+
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
243+
244+
t1 = time.time()
245+
for _ in range(test_interval):
246+
with torch.no_grad():
247+
#---------------------------------------------------------#
248+
# 将图像输入网络当中进行预测!
249+
#---------------------------------------------------------#
250+
outputs = self.net(images)
251+
outputs = decode_outputs(outputs, self.input_shape)
252+
#---------------------------------------------------------#
253+
# 将预测框进行堆叠,然后进行非极大抑制
254+
#---------------------------------------------------------#
255+
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
256+
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
257+
258+
t2 = time.time()
259+
tact_time = (t2 - t1) / test_interval
260+
return tact_time
211261

212262
def detect_heatmap(self, image, heatmap_save_path):
213263
import cv2
@@ -265,56 +315,6 @@ def sigmoid(x):
265315
plt.savefig(heatmap_save_path, dpi=200)
266316
print("Save to the " + heatmap_save_path)
267317
plt.cla()
268-
269-
def get_FPS(self, image, test_interval):
270-
image_shape = np.array(np.shape(image)[0:2])
271-
#---------------------------------------------------------#
272-
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
273-
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
274-
#---------------------------------------------------------#
275-
image = cvtColor(image)
276-
#---------------------------------------------------------#
277-
# 给图像增加灰条,实现不失真的resize
278-
# 也可以直接resize进行识别
279-
#---------------------------------------------------------#
280-
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
281-
#---------------------------------------------------------#
282-
# 添加上batch_size维度
283-
#---------------------------------------------------------#
284-
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
285-
286-
with torch.no_grad():
287-
images = torch.from_numpy(image_data)
288-
if self.cuda:
289-
images = images.cuda()
290-
#---------------------------------------------------------#
291-
# 将图像输入网络当中进行预测!
292-
#---------------------------------------------------------#
293-
outputs = self.net(images)
294-
outputs = decode_outputs(outputs, self.input_shape)
295-
#---------------------------------------------------------#
296-
# 将预测框进行堆叠,然后进行非极大抑制
297-
#---------------------------------------------------------#
298-
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
299-
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
300-
301-
t1 = time.time()
302-
for _ in range(test_interval):
303-
with torch.no_grad():
304-
#---------------------------------------------------------#
305-
# 将图像输入网络当中进行预测!
306-
#---------------------------------------------------------#
307-
outputs = self.net(images)
308-
outputs = decode_outputs(outputs, self.input_shape)
309-
#---------------------------------------------------------#
310-
# 将预测框进行堆叠,然后进行非极大抑制
311-
#---------------------------------------------------------#
312-
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
313-
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
314-
315-
t2 = time.time()
316-
tact_time = (t2 - t1) / test_interval
317-
return tact_time
318318

319319
def convert_to_onnx(self, simplify, model_path):
320320
import onnx

0 commit comments

Comments
 (0)