From acd9745e0904db39993d12c591f6941f39d88e5a Mon Sep 17 00:00:00 2001 From: Madhur Karampudi Date: Tue, 14 Jan 2025 11:56:45 -0800 Subject: [PATCH] Fix fairness-indicators tests broken by keras and v2 compatibility changes. PiperOrigin-RevId: 715469494 --- .../plugin.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py index b03a695..21610c6 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py @@ -22,16 +22,15 @@ from typing import Any, Union from absl import logging -from tensorboard_plugin_fairness_indicators import metadata +from google.protobuf import json_format import six +from tensorboard.backend import http_util +from tensorboard.plugins import base_plugin +from tensorboard_plugin_fairness_indicators import metadata import tensorflow as tf import tensorflow_model_analysis as tfma from werkzeug import wrappers -from google.protobuf import json_format -from tensorboard.backend import http_util -from tensorboard.plugins import base_plugin - _TEMPLATE_LOCATION = os.path.normpath( os.path.join( @@ -40,7 +39,7 @@ def stringify_slice_key_value( - slice_key: tfma.slicer.slicer_lib.SliceKeyType, + slice_key: tfma.types.SliceKeyType, ) -> str: """Stringifies a slice key value. @@ -89,7 +88,7 @@ def stringify_slice_key_value( def _add_cross_slice_key_data( - slice_key: tfma.slicer.slicer_lib.CrossSliceKeyType, + slice_key: tfma.types.CrossSliceKeyType, metrics: tfma.view.view_types.MetricsByTextKey, data: list[Any], ): @@ -109,9 +108,9 @@ def _add_cross_slice_key_data( + stringify_slice_key_value(comparison_key) ) stringify_slice = ( - tfma.slicer.slicer_lib.stringify_slice_key(baseline_key) + tfma.stringify_slice_key(baseline_key) + '__XX__' - + tfma.slicer.slicer_lib.stringify_slice_key(comparison_key) + + tfma.stringify_slice_key(comparison_key) ) data.append({ 'sliceValue': stringify_slice_value, @@ -123,7 +122,7 @@ def _add_cross_slice_key_data( def convert_slicing_metrics_to_ui_input( slicing_metrics: list[ tuple[ - tfma.slicer.slicer_lib.SliceKeyOrCrossSliceKeyType, + tfma.types.SliceKeyOrCrossSliceKeyType, tfma.view.view_types.MetricsByOutputName, ] ], @@ -172,7 +171,7 @@ def convert_slicing_metrics_to_ui_input( ): metrics = metric_value[output_name][multi_class_key] # To add evaluation data for cross slice comparison. - if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key): + if tfma.types.is_cross_slice_key(slice_key): _add_cross_slice_key_data(slice_key, metrics, data) # To add evaluation data for regular slices. elif ( @@ -182,7 +181,7 @@ def convert_slicing_metrics_to_ui_input( ): data.append({ 'sliceValue': stringify_slice_key_value(slice_key), - 'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key), + 'slice': tfma.stringify_slice_key(slice_key), 'metrics': metrics, }) if not data: