Skip to content

Commit

Permalink
Forward kwarg use_legacy_model_save from runner.KerasModelExporter
Browse files Browse the repository at this point in the history
and `runner.SubmoduleExporter` to `runner.export_model()`.
The `KerasModelExporter` already had the kwarg but didn't forward it.

Along the way, corrects the docstrings of the exporters.

PiperOrigin-RevId: 719250357
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Jan 24, 2025
1 parent 7509061 commit a69d97a
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions tensorflow_gnn/runner/utils/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class KerasModelExporter(interfaces.ModelExporter):
"""Exports a Keras model (with Keras API) via `tf.keras.models.save_model`."""
"""Exports a Keras model via `runner.export_model()`."""

def __init__(self,
*,
Expand Down Expand Up @@ -84,11 +84,12 @@ def save(self, run_result: interfaces.RunResult, export_dir: str):
self._include_preprocessing,
self._output_names,
self._subdirectory,
self._options)
self._options,
self._use_legacy_model_save)


class SubmoduleExporter(interfaces.ModelExporter):
"""Exports a Keras submodule.
"""Exports a Keras submodule via `runner.export_model()`.
Given a `RunResult`, this exporter creates and exports a submodule with
inputs identical to the trained model and outputs from some intermediate layer
Expand All @@ -109,7 +110,8 @@ def __init__(self,
output_names: Optional[Any] = None,
subdirectory: Optional[str] = None,
include_preprocessing: bool = False,
options: Optional[tf.saved_model.SaveOptions] = None):
options: Optional[tf.saved_model.SaveOptions] = None,
use_legacy_model_save: Optional[bool] = None):
"""Captures the args shared across `save(...)` calls.
Args:
Expand All @@ -119,12 +121,16 @@ def __init__(self,
to `os.path.join(export_dir, subdirectory)`.
include_preprocessing: Whether to include any `preprocess_model`.
options: Options for saving to a TensorFlow `SavedModel`.
use_legacy_model_save: Optional; most users can leave it unset to get a
useful default for export to inference. See `runner.export_model()`
for more.
"""
self._sublayer_name = sublayer_name
self._output_names = output_names
self._subdirectory = subdirectory
self._include_preprocessing = include_preprocessing
self._options = options
self._use_legacy_model_save = use_legacy_model_save

def save(self, run_result: interfaces.RunResult, export_dir: str):
"""Saves a Keras model submodule.
Expand Down Expand Up @@ -164,7 +170,8 @@ def save(self, run_result: interfaces.RunResult, export_dir: str):
self._include_preprocessing,
self._output_names,
self._subdirectory,
self._options)
self._options,
self._use_legacy_model_save)


def export_model(model: tf.keras.Model,
Expand Down

0 comments on commit a69d97a

Please sign in to comment.