From b3395d716a4334738eb0c333db3591d904249af5 Mon Sep 17 00:00:00 2001 From: Bogdan-Wiederspan <79155113+Bogdan-Wiederspan@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:40:48 +0100 Subject: [PATCH] Fix model preparation for custom signatures (#22) * fixed, that the input_serving_key is not passed throught the compiler * adapted the test of compile_tf_graph to check after custom signatures * linting typo Co-authored-by: Marcel Rieger --------- Co-authored-by: Marcel Rieger --- cmsml/scripts/compile_tf_graph.py | 2 +- tests/test_compile_tf_graph.py | 102 +++++++++++++++++------------- 2 files changed, 59 insertions(+), 45 deletions(-) diff --git a/cmsml/scripts/compile_tf_graph.py b/cmsml/scripts/compile_tf_graph.py index 40b708c..78dcb5e 100644 --- a/cmsml/scripts/compile_tf_graph.py +++ b/cmsml/scripts/compile_tf_graph.py @@ -71,7 +71,7 @@ def compile_tf_graph( for bs in sorted(set(map(int, batch_sizes))): # create a fully defined signature, filling leading None's in shapes with the batch size specs = {} - for key, spec in model.signatures["serving_default"].structured_input_signature[1].items(): + for key, spec in model.signatures[input_serving_key].structured_input_signature[1].items(): # ignore inputs without undefined axes if None not in spec.shape: continue diff --git a/tests/test_compile_tf_graph.py b/tests/test_compile_tf_graph.py index 73e28ab..72cde6b 100644 --- a/tests/test_compile_tf_graph.py +++ b/tests/test_compile_tf_graph.py @@ -58,59 +58,73 @@ def test_compile_tf_graph_static_preparation(self): model = self.create_test_model(tf) with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path: - tf.saved_model.save(model, model_path) + spec = [ + tf.TensorSpec(shape=(None, 2), dtype=tf.float32, name="inputs"), + tf.TensorSpec(shape=(None, 3), dtype=tf.float32, name="inputs_1"), + tf.TensorSpec(shape=(None, 10), dtype=tf.float32, name="inputs_2"), + ] + + conc_func = tf.function(model.call).get_concrete_function(spec) + signatures = { + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: conc_func, + "custom_signature": conc_func, + } - # throw error if compilation happens with illegal batch size - with self.assertRaises(ValueError): + tf.saved_model.save(model, model_path, signatures=signatures) + + for signature in signatures: + # throw error if compilation happens with illegal batch size + with self.assertRaises(ValueError): + compile_tf_graph( + model_path=model_path, + output_path=static_saved_model_path, + batch_sizes=[-1], + input_serving_key=signature, + output_serving_key=None, + compile_prefix=None, + compile_class=None, + ) + + batch_sizes = [1, 2] compile_tf_graph( model_path=model_path, output_path=static_saved_model_path, - batch_sizes=[-1], - input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + batch_sizes=batch_sizes, + input_serving_key=signature, output_serving_key=None, compile_prefix=None, compile_class=None, ) - batch_sizes = [1, 2] - compile_tf_graph( - model_path=model_path, - output_path=static_saved_model_path, - batch_sizes=batch_sizes, - input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - output_serving_key=None, - compile_prefix=None, - compile_class=None, - ) - - # load model - loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path) - - # check input shape - for batch_size in batch_sizes: - # first entry is empty, second contains inputs tuple(tensorspecs) - key = f"serving_default_bs{batch_size}" - model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1] - - expected_model_static_inputs = { - f"inputs_bs{batch_size}": tf.TensorSpec( - shape=(batch_size, 2), - dtype=tf.float32, - name=f"inputs_bs{batch_size}", - ), - f"inputs_1_bs{batch_size}": tf.TensorSpec( - shape=(batch_size, 3), - dtype=tf.float32, - name=f"inputs_1_bs{batch_size}", - ), - f"inputs_2_bs{batch_size}": tf.TensorSpec( - shape=(batch_size, 10), - dtype=tf.float32, - name=f"inputs_2_bs{batch_size}", - ), - } - - self.assertDictEqual(model_static_inputs, expected_model_static_inputs) + # load model + breakpoint(header='') + loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path) + + # check input shape + for batch_size in batch_sizes: + # first entry is empty, second contains inputs tuple(tensorspecs) + key = f"{signature}_bs{batch_size}" + model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1] + + expected_model_static_inputs = { + f"inputs_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 2), + dtype=tf.float32, + name=f"inputs_bs{batch_size}", + ), + f"inputs_1_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 3), + dtype=tf.float32, + name=f"inputs_1_bs{batch_size}", + ), + f"inputs_2_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 10), + dtype=tf.float32, + name=f"inputs_2_bs{batch_size}", + ), + } + + self.assertDictEqual(model_static_inputs, expected_model_static_inputs) def test_compile_tf_graph_static_aot_compilation(self): from cmsml.scripts.compile_tf_graph import compile_tf_graph