Commit 63ebe84

Polish tests.

1 parent cc3cdf8
4 files changed: +82 -72 lines changed

.flake8 (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [flake8]
 
-max-line-length = 101
+max-line-length = 120
 
 # codes of errors to ignore
 ignore = E128, E306, E402, E722, E731, E741, W504, Q003
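The raised limit can also be checked programmatically, for instance when reproducing the lint step locally. A minimal sketch using flake8's documented legacy API; the target directory "cmsml" is illustrative:

    from flake8.api import legacy as flake8

    # mirror the .flake8 settings above (a sketch, not the project's lint entry point)
    style_guide = flake8.get_style_guide(
        max_line_length=120,
        ignore=["E128", "E306", "E402", "E722", "E731", "E741", "W504", "Q003"],
    )
    report = style_guide.check_files(["cmsml"])
    print(report.total_errors)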

cmsml/scripts/compile_tf_graph.py (4 additions, 4 deletions)

@@ -34,7 +34,7 @@ def compile_tf_graph(
     The resulting static 'ConcreteFunction' is saved as subgraph under a new *output_serving_key*
     signature in a SavedModel stored at *output_path*.
     If no *output_serving_key* is given the 'ConcreteFunction' are saved with the
-    signature "{*input_serving_key*}__bs{*batch_size*}"
+    signature "{*input_serving_key*}_bs{*batch_size*}".
 
     An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given.
     In this case *compile_prefix* is the file prefix, while *compile_class* is the name of the
@@ -44,7 +44,7 @@ def compile_tf_graph(
 
     # default output_serving key
     if not output_serving_key:
-        output_serving_key = input_serving_key + "__{}"
+        output_serving_key = input_serving_key + "_bs{}"
 
     # check compile values
     if compile_prefix and not compile_class:
@@ -80,7 +80,7 @@ def compile_tf_graph(
             for n in spec.shape
         ]
         # : is the delimiter of ops numering scheme
-        name = f"{spec.name.replace(':', '_')}__bs{bs}"
+        name = f"{spec.name.replace(':', '_')}_bs{bs}"
         # store the new spec
         specs[key] = type(spec)(type(spec.shape)(shape), dtype=spec.dtype, name=name)
 
@@ -109,7 +109,7 @@ def aot_compile(
     prefix: str,
     class_name: str,
     batch_sizes: tuple[int] = (1,),
-    serving_key: str = r"serving_default__bs{}",
+    serving_key: str = r"serving_default_bs{}",
 ) -> None:
     """
     Take the provided static subgraph under specified *serving_key* from the SavedModel
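The net effect of this file's changes is a single, consistent "_bs" naming scheme for serving keys and tensor spec names. A small sketch of the convention, with illustrative values ("serving_default" is tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, "first:0" is a typical tensor name):

    # default serving key template after the change
    input_serving_key = "serving_default"
    output_serving_key = input_serving_key + "_bs{}"
    batch_size = 2
    print(output_serving_key.format(batch_size))  # -> serving_default_bs2

    # tensor spec names: ":" delimits the op numbering suffix,
    # so it is replaced before the batch size is appended
    spec_name = "first:0"
    print(f"{spec_name.replace(':', '_')}_bs{batch_size}")  # -> first_0_bs2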

tests/test_aot.py (4 additions, 2 deletions)

@@ -23,6 +23,7 @@ def skip_if_no_tf2xla_supported_ops(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if not HAS_TF2XLA_SUPPORTED_OPS:
+            print(f"skipping {func.__name__} because tf2xla_supported_ops is not available")
             return
         return func(*args, **kwargs)
     return wrapper
@@ -79,7 +80,7 @@ def create_graph_def(self, create="saved_model", **kwargs):
             keras_graph_def = cmsml_tools.load_graph_def(keras_path, default_signature)
             return tf_graph_def, keras_graph_def
 
-        elif create == "graph":
+        if create == "graph":
             concrete_func = tf.function(model).get_concrete_function(tf.ones((2, 10)))
 
             with tmp_file(suffix=".pb") as pb_path:
@@ -90,6 +91,8 @@ def create_graph_def(self, create="saved_model", **kwargs):
                 )
                 return graph_graph_def
 
+        self.assertTrue(False)
+
     @skip_if_no_tf2xla_supported_ops
     def test_get_graph_ops_saved_model(self):
         from cmsml.tensorflow.aot import get_graph_ops
@@ -159,7 +162,6 @@ def test_parse_ops_table(self):
 
         # check if ops name and content exist
         # since content changes with every version only naiv test is done
-
        for op in expected_ops:
            self.assertTrue(bool(ops_dict[op]["allowed_types"]))

tests/test_compile_tf_graph.py (73 additions, 65 deletions)

@@ -10,6 +10,7 @@
 
 
 class TfCompileTestCase(CMSMLTestCase):
+
     def __init__(self, *args, **kwargs):
         super(TfCompileTestCase, self).__init__(*args, **kwargs)
 
@@ -44,6 +45,7 @@ def create_test_model(self, tf):
         x = tf.concat([x1, x2], axis=1)
         a1 = tf.keras.layers.Dense(10, activation="elu")(x)
         y = tf.keras.layers.Dense(5, activation="softmax")(a1)
+
         model = tf.keras.Model(inputs=(x1, x2, x3), outputs=y)
         return model
 
@@ -55,55 +57,60 @@ def test_compile_tf_graph_static_preparation(self):
 
         model = self.create_test_model(tf)
 
-        with tmp_dir(create=False) as model_path:
+        with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path:
             tf.saved_model.save(model, model_path)
 
-            with tmp_dir(create=False) as static_saved_model_path:
-                batch_sizes = [1, 2]
-
-                compile_tf_graph(model_path=model_path,
-                                 output_path=static_saved_model_path,
-                                 batch_sizes=batch_sizes,
-                                 input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
-                                 output_serving_key=None,
-                                 compile_prefix=None,
-                                 compile_class=None)
-
-                # load model and check input shape
-                loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)
-                for batch_size in batch_sizes:
-                    # first entry is empty, second contains inputs tuple(tensorspecs)
-                    model_static_inputs = loaded_static_model.signatures[f"serving_default__{batch_size}"].structured_input_signature[1]  # noqa
-
-                    expected_model_static_inputs = {
-                        f"first__bs{batch_size}": tf.TensorSpec(
-                            shape=(batch_size, 2),
-                            dtype=tf.float32,
-                            name=f"first__bs{batch_size}",
-                        ),
-                        f"second__bs{batch_size}": tf.TensorSpec(
-                            shape=(batch_size, 3),
-                            dtype=tf.float32,
-                            name=f"second__bs{batch_size}",
-                        ),
-                        f"third__bs{batch_size}": tf.TensorSpec(
-                            shape=(batch_size, 10),
-                            dtype=tf.float32,
-                            name=f"third__bs{batch_size}",
-                        ),
-                    }
-
-                    self.assertDictEqual(model_static_inputs, expected_model_static_inputs)
-
-                # throw error if compilation happens with illegal batch size
-                with self.assertRaises(ValueError):
-                    compile_tf_graph(model_path=model_path,
-                                     output_path=static_saved_model_path,
-                                     batch_sizes=[-1,],
-                                     input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
-                                     output_serving_key=None,
-                                     compile_prefix=None,
-                                     compile_class=None)
+            # throw error if compilation happens with illegal batch size
+            with self.assertRaises(ValueError):
+                compile_tf_graph(
+                    model_path=model_path,
+                    output_path=static_saved_model_path,
+                    batch_sizes=[-1],
+                    input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+                    output_serving_key=None,
+                    compile_prefix=None,
+                    compile_class=None,
+                )
+
+            batch_sizes = [1, 2]
+            compile_tf_graph(
+                model_path=model_path,
+                output_path=static_saved_model_path,
+                batch_sizes=batch_sizes,
+                input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+                output_serving_key=None,
+                compile_prefix=None,
+                compile_class=None,
+            )
+
+            # load model
+            loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)
+
+            # check input shape
+            for batch_size in batch_sizes:
+                # first entry is empty, second contains inputs tuple(tensorspecs)
+                key = f"serving_default_bs{batch_size}"
+                model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1]
+
+                expected_model_static_inputs = {
+                    f"first_bs{batch_size}": tf.TensorSpec(
+                        shape=(batch_size, 2),
+                        dtype=tf.float32,
+                        name=f"first_bs{batch_size}",
+                    ),
+                    f"second_bs{batch_size}": tf.TensorSpec(
+                        shape=(batch_size, 3),
+                        dtype=tf.float32,
+                        name=f"second_bs{batch_size}",
+                    ),
+                    f"third_bs{batch_size}": tf.TensorSpec(
+                        shape=(batch_size, 10),
+                        dtype=tf.float32,
+                        name=f"third_bs{batch_size}",
+                    ),
+                }
+
+                self.assertDictEqual(model_static_inputs, expected_model_static_inputs)
 
     def test_compile_tf_graph_static_aot_compilation(self):
         from cmsml.scripts.compile_tf_graph import compile_tf_graph
@@ -112,23 +119,24 @@ def test_compile_tf_graph_static_aot_compilation(self):
         tf = self.tf
         model = self.create_test_model(tf)
 
-        with tmp_dir(create=False) as model_path:
+        with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path:
             tf.saved_model.save(model, model_path)
 
-            with tmp_dir(create=False) as static_saved_model_path:
-                batch_sizes = [1, 2]
-                compile_tf_graph(model_path=model_path,
-                                 output_path=static_saved_model_path,
-                                 batch_sizes=batch_sizes,
-                                 input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
-                                 output_serving_key=None,
-                                 compile_prefix="aot_model_bs_{}",
-                                 compile_class="bs_{}")
-
-                aot_dir = os.path.join(static_saved_model_path, "aot")
-                for batch_size in batch_sizes:
-                    aot_model_header = os.path.join(aot_dir, "aot_model_bs_{}.h".format(batch_size))
-                    aot_model_object = os.path.join(aot_dir, "aot_model_bs_{}.o".format(batch_size))
-
-                    self.assertTrue(os.path.exists(aot_model_object))
-                    self.assertTrue(os.path.exists(aot_model_header))
+            batch_sizes = [1, 2]
+            compile_tf_graph(
+                model_path=model_path,
+                output_path=static_saved_model_path,
+                batch_sizes=batch_sizes,
+                input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+                output_serving_key=None,
+                compile_prefix="aot_model_bs{}",
+                compile_class="bs_{}",
+            )
+
+            aot_dir = os.path.join(static_saved_model_path, "aot")
+            for batch_size in batch_sizes:
+                aot_model_header = os.path.join(aot_dir, "aot_model_bs{}.h".format(batch_size))
+                aot_model_object = os.path.join(aot_dir, "aot_model_bs{}.o".format(batch_size))
+
+                self.assertTrue(os.path.exists(aot_model_object))
+                self.assertTrue(os.path.exists(aot_model_header))
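The signature lookup exercised by the updated test relies on standard TensorFlow APIs. A hedged sketch of the same inspection outside the test harness, using plain tf.saved_model.load instead of cmsml.tensorflow.load_model; the path is a placeholder:

    import tensorflow as tf

    static_saved_model_path = "/path/to/static_saved_model"  # placeholder

    loaded = tf.saved_model.load(static_saved_model_path)
    # one signature per compiled batch size, e.g. ['serving_default_bs1', 'serving_default_bs2']
    print(list(loaded.signatures.keys()))

    concrete = loaded.signatures["serving_default_bs2"]
    # structured_input_signature is (args, kwargs); index 1 maps input names to TensorSpecs
    _, input_specs = concrete.structured_input_signature
    for name, spec in input_specs.items():
        print(name, spec.shape, spec.dtype.name)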
