Skip to content

Commit

Permalink
Expand saved model sh_test to cover Model.export().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587648808
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Dec 4, 2023
1 parent e165738 commit f8d8a8e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 7 deletions.
16 changes: 13 additions & 3 deletions tensorflow_gnn/runner/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,19 @@ pytype_strict_binary(
)

sh_test(
name = "saved_model_test",
size = "small",
srcs = ["saved_model_test.sh"],
name = "saved_model_export_test",
size = "medium",
srcs = ["saved_model_export_test.sh"],
data = [
":saved_model_gen_testdata",
":saved_model_load_testdata",
],
)

sh_test(
name = "saved_model_save_test",
size = "medium",
srcs = ["saved_model_save_test.sh"],
data = [
":saved_model_gen_testdata",
":saved_model_load_testdata",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,30 @@
# limitations under the License.
# ==============================================================================

set -eu # Leave no failure undetected.

get_binary () {
echo "runner/utils/$1"
}

# Generate the saved model.
die() {
echo "$*" 1>&2
exit 1
}

# Generate the SavedModel by the Model.export() API.
gen_test_data_par="${TEST_SRCDIR}/$(get_binary 'saved_model_gen_testdata')"
readonly gen_test_data_par

$gen_test_data_par --filepath=${TEST_TMPDIR}/saved_model_testdata || die "Failed to execute $gen_test_data_par"
$gen_test_data_par \
--filepath=${TEST_TMPDIR}/saved_model_testdata \
--use_legacy_model_save=0 \
|| die "Failed to execute $gen_test_data_par"

# Attempt to load the saved model.
testpar="${TEST_SRCDIR}/$(get_binary 'saved_model_load_testdata')"
readonly testpar

$testpar --filepath=${TEST_TMPDIR}/saved_model_testdata || die "Failed to execute $testpar"

echo "PASS"
echo "PASS"
9 changes: 8 additions & 1 deletion tensorflow_gnn/runner/utils/saved_model_gen_testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
required=True,
)

_USE_LEGACY_MODEL_SAVE = flags.DEFINE_boolean(
"use_legacy_model_save",
None,
"Flag forwarded to runner.export_model().",
)


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
Expand All @@ -49,7 +55,8 @@ def fn(inputs, **unused_kwargs):
node_set_name="nodes")(outputs)
model = tf.keras.Model((source, target, h), outputs)

model_export.export_model(model, _FILEPATH.value)
model_export.export_model(model, _FILEPATH.value,
use_legacy_model_save=_USE_LEGACY_MODEL_SAVE.value)


if __name__ == "__main__":
Expand Down
44 changes: 44 additions & 0 deletions tensorflow_gnn/runner/utils/saved_model_save_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash
#
# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -eu # Leave no failure undetected.

get_binary () {
echo "runner/utils/$1"
}

die() {
echo "$*" 1>&2
exit 1
}

# Generate the SavedModel by the legay Model.save() API.
gen_test_data_par="${TEST_SRCDIR}/$(get_binary 'saved_model_gen_testdata')"
readonly gen_test_data_par

$gen_test_data_par \
--filepath=${TEST_TMPDIR}/saved_model_testdata \
--use_legacy_model_save=1 \
|| die "Failed to execute $gen_test_data_par"

# Attempt to load the saved model.
testpar="${TEST_SRCDIR}/$(get_binary 'saved_model_load_testdata')"
readonly testpar

$testpar --filepath=${TEST_TMPDIR}/saved_model_testdata || die "Failed to execute $testpar"

echo "PASS"

0 comments on commit f8d8a8e

Please sign in to comment.