Skip to content

Commit

Permalink
Fix fairness-indicators tests broken by keras and v2 compatibility ch…
Browse files Browse the repository at this point in the history
…anges.

PiperOrigin-RevId: 715469494
  • Loading branch information
vkarampudi authored and Responsible ML Infra Team committed Jan 15, 2025
1 parent 2e23c83 commit acd9745
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@
from typing import Any, Union

from absl import logging
from tensorboard_plugin_fairness_indicators import metadata
from google.protobuf import json_format
import six
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard_plugin_fairness_indicators import metadata
import tensorflow as tf
import tensorflow_model_analysis as tfma
from werkzeug import wrappers

from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin


_TEMPLATE_LOCATION = os.path.normpath(
os.path.join(
Expand All @@ -40,7 +39,7 @@


def stringify_slice_key_value(
slice_key: tfma.slicer.slicer_lib.SliceKeyType,
slice_key: tfma.types.SliceKeyType,
) -> str:
"""Stringifies a slice key value.
Expand Down Expand Up @@ -89,7 +88,7 @@ def stringify_slice_key_value(


def _add_cross_slice_key_data(
slice_key: tfma.slicer.slicer_lib.CrossSliceKeyType,
slice_key: tfma.types.CrossSliceKeyType,
metrics: tfma.view.view_types.MetricsByTextKey,
data: list[Any],
):
Expand All @@ -109,9 +108,9 @@ def _add_cross_slice_key_data(
+ stringify_slice_key_value(comparison_key)
)
stringify_slice = (
tfma.slicer.slicer_lib.stringify_slice_key(baseline_key)
tfma.stringify_slice_key(baseline_key)
+ '__XX__'
+ tfma.slicer.slicer_lib.stringify_slice_key(comparison_key)
+ tfma.stringify_slice_key(comparison_key)
)
data.append({
'sliceValue': stringify_slice_value,
Expand All @@ -123,7 +122,7 @@ def _add_cross_slice_key_data(
def convert_slicing_metrics_to_ui_input(
slicing_metrics: list[
tuple[
tfma.slicer.slicer_lib.SliceKeyOrCrossSliceKeyType,
tfma.types.SliceKeyOrCrossSliceKeyType,
tfma.view.view_types.MetricsByOutputName,
]
],
Expand Down Expand Up @@ -172,7 +171,7 @@ def convert_slicing_metrics_to_ui_input(
):
metrics = metric_value[output_name][multi_class_key]
# To add evaluation data for cross slice comparison.
if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key):
if tfma.types.is_cross_slice_key(slice_key):
_add_cross_slice_key_data(slice_key, metrics, data)
# To add evaluation data for regular slices.
elif (
Expand All @@ -182,7 +181,7 @@ def convert_slicing_metrics_to_ui_input(
):
data.append({
'sliceValue': stringify_slice_key_value(slice_key),
'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key),
'slice': tfma.stringify_slice_key(slice_key),
'metrics': metrics,
})
if not data:
Expand Down

0 comments on commit acd9745

Please sign in to comment.