-
Notifications
You must be signed in to change notification settings - Fork 13
/
demo.py
executable file
·121 lines (103 loc) · 4.79 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
import mxnet as mx
import argparse
import cv2
import time
from core.symbol import P_Net, R_Net, O_Net
from core.imdb import IMDB
from config import config
from core.loader import TestLoader
from core.detector import Detector
from core.fcn_detector import FcnDetector
from tools.load_model import load_param
from core.MtcnnDetector import MtcnnDetector
def _load_detectors(prefix, epoch, batch_size, ctx, slide_window):
    """Load the three stage checkpoints and wrap them as detectors.

    Returns a list ``[pnet, rnet, onet]``. Each stage is loaded from its own
    ``prefix[i]`` / ``epoch[i]`` pair.
    """
    # PNet is wrapped by Detector when slide_window is set, else by FcnDetector.
    args, auxs = load_param(prefix[0], epoch[0], convert=True, ctx=ctx)
    if slide_window:
        pnet = Detector(P_Net("test"), 12, batch_size[0], ctx, args, auxs)
    else:
        pnet = FcnDetector(P_Net("test"), ctx, args, auxs)

    # RNet must use its own epoch (epoch[1]); the earlier code reused epoch[0],
    # which loaded the wrong checkpoint whenever the stage epochs differed.
    args, auxs = load_param(prefix[1], epoch[1], convert=True, ctx=ctx)
    rnet = Detector(R_Net("test"), 24, batch_size[1], ctx, args, auxs)

    args, auxs = load_param(prefix[2], epoch[2], convert=True, ctx=ctx)
    onet = Detector(O_Net("test"), 48, batch_size[2], ctx, args, auxs)

    return [pnet, rnet, onet]


def test_net(prefix, epoch, batch_size, ctx,
             thresh=[0.6, 0.6, 0.7], min_face_size=24,
             stride=2, slide_window=False):
    """Run the MTCNN detector cascade on frames from capture device 0.

    Frames are read in an endless loop. Each frame goes through
    ``detect_pnet``, ``detect_rnet`` and ``detect_onet`` in turn; when the
    final stage yields boxes they are drawn on a copy of the frame and shown
    in the "detection result" window. Frames for which PNet or RNet return no
    boxes are skipped without being displayed.

    Args:
        prefix: Sequence of three checkpoint prefixes (pnet, rnet, onet).
        epoch: Sequence of three checkpoint epochs, one per stage.
        batch_size: Sequence of three batch sizes, one per stage
            (``batch_size[0]`` is only used when ``slide_window`` is set).
        ctx: MXNet context the models are loaded on.
        thresh: Per-stage thresholds forwarded to ``MtcnnDetector``. The
            default list is never mutated here.
        min_face_size: Forwarded to ``MtcnnDetector``.
        stride: Forwarded to ``MtcnnDetector``.
        slide_window: Selects the ``Detector`` wrapper for PNet instead of
            ``FcnDetector`` and is forwarded to ``MtcnnDetector``.

    Raises:
        IOError: If capture device 0 cannot be opened.
    """
    detectors = _load_detectors(prefix, epoch, batch_size, ctx, slide_window)
    mtcnn_detector = MtcnnDetector(detectors=detectors, ctx=ctx, min_face_size=min_face_size,
                                   stride=stride, threshold=thresh, slide_window=slide_window)

    video_capture = cv2.VideoCapture(0)
    try:
        if not video_capture.isOpened():
            raise IOError("cannot open video capture device 0")

        while True:
            grabbed, img = video_capture.read()
            if not grabbed or img is None:
                # The device stopped delivering frames; feeding None to the
                # detector would only fail with an unrelated error.
                break

            t1 = time.time()
            _, boxes_c = mtcnn_detector.detect_pnet(img)
            if boxes_c is None:
                continue
            _, boxes_c = mtcnn_detector.detect_rnet(img, boxes_c)
            if boxes_c is None:
                continue
            _, boxes_c = mtcnn_detector.detect_onet(img, boxes_c)
            # Same text as the old ``print 'time: ', dt`` statement, written
            # so that it is valid on both Python 2 and Python 3.
            print('time:  %s' % (time.time() - t1))

            if boxes_c is not None:
                draw = img.copy()
                for b in boxes_c:
                    # b[0..3] are used as the two opposite rectangle corners.
                    cv2.rectangle(draw, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])),
                                  (0, 255, 255), 1)
                cv2.imshow("detection result", draw)
                cv2.waitKey(10)
    finally:
        # Always hand the camera and the window back, including on Ctrl-C.
        video_capture.release()
        cv2.destroyAllWindows()
# img = cv2.imread('/home/zzg/Opensource/mtcnn-master/data/custom/03.jpg')
# t1 = time.time()
# boxes, boxes_c = mtcnn_detector.detect_pnet(img)
# boxes, boxes_c = mtcnn_detector.detect_rnet(img, boxes_c)
# boxes, boxes_c = mtcnn_detector.detect_onet(img, boxes_c)
# print 'time: ',time.time() - t1
# if boxes_c is not None:
# draw = img.copy()
# font = cv2.FONT_HERSHEY_SIMPLEX
# for b in boxes_c:
# cv2.rectangle(draw, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (0, 255, 255), 1)
# #cv2.putText(draw, '%.3f'%b[4], (int(b[0]), int(b[1])), font, 0.4, (255, 255, 255), 1)
# cv2.imshow("detection result", draw)
# cv2.imwrite("o12.jpg",draw)
# cv2.waitKey(10)
def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace with the attributes ``prefix``, ``epoch``,
        ``batch_size``, ``thresh``, ``min_face``, ``stride``,
        ``slide_window`` and ``gpu_id``.

    Exits through ``parser.error`` (SystemExit) when ``--prefix``,
    ``--epoch`` or ``--batch_size`` is given fewer than three values.
    """
    parser = argparse.ArgumentParser(description='Test mtcnn',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--prefix', dest='prefix', help='prefix of model name', nargs="+",
                        default=['model/pnet', 'model/rnet', 'model/onet'], type=str)
    parser.add_argument('--epoch', dest='epoch', help='epoch number of model to load', nargs="+",
                        default=[16, 16, 16], type=int)
    parser.add_argument('--batch_size', dest='batch_size', help='list of batch size used in prediction', nargs="+",
                        default=[2048, 256, 16], type=int)
    parser.add_argument('--thresh', dest='thresh', help='list of thresh for pnet, rnet, onet', nargs="+",
                        default=[0.8, 0.9, 0.9], type=float)
    parser.add_argument('--min_face', dest='min_face', help='minimum face size for detection',
                        default=40, type=int)
    parser.add_argument('--stride', dest='stride', help='stride of sliding window',
                        default=2, type=int)
    parser.add_argument('--sw', dest='slide_window', help='use sliding window in pnet', action='store_true')
    # The script maps -1 to the CPU context, so say so in the help text.
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to run on (-1 selects the CPU)',
                        default=-1, type=int)
    args = parser.parse_args(argv)

    # test_net indexes these lists at [0], [1] and [2]; report a short list
    # here rather than failing later with an IndexError.
    for option in ('prefix', 'epoch', 'batch_size'):
        if len(getattr(args, option)) < 3:
            parser.error('--%s needs three values (pnet, rnet, onet)' % option)
    return args
if __name__ == '__main__':
    # Script entry: parse the CLI, echo the options, choose a device, run.
    cli = parse_args()
    print('Called with argument:')
    print(cli)
    # A gpu id of -1 selects the CPU; any other value selects that GPU.
    device = mx.cpu(0) if cli.gpu_id == -1 else mx.gpu(cli.gpu_id)
    test_net(cli.prefix, cli.epoch, cli.batch_size,
             device, cli.thresh, cli.min_face,
             cli.stride, cli.slide_window)