
Commit 68565c5

zhouhao138 (Responsible ML Infra Team) authored and committed
Fix the bug in Fairness Indicators caused by the deprecation of TF Estimator and the feature column utils.
PiperOrigin-RevId: 627145650
1 parent 8f2dc52 commit 68565c5

File tree

3 files changed: +150 −143 lines changed


RELEASE.md

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 # Current Version (Still in Development)
 
 ## Major Features and Improvements
-
+Update example model to use Keras models instead of estimators.
 ## Bug Fixes and Other Changes
 
 * Deprecated python 3.8 support

fairness_indicators/example_model.py

Lines changed: 65 additions & 105 deletions

@@ -14,132 +14,92 @@
 # ==============================================================================
 """Demo script to train and evaluate a model.
 
-This scripts contains boilerplate code to train a DNNClassifier
+This script contains boilerplate code to train a Keras text classifier
 and evaluate it using Tensorflow Model Analysis. Evaluation
 results can be visualized using tools like TensorBoard.
-
-Usage:
-
-1. Train model:
-  demo_script.train_model(...)
-
-2. Evaluate:
-  demo_script.evaluate_model(...)
 """
 
-import os
-import tempfile
+from tensorflow import keras
 import tensorflow.compat.v1 as tf
-from tensorflow.compat.v1 import estimator as tf_estimator
-import tensorflow_hub as hub
 import tensorflow_model_analysis as tfma
 from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators  # pylint: disable=unused-import
 
 
-def train_model(model_dir,
-                train_tf_file,
-                label,
-                text_feature,
-                feature_map,
-                module_spec='https://tfhub.dev/google/nnlm-en-dim128/1'):
-  """Train model using DNN Classifier.
-
-  Args:
-    model_dir: Directory path to save trained model.
-    train_tf_file: File containing training TFRecordDataset.
-    label: Groundtruth label.
-    text_feature: Text feature to be evaluated.
-    feature_map: Dict of feature names to their data type.
-    module_spec: A module spec defining the module to instantiate or a path
-      where to load a module spec.
-
-  Returns:
-    Trained DNNClassifier.
-  """
-
-  def train_input_fn():
-    """Train Input function."""
-
-    def parse_function(serialized):
-      parsed_example = tf.io.parse_single_example(
-          serialized=serialized, features=feature_map)
-      # Adds a weight column to deal with unbalanced classes.
-      parsed_example['weight'] = tf.add(parsed_example[label], 0.1)
-      return (parsed_example, parsed_example[label])
-
-    train_dataset = tf.data.TFRecordDataset(
-        filenames=[train_tf_file]).map(parse_function).batch(512)
-    return train_dataset
+TEXT_FEATURE = 'comment_text'
+LABEL = 'toxicity'
+SLICE = 'slice'
+FEATURE_MAP = {
+    LABEL: tf.io.FixedLenFeature([], tf.float32),
+    TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),
+    SLICE: tf.io.VarLenFeature(tf.string),
+}
 
-  text_embedding_column = hub.text_embedding_column(
-      key=text_feature, module_spec=module_spec)
 
-  classifier = tf_estimator.DNNClassifier(
-      hidden_units=[500, 100],
-      weight_column='weight',
-      feature_columns=[text_embedding_column],
-      n_classes=2,
-      optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),
-      model_dir=model_dir)
+class ExampleParser(keras.layers.Layer):
+  """A Keras layer that parses the tf.Example."""
 
-  classifier.train(input_fn=train_input_fn, steps=1000)
-  return classifier
+  def __init__(self, input_feature_key):
+    self._input_feature_key = input_feature_key
+    super().__init__()
 
-
-def evaluate_model(classifier, validate_tf_file, tfma_eval_result_path,
-                   selected_slice, label, feature_map):
+  def call(self, serialized_examples):
+    def get_feature(serialized_example):
+      parsed_example = tf.io.parse_single_example(
+          serialized_example, features=FEATURE_MAP
+      )
+      return parsed_example[self._input_feature_key]
+
+    return tf.map_fn(get_feature, serialized_examples)
+
+
+class ExampleModel(keras.Model):
+  """An example Keras NLP model."""
+
+  def __init__(self, input_feature_key):
+    super().__init__()
+    self.parser = ExampleParser(input_feature_key)
+    self.text_vectorization = keras.layers.TextVectorization(
+        max_tokens=32,
+        output_mode='int',
+        output_sequence_length=32,
+    )
+    self.text_vectorization.adapt(
+        ['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random']
+    )
+    self.dense1 = keras.layers.Dense(32, activation='relu')
+    self.dense2 = keras.layers.Dense(1)
+
+  def call(self, inputs, training=True, mask=None):
+    parsed_example = self.parser(inputs)
+    text_vector = self.text_vectorization(parsed_example)
+    output1 = self.dense1(tf.cast(text_vector, tf.float32))
+    output2 = self.dense2(output1)
+    return output2
+
+
+def evaluate_model(
+    classifier_model_path,
+    validate_tf_file_path,
+    tfma_eval_result_path,
+    eval_config,
+):
   """Evaluate Model using Tensorflow Model Analysis.
 
   Args:
-    classifier: Trained classifier model to be evaluted.
-    validate_tf_file: File containing validation TFRecordDataset.
-    tfma_eval_result_path: Directory path where eval results will be written.
-    selected_slice: Feature for slicing the data.
-    label: Groundtruth label.
-    feature_map: Dict of feature names to their data type.
+    classifier_model_path: Path to the trained classifier model to evaluate.
+    validate_tf_file_path: File containing the validation TFRecordDataset.
+    tfma_eval_result_path: Directory path where TFMA eval results are written.
+    eval_config: The tfma.EvalConfig used for the evaluation.
   """
 
-  def eval_input_receiver_fn():
-    """Eval Input Receiver function."""
-    serialized_tf_example = tf.compat.v1.placeholder(
-        dtype=tf.string, shape=[None], name='input_example_placeholder')
-
-    receiver_tensors = {'examples': serialized_tf_example}
-
-    features = tf.io.parse_example(serialized_tf_example, feature_map)
-    features['weight'] = tf.ones_like(features[label])
-
-    return tfma.export.EvalInputReceiver(
-        features=features,
-        receiver_tensors=receiver_tensors,
-        labels=features[label])
-
-  tfma_export_dir = tfma.export.export_eval_savedmodel(
-      estimator=classifier,
-      export_dir_base=os.path.join(tempfile.gettempdir(), 'tfma_eval_model'),
-      eval_input_receiver_fn=eval_input_receiver_fn)
-
-  # Define slices that you want the evaluation to run on.
-  slice_spec = [
-      tfma.slicer.SingleSliceSpec(),  # Overall slice
-      tfma.slicer.SingleSliceSpec(columns=[selected_slice]),
-  ]
-
-  # Add the fairness metrics.
-  # pytype: disable=module-attr
-  add_metrics_callbacks = [
-      tfma.post_export_metrics.fairness_indicators(
-          thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
-  ]
-  # pytype: enable=module-attr
-
   eval_shared_model = tfma.default_eval_shared_model(
-      eval_saved_model_path=tfma_export_dir,
-      add_metrics_callbacks=add_metrics_callbacks)
+      eval_saved_model_path=classifier_model_path, eval_config=eval_config
+  )
 
   # Run the fairness evaluation.
   tfma.run_model_analysis(
       eval_shared_model=eval_shared_model,
-      data_location=validate_tf_file,
+      data_location=validate_tf_file_path,
       output_path=tfma_eval_result_path,
-      slice_spec=slice_spec)
+      eval_config=eval_config,
+  )
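With the estimator path gone, the example now follows a plain Keras flow: build and fit ExampleModel, save it as a SavedModel, then hand the saved path plus a tfma.EvalConfig to evaluate_model. A minimal usage sketch of that flow, mirroring the updated test below; the /tmp paths and the prepared `examples`/`labels` inputs are illustrative assumptions, not part of this commit:

# Sketch only: assumes `examples` is a list of serialized tf.train.Example
# protos and `labels` a matching array of float labels, prepared elsewhere.
import tensorflow as tf
import tensorflow_model_analysis as tfma
from google.protobuf import text_format
from fairness_indicators import example_model

model = example_model.ExampleModel(example_model.TEXT_FEATURE)
model.compile(optimizer='adam', loss='mse')
model.fit(tf.constant(examples), labels)
model.save('/tmp/example_model', save_format='tf')

# Slices and metrics now live in the EvalConfig instead of the old
# slice_spec / add_metrics_callbacks arguments.
eval_config = text_format.Parse(
    """
    model_specs { signature_name: "serving_default" label_key: "toxicity" }
    slicing_specs {}
    slicing_specs { feature_keys: ["slice"] }
    metrics_specs {
      metrics { class_name: "ExampleCount" }
      metrics { class_name: "FairnessIndicators" }
    }
    """,
    tfma.EvalConfig(),
)
example_model.evaluate_model(
    '/tmp/example_model',
    '/tmp/validate.tfrecord',
    '/tmp/tfma_eval_result',
    eval_config,
)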

fairness_indicators/example_model_test.py

Lines changed: 84 additions & 37 deletions
@@ -21,22 +21,17 @@
 import datetime
 import os
 import tempfile
+
 from fairness_indicators import example_model
+import numpy as np
 import six
+from tensorflow import keras
 import tensorflow.compat.v1 as tf
 import tensorflow_model_analysis as tfma
-from tensorflow_model_analysis.slicer import slicer_lib as slicer
 
-tf.compat.v1.enable_eager_execution()
+from google.protobuf import text_format
 
-TEXT_FEATURE = 'comment_text'
-LABEL = 'toxicity'
-SLICE = 'slice'
-FEATURE_MAP = {
-    LABEL: tf.io.FixedLenFeature([], tf.float32),
-    TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),
-    SLICE: tf.io.VarLenFeature(tf.string),
-}
+tf.compat.v1.enable_eager_execution()
 
 
 class ExampleModelTest(tf.test.TestCase):
@@ -51,13 +46,13 @@ def setUp(self):
 
   def _create_example(self, comment_text, label, slice_value):
     example = tf.train.Example()
-    example.features.feature[TEXT_FEATURE].bytes_list.value[:] = [
+    example.features.feature[example_model.TEXT_FEATURE].bytes_list.value[:] = [
        six.ensure_binary(comment_text, 'utf8')
     ]
-    example.features.feature[SLICE].bytes_list.value[:] = [
+    example.features.feature[example_model.SLICE].bytes_list.value[:] = [
        six.ensure_binary(slice_value, 'utf8')
     ]
-    example.features.feature[LABEL].float_list.value[:] = [label]
+    example.features.feature[example_model.LABEL].float_list.value[:] = [label]
     return example
 
   def _create_data(self):
@@ -85,34 +80,86 @@ def _write_tf_records(self, examples):
     return data_location
 
   def test_example_model(self):
-    train_tf_file = self._write_tf_records(self._create_data())
-    classifier = example_model.train_model(self._model_dir, train_tf_file,
-                                           LABEL, TEXT_FEATURE, FEATURE_MAP)
-
-    validate_tf_file = self._write_tf_records(self._create_data())
+    data = self._create_data()
+    classifier = example_model.ExampleModel(example_model.TEXT_FEATURE)
+    classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
+    print([e.SerializeToString() for e in data])
+    classifier.predict(tf.constant([e.SerializeToString() for e in data]))
+    classifier.fit(
+        tf.constant([e.SerializeToString() for e in data]),
+        np.array([
+            e.features.feature[example_model.LABEL].float_list.value[:][0]
+            for e in data
+        ]),
+    )
+    classifier.save(self._model_dir, save_format='tf')
+
+    eval_config = text_format.Parse(
+        """
+        model_specs {
+          signature_name: "serving_default"
+          prediction_key: "predictions"  # placeholder
+          label_key: "toxicity"  # placeholder
+        }
+        slicing_specs {}
+        slicing_specs {
+          feature_keys: ["slice"]
+        }
+        metrics_specs {
+          metrics {
+            class_name: "ExampleCount"
+          }
+          metrics {
+            class_name: "FairnessIndicators"
+          }
+        }
+        """,
+        tfma.EvalConfig(),
+    )
+
+    validate_tf_file_path = self._write_tf_records(data)
     tfma_eval_result_path = os.path.join(self._model_dir, 'tfma_eval_result')
-    example_model.evaluate_model(classifier, validate_tf_file,
-                                 tfma_eval_result_path, SLICE, LABEL,
-                                 FEATURE_MAP)
+    example_model.evaluate_model(
+        self._model_dir,
+        validate_tf_file_path,
+        tfma_eval_result_path,
+        eval_config,
+    )
 
-    expected_slice_keys = [
-        'Overall', 'slice:slice3', 'slice:slice1', 'slice:slice2'
-    ]
     evaluation_results = tfma.load_eval_result(tfma_eval_result_path)
 
-    self.assertLen(evaluation_results.slicing_metrics, 4)
-
-    # Verify if false_positive_rate metrics are computed for all values of
-    # slice.
-    for (slice_key, metric_value) in evaluation_results.slicing_metrics:
-      slice_key = slicer.stringify_slice_key(slice_key)
-      self.assertIn(slice_key, expected_slice_keys)
-      self.assertGreaterEqual(
-          1.0, metric_value['']['']
-          ['post_export_metrics/false_positive_rate@0.50']['doubleValue'])
-      self.assertLessEqual(
-          0.0, metric_value['']['']
-          ['post_export_metrics/false_positive_rate@0.50']['doubleValue'])
+    expected_slice_keys = [
+        (),
+        (('slice', 'slice1'),),
+        (('slice', 'slice2'),),
+        (('slice', 'slice3'),),
+    ]
+    slice_keys = [
+        slice_key for slice_key, _ in evaluation_results.slicing_metrics
+    ]
+    self.assertEqual(set(expected_slice_keys), set(slice_keys))
+    # Verify a subset of the fairness indicators metrics.
+    metric_values = dict(evaluation_results.slicing_metrics)[(
+        ('slice', 'slice1'),
+    )]['']['']
+    self.assertEqual(metric_values['example_count'], {'doubleValue': 5.0})
+
+    self.assertEqual(
+        metric_values['fairness_indicators_metrics/false_positive_rate@0.1'],
+        {'doubleValue': 0.0},
+    )
+    self.assertEqual(
+        metric_values['fairness_indicators_metrics/false_negative_rate@0.1'],
+        {'doubleValue': 1.0},
+    )
+    self.assertEqual(
+        metric_values['fairness_indicators_metrics/true_positive_rate@0.1'],
+        {'doubleValue': 0.0},
+    )
+    self.assertEqual(
+        metric_values['fairness_indicators_metrics/true_negative_rate@0.1'],
+        {'doubleValue': 1.0},
+    )
 
 
 if __name__ == '__main__':
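Beyond the assertions in the test, the eval results written by evaluate_model can be reloaded and, in a notebook, rendered with the Fairness Indicators widget. A short sketch, assuming the widget_view addon that ships with tensorflow_model_analysis; the result path is the illustrative one used above:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.addons.fairness.view import widget_view

# Load the eval results written by example_model.evaluate_model and render
# the per-slice fairness metrics interactively (notebook environments only).
eval_result = tfma.load_eval_result('/tmp/tfma_eval_result')
widget_view.render_fairness_indicator(eval_result)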
