-
Notifications
You must be signed in to change notification settings - Fork 101
/
export_inference_graph.py
70 lines (51 loc) · 2.32 KB
/
export_inference_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Export inference graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
import deeplab_model
from utils import preprocessing
def _build_arg_parser():
    """Create the command-line parser for this export script.

    Flags:
      --model_dir: directory holding the trained checkpoint to export.
      --export_dir: directory the SavedModel is written to.
      --base_architecture: ResNet v2 backbone variant ('resnet_v2_50' or
        'resnet_v2_101').
      --output_stride: DeepLab v3 output stride, restricted to 8 or 16.

    Returns:
      A configured ``argparse.ArgumentParser``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model_dir',
        type=str,
        default='./model',
        help="Base directory for the model. "
             "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
             "with checkpoint name.")
    arg_parser.add_argument(
        '--export_dir',
        type=str,
        default='dataset/export_output',
        help='The directory where the exported SavedModel will be stored.')
    arg_parser.add_argument(
        '--base_architecture',
        type=str,
        default='resnet_v2_101',
        choices=['resnet_v2_50', 'resnet_v2_101'],
        help='The architecture of base Resnet building block.')
    arg_parser.add_argument(
        '--output_stride',
        type=int,
        default=16,
        choices=[8, 16],
        help='Output stride for DeepLab v3. Currently 8 or 16 is supported.')
    return arg_parser


# Module-level parser; the __main__ block parses sys.argv with it.
parser = _build_arg_parser()

# Passed to the model as params['num_classes']. The value 21 is presumably
# PASCAL VOC's 20 classes + background; confirm it matches the training run.
_NUM_CLASSES = 21
def main(unused_argv):
    """Build the DeepLab v3 Estimator and export it as a SavedModel.

    Configuration comes from the module-level ``FLAGS`` namespace, which the
    ``__main__`` block assigns before ``tf.app.run`` calls this function.

    Args:
      unused_argv: Leftover command-line arguments passed in by
        ``tf.app.run``; ignored.
    """
    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    estimator_params = {
        'output_stride': FLAGS.output_stride,
        'batch_size': 1,  # Batch size must be 1 because the images' size may differ
        'base_architecture': FLAGS.base_architecture,
        # NOTE(review): both presumably only matter for training; export
        # restores every weight from the checkpoint in FLAGS.model_dir.
        'pre_trained_model': None,
        'batch_norm_decay': None,
        'num_classes': _NUM_CLASSES,
    }
    model = tf.estimator.Estimator(
        model_fn=deeplab_model.deeplabv3_model_fn,
        model_dir=FLAGS.model_dir,
        params=estimator_params)

    def serving_input_receiver_fn():
        """Describe the serving-time input of the exported model.

        Clients feed a float32 tensor of shape [batch, height, width, 3] into
        the 'image_tensor' placeholder, exposed under the key 'image'. Height
        and width are dynamic, so preprocessing is applied per image via
        ``tf.map_fn`` rather than to the whole batch at once.
        """
        image = tf.placeholder(tf.float32, [None, None, None, 3],
                               name='image_tensor')
        features = tf.map_fn(preprocessing.mean_image_subtraction, image)
        return tf.estimator.export.ServingInputReceiver(
            features=features,
            receiver_tensors={'image': image})

    # Export the model. The returned export path is intentionally not
    # returned: tf.app.run forwards main's return value to sys.exit.
    model.export_savedmodel(FLAGS.export_dir, serving_input_receiver_fn)
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    # Only this script's own flags are consumed here. FLAGS becomes a
    # module-level global that main() reads. Unrecognised arguments are
    # forwarded to tf.app.run behind the program name.
    FLAGS, remaining_argv = parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + remaining_argv
    tf.app.run(main=main, argv=forwarded_argv)