Skip to content

Commit

Permalink
Fix model preparation for custom signatures (#22)
Browse files Browse the repository at this point in the history
* fixed, that the input_serving_key is not passed throught the compiler

* adapted the test of compile_tf_graph to check after custom signatures

* linting typo

Co-authored-by: Marcel Rieger <riga@users.noreply.github.com>

---------

Co-authored-by: Marcel Rieger <riga@users.noreply.github.com>
  • Loading branch information
Bogdan-Wiederspan and riga authored Oct 29, 2024
1 parent 439a2e6 commit b3395d7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 45 deletions.
2 changes: 1 addition & 1 deletion cmsml/scripts/compile_tf_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def compile_tf_graph(
for bs in sorted(set(map(int, batch_sizes))):
# create a fully defined signature, filling leading None's in shapes with the batch size
specs = {}
for key, spec in model.signatures["serving_default"].structured_input_signature[1].items():
for key, spec in model.signatures[input_serving_key].structured_input_signature[1].items():
# ignore inputs without undefined axes
if None not in spec.shape:
continue
Expand Down
102 changes: 58 additions & 44 deletions tests/test_compile_tf_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,59 +58,73 @@ def test_compile_tf_graph_static_preparation(self):
model = self.create_test_model(tf)

with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path:
tf.saved_model.save(model, model_path)
spec = [
tf.TensorSpec(shape=(None, 2), dtype=tf.float32, name="inputs"),
tf.TensorSpec(shape=(None, 3), dtype=tf.float32, name="inputs_1"),
tf.TensorSpec(shape=(None, 10), dtype=tf.float32, name="inputs_2"),
]

conc_func = tf.function(model.call).get_concrete_function(spec)
signatures = {
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: conc_func,
"custom_signature": conc_func,
}

# throw error if compilation happens with illegal batch size
with self.assertRaises(ValueError):
tf.saved_model.save(model, model_path, signatures=signatures)

for signature in signatures:
# throw error if compilation happens with illegal batch size
with self.assertRaises(ValueError):
compile_tf_graph(
model_path=model_path,
output_path=static_saved_model_path,
batch_sizes=[-1],
input_serving_key=signature,
output_serving_key=None,
compile_prefix=None,
compile_class=None,
)

batch_sizes = [1, 2]
compile_tf_graph(
model_path=model_path,
output_path=static_saved_model_path,
batch_sizes=[-1],
input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
batch_sizes=batch_sizes,
input_serving_key=signature,
output_serving_key=None,
compile_prefix=None,
compile_class=None,
)

batch_sizes = [1, 2]
compile_tf_graph(
model_path=model_path,
output_path=static_saved_model_path,
batch_sizes=batch_sizes,
input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
output_serving_key=None,
compile_prefix=None,
compile_class=None,
)

# load model
loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)

# check input shape
for batch_size in batch_sizes:
# first entry is empty, second contains inputs tuple(tensorspecs)
key = f"serving_default_bs{batch_size}"
model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1]

expected_model_static_inputs = {
f"inputs_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 2),
dtype=tf.float32,
name=f"inputs_bs{batch_size}",
),
f"inputs_1_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 3),
dtype=tf.float32,
name=f"inputs_1_bs{batch_size}",
),
f"inputs_2_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 10),
dtype=tf.float32,
name=f"inputs_2_bs{batch_size}",
),
}

self.assertDictEqual(model_static_inputs, expected_model_static_inputs)
# load model
breakpoint(header='')
loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path)

# check input shape
for batch_size in batch_sizes:
# first entry is empty, second contains inputs tuple(tensorspecs)
key = f"{signature}_bs{batch_size}"
model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1]

expected_model_static_inputs = {
f"inputs_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 2),
dtype=tf.float32,
name=f"inputs_bs{batch_size}",
),
f"inputs_1_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 3),
dtype=tf.float32,
name=f"inputs_1_bs{batch_size}",
),
f"inputs_2_bs{batch_size}": tf.TensorSpec(
shape=(batch_size, 10),
dtype=tf.float32,
name=f"inputs_2_bs{batch_size}",
),
}

self.assertDictEqual(model_static_inputs, expected_model_static_inputs)

def test_compile_tf_graph_static_aot_compilation(self):
from cmsml.scripts.compile_tf_graph import compile_tf_graph
Expand Down

0 comments on commit b3395d7

Please sign in to comment.