keras_to_tensorflow.patch
diff --git a/tools/model_converter/keras_to_tensorflow.py b/tools/model_converter/keras_to_tensorflow.py
index 07865a8..b3817af 100644
--- a/tools/model_converter/keras_to_tensorflow.py
+++ b/tools/model_converter/keras_to_tensorflow.py
@@ -23,6 +23,10 @@ from tensorflow.keras.models import model_from_json, model_from_yaml, load_model
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..'))
from common.utils import get_custom_objects
+# tf.enable_eager_execution()
+tf.compat.v1.disable_eager_execution()
+from tensorflow.python.keras.backend import get_session
+
K.set_learning_phase(0)
FLAGS = flags.FLAGS
@@ -53,6 +57,8 @@ flags.DEFINE_boolean('output_meta_ckpt', False,
'If set to True, exports the model as .meta, .index, and '
'.data files, with a checkpoint file. These can be later '
'loaded in TensorFlow to continue training.')
+flags.DEFINE_boolean('saved_model', False,
+ 'If set, model is saved in saved model format')
flags.mark_flag_as_required('input_model')
flags.mark_flag_as_required('output_model')
@@ -150,7 +156,7 @@ def main(args):
logging.info('Converted output node names are: %s',
str(converted_output_node_names))
- sess = K.get_session()
+ sess = get_session()
if FLAGS.output_meta_ckpt:
saver = tf.train.Saver()
saver.save(sess, str(output_fld / output_model_stem))
@@ -172,16 +178,20 @@ def main(args):
transformed_graph_def,
converted_output_node_names)
else:
- constant_graph = graph_util.convert_variables_to_constants(
- sess,
- sess.graph.as_graph_def(),
- converted_output_node_names)
-
- graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
- as_text=False)
- logging.info('Saved the freezed graph at %s',
- str(Path(output_fld) / output_model_name))
-
+ if FLAGS.saved_model:
+ tf.saved_model.save(model, output_model)
+ logging.info("Saved model in TF2 saved model format at %s",
+ str(Path(output_fld) / output_model_name))
+ else:
+ constant_graph = graph_util.convert_variables_to_constants(
+ sess,
+ sess.graph.as_graph_def(),
+ converted_output_node_names)
+
+ graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
+ as_text=False)
+ logging.info('Saved the freezed graph at %s',
+ str(Path(output_fld) / output_model_name))
if __name__ == "__main__":
app.run(main)
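
For readers who want to see the conversion flow outside the patched script, here is a minimal, self-contained sketch of what the modified keras_to_tensorflow.py ends up doing, written against TensorFlow 2.x. The model path, output directory, and the SAVE_AS_SAVED_MODEL switch are placeholders standing in for the script's --input_model, --output_model, and new --saved_model flags, and the sketch uses tf.compat.v1.keras.backend.get_session() where the patch imports get_session from tensorflow.python.keras.backend; exact behaviour can vary between TensorFlow releases.

import tensorflow as tf
from tensorflow.keras.models import load_model

SAVE_AS_SAVED_MODEL = True   # stands in for the boolean --saved_model flag
INPUT_MODEL = 'model.h5'     # placeholder for --input_model
OUTPUT_DIR = 'exported'      # placeholder output folder

if SAVE_AS_SAVED_MODEL:
    # TF2 path added by the patch: export the Keras model as a SavedModel directory.
    model = load_model(INPUT_MODEL, compile=False)
    tf.saved_model.save(model, OUTPUT_DIR)
else:
    # Frozen-graph path: eager execution must be disabled *before* the model
    # is built so that a v1-style session and graph exist to freeze.
    tf.compat.v1.disable_eager_execution()
    model = load_model(INPUT_MODEL, compile=False)
    sess = tf.compat.v1.keras.backend.get_session()
    output_node_names = [t.op.name for t in model.outputs]
    constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)
    tf.io.write_graph(constant_graph, OUTPUT_DIR, 'frozen_model.pb',
                      as_text=False)

With the patch applied, the converter itself would be invoked along the lines of
python tools/model_converter/keras_to_tensorflow.py --input_model model.h5 --output_model model.pb --saved_model
where the model file names are illustrative, not taken from the repository.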