-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathutil_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py
310 lines (234 loc) · 13 KB
/
util_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
#-------------------------------------------------------------------------------#
# Utility to convert Keras model to Tensorflow's .PB (proto-binary) and then to
# Nvidia libnvinfer's uff format. With UFF one can execute models on
# TensorRT compatible devices like TX2.
#
# Author : Manohar Kuse <mpkuse@connect.ust.hk>
# Created: 29th May, 2019
# Site : https://kusemanohar.wordpress.com/2019/05/25/hands-on-tensorrt-on-nvidiatx2/
#-------------------------------------------------------------------------------#
import keras
import numpy as np
import os
import tensorflow as tf
from CustomNets import NetVLADLayer, GhostVLADLayer
from predict_utils import change_model_inputshape
from keras import backend as K
import TerminalColors
tcol = TerminalColors.bcolors()
import argparse
def load_keras_hdf5_model( kerasmodel_h5file, verbose=True ):
""" Loads keras model from a HDF5 file """
assert os.path.isfile( kerasmodel_h5file ), 'The model weights file doesnot exists or there is a permission issue.'+"kerasmodel_file="+kerasmodel_h5file
K.set_learning_phase(0)
model = keras.models.load_model(kerasmodel_h5file, custom_objects={'NetVLADLayer': NetVLADLayer, 'GhostVLADLayer': GhostVLADLayer} )
if verbose:
model.summary();
print tcol.OKGREEN, 'Successfully Loaded kerasmodel_h5file: ', tcol.ENDC, kerasmodel_h5file
return model
def load_basic_model( arch='mobilenet' ):
    """ Build one of a few small test models (handy for debugging the
    keras --> pb --> uff conversion on a known architecture).

    Args:
        arch: which network to build. One of
            'vgg16'             : VGG16 up to layer `block5_pool`, input 240x320x3
            'mobilenet'         : MobileNet up to `conv_pw_5_relu`, input 240x320x3 (default)
            'mobilenet_netvlad' : the mobilenet above followed by NetVLADLayer(num_clusters=16)
            'netvlad'           : only NetVLADLayer(num_clusters=16) on a 60x80x256 input

    Returns:
        keras.models.Model for the requested architecture. `model.summary()`
        is printed before returning.

    Raises:
        ValueError: if `arch` is not one of the names listed above.

    Note: `weights=None` is passed to the CustomNets builders -- presumably
    this means no pretrained weights are loaded (confirm in CustomNets).
    """
    K.set_learning_phase(0)

    from CustomNets import make_from_mobilenet, make_from_vgg16
    from CustomNets import NetVLADLayer

    if arch == 'vgg16':
        input_img = keras.layers.Input( shape=(240, 320, 3 ) )
        out = make_from_vgg16( input_img, weights=None, layer_name='block5_pool', kernel_regularizer=keras.regularizers.l2(0.01) )

    elif arch == 'mobilenet':
        input_img = keras.layers.Input( shape=(240, 320, 3 ) )
        out = make_from_mobilenet( input_img, layer_name='conv_pw_5_relu', weights=None, kernel_regularizer=keras.regularizers.l2(0.01) )

    elif arch == 'mobilenet_netvlad':
        input_img = keras.layers.Input( shape=(240, 320, 3 ) )
        cnn = make_from_mobilenet( input_img, layer_name='conv_pw_5_relu', weights=None, kernel_regularizer=keras.regularizers.l2(0.01) )
        out = NetVLADLayer(num_clusters = 16)( cnn )

    elif arch == 'netvlad':
        # Input here is a feature map, not an image.
        input_img = keras.layers.Input( shape=(60, 80, 256 ) )
        out = NetVLADLayer(num_clusters = 16)( input_img )

    else:
        raise ValueError( "[load_basic_model] unknown arch=%r, expected one of "
                          "'vgg16', 'mobilenet', 'mobilenet_netvlad', 'netvlad'" % (arch,) )

    model = keras.models.Model( inputs=input_img, outputs=out )
    model.summary()
    return model
def write_kerasmodel_as_tensorflow_pb( model, LOG_DIR, output_model_name='output_model.pb' ):
""" Takes as input a keras.models.Model() and writes out
Tensorflow proto-binary.
"""
print tcol.HEADER,'[write_kerasmodel_as_tensorflow_pb] Start', tcol.ENDC
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
K.set_learning_phase(0)
sess = K.get_session()
# Make const
print 'Make Computation Graph as Constant and Prune unnecessary stuff from it'
constant_graph = graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(),
[node.op.name for node in model.outputs])
constant_graph = tf.graph_util.remove_training_nodes(constant_graph)
#--- convert Switch --> Identity
# I am doing this because TensorRT cannot process Switch operations.
# # https://github.com/tensorflow/tensorflow/issues/8404#issuecomment-297469468
# for node in constant_graph.node:
# if node.op == "Switch":
# node.op = "Identity"
# del node.input[1]
# # END
# Write .pb
# output_model_name = 'output_model.pb'
print tcol.OKGREEN, 'Write ', output_model_name, tcol.ENDC
print 'model.outputs=', [node.op.name for node in model.outputs]
graph_io.write_graph(constant_graph, LOG_DIR, output_model_name,
as_text=False)
print tcol.HEADER, '[write_kerasmodel_as_tensorflow_pb] Done', tcol.ENDC
# Write .pbtxt (for viz only)
output_model_pbtxt_name = output_model_name+'.pbtxt' #'output_model.pbtxt'
print tcol.OKGREEN, 'Write ', output_model_pbtxt_name, tcol.ENDC
tf.train.write_graph(constant_graph, LOG_DIR,
output_model_pbtxt_name, as_text=True)
# Write model.summary to file (to get info on input and output shapes)
output_modelsummary_fname = LOG_DIR+'/'+output_model_name + '.modelsummary.log'
print tcol.OKGREEN, 'Write ', output_modelsummary_fname, tcol.ENDC
with open(output_modelsummary_fname,'w') as fh:
# Pass the file handle in as a lambda function to make it callable
model.summary(print_fn=lambda x: fh.write(x + '\n'))
def convert_to_uff( pb_input_fname, uff_output_fname ):
""" Uses Nvidia's `convert-to-uff` through os.system.
This will convert the .pb file (generated from call to `write_kerasmodel_as_tensorflow_pb` )
and write out .uff file.
usage: convert-to-uff [-h] [-l] [-t] [--write_preprocessed] [-q] [-d]
[-o OUTPUT] [-O OUTPUT_NODE] [-I INPUT_NODE]
[-p PREPROCESSOR]
input_file
Converts TensorFlow models to Unified Framework Format (UFF).
positional arguments:
input_file path to input model (protobuf file of frozen GraphDef)
optional arguments:
-h, --help show this help message and exit
-l, --list-nodes show list of nodes contained in input file
-t, --text write a text version of the output in addition to the
binary
--write_preprocessed write the preprocessed protobuf in addition to the
binary
-q, --quiet disable log messages
-d, --debug Enables debug mode to provide helpful debugging output
-o OUTPUT, --output OUTPUT
name of output uff file
-O OUTPUT_NODE, --output-node OUTPUT_NODE
name of output nodes of the model
-I INPUT_NODE, --input-node INPUT_NODE
name of a node to replace with an input to the model.
Must be specified as:
"name,new_name,dtype,dim1,dim2,..."
-p PREPROCESSOR, --preprocessor PREPROCESSOR
the preprocessing file to run before handling the
graph. This file must define a `preprocess` function
that accepts a GraphSurgeon DynamicGraph as it's
input. All transformations should happen in place on
the graph, as return values are discarded
"""
assert os.path.isfile( pb_input_fname ), "The .pb file="+str(pb_input_fname)+" does not exist"
cmd = 'convert-to-uff -t -o %s %s | tee %s' %(uff_output_fname, pb_input_fname, uff_output_fname+'.log')
print tcol.HEADER, '[bash run] ', cmd, tcol.ENDC
os.system( cmd )
print tcol.WARNING, 'If there are warning above like `No conversion function...`, this means that Nvidias UFF doesnt yet have certain function. Most like in this case your model cannot be run with tensorrt.', tcol.ENDC
def graphsurgeon_cleanup( LOG_DIR, input_model_name='output_model.pb', cleaned_model_name='output_model_aftersurgery.pb' ):
""" Loads the tensorflow frozen_graph and cleans up with nvidia's graphsurgeon
"""
assert os.path.isfile( LOG_DIR+'/'+input_model_name ), "[graphsurgeon_cleanup]The .pb file="+str(input_model_name)+" does not exist"
import graphsurgeon as gs
print tcol.HEADER, '[graphsurgeon_cleanup] graphsurgeon.__version__', gs.__version__, tcol.ENDC
DG = gs.DynamicGraph()
print tcol.OKGREEN, '[graphsurgeon_cleanup] READ tensorflow Graph using graphsurgeon.DynamicGraph: ', LOG_DIR+'/'+input_model_name, tcol.ENDC
DG.read( LOG_DIR+'/'+input_model_name )
# Remove control variable first
all_switch = DG.find_nodes_by_op( 'Switch' )
DG.forward_inputs( all_switch )
print 'Write (after graphsurgery) : ', LOG_DIR+'/'+cleaned_model_name
DG.write( LOG_DIR+'/'+cleaned_model_name )
if os.path.isdir( LOG_DIR+'/graphsurgeon_cleanup' ):
pass
else:
os.mkdir( LOG_DIR+'/graphsurgeon_cleanup')
DG.write_tensorboard( LOG_DIR+'/graphsurgeon_cleanup' )
# import code
# code.interact( local=locals() )
print tcol.HEADER, '[graphsurgeon_cleanup] END', tcol.ENDC
# def verify_generated_uff_with_tensorrt_uffparser( ufffilename, uffinput, uffinput_dims, uff_output ):
def verify_generated_uff_with_tensorrt_uffparser( ufffilename ):
""" Loads the UFF file with TensorRT (py). """
assert os.path.isfile( ufffilename ), "ufffilename="+ ufffilename+ ' doesnt exist'
import tensorrt as trt
print tcol.HEADER, '[verify_generated_uff_with_tensorrt_uffparser] TensorRT version=', trt.__version__, tcol.ENDC
try:
uffinput = "input_1"
uffinput_dims = (3,240,320)
# uffinput_dims = (256, 80,60)
uffoutput = "conv_pw_5_relu/Relu6"
# uffoutput = "net_vlad_layer_1/l2_normalize_1"
# uffoutput = "net_vlad_layer_1/add_1"
# uffoutput = "net_vlad_layer_1/Reshape_1"
TRT_LOGGER = trt.Logger( trt.Logger.WARNING)
with trt.Builder( TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
print 'ufffilename=', str( ufffilename)
print 'uffinput=', str( uffinput), '\t', 'uffinput_dims=', str( uffinput_dims)
print 'uffoutput=', str( uffoutput)
parser.register_input( uffinput, uffinput_dims )
parser.register_output( uffoutput )
parser.parse( ufffilename, network )
pass
print tcol.OKGREEN, '[verify_generated_uff_with_tensorrt_uffparser] Verified.....!', tcol.ENDC
except:
print tcol.FAIL, '[verify_generated_uff_with_tensorrt_uffparser] UFF file=', ufffilename, ' with uffinput=', uffinput , ' uffoutput=', uffoutput , ' cannot be parsed.'
if __name__ == '__main__':
#---
# Parse Command line
parser = argparse.ArgumentParser(description='Convert Keras hdf5 models to .uff models for TensorRT.')
parser.add_argument('--kerasmodel_h5file', '-h5', required=True, type=str, help='The input keras modelarch_and_weights full filename')
args = parser.parse_args()
#---
# Paths, File Init and other initialize
# kerasmodel_h5file = 'models.keras/June2019/centeredinput-m1to1-240x320x3__mobilenet-conv_pw_6_relu__K16__allpairloss/modelarch_and_weights.700.h5'
kerasmodel_h5file = args.kerasmodel_h5file
LOG_DIR = '/'.join( kerasmodel_h5file.split('/')[0:-1] )
print tcol.HEADER
print '##------------------------------------------------------------##'
print '## kerasmodel_h5file = ', kerasmodel_h5file
print '## LOG_DIR = ', LOG_DIR
print '##------------------------------------------------------------##'
print tcol.ENDC
#---
# Load HDF5 Keras model
model = load_keras_hdf5_model( kerasmodel_h5file, verbose=True ) #this
# model = load_basic_model()
# quit()
#-----
# Replace Input Layer's Dimensions
im_rows = None#480
im_cols = 752
im_chnls = 3
if im_rows == None or im_cols == None or im_chnls == None:
print tcol.WARNING, 'NOT doing `change_model_inputshape`', tcol.ENDC
new_model = model
else:
# change_model_inputshape uses model_from_json internally, I feel a bit uncomfortable about this.
new_model = change_model_inputshape( model, new_input_shape=(1,im_rows,im_cols,im_chnls), verbose=True )
print 'OLD MODEL: ', 'input_shape=', str(model.inputs)
print 'NEW MODEL: input_shape=', str(new_model.inputs)
#-----
# Write Tensorflow (atleast 1.12) proto-binary (.pb)
write_kerasmodel_as_tensorflow_pb( new_model, LOG_DIR=LOG_DIR, output_model_name='output_model.pb' )
# write_kerasmodel_as_tensorflow_pb( new_model, LOG_DIR=LOG_DIR, output_model_name=kerasmodel_h5file.split('/')[-1]+'.pb' )
#-----
# Clean up graph with Nvidia's graphsurgeon
# currently not in use but might come in handly later...maybe
# graphsurgeon_cleanup( LOG_DIR=LOG_DIR, input_model_name='output_model.pb', cleaned_model_name='output_model_aftersurgery.pb')
#-----
# Write UFF
convert_to_uff( pb_input_fname=LOG_DIR+'/output_model.pb', uff_output_fname=LOG_DIR+'/output_nvinfer.uff' )
# convert_to_uff( pb_input_fname=LOG_DIR+'/output_model_aftersurgery.pb', uff_output_fname=LOG_DIR+'/output_nvinfer.uff' )
#-----
# Try to load UFF with tensorrt
verify_generated_uff_with_tensorrt_uffparser( ufffilename=LOG_DIR+'/output_nvinfer.uff' )