From 39a31384d2a5713172623a0c3078da69d149549e Mon Sep 17 00:00:00 2001 From: Madhur Karampudi Date: Wed, 18 Dec 2024 09:31:57 -0800 Subject: [PATCH] Remove unused keyword arguments to Keras Model.save and Model.load. PiperOrigin-RevId: 707577369 --- fairness_indicators/example_model_test.py | 2 +- setup.py | 10 +-- tensorboard_plugin/setup.py | 8 +- .../plugin.py | 3 +- .../plugin_test.py | 84 +++++++++++-------- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/fairness_indicators/example_model_test.py b/fairness_indicators/example_model_test.py index 09266a2..8be9fa4 100644 --- a/fairness_indicators/example_model_test.py +++ b/fairness_indicators/example_model_test.py @@ -91,7 +91,7 @@ def test_example_model(self): ]), batch_size=1, ) - classifier.save(self._model_dir, save_format='tf') + tf.saved_model.save(classifier, self._model_dir) eval_config = text_format.Parse( """ diff --git a/setup.py b/setup.py index 0451434..9ee89c9 100644 --- a/setup.py +++ b/setup.py @@ -38,15 +38,15 @@ def select_constraint(default, nightly=None, git_master=None): return default REQUIRED_PACKAGES = [ - 'tensorflow>=2.15,<2.16', + 'tensorflow>=2.16,<2.17', 'tensorflow-hub>=0.16.1,<1.0.0', 'tensorflow-data-validation' + select_constraint( - default='>=1.15.1,<2.0.0', - nightly='>=1.16.0.dev', + default='>=1.16.1,<2.0.0', + nightly='>=1.17.0.dev', git_master='@git+https://github.com/tensorflow/data-validation@master'), 'tensorflow-model-analysis' + select_constraint( - default='>=0.46,<0.47', - nightly='>=0.47.0.dev', + default='>=0.47.0,<0.48.0', + nightly='>=0.48.0.dev', git_master='@git+https://github.com/tensorflow/model-analysis@master'), 'witwidget>=1.4.4,<2', 'protobuf>=3.20.3,<5', diff --git a/tensorboard_plugin/setup.py b/tensorboard_plugin/setup.py index 6663771..a097a48 100644 --- a/tensorboard_plugin/setup.py +++ b/tensorboard_plugin/setup.py @@ -43,12 +43,12 @@ def select_constraint(default, nightly=None, git_master=None): REQUIRED_PACKAGES = [ 'protobuf>=3.20.3,<5', - 'tensorboard>=2.15.2,<2.16.0', - 'tensorflow>=2.15,<2.16', + 'tensorboard>=2.16.2,<2.17.0', + 'tensorflow>=2.16,<2.17', 'tensorflow-model-analysis' + select_constraint( - default='>=0.46,<0.47', - nightly='>=0.47.0.dev', + default='>=0.47,<0.48', + nightly='>=0.48.0.dev', git_master='@git+https://github.com/tensorflow/model-analysis@master', ), 'werkzeug<2', diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py index f3e1856..69390be 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py @@ -23,7 +23,8 @@ from tensorboard_plugin_fairness_indicators import metadata import six import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.addons.fairness.view import widget_view +# from tensorflow_model_analysis.addons.fairness.view import widget_view +from tensorflow_model_analysis.view import widget_view from werkzeug import wrappers from google.protobuf import json_format from tensorboard.backend import http_util diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py index e465bef..1a0965f 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py @@ -25,10 +25,10 @@ from tensorboard_plugin_fairness_indicators import plugin from tensorboard_plugin_fairness_indicators import summary_v2 import six -import tensorflow.compat.v1 as tf -import tensorflow.compat.v2 as tf2 +import tensorflow as tf2 +from tensorflow.keras import layers +from tensorflow.keras import models import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.eval_saved_model.example_trainers import linear_classifier from werkzeug import test as werkzeug_test from werkzeug import wrappers @@ -36,10 +36,23 @@ from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer from tensorboard.plugins import base_plugin -tf.enable_eager_execution() +Sequential = models.Sequential +Dense = layers.Dense + tf = tf2 +# Define keras based linear classifier. +def create_linear_classifier(model_dir): + + model = Sequential([Dense(1, activation="sigmoid", input_shape=(2,))]) + model.compile( + optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"] + ) + tf.saved_model.save(model_dir, model) + return model + + class PluginTest(tf.test.TestCase): """Tests for Fairness Indicators plugin server.""" @@ -74,19 +87,19 @@ def tearDown(self): super(PluginTest, self).tearDown() shutil.rmtree(self._log_dir, ignore_errors=True) - def _exportEvalSavedModel(self, classifier): + def _export_eval_saved_model(self): + """Export the evaluation saved model.""" temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir") - _, eval_export_dir = classifier(None, temp_eval_export_dir) - return eval_export_dir + return create_linear_classifier(temp_eval_export_dir) - def _writeTFExamplesToTFRecords(self, examples): + def _write_tf_examples_to_tfrecords(self, examples): data_location = os.path.join(self.get_temp_dir(), "input_data.rio") with tf.io.TFRecordWriter(data_location) as writer: for example in examples: writer.write(example.SerializeToString()) return data_location - def _makeExample(self, age, language, label): + def _make_tf_example(self, age, language, label): example = tf.train.Example() example.features.feature["age"].float_list.value[:] = [age] example.features.feature["language"].bytes_list.value[:] = [ @@ -112,14 +125,14 @@ def testRoutes(self): "foo": "".encode("utf-8") }}, ) - def testIsActive(self, get_random_stub): + def testIsActive(self): self.assertTrue(self._plugin.is_active()) @mock.patch.object( event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={}) - def testIsInactive(self, get_random_stub): + def testIsInactive(self): self.assertFalse(self._plugin.is_active()) def testIndexJsRoute(self): @@ -134,16 +147,15 @@ def testVulcanizedTemplateRoute(self): self.assertEqual(200, response.status_code) def testGetEvalResultsRoute(self): - model_location = self._exportEvalSavedModel( - linear_classifier.simple_linear_classifier) + model_location = self._export_eval_saved_model() # Call the method examples = [ - self._makeExample(age=3.0, language="english", label=1.0), - self._makeExample(age=3.0, language="chinese", label=0.0), - self._makeExample(age=4.0, language="english", label=1.0), - self._makeExample(age=5.0, language="chinese", label=1.0), - self._makeExample(age=5.0, language="hindi", label=1.0) + self._make_tf_example(age=3.0, language="english", label=1.0), + self._make_tf_example(age=3.0, language="chinese", label=0.0), + self._make_tf_example(age=4.0, language="english", label=1.0), + self._make_tf_example(age=5.0, language="chinese", label=1.0), + self._make_tf_example(age=5.0, language="hindi", label=1.0), ] - data_location = self._writeTFExamplesToTFRecords(examples) + data_location = self._write_tf_examples_to_tfrecords(examples) _ = tfma.run_model_analysis( eval_shared_model=tfma.default_eval_shared_model( eval_saved_model_path=model_location, example_weight_key="age"), @@ -155,16 +167,15 @@ def testGetEvalResultsRoute(self): self.assertEqual(200, response.status_code) def testGetEvalResultsFromURLRoute(self): - model_location = self._exportEvalSavedModel( - linear_classifier.simple_linear_classifier) + model_location = self._export_eval_saved_model() # Call the method examples = [ - self._makeExample(age=3.0, language="english", label=1.0), - self._makeExample(age=3.0, language="chinese", label=0.0), - self._makeExample(age=4.0, language="english", label=1.0), - self._makeExample(age=5.0, language="chinese", label=1.0), - self._makeExample(age=5.0, language="hindi", label=1.0) + self._make_tf_example(age=3.0, language="english", label=1.0), + self._make_tf_example(age=3.0, language="chinese", label=0.0), + self._make_tf_example(age=4.0, language="english", label=1.0), + self._make_tf_example(age=5.0, language="chinese", label=1.0), + self._make_tf_example(age=5.0, language="hindi", label=1.0), ] - data_location = self._writeTFExamplesToTFRecords(examples) + data_location = self._write_tf_examples_to_tfrecords(examples) _ = tfma.run_model_analysis( eval_shared_model=tfma.default_eval_shared_model( eval_saved_model_path=model_location, example_weight_key="age"), @@ -172,15 +183,20 @@ def testGetEvalResultsFromURLRoute(self): output_path=self._eval_result_output_dir) response = self._server.get( - "/data/plugin/fairness_indicators/" + - "get_evaluation_result_from_remote_path?evaluation_output_path=" + - os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY)) + "/data/plugin/fairness_indicators/" + + "get_evaluation_result_from_remote_path?evaluation_output_path=" + + self._eval_result_output_dir + ) self.assertEqual(200, response.status_code) - def testGetOutputFileFormat(self): - self.assertEqual("", self._plugin._get_output_file_format("abc_path")) - self.assertEqual("tfrecord", - self._plugin._get_output_file_format("abc_path.tfrecord")) + def test_get_output_file_format(self): + evaluation_output_path = os.path.join( + self._eval_result_output_dir, "eval_result.tfrecord" + ) + self.assertEqual( + self._plugin._get_output_file_format(evaluation_output_path), + "tfrecord", + ) if __name__ == "__main__":