@@ -208,6 +208,56 @@ def detect_image(self, image, crop = False, count = False):
208
208
del draw
209
209
210
210
return image
211
+
212
+ def get_FPS (self , image , test_interval ):
213
+ image_shape = np .array (np .shape (image )[0 :2 ])
214
+ #---------------------------------------------------------#
215
+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
216
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
217
+ #---------------------------------------------------------#
218
+ image = cvtColor (image )
219
+ #---------------------------------------------------------#
220
+ # 给图像增加灰条,实现不失真的resize
221
+ # 也可以直接resize进行识别
222
+ #---------------------------------------------------------#
223
+ image_data = resize_image (image , (self .input_shape [1 ],self .input_shape [0 ]), self .letterbox_image )
224
+ #---------------------------------------------------------#
225
+ # 添加上batch_size维度
226
+ #---------------------------------------------------------#
227
+ image_data = np .expand_dims (np .transpose (preprocess_input (np .array (image_data , dtype = 'float32' )), (2 , 0 , 1 )), 0 )
228
+
229
+ with torch .no_grad ():
230
+ images = torch .from_numpy (image_data )
231
+ if self .cuda :
232
+ images = images .cuda ()
233
+ #---------------------------------------------------------#
234
+ # 将图像输入网络当中进行预测!
235
+ #---------------------------------------------------------#
236
+ outputs = self .net (images )
237
+ outputs = decode_outputs (outputs , self .input_shape )
238
+ #---------------------------------------------------------#
239
+ # 将预测框进行堆叠,然后进行非极大抑制
240
+ #---------------------------------------------------------#
241
+ results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
242
+ image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
243
+
244
+ t1 = time .time ()
245
+ for _ in range (test_interval ):
246
+ with torch .no_grad ():
247
+ #---------------------------------------------------------#
248
+ # 将图像输入网络当中进行预测!
249
+ #---------------------------------------------------------#
250
+ outputs = self .net (images )
251
+ outputs = decode_outputs (outputs , self .input_shape )
252
+ #---------------------------------------------------------#
253
+ # 将预测框进行堆叠,然后进行非极大抑制
254
+ #---------------------------------------------------------#
255
+ results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
256
+ image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
257
+
258
+ t2 = time .time ()
259
+ tact_time = (t2 - t1 ) / test_interval
260
+ return tact_time
211
261
212
262
def detect_heatmap (self , image , heatmap_save_path ):
213
263
import cv2
@@ -265,56 +315,6 @@ def sigmoid(x):
265
315
plt .savefig (heatmap_save_path , dpi = 200 )
266
316
print ("Save to the " + heatmap_save_path )
267
317
plt .cla ()
268
-
269
- def get_FPS (self , image , test_interval ):
270
- image_shape = np .array (np .shape (image )[0 :2 ])
271
- #---------------------------------------------------------#
272
- # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
273
- # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
274
- #---------------------------------------------------------#
275
- image = cvtColor (image )
276
- #---------------------------------------------------------#
277
- # 给图像增加灰条,实现不失真的resize
278
- # 也可以直接resize进行识别
279
- #---------------------------------------------------------#
280
- image_data = resize_image (image , (self .input_shape [1 ],self .input_shape [0 ]), self .letterbox_image )
281
- #---------------------------------------------------------#
282
- # 添加上batch_size维度
283
- #---------------------------------------------------------#
284
- image_data = np .expand_dims (np .transpose (preprocess_input (np .array (image_data , dtype = 'float32' )), (2 , 0 , 1 )), 0 )
285
-
286
- with torch .no_grad ():
287
- images = torch .from_numpy (image_data )
288
- if self .cuda :
289
- images = images .cuda ()
290
- #---------------------------------------------------------#
291
- # 将图像输入网络当中进行预测!
292
- #---------------------------------------------------------#
293
- outputs = self .net (images )
294
- outputs = decode_outputs (outputs , self .input_shape )
295
- #---------------------------------------------------------#
296
- # 将预测框进行堆叠,然后进行非极大抑制
297
- #---------------------------------------------------------#
298
- results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
299
- image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
300
-
301
- t1 = time .time ()
302
- for _ in range (test_interval ):
303
- with torch .no_grad ():
304
- #---------------------------------------------------------#
305
- # 将图像输入网络当中进行预测!
306
- #---------------------------------------------------------#
307
- outputs = self .net (images )
308
- outputs = decode_outputs (outputs , self .input_shape )
309
- #---------------------------------------------------------#
310
- # 将预测框进行堆叠,然后进行非极大抑制
311
- #---------------------------------------------------------#
312
- results = non_max_suppression (outputs , self .num_classes , self .input_shape ,
313
- image_shape , self .letterbox_image , conf_thres = self .confidence , nms_thres = self .nms_iou )
314
-
315
- t2 = time .time ()
316
- tact_time = (t2 - t1 ) / test_interval
317
- return tact_time
318
318
319
319
def convert_to_onnx (self , simplify , model_path ):
320
320
import onnx
0 commit comments