@@ -75,31 +75,36 @@ def test_save_correct_files(saver, output_dir):
75
75
assert expected_output_filenames == real_output_filenames
76
76
77
77
78
+ def test_save_several_times (saver , output_dir ):
79
+ n = 5
80
+ for _ in range (n ):
81
+ saver .save ()
82
+
83
+ expected_output_filenames = set (
84
+ [
85
+ join (output_dir , f"{ i } " , filename )
86
+ for i in range (n ) for filename in [
87
+ 'saved_model.pb' ,
88
+ 'variables' ,
89
+ 'variables/variables.data-00000-of-00001' ,
90
+ 'variables/variables.index' ,
91
+ ]
92
+ ],
93
+ )
94
+ real_output_filenames = set (
95
+ glob (output_dir + '/*/*' ) + # noqa: W504
96
+ glob (output_dir + '/*/*/*' ),
97
+ )
98
+ assert expected_output_filenames == real_output_filenames
99
+
100
+
78
101
def test_save_will_not_change_model (x_test , y_test , model , saver , output_dir ):
79
102
old_loss = model .evaluate (x_test , y_test )
80
103
saver .save ()
81
- new_loss = load_n_evaluate (model . sess , output_dir , x_test , y_test )
104
+ new_loss = load_n_evaluate (output_dir , x_test , y_test )
82
105
assert old_loss == new_loss
83
106
84
107
85
- def load_n_evaluate (sess , output_dir , x_test , y_test ):
86
- with tf .Session (graph = tf .Graph ()) as sess :
87
- meta_graph_def = tf .saved_model .loader .load (
88
- sess = sess ,
89
- tags = [tf .saved_model .tag_constants .SERVING ],
90
- export_dir = join (output_dir , '0' ),
91
- )
92
- evaluate_graph = meta_graph_def .signature_def ['evaluate' ]
93
- loss = sess .run (
94
- evaluate_graph .outputs ['loss' ].name ,
95
- feed_dict = {
96
- evaluate_graph .inputs ['x' ].name : x_test ,
97
- evaluate_graph .inputs ['y' ].name : y_test ,
98
- },
99
- )
100
- return loss
101
-
102
-
103
108
def test_frozen_save_correct_files (saver_with_freeze , output_dir ):
104
109
saver_with_freeze .save ()
105
110
expected_output_filenames = set (
@@ -116,19 +121,48 @@ def test_frozen_save_correct_files(saver_with_freeze, output_dir):
116
121
)
117
122
assert expected_output_filenames == real_output_filenames
118
123
119
- # def test_freeze_graph_has_session_update(self):
120
- # old_sess = saver.session
121
- # saver.freeze_graph()
122
- # new_sess = saver.session
123
- # assertNotEqual(old_sess, new_sess)
124
-
125
- # def test_freeze_graph_will_not_change_loss(self):
126
- # old_loss = model.evaluate(x_test, y_test)
127
- # saver.freeze_graph()
128
- # model.sess = saver.session
129
- # new_loss = model.evaluate(x_test, y_test)
130
- # assertEqual(old_loss, new_loss)
131
-
132
- # def test_freeze_n_save(self):
133
- # saver.freeze_graph()
134
- # saver.save()
124
+
125
+ def test_frozen_save_several_times (saver_with_freeze , output_dir ):
126
+ n = 5
127
+ for _ in range (n ):
128
+ saver_with_freeze .save ()
129
+
130
+ expected_output_filenames = set (
131
+ [
132
+ join (output_dir , f"{ i } " , filename )
133
+ for i in range (n ) for filename in [
134
+ 'saved_model.pb' ,
135
+ 'variables' ,
136
+ ]
137
+ ],
138
+ )
139
+ real_output_filenames = set (
140
+ glob (output_dir + '/*/*' ) + # noqa: W504
141
+ glob (output_dir + '/*/*/*' ),
142
+ )
143
+ assert expected_output_filenames == real_output_filenames
144
+
145
+
146
+ def test_frozen_save_will_not_change_model (x_test , y_test , model , saver_with_freeze , output_dir ):
147
+ old_loss = model .evaluate (x_test , y_test )
148
+ saver_with_freeze .save ()
149
+ new_loss = load_n_evaluate (output_dir , x_test , y_test )
150
+ assert old_loss == new_loss
151
+
152
+
153
+ def load_n_evaluate (output_dir , x_test , y_test ):
154
+ with tf .Session (graph = tf .Graph ()) as sess :
155
+ meta_graph_def = tf .saved_model .loader .load (
156
+ sess = sess ,
157
+ tags = [tf .saved_model .tag_constants .SERVING ],
158
+ export_dir = join (output_dir , '0' ),
159
+ )
160
+ evaluate_graph = meta_graph_def .signature_def ['evaluate' ]
161
+ loss = sess .run (
162
+ evaluate_graph .outputs ['loss' ].name ,
163
+ feed_dict = {
164
+ evaluate_graph .inputs ['x' ].name : x_test ,
165
+ evaluate_graph .inputs ['y' ].name : y_test ,
166
+ },
167
+ )
168
+ return loss
0 commit comments