Skip to content

Commit

Permalink
Update tf_estimator_evaluation and keras_evaluation to new API.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 328195220
  • Loading branch information
shs037 authored and tensorflower-gardener committed Aug 24, 2020
1 parent 7a77d5d commit f90c78b
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ def __str__(self):
"""Returns AUC and advantage metrics."""
return '\n'.join([
'RocCurve(',
' AUC: %f.02' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
])


Expand Down Expand Up @@ -324,8 +324,8 @@ def __str__(self):
'SingleAttackResult(',
' SliceSpec: %s' % str(self.slice_spec),
' AttackType: %s' % str(self.attack_type),
' AUC: %f.02' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
])


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
# Lint as: python3
"""A callback and a function in keras for membership inference attack."""

from typing import Iterable

from absl import logging

import tensorflow.compat.v1 as tf

from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard

Expand All @@ -44,20 +50,25 @@ def calculate_losses(model, data, labels):
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
"""Callback to perform membership inference attack on epoch end."""

def __init__(self, in_train, out_train, attack_classifiers,
tensorboard_dir=None):
def __init__(
self,
in_train, out_train,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
tensorboard_dir=None):
"""Initializes the callback.
Args:
in_train: (in_training samples, in_training labels)
out_train: (out_training samples, out_training labels)
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType
tensorboard_dir: directory for tensorboard summary
"""
self._in_train_data, self._in_train_labels = in_train
self._out_train_data, self._out_train_labels = out_train
self._attack_classifiers = attack_classifiers
self._slicing_spec = slicing_spec
self._attack_types = attack_types
# Setup tensorboard writer if tensorboard_dir is specified
if tensorboard_dir:
with tf.Graph().as_default():
Expand All @@ -71,24 +82,33 @@ def on_epoch_end(self, epoch, logs=None):
self.model,
(self._in_train_data, self._in_train_labels),
(self._out_train_data, self._out_train_labels),
self._attack_classifiers)
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
self._slicing_spec,
self._attack_types)
logging.info(results)

attack_properties, attack_values = get_all_attack_results(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))

# Write to tensorboard if tensorboard_dir is specified
write_to_tensorboard(self._writer, ['attack advantage'],
[results['all_thresh_loss_advantage']], epoch)
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
epoch)


def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
def run_attack_on_keras_model(
model, in_train, out_train,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""Performs the attack on a trained model.
Args:
model: model to be tested
in_train: a (in_training samples, in_training labels) tuple
out_train: a (out_training samples, out_training labels) tuple
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType
Returns:
Results of the attack
"""
Expand All @@ -100,9 +120,12 @@ def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
in_train_labels)
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
out_train_labels)
results = mia.run_all_attacks(in_train_loss, out_train_loss,
in_train_pred, out_train_pred,
in_train_labels, out_train_labels,
attack_classifiers=attack_classifiers)
attack_input = AttackInputData(
logits_train=in_train_pred, logits_test=out_train_pred,
labels_train=in_train_labels, labels_test=out_train_labels,
loss_train=in_train_loss, loss_test=out_train_loss
)
results = mia.run_attacks(attack_input,
slicing_spec=slicing_spec,
attack_types=attack_types)
return results

Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results


GradientDescentOptimizer = tf.train.GradientDescentOptimizer

Expand Down Expand Up @@ -78,10 +82,11 @@ def main(unused_argv):
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

# Get callback for membership inference attack.
mia_callback = MembershipInferenceCallback((train_data, train_labels),
(test_data, test_labels),
[],
FLAGS.model_dir)
mia_callback = MembershipInferenceCallback(
(train_data, train_labels),
(test_data, test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK],
tensorboard_dir=FLAGS.model_dir)

# Train model with Keras
model.fit(train_data, train_labels,
Expand All @@ -91,13 +96,18 @@ def main(unused_argv):
callbacks=[mia_callback],
verbose=2)

print('End of training attack')
attack_results = run_attack_on_keras_model(model,
(train_data, train_labels),
(test_data, test_labels),
[])
print('all_thresh_loss_advantage',
attack_results['all_thresh_loss_advantage'])
print('End of training attack:')
attack_results = run_attack_on_keras_model(
model,
(train_data, train_labels),
(test_data, test_labels),
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
)

attack_properties, attack_values = get_all_attack_results(attack_results)
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import tensorflow.compat.v1 as tf

from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results


class UtilsTest(absltest.TestCase):
Expand Down Expand Up @@ -62,10 +65,11 @@ def test_run_attack_on_keras_model(self):
self.model,
(self.train_data, self.train_labels),
(self.test_data, self.test_labels),
[])
self.assertIsInstance(results, dict)
self.assertIn('all_thresh_loss_auc', results)
self.assertIn('all_thresh_loss_advantage', results)
attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results)
self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
# Lint as: python3
"""A hook and a function in tf estimator for membership inference attack."""

from typing import Iterable

from absl import logging

import numpy as np

import tensorflow.compat.v1 as tf

from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard

Expand Down Expand Up @@ -49,25 +55,26 @@ def calculate_losses(estimator, input_fn, labels):


class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
"""Training hook to perform membership inference attack after an epoch."""

def __init__(self,
estimator,
in_train,
out_train,
input_fn_constructor,
attack_classifiers,
writer=None):
"""Initalizes the hook.
"""Training hook to perform membership inference attack on epoch end."""

def __init__(
self,
estimator,
in_train, out_train,
input_fn_constructor,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
writer=None):
"""Initialize the hook.
Args:
estimator: model to be tested
in_train: (in_training samples, in_training labels)
out_train: (out_training samples, out_training labels)
input_fn_constructor: a function that receives a sample and a label and
constructs the input_fn for model prediction
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType
writer: summary writer for tensorboard
"""
in_train_data, self._in_train_labels = in_train
Expand All @@ -79,7 +86,8 @@ def __init__(self,
self._out_train_input_fn = input_fn_constructor(out_train_data,
self._out_train_labels)
self._estimator = estimator
self._attack_classifiers = attack_classifiers
self._slicing_spec = slicing_spec
self._attack_types = attack_types
self._writer = writer
if self._writer:
logging.info('Will write to tensorboard.')
Expand All @@ -89,28 +97,37 @@ def end(self, session):
self._in_train_input_fn,
self._out_train_input_fn,
self._in_train_labels, self._out_train_labels,
self._attack_classifiers)
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
self._slicing_spec,
self._attack_types)
logging.info(results)

attack_properties, attack_values = get_all_attack_results(results)
print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)]))

# Write to tensorboard if writer is specified
global_step = self._estimator.get_variable_value('global_step')
write_to_tensorboard(self._writer, ['attack advantage'],
[results['all_thresh_loss_advantage']], global_step)
attack_property_tags = ['attack/' + '_'.join(p) for p in attack_properties]
write_to_tensorboard(self._writer, attack_property_tags, attack_values,
global_step)


def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
input_fn_constructor, attack_classifiers):
"""A function to perform the attack in the end of training.
def run_attack_on_tf_estimator_model(
estimator, in_train, out_train,
input_fn_constructor,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""Performs the attack at the end of training.
Args:
estimator: model to be tested
in_train: (in_training samples, in_training labels)
out_train: (out_training samples, out_training labels)
input_fn_constructor: a function that receives a sample and a label and
constructs the input_fn for model prediction
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType
Returns:
Results of the attack
"""
Expand All @@ -125,17 +142,19 @@ def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
results = run_attack_helper(estimator,
in_train_input_fn, out_train_input_fn,
in_train_labels, out_train_labels,
attack_classifiers)
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
slicing_spec,
attack_types)
logging.info('End of training attack:')
logging.info(results)
return results


def run_attack_helper(estimator,
in_train_input_fn, out_train_input_fn,
in_train_labels, out_train_labels,
attack_classifiers):
def run_attack_helper(
estimator,
in_train_input_fn, out_train_input_fn,
in_train_labels, out_train_labels,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)):
"""A helper function to perform attack.
Args:
Expand All @@ -144,8 +163,8 @@ def run_attack_helper(estimator,
out_train_input_fn: input_fn for out of training data
in_train_labels: in training labels
out_train_labels: out of training labels
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
slicing_spec: slicing specification of the attack
attack_types: a list of attacks, each of type AttackType
Returns:
Results of the attack
"""
Expand All @@ -156,9 +175,13 @@ def run_attack_helper(estimator,
out_train_pred, out_train_loss = calculate_losses(estimator,
out_train_input_fn,
out_train_labels)
results = mia.run_all_attacks(in_train_loss, out_train_loss,
in_train_pred, out_train_pred,
in_train_labels, out_train_labels,
attack_classifiers=attack_classifiers)
attack_input = AttackInputData(
logits_train=in_train_pred, logits_test=out_train_pred,
labels_train=in_train_labels, labels_test=out_train_labels,
loss_train=in_train_loss, loss_test=out_train_loss
)
results = mia.run_attacks(attack_input,
slicing_spec=slicing_spec,
attack_types=attack_types)
return results

Loading

0 comments on commit f90c78b

Please sign in to comment.