forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_tflite_lstd_graph_lib.py
327 lines (293 loc) · 13.5 KB
/
export_tflite_lstd_graph_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
# Channel count used for the input placeholder unless the fixed_shape_resizer
# config asks for grayscale conversion (in which case 1 is used instead).
_DEFAULT_NUM_CHANNELS = 3
# Number of values stored per anchor in the exported 'anchors' constant
# (the y, x, h, w center-size quantities).
_DEFAULT_NUM_COORD_BOX = 4
def get_const_center_size_encoded_anchors(anchors):
  """Bakes the anchors into the graph as a center-size encoded constant.

  The anchor tensor is evaluated in a temporary session and the resulting
  values are stored in a constant node named 'anchors'.

  Args:
    anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
      boxes

  Returns:
    encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
      containing the anchor boxes in center-size form.
  """
  center_size = box_list.BoxList(anchors).get_center_coordinates_and_sizes()
  ycenter, xcenter, height, width = center_size
  anchor_shape = ycenter.get_shape().as_list()

  # Evaluate the symbolic anchors so they can be stored as literal values.
  with tf.Session() as sess:
    ycenter_val, xcenter_val, height_val, width_val = sess.run(
        [ycenter, xcenter, height, width])

  # Stacking yields one row per quantity; transpose to one row per anchor.
  stacked = np.stack((ycenter_val, xcenter_val, height_val, width_val))
  anchor_matrix = np.transpose(stacked)

  return tf.constant(
      anchor_matrix,
      dtype=tf.float32,
      shape=[anchor_shape[0], _DEFAULT_NUM_COORD_BOX],
      name='anchors')
def _peek_single_value(values):
  """Returns one element of `values` without removing it.

  Args:
    values: an iterable collection (e.g. a one-element set).

  Returns:
    The first element produced when iterating over `values`.

  Raises:
    KeyError: if `values` is empty, matching what `set.pop()` raises.
  """
  for value in values:
    return value
  raise KeyError('cannot read a value from an empty collection')


def append_postprocessing_op(frozen_graph_def,
                             max_detections,
                             max_classes_per_detection,
                             nms_score_threshold,
                             nms_iou_threshold,
                             num_classes,
                             scale_values,
                             detections_per_class=100,
                             use_regular_nms=False):
  """Appends postprocessing custom op.

  Note that the new node is added to `frozen_graph_def` itself before the graph
  transform is run, so the passed-in GraphDef is modified. The threshold and
  scale collections, on the other hand, are only read and are left intact, so
  the same arguments can be reused for another call.

  Args:
    frozen_graph_def: Frozen GraphDef for SSD model after freezing the
      checkpoint
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection
    nms_score_threshold: Collection (e.g. a one-element set) holding the score
      threshold used in Non-maximal suppression in post-processing
    nms_iou_threshold: Collection (e.g. a one-element set) holding the
      Intersection-over-union threshold used in Non-maximal suppression in
      post-processing
    num_classes: number of classes in SSD detector
    scale_values: dict with the keys y_scale, x_scale, h_scale and w_scale.
      Each value is a collection (e.g. a one-element set) holding the scale
      used in decode centersize boxes, for example
      {'y_scale': {10.0}, 'x_scale': {10.0}, 'h_scale': {5.0},
      'w_scale': {5.0}}
    detections_per_class: In regular NonMaxSuppression, number of anchors used
      for NonMaxSuppression per class
    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
      Fast NMS.

  Returns:
    transformed_graph_def: Frozen GraphDef with postprocessing custom op
    appended
    TFLite_Detection_PostProcess custom op node has four outputs:
    detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
      locations
    detection_classes: a float32 tensor of shape [1, num_boxes]
      with class indices
    detection_scores: a float32 tensor of shape [1, num_boxes]
      with class scores
    num_boxes: a float32 tensor of size 1 containing the number of detected
      boxes

  Raises:
    KeyError: if a threshold or scale collection is empty, or if a scale key
      is missing from `scale_values`.
  """
  postprocess_node = frozen_graph_def.node.add()
  postprocess_node.op = 'TFLite_Detection_PostProcess'
  postprocess_node.name = 'TFLite_Detection_PostProcess'
  attrs = postprocess_node.attr

  attrs['_output_quantized'].CopyFrom(attr_value_pb2.AttrValue(b=True))
  # All four outputs of the custom op are declared as float.
  attrs['_output_types'].list.type.extend([
      types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
      types_pb2.DT_FLOAT
  ])
  attrs['_support_output_type_float_in_quantized_op'].CopyFrom(
      attr_value_pb2.AttrValue(b=True))
  attrs['max_detections'].CopyFrom(
      attr_value_pb2.AttrValue(i=max_detections))
  attrs['max_classes_per_detection'].CopyFrom(
      attr_value_pb2.AttrValue(i=max_classes_per_detection))
  # Read the values without pop() so the caller's collections stay usable.
  attrs['nms_score_threshold'].CopyFrom(
      attr_value_pb2.AttrValue(f=_peek_single_value(nms_score_threshold)))
  attrs['nms_iou_threshold'].CopyFrom(
      attr_value_pb2.AttrValue(f=_peek_single_value(nms_iou_threshold)))
  attrs['num_classes'].CopyFrom(attr_value_pb2.AttrValue(i=num_classes))
  for scale_key in ('y_scale', 'x_scale', 'h_scale', 'w_scale'):
    attrs[scale_key].CopyFrom(
        attr_value_pb2.AttrValue(
            f=_peek_single_value(scale_values[scale_key])))
  attrs['detections_per_class'].CopyFrom(
      attr_value_pb2.AttrValue(i=detections_per_class))
  attrs['use_regular_nms'].CopyFrom(
      attr_value_pb2.AttrValue(b=use_regular_nms))

  postprocess_node.input.extend(
      ['raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'])

  # Transform the graph to append new postprocessing op
  input_names = []
  output_names = ['TFLite_Detection_PostProcess']
  transforms = ['strip_unused_nodes']
  transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
                                         output_names, transforms)
  return transformed_graph_def
def export_tflite_graph(pipeline_config,
                        trained_checkpoint_prefix,
                        output_dir,
                        add_postprocessing_op,
                        max_detections,
                        max_classes_per_detection,
                        detections_per_class=100,
                        use_regular_nms=False,
                        binary_graph_name='tflite_graph.pb',
                        txt_graph_name='tflite_graph.pbtxt'):
  """Exports a tflite compatible graph and anchors for ssd detection model.

  Anchors are written to a constant tensor inside the graph. The tflite
  compatible graph is written to output_dir twice: in binary format as
  binary_graph_name and in text format as txt_graph_name.

  Args:
    pipeline_config: Dictionary of configuration objects. Keys are `model`,
      `train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
      `lstm_model`. Value are the corresponding config objects. An optional
      `graph_rewriter` key enables the graph rewriter step.
    trained_checkpoint_prefix: a file prefix for the checkpoint containing the
      trained parameters of the SSD model.
    output_dir: A directory to write the tflite graph and anchor file to. It is
      created if needed.
    add_postprocessing_op: If add_postprocessing_op is true: frozen graph adds a
      TFLite_Detection_PostProcess custom op
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection
    detections_per_class: In regular NonMaxSuppression, number of anchors used
      for NonMaxSuppression per class
    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
      Fast NMS.
    binary_graph_name: Name of the exported graph file in binary format.
    txt_graph_name: Name of the exported graph file in text format.

  Raises:
    ValueError: if the pipeline config contains a model other than ssd, or if
      the ssd image resizer is anything other than a fixed_shape_resizer.
  """
  model_config = pipeline_config['model']
  lstm_config = pipeline_config['lstm_model']
  eval_config = pipeline_config['eval_config']
  tf.gfile.MakeDirs(output_dir)
  if model_config.WhichOneof('model') != 'ssd':
    raise ValueError('Only ssd models are supported in tflite. '
                     'Found {} in config'.format(
                         model_config.WhichOneof('model')))

  num_classes = model_config.ssd.num_classes
  # append_postprocessing_op expects each threshold / scale wrapped in a
  # collection, so every value is passed as a one-element set.
  nms_score_threshold = {
      model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
  }
  nms_iou_threshold = {
      model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
  }
  scale_values = {}
  scale_values['y_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.y_scale
  }
  scale_values['x_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.x_scale
  }
  scale_values['h_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.height_scale
  }
  scale_values['w_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.width_scale
  }

  image_resizer_config = model_config.ssd.image_resizer
  image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
  num_channels = _DEFAULT_NUM_CHANNELS
  if image_resizer == 'fixed_shape_resizer':
    height = image_resizer_config.fixed_shape_resizer.height
    width = image_resizer_config.fixed_shape_resizer.width
    if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
      num_channels = 1
    # The leading dimension is the number of unrolled frames of the LSTM model.
    shape = [lstm_config.eval_unroll_length, height, width, num_channels]
  else:
    # The trailing space keeps the two adjacent literals from running together.
    raise ValueError(
        'Only fixed_shape_resizer '
        'is supported with tflite. Found {}'.format(
            image_resizer_config.WhichOneof('image_resizer_oneof')))

  video_tensor = tf.placeholder(
      tf.float32, shape=shape, name='input_video_tensor')

  detection_model = model_builder.build(
      model_config, lstm_config, is_training=False)
  preprocessed_video, true_image_shapes = detection_model.preprocess(
      tf.to_float(video_tensor))
  predicted_tensors = detection_model.predict(preprocessed_video,
                                              true_image_shapes)
  # The model's own postprocess step is intentionally not applied: only the
  # raw predictions are exported.
  # The score conversion occurs before the post-processing custom op
  _, score_conversion_fn = post_processing_builder.build(
      model_config.ssd.post_processing)
  class_predictions = score_conversion_fn(
      predicted_tensors['class_predictions_with_background'])

  with tf.name_scope('raw_outputs'):
    # 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
    #  containing the encoded box predictions. Note that these are raw
    #  predictions and no Non-Max suppression is applied on them and
    #  no decode center size boxes is applied to them.
    tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
    # 'raw_outputs/class_predictions': a float32 tensor of shape
    #  [1, num_anchors, num_classes] containing the class scores for each anchor
    #  after applying score conversion.
    tf.identity(class_predictions, name='class_predictions')
  # 'anchors': a float32 tensor of shape
  #   [num_anchors, 4] containing the anchors as a constant node.
  tf.identity(
      get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
      name='anchors')

  # Add global step to the graph, so we know the training step number when we
  # evaluate the model.
  tf.train.get_or_create_global_step()

  # graph rewriter
  is_quantized = ('graph_rewriter' in pipeline_config)
  if is_quantized:
    graph_rewriter_config = pipeline_config['graph_rewriter']
    graph_rewriter_fn = graph_rewriter_builder.build(
        graph_rewriter_config, is_training=False, is_export=True)
    graph_rewriter_fn()

  if model_config.ssd.feature_extractor.HasField('fpn'):
    exporter.rewrite_nn_resize_op(is_quantized)

  # freeze the graph
  saver_kwargs = {}
  if eval_config.use_moving_averages:
    saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
    # NOTE(review): this temporary file is never closed explicitly; the local
    # reference keeps it alive until the function returns, which covers the
    # freezing step below that reads from its path.
    moving_average_checkpoint = tempfile.NamedTemporaryFile()
    exporter.replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        moving_average_checkpoint.name)
    checkpoint_to_use = moving_average_checkpoint.name
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()
  frozen_graph_def = exporter.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=','.join([
          'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
          'anchors'
      ]),
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      clear_devices=True,
      output_graph='',
      initializer_nodes='')

  # Add new operation to do post processing in a custom op (TF Lite only)
  if add_postprocessing_op:
    transformed_graph_def = append_postprocessing_op(
        frozen_graph_def, max_detections, max_classes_per_detection,
        nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
        detections_per_class, use_regular_nms)
  else:
    # Return frozen without adding post-processing custom op
    transformed_graph_def = frozen_graph_def

  # Write the graph in binary format, then again in text format.
  binary_graph = os.path.join(output_dir, binary_graph_name)
  with tf.gfile.GFile(binary_graph, 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())
  txt_graph = os.path.join(output_dir, txt_graph_name)
  with tf.gfile.GFile(txt_graph, 'w') as f:
    f.write(str(transformed_graph_def))