```diff
@@ -1,4 +1,7 @@
+# coding: utf-8
+
 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 import cmsml
 from cmsml.util import tmp_dir
@@ -10,7 +13,6 @@
 class TfCompileTestCase(CMSMLTestCase):
     def __init__(self, *args, **kwargs):
         super(TfCompileTestCase, self).__init__(*args, **kwargs)
-        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
         self._tf = None
         self._tf1 = None
```
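The GPU-hiding variable is moved to module level so that it is already set when TensorFlow is first imported: `CUDA_VISIBLE_DEVICES` is only read when CUDA is initialized, so setting it later (for instance in a test case constructor) may have no effect. A minimal standalone sketch of the pattern, independent of the cmsml test suite:

```python
import os

# Hide all CUDA devices before any TensorFlow import so everything runs on CPU;
# once TensorFlow has initialized CUDA, changing this variable no longer helps.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf  # noqa: E402

# expected to print an empty list since all GPUs are hidden
print(tf.config.list_physical_devices("GPU"))
```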
```diff
@@ -70,11 +72,25 @@ def test_compile_tf_graph_static_preparation(self):
         loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)
         for batch_size in batch_sizes:
             # first entry is empty, second contains inputs tuple(tensorspecs)
-            model_static_inputs = loaded_static_model.signatures[f'serving_default__{batch_size}'].structured_input_signature[1]
-
-            expected_model_static_inputs = {f"first__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name=f"first__bs{batch_size}"),
-                                            f"second__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name=f"second__bs{batch_size}"),
-                                            f"third__bs{batch_size}": tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name=f"third__bs{batch_size}")}
+            model_static_inputs = loaded_static_model.signatures[f"serving_default__{batch_size}"].structured_input_signature[1]  # noqa
+
+            expected_model_static_inputs = {
+                f"first__bs{batch_size}": tf.TensorSpec(
+                    shape=(batch_size, 2),
+                    dtype=tf.float32,
+                    name=f"first__bs{batch_size}",
+                ),
+                f"second__bs{batch_size}": tf.TensorSpec(
+                    shape=(batch_size, 3),
+                    dtype=tf.float32,
+                    name=f"second__bs{batch_size}",
+                ),
+                f"third__bs{batch_size}": tf.TensorSpec(
+                    shape=(batch_size, 10),
+                    dtype=tf.float32,
+                    name=f"third__bs{batch_size}",
+                ),
+            }
 
             self.assertDictEqual(model_static_inputs, expected_model_static_inputs)
 
```
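For reference, the structure asserted here can be inspected directly on any SavedModel prepared this way. The sketch below uses plain `tf.saved_model.load` with a placeholder path and batch sizes; the `serving_default__{batch_size}` key naming follows the test above:

```python
import tensorflow as tf

# Placeholder path and batch sizes; adjust to the model under inspection.
loaded = tf.saved_model.load("/path/to/static_saved_model")
for batch_size in (1, 2):
    signature = loaded.signatures[f"serving_default__{batch_size}"]
    # structured_input_signature is an (args, kwargs) pair; the kwargs dict
    # maps input names to tf.TensorSpec objects with the frozen batch size.
    _, input_specs = signature.structured_input_signature
    for name, spec in input_specs.items():
        print(name, spec.shape, spec.dtype)
```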
```diff
@@ -103,13 +119,13 @@ def test_compile_tf_graph_static_aot_compilation(self):
             batch_sizes=batch_sizes,
             input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
             output_serving_key=None,
-            compile_prefix='aot_model_bs_{}',
-            compile_class='bs_{}')
+            compile_prefix="aot_model_bs_{}",
+            compile_class="bs_{}")
 
         aot_dir = os.path.join(static_saved_model_path, "aot")
         for batch_size in batch_sizes:
-            aot_model_header = os.path.join(aot_dir, 'aot_model_bs_{}.h'.format(batch_size))
-            aot_model_object = os.path.join(aot_dir, 'aot_model_bs_{}.o'.format(batch_size))
+            aot_model_header = os.path.join(aot_dir, "aot_model_bs_{}.h".format(batch_size))
+            aot_model_object = os.path.join(aot_dir, "aot_model_bs_{}.o".format(batch_size))
 
             self.assertTrue(os.path.exists(aot_model_object))
             self.assertTrue(os.path.exists(aot_model_header))
```
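The `compile_prefix` and `compile_class` arguments are format templates filled with the batch size, so the AOT step is expected to leave one header/object pair per batch size under the `aot` subdirectory of the saved model, with `compile_class` presumably naming the generated class. A small illustration-only expansion of the templates, using the values from the test above:

```python
# Illustration only: expand the templates the same way the test expects the
# files to appear under <static_saved_model_path>/aot.
for batch_size in (1, 2):
    prefix = "aot_model_bs_{}".format(batch_size)  # file stem: aot_model_bs_1, ...
    cpp_class = "bs_{}".format(batch_size)         # per-batch-size class name: bs_1, ...
    print(prefix + ".h", prefix + ".o", cpp_class)
```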