-
Notifications
You must be signed in to change notification settings - Fork 548
/
trt_modnet.py
170 lines (138 loc) · 5.4 KB
/
trt_modnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""trt_modnet.py
This script demonstrates how to do real-time "image matting" with
TensorRT optimized MODNet engine.
"""
import argparse
import numpy as np
import cv2
import pycuda.autoinit # This is needed for initializing CUDA driver
from utils.camera import add_camera_args, Camera
from utils.writer import get_video_writer
from utils.background import Background
from utils.display import open_window, show_fps
from utils.display import FpsCalculator, ScreenToggler
from utils.modnet import TrtMODNet
# Identifier of the OpenCV display window. TrtMODNetRunner creates it with
# open_window(), draws frames into it with cv2.imshow(), and polls it with
# cv2.getWindowProperty() to notice when the user has closed it.
WINDOW_NAME = 'TrtMODNetDemo'
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options registered by ``add_camera_args`` (project helper) are included,
    followed by the demo-specific switches:

    --background    background image or video file name (default: '')
    --create_video  output video file name, .ts or .mp4 (default: '')
    --demo_mode     enable the scripted "demo mode" blending (flag)
    """
    parser = add_camera_args(argparse.ArgumentParser(
        description='Capture and display live camera video, while doing '
                    'real-time image matting with TensorRT optimized MODNet'))
    # (flag, add_argument keyword arguments), registered in this order so the
    # --help listing matches the original layout.
    option_specs = (
        ('--background',
         dict(type=str, default='',
              help='background image or video file name [None]')),
        ('--create_video',
         dict(type=str, default='',
              help='create output video (either .ts or .mp4) [None]')),
        ('--demo_mode',
         dict(action='store_true',
              help='run the program in a special "demo mode" [False]')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
class BackgroundBlender():
    """BackgroundBlender

    Composite a foreground image over a background image with an alpha
    matte: ``img * matte + bg * (1 - matte)``.

    # Arguments
        demo_mode: if True, do foreground/background blending in a
                   special "demo mode" which alternates among the
                   original, replaced and black backgrounds.
    """

    def __init__(self, demo_mode=False):
        self.demo_mode = demo_mode
        # Frame counter for the 360-frame demo script; it only advances
        # when demo_mode is enabled.
        self.count = 0

    def blend(self, img, bg, matte):
        """Blend foreground and background using the 'matte'.

        The arrays passed in are never modified.

        # Arguments
            img: uint8 np.array of shape (H, W, 3), the foreground image
            bg: uint8 np.array of shape (H, W, 3), the background image
            matte: float32 np.array of shape (H, W), values between 0.0 and 1.0

        # Returns
            uint8 np.array of shape (H, W, 3): img * matte + bg * (1 - matte),
            truncated (not rounded) to uint8.
        """
        if self.demo_mode:
            img, bg, matte = self._mod_for_demo(img, bg, matte)
        # Add a trailing axis so the (H, W) matte broadcasts over channels.
        alpha = matte[..., np.newaxis]
        return (img * alpha + bg * (1 - alpha)).astype(np.uint8)

    def _mod_for_demo(self, img, bg, matte):
        """Modify img, bg and matte for "demo mode", then advance the counter.

        Returns a new (img, bg, matte) tuple. Whenever bg or matte needs
        changing, a copy is modified instead of the caller's array. bg and
        matte may be buffers the caller reuses (e.g. a cached background
        frame), so editing them in place would leak into later frames.

        # Demo script (based on "count", one step per call, 360-step cycle)
           0~ 59: black background left to right
          60~119: black background only
         120~179: replaced background left to right
         180~239: replaced background
         240~299: original background left to right
         300~359: original background
        """
        img_w = img.shape[1]
        count = self.count
        if count < 120:
            bg = np.zeros(bg.shape, dtype=np.uint8)
            if count < 60:
                # Columns right of the sweep line still show the original
                # frame (matte = 1 means pure foreground).
                matte = matte.copy()
                offset = int(img_w * count / 59)
                matte[:, offset:img_w] = 1.0
        elif count < 240:
            if count < 180:
                # Columns right of the sweep line keep the black background.
                bg = bg.copy()
                offset = int(img_w * (count - 120) / 59)
                bg[:, offset:img_w, :] = 0
        else:
            matte = matte.copy()
            if count < 300:
                # Columns left of the sweep line revert to the original frame.
                offset = int(img_w * (count - 240) / 59)
                matte[:, 0:offset] = 1.0
            else:
                matte[:, :] = 1.0
        self.count = (count + 1) % 360
        return img, bg, matte
class TrtMODNetRunner():
    """TrtMODNetRunner

    Drives the capture -> matte inference -> blend -> display loop.

    # Arguments
        modnet: TrtMODNet instance (its infer() produces the matte)
        cam: Camera object (for reading foreground images)
        bggen: background generator (for reading background images)
        blender: BackgroundBlender object
        writer: optional VideoWriter object (for saving output video)
    """

    def __init__(self, modnet, cam, bggen, blender, writer=None):
        self.modnet, self.cam = modnet, cam
        self.bggen, self.blender = bggen, blender
        self.writer = writer
        # Create the display window up front, sized to the camera frames.
        open_window(WINDOW_NAME, 'TensorRT MODNet Demo',
                    cam.img_width, cam.img_height)

    @staticmethod
    def _window_closed():
        """True once OpenCV reports a negative property for the window."""
        return cv2.getWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN) < 0

    def run(self):
        """Get img and bg, infer matte, blend and show img, then repeat.

        Stops when the window is closed, the camera yields None, or ESC is
        pressed. 'F'/'f' toggles fullscreen.
        """
        toggler = ScreenToggler()
        fps_counter = FpsCalculator()
        fullscreen_keys = (ord('F'), ord('f'))
        while not self._window_closed():
            # Both sources are read before the end-of-stream check.
            frame = self.cam.read()
            background = self.bggen.read()
            if frame is None:
                break
            alpha = self.modnet.infer(frame)
            composed = self.blender.blend(frame, background, alpha)
            composed = show_fps(composed, fps_counter.update())
            if self.writer:
                self.writer.write(composed)
            cv2.imshow(WINDOW_NAME, composed)
            pressed = cv2.waitKey(1)
            if pressed == 27:  # ESC key: quit
                break
            if pressed in fullscreen_keys:  # Toggle fullscreen
                toggler.toggle()

    def __del__(self):
        # Tear down all OpenCV HighGUI windows when the runner is collected.
        cv2.destroyAllWindows()
def main():
    """Parse arguments, build the matting pipeline and run it.

    The camera and the optional video writer are released even if building
    the pipeline or running the display loop raises.

    # Raises
        SystemExit: if the camera cannot be opened.
    """
    args = parse_args()
    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    writer = None
    try:
        if args.create_video:
            writer = get_video_writer(
                args.create_video, cam.img_width, cam.img_height)
        modnet = TrtMODNet()
        bggen = Background(args.background, cam.img_width, cam.img_height)
        blender = BackgroundBlender(args.demo_mode)
        runner = TrtMODNetRunner(modnet, cam, bggen, blender, writer)
        runner.run()
    finally:
        # Previously skipped whenever anything above raised, which left the
        # output file unfinalized and the capture device held open.
        if writer:
            writer.release()
        cam.release()
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()