Skip to content

Commit

Permalink
Fix fairness-indicators tests broken by keras and v2 compatibility ch…
Browse files Browse the repository at this point in the history
…anges.

PiperOrigin-RevId: 715469494
  • Loading branch information
vkarampudi authored and Responsible ML Infra Team committed Jan 15, 2025
1 parent 2e23c83 commit 46cb27a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 150 deletions.
153 changes: 7 additions & 146 deletions tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,13 @@
# ==============================================================================
"""TensorBoard Fairnss Indicators plugin."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from typing import Any, Union

from absl import logging
from tensorboard_plugin_fairness_indicators import metadata
import six
import tensorflow as tf
import tensorflow_model_analysis as tfma
from werkzeug import wrappers

from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin


_TEMPLATE_LOCATION = os.path.normpath(
os.path.join(
__file__, '../../'
'tensorflow_model_analysis/static/vulcanized_tfma.js'))


def stringify_slice_key_value(
slice_key: tfma.slicer.slicer_lib.SliceKeyType,
slice_key: tfma.types.SliceKeyType,
) -> str:
"""Stringifies a slice key value.
Expand Down Expand Up @@ -89,7 +69,7 @@ def stringify_slice_key_value(


def _add_cross_slice_key_data(
slice_key: tfma.slicer.slicer_lib.CrossSliceKeyType,
slice_key: tfma.types.CrossSliceKeyType,
metrics: tfma.view.view_types.MetricsByTextKey,
data: list[Any],
):
Expand All @@ -109,9 +89,9 @@ def _add_cross_slice_key_data(
+ stringify_slice_key_value(comparison_key)
)
stringify_slice = (
tfma.slicer.slicer_lib.stringify_slice_key(baseline_key)
tfma.stringify_slice_key(baseline_key)
+ '__XX__'
+ tfma.slicer.slicer_lib.stringify_slice_key(comparison_key)
+ tfma.stringify_slice_key(comparison_key)
)
data.append({
'sliceValue': stringify_slice_value,
Expand All @@ -123,7 +103,7 @@ def _add_cross_slice_key_data(
def convert_slicing_metrics_to_ui_input(
slicing_metrics: list[
tuple[
tfma.slicer.slicer_lib.SliceKeyOrCrossSliceKeyType,
tfma.types.SliceKeyOrCrossSliceKeyType,
tfma.view.view_types.MetricsByOutputName,
]
],
Expand Down Expand Up @@ -172,7 +152,7 @@ def convert_slicing_metrics_to_ui_input(
):
metrics = metric_value[output_name][multi_class_key]
# To add evaluation data for cross slice comparison.
if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key):
if tfma.types.is_cross_slice_key(slice_key):
_add_cross_slice_key_data(slice_key, metrics, data)
# To add evaluation data for regular slices.
elif (
Expand All @@ -182,7 +162,7 @@ def convert_slicing_metrics_to_ui_input(
):
data.append({
'sliceValue': stringify_slice_key_value(slice_key),
'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key),
'slice': tfma.stringify_slice_key(slice_key),
'metrics': metrics,
})
if not data:
Expand All @@ -192,122 +172,3 @@ def convert_slicing_metrics_to_ui_input(
% (output_name, multi_class_key, slicing_column, slicing_spec)
)
return data


class FairnessIndicatorsPlugin(base_plugin.TBPlugin):
"""A plugin to visualize Fairness Indicators."""

plugin_name = metadata.PLUGIN_NAME

def __init__(self, context):
"""Instantiates plugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance. A magic container that
TensorBoard uses to make objects available to the plugin.
"""
self._multiplexer = context.multiplexer

def get_plugin_apps(self):
"""Gets all routes offered by the plugin.
This method is called by TensorBoard when retrieving all the
routes offered by the plugin.
Returns:
A dictionary mapping URL path to route that handles it.
"""
return {
'/get_evaluation_result':
self._get_evaluation_result,
'/get_evaluation_result_from_remote_path':
self._get_evaluation_result_from_remote_path,
'/index.js':
self._serve_js,
'/vulcanized_tfma.js':
self._serve_vulcanized_js,
}

def frontend_metadata(self):
return base_plugin.FrontendMetadata(
es_module_path='/index.js',
disable_reload=False,
tab_name='Fairness Indicators',
remove_dom=False,
element_name=None)

def is_active(self):
"""Determines whether this plugin is active.
This plugin is only active if TensorBoard sampled any summaries
relevant to the plugin.
Returns:
Whether this plugin is active.
"""
return bool(
self._multiplexer.PluginRunToTagToContent(
FairnessIndicatorsPlugin.plugin_name))

# pytype: disable=wrong-arg-types
@wrappers.Request.application
def _serve_js(self, request):
filepath = os.path.join(os.path.dirname(__file__), 'static', 'index.js')
with open(filepath) as infile:
contents = infile.read()
return http_util.Respond(
request, contents, content_type='application/javascript')

@wrappers.Request.application
def _serve_vulcanized_js(self, request):
with open(_TEMPLATE_LOCATION) as infile:
contents = infile.read()
return http_util.Respond(
request, contents, content_type='application/javascript')

@wrappers.Request.application
def _get_evaluation_result(self, request):
run = request.args.get('run')
try:
run = six.ensure_text(run)
except (UnicodeDecodeError, AttributeError):
pass

data = []
try:
eval_result_output_dir = six.ensure_text(
self._multiplexer.Tensors(run, FairnessIndicatorsPlugin.plugin_name)
[0].tensor_proto.string_val[0])
eval_result = tfma.load_eval_result(output_path=eval_result_output_dir)
# TODO(b/141283811): Allow users to choose different model output names
# and class keys in case of multi-output and multi-class model.
data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics)
except (KeyError, json_format.ParseError) as error:
logging.info('Error while fetching evaluation data, %s', error)
return http_util.Respond(request, data, content_type='application/json')

def _get_output_file_format(self, evaluation_output_path):
file_format = os.path.splitext(evaluation_output_path)[1]
if file_format:
return file_format[1:]

return ''

@wrappers.Request.application
def _get_evaluation_result_from_remote_path(self, request):
evaluation_output_path = request.args.get('evaluation_output_path')
try:
evaluation_output_path = six.ensure_text(evaluation_output_path)
except (UnicodeDecodeError, AttributeError):
pass
try:
eval_result = tfma.load_eval_result(
os.path.dirname(evaluation_output_path),
output_file_format=self._get_output_file_format(
evaluation_output_path))
data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics)
except (KeyError, json_format.ParseError) as error:
logging.info('Error while fetching evaluation data, %s', error)
data = []
return http_util.Respond(request, data, content_type='application/json')
# pytype: enable=wrong-arg-types
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from tensorboard_plugin_fairness_indicators import plugin
from tensorboard_plugin_fairness_indicators import summary_v2
import six
from tensorflow import keras
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
# import tensorflow.compat.v2 as tf2
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.utils import example_keras_model
from werkzeug import test as werkzeug_test
Expand All @@ -38,7 +39,7 @@
from tensorboard.plugins import base_plugin

tf.enable_eager_execution()
tf = tf2
# tf = tf2


class PluginTest(tf.test.TestCase):
Expand All @@ -53,7 +54,7 @@ def setUp(self):
if not os.path.isdir(self._eval_result_output_dir):
os.mkdir(self._eval_result_output_dir)

writer = tf.summary.create_file_writer(self._log_dir)
writer = tf.summary.FileWriter(self._log_dir)

with writer.as_default():
summary_v2.FairnessIndicators(self._eval_result_output_dir, step=1)
Expand All @@ -77,7 +78,7 @@ def tearDown(self):

def _export_keras_model(self, classifier):
temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir")
classifier.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse")
tf.saved_model.save(classifier, temp_eval_export_dir)
return temp_eval_export_dir

Expand Down

0 comments on commit 46cb27a

Please sign in to comment.