Skip to content

Commit b1bf64f

Browse files
authored
Merge pull request #31 from Yoctol/more_test_cases
case: save more than one time
2 parents 43b209d + c61f1b4 commit b1bf64f

File tree

1 file changed

+69
-35
lines changed

1 file changed

+69
-35
lines changed

serving_utils/tests/test_saver.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -75,31 +75,36 @@ def test_save_correct_files(saver, output_dir):
7575
assert expected_output_filenames == real_output_filenames
7676

7777

78+
def test_save_several_times(saver, output_dir):
79+
n = 5
80+
for _ in range(n):
81+
saver.save()
82+
83+
expected_output_filenames = set(
84+
[
85+
join(output_dir, f"{i}", filename)
86+
for i in range(n) for filename in [
87+
'saved_model.pb',
88+
'variables',
89+
'variables/variables.data-00000-of-00001',
90+
'variables/variables.index',
91+
]
92+
],
93+
)
94+
real_output_filenames = set(
95+
glob(output_dir + '/*/*') + # noqa: W504
96+
glob(output_dir + '/*/*/*'),
97+
)
98+
assert expected_output_filenames == real_output_filenames
99+
100+
78101
def test_save_will_not_change_model(x_test, y_test, model, saver, output_dir):
79102
old_loss = model.evaluate(x_test, y_test)
80103
saver.save()
81-
new_loss = load_n_evaluate(model.sess, output_dir, x_test, y_test)
104+
new_loss = load_n_evaluate(output_dir, x_test, y_test)
82105
assert old_loss == new_loss
83106

84107

85-
def load_n_evaluate(sess, output_dir, x_test, y_test):
86-
with tf.Session(graph=tf.Graph()) as sess:
87-
meta_graph_def = tf.saved_model.loader.load(
88-
sess=sess,
89-
tags=[tf.saved_model.tag_constants.SERVING],
90-
export_dir=join(output_dir, '0'),
91-
)
92-
evaluate_graph = meta_graph_def.signature_def['evaluate']
93-
loss = sess.run(
94-
evaluate_graph.outputs['loss'].name,
95-
feed_dict={
96-
evaluate_graph.inputs['x'].name: x_test,
97-
evaluate_graph.inputs['y'].name: y_test,
98-
},
99-
)
100-
return loss
101-
102-
103108
def test_frozen_save_correct_files(saver_with_freeze, output_dir):
104109
saver_with_freeze.save()
105110
expected_output_filenames = set(
@@ -116,19 +121,48 @@ def test_frozen_save_correct_files(saver_with_freeze, output_dir):
116121
)
117122
assert expected_output_filenames == real_output_filenames
118123

119-
# def test_freeze_graph_has_session_update(self):
120-
# old_sess = saver.session
121-
# saver.freeze_graph()
122-
# new_sess = saver.session
123-
# assertNotEqual(old_sess, new_sess)
124-
125-
# def test_freeze_graph_will_not_change_loss(self):
126-
# old_loss = model.evaluate(x_test, y_test)
127-
# saver.freeze_graph()
128-
# model.sess = saver.session
129-
# new_loss = model.evaluate(x_test, y_test)
130-
# assertEqual(old_loss, new_loss)
131-
132-
# def test_freeze_n_save(self):
133-
# saver.freeze_graph()
134-
# saver.save()
124+
125+
def test_frozen_save_several_times(saver_with_freeze, output_dir):
126+
n = 5
127+
for _ in range(n):
128+
saver_with_freeze.save()
129+
130+
expected_output_filenames = set(
131+
[
132+
join(output_dir, f"{i}", filename)
133+
for i in range(n) for filename in [
134+
'saved_model.pb',
135+
'variables',
136+
]
137+
],
138+
)
139+
real_output_filenames = set(
140+
glob(output_dir + '/*/*') + # noqa: W504
141+
glob(output_dir + '/*/*/*'),
142+
)
143+
assert expected_output_filenames == real_output_filenames
144+
145+
146+
def test_frozen_save_will_not_change_model(x_test, y_test, model, saver_with_freeze, output_dir):
147+
old_loss = model.evaluate(x_test, y_test)
148+
saver_with_freeze.save()
149+
new_loss = load_n_evaluate(output_dir, x_test, y_test)
150+
assert old_loss == new_loss
151+
152+
153+
def load_n_evaluate(output_dir, x_test, y_test):
154+
with tf.Session(graph=tf.Graph()) as sess:
155+
meta_graph_def = tf.saved_model.loader.load(
156+
sess=sess,
157+
tags=[tf.saved_model.tag_constants.SERVING],
158+
export_dir=join(output_dir, '0'),
159+
)
160+
evaluate_graph = meta_graph_def.signature_def['evaluate']
161+
loss = sess.run(
162+
evaluate_graph.outputs['loss'].name,
163+
feed_dict={
164+
evaluate_graph.inputs['x'].name: x_test,
165+
evaluate_graph.inputs['y'].name: y_test,
166+
},
167+
)
168+
return loss

0 commit comments

Comments
 (0)