Skip to content

Commit e1aa4e0

Browse files
committed
Adjust tests.
1 parent 795abab commit e1aa4e0

File tree

2 files changed

+40
-38
lines changed

2 files changed

+40
-38
lines changed

tests/test_aot.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,23 @@ def create_graph_def(self, create="saved_model", **kwargs):
8383

8484
with tmp_file(suffix=".pb") as pb_path:
8585
cmsml_tools.save_graph(pb_path, concrete_func, variables_to_constants=False)
86-
graph_graph_def = cmsml.tensorflow.load_graph_def(pb_path,
87-
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
86+
graph_graph_def = cmsml.tensorflow.load_graph_def(
87+
pb_path,
88+
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
89+
)
8890
return graph_graph_def
8991

9092
@skip_if_no_tf2xla_supported_ops
9193
def test_get_graph_ops_saved_model(self):
9294
tf_graph_def, keras_graph_def = self.create_graph_def(create="saved_model")
9395

9496
graph_ops = set(get_graph_ops(tf_graph_def, node_def_number=0))
95-
expected_ops = {"AddV2",
96-
"BiasAdd",
97-
"Const",
98-
"Identity",
99-
"MatMul",
100-
"Mul",
101-
"NoOp",
102-
"Rsqrt",
103-
"Softmax",
104-
"Sub",
105-
"Tanh"
106-
}
107-
97+
expected_ops = {
98+
"AddV2", "BiasAdd", "Const", "Identity", "MatMul", "Mul", "NoOp", "Rsqrt", "Softmax",
99+
"Sub", "Tanh",
100+
}
108101
io_ops = {"ReadVariableOp", "Placeholder"}
102+
109103
ops_without_io = graph_ops - io_ops
110104
self.assertSetEqual(ops_without_io, expected_ops)
111105

@@ -114,19 +108,10 @@ def test_get_graph_ops_graph(self):
114108
concrete_function_graph_def = self.create_graph_def(create="graph")
115109
graph_ops = set(get_graph_ops(concrete_function_graph_def, node_def_number=0))
116110

117-
expected_ops = {"AddV2",
118-
"BiasAdd",
119-
"Const",
120-
"Identity",
121-
"MatMul",
122-
"Mul",
123-
"NoOp",
124-
"Rsqrt",
125-
"Softmax",
126-
"Sub",
127-
"Tanh"
128-
}
129-
111+
expected_ops = {
112+
"AddV2", "BiasAdd", "Const", "Identity", "MatMul", "Mul", "NoOp", "Rsqrt", "Softmax",
113+
"Sub", "Tanh",
114+
}
130115
io_ops = {"ReadVariableOp", "Placeholder"}
131116

132117
ops_without_io = graph_ops - io_ops
@@ -160,6 +145,7 @@ def tf_version(self):
160145
self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf()
161146
return self._tf_version
162147

148+
@skip_if_no_tf2xla_supported_ops
163149
def test_parse_ops_table(self):
164150
ops_dict = OpsData.parse_ops_table(device="cpu")
165151
expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D")

tests/test_compile_tf_graph.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
# coding: utf-8
2+
13
import os
4+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
25

36
import cmsml
47
from cmsml.util import tmp_dir
@@ -10,7 +13,6 @@
1013
class TfCompileTestCase(CMSMLTestCase):
1114
def __init__(self, *args, **kwargs):
1215
super(TfCompileTestCase, self).__init__(*args, **kwargs)
13-
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
1416

1517
self._tf = None
1618
self._tf1 = None
@@ -70,11 +72,25 @@ def test_compile_tf_graph_static_preparation(self):
7072
loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)
7173
for batch_size in batch_sizes:
7274
# first entry is empty, second contains inputs tuple(tensorspecs)
73-
model_static_inputs = loaded_static_model.signatures[f'serving_default__{batch_size}'].structured_input_signature[1]
74-
75-
expected_model_static_inputs = {f"first__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name=f"first__bs{batch_size}"),
76-
f"second__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name=f"second__bs{batch_size}"),
77-
f"third__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name=f"third__bs{batch_size}")}
75+
model_static_inputs = loaded_static_model.signatures[f"serving_default__{batch_size}"].structured_input_signature[1] # noqa
76+
77+
expected_model_static_inputs = {
78+
f"first__bs{batch_size}": tf.TensorSpec(
79+
shape=(batch_size, 2),
80+
dtype=tf.float32,
81+
name=f"first__bs{batch_size}",
82+
),
83+
f"second__bs{batch_size}": tf.TensorSpec(
84+
shape=(batch_size, 3),
85+
dtype=tf.float32,
86+
name=f"second__bs{batch_size}",
87+
),
88+
f"third__bs{batch_size}": tf.TensorSpec(
89+
shape=(batch_size, 10),
90+
dtype=tf.float32,
91+
name=f"third__bs{batch_size}",
92+
),
93+
}
7894

7995
self.assertDictEqual(model_static_inputs, expected_model_static_inputs)
8096

@@ -103,13 +119,13 @@ def test_compile_tf_graph_static_aot_compilation(self):
103119
batch_sizes=batch_sizes,
104120
input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
105121
output_serving_key=None,
106-
compile_prefix='aot_model_bs_{}',
107-
compile_class='bs_{}')
122+
compile_prefix="aot_model_bs_{}",
123+
compile_class="bs_{}")
108124

109125
aot_dir = os.path.join(static_saved_model_path, "aot")
110126
for batch_size in batch_sizes:
111-
aot_model_header = os.path.join(aot_dir, 'aot_model_bs_{}.h'.format(batch_size))
112-
aot_model_object = os.path.join(aot_dir, 'aot_model_bs_{}.o'.format(batch_size))
127+
aot_model_header = os.path.join(aot_dir, "aot_model_bs_{}.h".format(batch_size))
128+
aot_model_object = os.path.join(aot_dir, "aot_model_bs_{}.o".format(batch_size))
113129

114130
self.assertTrue(os.path.exists(aot_model_object))
115131
self.assertTrue(os.path.exists(aot_model_header))

0 commit comments

Comments
 (0)