From a69d97a590fc20131d2916fb47a4778c7710692b Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Fri, 24 Jan 2025 04:32:39 -0800 Subject: [PATCH] Forward kwarg `use_legacy_model_save` from `runner.KerasModelExporter` and `runner.SubmoduleExporter` to `runner.export_model()`. The `KerasModelExporter` already had the kwarg but didn't forward it. Along the way, corrects the docstrings of the exporters. PiperOrigin-RevId: 719250357 --- tensorflow_gnn/runner/utils/model_export.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow_gnn/runner/utils/model_export.py b/tensorflow_gnn/runner/utils/model_export.py index fc910301..f1a5ea99 100644 --- a/tensorflow_gnn/runner/utils/model_export.py +++ b/tensorflow_gnn/runner/utils/model_export.py @@ -23,7 +23,7 @@ class KerasModelExporter(interfaces.ModelExporter): - """Exports a Keras model (with Keras API) via `tf.keras.models.save_model`.""" + """Exports a Keras model via `runner.export_model()`.""" def __init__(self, *, @@ -84,11 +84,12 @@ def save(self, run_result: interfaces.RunResult, export_dir: str): self._include_preprocessing, self._output_names, self._subdirectory, - self._options) + self._options, + self._use_legacy_model_save) class SubmoduleExporter(interfaces.ModelExporter): - """Exports a Keras submodule. + """Exports a Keras submodule via `runner.export_model()`. Given a `RunResult`, this exporter creates and exports a submodule with inputs identical to the trained model and outputs from some intermediate layer @@ -109,7 +110,8 @@ def __init__(self, output_names: Optional[Any] = None, subdirectory: Optional[str] = None, include_preprocessing: bool = False, - options: Optional[tf.saved_model.SaveOptions] = None): + options: Optional[tf.saved_model.SaveOptions] = None, + use_legacy_model_save: Optional[bool] = None): """Captures the args shared across `save(...)` calls. Args: @@ -119,12 +121,16 @@ def __init__(self, to `os.path.join(export_dir, subdirectory)`. include_preprocessing: Whether to include any `preprocess_model`. options: Options for saving to a TensorFlow `SavedModel`. + use_legacy_model_save: Optional; most users can leave it unset to get a + useful default for export to inference. See `runner.export_model()` + for more. """ self._sublayer_name = sublayer_name self._output_names = output_names self._subdirectory = subdirectory self._include_preprocessing = include_preprocessing self._options = options + self._use_legacy_model_save = use_legacy_model_save def save(self, run_result: interfaces.RunResult, export_dir: str): """Saves a Keras model submodule. @@ -164,7 +170,8 @@ def save(self, run_result: interfaces.RunResult, export_dir: str): self._include_preprocessing, self._output_names, self._subdirectory, - self._options) + self._options, + self._use_legacy_model_save) def export_model(model: tf.keras.Model,