From f90c78bd5444f77936566e9a67d9a7513fd3be48 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Mon, 24 Aug 2020 13:03:02 -0700 Subject: [PATCH] Update tf_estimator_evaluation and keras_evaluation to new API. PiperOrigin-RevId: 328195220 --- .../data_structures.py | 8 +- .../keras_evaluation.py | 59 ++++++++---- .../keras_evaluation_example.py | 32 ++++--- .../keras_evaluation_test.py | 12 ++- .../tf_estimator_evaluation.py | 93 ++++++++++++------- .../tf_estimator_evaluation_example.py | 39 +++++--- .../tf_estimator_evaluation_test.py | 32 ++++--- .../membership_inference_attack/utils.py | 20 ++++ 8 files changed, 195 insertions(+), 100 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index b78a63f4..082a90ba 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -293,8 +293,8 @@ def __str__(self): """Returns AUC and advantage metrics.""" return '\n'.join([ 'RocCurve(', - ' AUC: %f.02' % self.get_auc(), - ' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' + ' AUC: %.2f' % self.get_auc(), + ' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')' ]) @@ -324,8 +324,8 @@ def __str__(self): 'SingleAttackResult(', ' SliceSpec: %s' % str(self.slice_spec), ' AttackType: %s' % str(self.attack_type), - ' AUC: %f.02' % self.get_auc(), - ' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' + ' AUC: %.2f' % self.get_auc(), + ' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')' ]) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py index 57284da2..938c42ec 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py @@ -15,11 +15,17 @@ # Lint as: python3 """A callback and a function in keras for membership inference attack.""" +from typing import Iterable + from absl import logging import tensorflow.compat.v1 as tf -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard @@ -44,20 +50,25 @@ def calculate_losses(model, data, labels): class MembershipInferenceCallback(tf.keras.callbacks.Callback): """Callback to perform membership inference attack on epoch end.""" - def __init__(self, in_train, out_train, attack_classifiers, - tensorboard_dir=None): + def __init__( + self, + in_train, out_train, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), + tensorboard_dir=None): """Initalizes the callback. Args: in_train: (in_training samples, in_training labels) out_train: (out_training samples, out_training labels) - attack_classifiers: a list of classifiers to be used by attacker, must be - a subset of ['lr', 'mlp', 'rf', 'knn'] + slicing_spec: slicing specification of the attack + attack_types: a list of attacks, each of type AttackType tensorboard_dir: directory for tensorboard summary """ self._in_train_data, self._in_train_labels = in_train self._out_train_data, self._out_train_labels = out_train - self._attack_classifiers = attack_classifiers + self._slicing_spec = slicing_spec + self._attack_types = attack_types # Setup tensorboard writer if tensorboard_dir is specified if tensorboard_dir: with tf.Graph().as_default(): @@ -71,24 +82,33 @@ def on_epoch_end(self, epoch, logs=None): self.model, (self._in_train_data, self._in_train_labels), (self._out_train_data, self._out_train_labels), - self._attack_classifiers) - print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) + self._slicing_spec, + self._attack_types) logging.info(results) + attack_properties, attack_values = get_all_attack_results(results) + print('Attack result:') + print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in + zip(attack_properties, attack_values)])) + # Write to tensorboard if tensorboard_dir is specified - write_to_tensorboard(self._writer, ['attack advantage'], - [results['all_thresh_loss_advantage']], epoch) + attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties] + write_to_tensorboard(self._writer, attack_property_tags, attack_values, + epoch) -def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers): +def run_attack_on_keras_model( + model, in_train, out_train, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): """Performs the attack on a trained model. Args: model: model to be tested in_train: a (in_training samples, in_training labels) tuple out_train: a (out_training samples, out_training labels) tuple - attack_classifiers: a list of classifiers to be used by attacker, must be - a subset of ['lr', 'mlp', 'rf', 'knn'] + slicing_spec: slicing specification of the attack + attack_types: a list of attacks, each of type AttackType Returns: Results of the attack """ @@ -100,9 +120,12 @@ def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers): in_train_labels) out_train_pred, out_train_loss = calculate_losses(model, out_train_data, out_train_labels) - results = mia.run_all_attacks(in_train_loss, out_train_loss, - in_train_pred, out_train_pred, - in_train_labels, out_train_labels, - attack_classifiers=attack_classifiers) + attack_input = AttackInputData( + logits_train=in_train_pred, logits_test=out_train_pred, + labels_train=in_train_labels, labels_test=out_train_labels, + loss_train=in_train_loss, loss_test=out_train_loss + ) + results = mia.run_attacks(attack_input, + slicing_spec=slicing_spec, + attack_types=attack_types) return results - diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py index 3ddfd1aa..2cdc0290 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py @@ -20,8 +20,12 @@ import numpy as np import tensorflow.compat.v1 as tf +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results + GradientDescentOptimizer = tf.train.GradientDescentOptimizer @@ -78,10 +82,11 @@ def main(unused_argv): model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) # Get callback for membership inference attack. - mia_callback = MembershipInferenceCallback((train_data, train_labels), - (test_data, test_labels), - [], - FLAGS.model_dir) + mia_callback = MembershipInferenceCallback( + (train_data, train_labels), + (test_data, test_labels), + attack_types=[AttackType.THRESHOLD_ATTACK], + tensorboard_dir=FLAGS.model_dir) # Train model with Keras model.fit(train_data, train_labels, @@ -91,13 +96,18 @@ def main(unused_argv): callbacks=[mia_callback], verbose=2) - print('End of training attack') - attack_results = run_attack_on_keras_model(model, - (train_data, train_labels), - (test_data, test_labels), - []) - print('all_thresh_loss_advantage', - attack_results['all_thresh_loss_advantage']) + print('End of training attack:') + attack_results = run_attack_on_keras_model( + model, + (train_data, train_labels), + (test_data, test_labels), + slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), + attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] + ) + + attack_properties, attack_values = get_all_attack_results(attack_results) + print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in + zip(attack_properties, attack_values)])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py index 916bc21f..8ce5c19f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py @@ -21,6 +21,9 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results class UtilsTest(absltest.TestCase): @@ -62,10 +65,11 @@ def test_run_attack_on_keras_model(self): self.model, (self.train_data, self.train_labels), (self.test_data, self.test_labels), - []) - self.assertIsInstance(results, dict) - self.assertIn('all_thresh_loss_auc', results) - self.assertIn('all_thresh_loss_advantage', results) + attack_types=[AttackType.THRESHOLD_ATTACK]) + self.assertIsInstance(results, AttackResults) + attack_properties, attack_values = get_all_attack_results(results) + self.assertLen(attack_properties, 2) + self.assertLen(attack_values, 2) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py index 693820ea..954ad5be 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py @@ -15,13 +15,19 @@ # Lint as: python3 """A hook and a function in tf estimator for membership inference attack.""" +from typing import Iterable + from absl import logging import numpy as np import tensorflow.compat.v1 as tf -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard @@ -49,16 +55,17 @@ def calculate_losses(estimator, input_fn, labels): class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): - """Training hook to perform membership inference attack after an epoch.""" - - def __init__(self, - estimator, - in_train, - out_train, - input_fn_constructor, - attack_classifiers, - writer=None): - """Initalizes the hook. + """Training hook to perform membership inference attack on epoch end.""" + + def __init__( + self, + estimator, + in_train, out_train, + input_fn_constructor, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), + writer=None): + """Initialize the hook. Args: estimator: model to be tested @@ -66,8 +73,8 @@ def __init__(self, out_train: (out_training samples, out_training labels) input_fn_constructor: a function that receives sample, label and construct the input_fn for model prediction - attack_classifiers: a list of classifiers to be used by attacker, must be - a subset of ['lr', 'mlp', 'rf', 'knn'] + slicing_spec: slicing specification of the attack + attack_types: a list of attacks, each of type AttackType writer: summary writer for tensorboard """ in_train_data, self._in_train_labels = in_train @@ -79,7 +86,8 @@ def __init__(self, self._out_train_input_fn = input_fn_constructor(out_train_data, self._out_train_labels) self._estimator = estimator - self._attack_classifiers = attack_classifiers + self._slicing_spec = slicing_spec + self._attack_types = attack_types self._writer = writer if self._writer: logging.info('Will write to tensorboard.') @@ -89,19 +97,28 @@ def end(self, session): self._in_train_input_fn, self._out_train_input_fn, self._in_train_labels, self._out_train_labels, - self._attack_classifiers) - print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) + self._slicing_spec, + self._attack_types) logging.info(results) + attack_properties, attack_values = get_all_attack_results(results) + print('Attack result:') + print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in + zip(attack_properties, attack_values)])) + # Write to tensorboard if writer is specified global_step = self._estimator.get_variable_value('global_step') - write_to_tensorboard(self._writer, ['attack advantage'], - [results['all_thresh_loss_advantage']], global_step) + attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties] + write_to_tensorboard(self._writer, attack_property_tags, attack_values, + global_step) -def run_attack_on_tf_estimator_model(estimator, in_train, out_train, - input_fn_constructor, attack_classifiers): - """A function to perform the attack in the end of training. +def run_attack_on_tf_estimator_model( + estimator, in_train, out_train, + input_fn_constructor, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): + """Performs the attack in the end of training. Args: estimator: model to be tested @@ -109,8 +126,8 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train, out_train: (out_training samples, out_training labels) input_fn_constructor: a function that receives sample, label and construct the input_fn for model prediction - attack_classifiers: a list of classifiers to be used by attacker, must be - a subset of ['lr', 'mlp', 'rf', 'knn'] + slicing_spec: slicing specification of the attack + attack_types: a list of attacks, each of type AttackType Returns: Results of the attack """ @@ -125,17 +142,19 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train, results = run_attack_helper(estimator, in_train_input_fn, out_train_input_fn, in_train_labels, out_train_labels, - attack_classifiers) - print('all_thresh_loss_advantage', results['all_thresh_loss_advantage']) + slicing_spec, + attack_types) logging.info('End of training attack:') logging.info(results) return results -def run_attack_helper(estimator, - in_train_input_fn, out_train_input_fn, - in_train_labels, out_train_labels, - attack_classifiers): +def run_attack_helper( + estimator, + in_train_input_fn, out_train_input_fn, + in_train_labels, out_train_labels, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)): """A helper function to perform attack. Args: @@ -144,8 +163,8 @@ def run_attack_helper(estimator, out_train_input_fn: input_fn for out of training data in_train_labels: in training labels out_train_labels: out of training labels - attack_classifiers: a list of classifiers to be used by attacker, must be - a subset of ['lr', 'mlp', 'rf', 'knn'] + slicing_spec: slicing specification of the attack + attack_types: a list of attacks, each of type AttackType Returns: Results of the attack """ @@ -156,9 +175,13 @@ def run_attack_helper(estimator, out_train_pred, out_train_loss = calculate_losses(estimator, out_train_input_fn, out_train_labels) - results = mia.run_all_attacks(in_train_loss, out_train_loss, - in_train_pred, out_train_pred, - in_train_labels, out_train_labels, - attack_classifiers=attack_classifiers) + attack_input = AttackInputData( + logits_train=in_train_pred, logits_test=out_train_pred, + labels_train=in_train_labels, labels_test=out_train_labels, + loss_train=in_train_loss, loss_test=out_train_loss + ) + results = mia.run_attacks(attack_input, + slicing_spec=slicing_spec, + attack_types=attack_types) return results diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py index 94e71830..6d4a1bac 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py @@ -21,9 +21,11 @@ import numpy as np import tensorflow.compat.v1 as tf - +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results GradientDescentOptimizer = tf.train.GradientDescentOptimizer @@ -97,9 +99,9 @@ def load_mnist(): def main(unused_argv): - tf.logging.set_verbosity(tf.logging.INFO) - logging.set_verbosity(logging.INFO) - logging.set_stderrthreshold(logging.INFO) + tf.logging.set_verbosity(tf.logging.ERROR) + logging.set_verbosity(logging.ERROR) + logging.set_stderrthreshold(logging.ERROR) logging.get_absl_handler().use_absl_log_file() # Load training and test data. @@ -121,12 +123,13 @@ def input_fn_constructor(x, y): summary_writer = tf.summary.FileWriter(FLAGS.model_dir) else: summary_writer = None - mia_hook = MembershipInferenceTrainingHook(mnist_classifier, - (train_data, train_labels), - (test_data, test_labels), - input_fn_constructor, - [], - summary_writer) + mia_hook = MembershipInferenceTrainingHook( + mnist_classifier, + (train_data, train_labels), + (test_data, test_labels), + input_fn_constructor, + attack_types=[AttackType.THRESHOLD_ATTACK], + writer=summary_writer) # Create tf.Estimator input functions for the training and test data. train_input_fn = tf.estimator.inputs.numpy_input_fn( @@ -151,11 +154,17 @@ def input_fn_constructor(x, y): print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy)) print('End of training attack') - run_attack_on_tf_estimator_model(mnist_classifier, - (train_data, train_labels), - (test_data, test_labels), - input_fn_constructor, - ['lr']) + attack_results = run_attack_on_tf_estimator_model( + mnist_classifier, + (train_data, train_labels), + (test_data, test_labels), + input_fn_constructor, + slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), + attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] + ) + attack_properties, attack_values = get_all_attack_results(attack_results) + print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in + zip(attack_properties, attack_values)])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py index fc73843f..bfb15855 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py @@ -21,6 +21,9 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results class UtilsTest(absltest.TestCase): @@ -77,15 +80,17 @@ def test_calculate_losses(self): def test_run_attack_helper(self): """Test the attack.""" - results = tf_estimator_evaluation.run_attack_helper(self.classifier, - self.input_fn_train, - self.input_fn_test, - self.train_labels, - self.test_labels, - []) - self.assertIsInstance(results, dict) - self.assertIn('all_thresh_loss_auc', results) - self.assertIn('all_thresh_loss_advantage', results) + results = tf_estimator_evaluation.run_attack_helper( + self.classifier, + self.input_fn_train, + self.input_fn_test, + self.train_labels, + self.test_labels, + attack_types=[AttackType.THRESHOLD_ATTACK]) + self.assertIsInstance(results, AttackResults) + attack_properties, attack_values = get_all_attack_results(results) + self.assertLen(attack_properties, 2) + self.assertLen(attack_values, 2) def test_run_attack_on_tf_estimator_model(self): """Test the attack on the final models.""" @@ -97,10 +102,11 @@ def input_fn_constructor(x, y): (self.train_data, self.train_labels), (self.test_data, self.test_labels), input_fn_constructor, - []) - self.assertIsInstance(results, dict) - self.assertIn('all_thresh_loss_auc', results) - self.assertIn('all_thresh_loss_advantage', results) + attack_types=[AttackType.THRESHOLD_ATTACK]) + self.assertIsInstance(results, AttackResults) + attack_properties, attack_values = get_all_attack_results(results) + self.assertLen(attack_properties, 2) + self.assertLen(attack_values, 2) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/membership_inference_attack/utils.py index d3aa5c69..cb2c660d 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils.py @@ -20,6 +20,7 @@ import numpy as np from sklearn import metrics import tensorflow.compat.v1 as tf +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults ArrayDict = Dict[Text, np.ndarray] @@ -73,6 +74,25 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]: return {prefix + k: v for k, v in in_dict.items()} +# ------------------------------------------------------------------------------ +# Utilities for managing result. +# ------------------------------------------------------------------------------ + + +def get_all_attack_results(results: AttackResults): + """Get all results as a list of attack properties and a list of attack result.""" + properties = [] + values = [] + for attack_result in results.single_attack_results: + slice_spec = attack_result.slice_spec + prop = [str(slice_spec), str(attack_result.attack_type)] + properties += [prop + ['adv'], prop + ['auc']] + values += [float(attack_result.get_attacker_advantage()), + float(attack_result.get_auc())] + + return properties, values + + # ------------------------------------------------------------------------------ # Subsampling and data selection functionality # ------------------------------------------------------------------------------