Skip to content

Commit

Permalink
Add Positive Predictive Value as a metric for membership attack model…
Browse files Browse the repository at this point in the history
…s performance on imbalanced data.

PiperOrigin-RevId: 461390184
  • Loading branch information
tensorflower-gardener committed Jul 16, 2022
1 parent 328795a commit 2b5d5b6
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
import pandas as pd
from scipy import special
from sklearn import metrics
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils as utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils

# Absolute tolerance below which a TPR or FPR value is treated as zero.
_ABSOLUTE_TOLERANCE = 1e-3

ENTIRE_DATASET_SLICE_STR = 'Entire dataset'

Expand Down Expand Up @@ -116,7 +119,6 @@ class AttackType(enum.Enum):
K_NEAREST_NEIGHBORS = 'knn'
THRESHOLD_ATTACK = 'threshold'
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
TF_LOGISTIC_REGRESSION = 'tf_lr'

@property
def is_trained_attack(self):
Expand All @@ -133,6 +135,7 @@ class PrivacyMetric(enum.Enum):
"""An enum for the supported privacy risk metrics."""
AUC = 'AUC'
ATTACKER_ADVANTAGE = 'Attacker advantage'
PPV = 'Positive predictive value'

def __str__(self):
"""Returns 'AUC' instead of PrivacyMetric.AUC."""
Expand Down Expand Up @@ -627,6 +630,11 @@ class RocCurve:
# False positive rates based on thresholds
fpr: np.ndarray

# Ratio of test to train set size.
# In Jayaraman et al. (https://arxiv.org/pdf/2005.10881.pdf) it is referred to
# as 'gamma' (see Table 1 for the definition).
test_train_ratio: np.float64

def get_auc(self):
  """Calculates area under curve (aka AUC).

  Returns:
    The result of `metrics.auc` called with `self.fpr` as its first argument
    and `self.tpr` as its second argument.
  """
  return metrics.auc(self.fpr, self.tpr)
Expand All @@ -643,12 +651,69 @@ def get_attacker_advantage(self):
"""
return max(np.abs(self.tpr - self.fpr))

def get_ppv(self) -> float:
  """Calculates Positive Predictive Value of the membership attacker.

  The Positive Predictive Value (PPV) is the proportion of positive
  predictions that are true positives. It can be expressed as PPV=TP/(TP+FP).
  It was suggested by Jayaraman et al. (https://arxiv.org/pdf/2005.10881.pdf)
  that this would be a suitable metric for membership attacks on datasets
  where the number of samples from the training set and the number of samples
  from the test set are very different. These are referred to as imbalanced
  datasets.

  The PPV is evaluated at every point of the curve as
  TPR / (TPR + test_train_ratio * FPR) and the largest value is returned.

  Returns:
    A single float number for the Positive Predictive Value.
  """
  num = np.asarray(self.tpr)
  # Scale all false positive rates at once instead of one element at a time.
  den = num + np.asarray(self.fpr) * self.test_train_ratio
  # There is a special case when both `num` and `den` are 0. Both would be 0
  # when TPR and FPR are both 0, since test_train_ratio is strictly positive
  # (exclude the case when there is no test set). Then TPR = 0 means that all
  # positive (train set) examples are misclassified and FPR = 0 means that all
  # negatives (test set) were correctly classified.
  # Consider that when TPR and FPR are close to 0, TPR ~ FPR. Call this value
  # 'R'. So the expression for PPV can be rewritten as:
  # PPV = R / (R + test_train_ratio * R) = 1 / (1 + test_train_ratio).
  # We can check this expression when test_train_ratio = 0, i.e. there is no
  # test set, then PPV = 1 (perfect classification). When
  # test_train_ratio >> 0, i.e. the test set size >> train set size, and
  # PPV = 0 (perfect mis-classification).
  # When test_train_ratio = 1, test and train sets are of the same size, and
  # PPV = 0.5 (random guessing). This is because TPR = 0 means all positives
  # are misclassified (i.e. classified as negatives) and FPR = 0 means all
  # negatives are correctly classified (i.e. classified as negatives).
  # The normal case is when both `num` and `den` are not 0, and PPV is just
  # the ratio of `num` to `den`.

  # Find when `tpr` and `fpr` are 0.
  tpr_is_0 = np.isclose(self.tpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
  fpr_is_0 = np.isclose(self.fpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
  tpr_and_fpr_both_0 = np.logical_and(tpr_is_0, fpr_is_0)
  # PPV when both are zero is given by the expression below.
  ppv_when_tpr_fpr_both_0 = 1. / (1. + self.test_train_ratio)
  # PPV when one is not zero is given by the expression below. Entries whose
  # denominator is exactly 0 are left at the 0 they were initialised with.
  ppv_when_one_of_tpr_fpr_not_0 = np.divide(
      num, den, out=np.zeros_like(den), where=den != 0)
  return np.max(
      np.where(tpr_and_fpr_both_0, ppv_when_tpr_fpr_both_0,
               ppv_when_one_of_tpr_fpr_not_0))

def __str__(self):
  """Returns AUC, advantage and PPV metrics."""
  # One entry per metric. The closing ')' is its own list element so that it
  # ends up on a separate line of the joined output.
  return '\n'.join([
      'RocCurve(',
      ' AUC: %.2f' % self.get_auc(),
      ' Attacker advantage: %.2f' % self.get_attacker_advantage(),
      ' Positive predictive value: %.2f' % self.get_ppv(), ')'
  ])


Expand Down Expand Up @@ -695,6 +760,11 @@ class SingleAttackResult:
def get_attacker_advantage(self):
  """Returns the attacker advantage reported by `self.roc_curve`."""
  curve = self.roc_curve
  return curve.get_attacker_advantage()

def get_ppv(self) -> float:
  """Returns the positive predictive value reported by `self.roc_curve`.

  Returns:
    The value of `self.roc_curve.get_ppv()`.

  Raises:
    ValueError: If `self.data_size.ntrain` equals zero.
  """
  train_size = self.data_size.ntrain
  if train_size == 0:
    raise ValueError('Size of the training data cannot be zero.')
  curve = self.roc_curve
  return curve.get_ppv()

def get_auc(self):
  """Returns the AUC reported by `self.roc_curve`."""
  curve = self.roc_curve
  return curve.get_auc()

Expand All @@ -707,7 +777,8 @@ def __str__(self):
(self.data_size.ntrain, self.data_size.ntest),
' AttackType: %s' % str(self.attack_type),
' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
' Attacker advantage: %.2f' % self.get_attacker_advantage(),
' Positive Predictive Value: %.2f' % self.get_ppv(), ')'
])


Expand Down Expand Up @@ -791,7 +862,14 @@ def collect_results(self, threshold_list, return_roc_results=True):
np.zeros(len(self.test_membership_probs)))),
np.concatenate(
(self.train_membership_probs, self.test_membership_probs)))
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
ntrain = np.shape(self.train_membership_probs)[0]
ntest = np.shape(self.test_membership_probs)[0]
test_train_ratio = ntest / ntrain
roc_curve = RocCurve(
tpr=tpr,
fpr=fpr,
thresholds=thresholds,
test_train_ratio=test_train_ratio)
summary.append(
' thresholding on membership probability achieved an AUC of %.2f' %
(roc_curve.get_auc()))
Expand Down Expand Up @@ -860,6 +938,7 @@ def calculate_pd_dataframe(self):
data_size_test = []
attack_types = []
advantages = []
ppvs = []
aucs = []

for attack_result in self.single_attack_results:
Expand All @@ -874,6 +953,7 @@ def calculate_pd_dataframe(self):
data_size_test.append(attack_result.data_size.ntest)
attack_types.append(str(attack_result.attack_type))
advantages.append(float(attack_result.get_attacker_advantage()))
ppvs.append(float(attack_result.get_ppv()))
aucs.append(float(attack_result.get_auc()))

df = pd.DataFrame({
Expand All @@ -883,6 +963,7 @@ def calculate_pd_dataframe(self):
str(AttackResultsDFColumns.DATA_SIZE_TEST): data_size_test,
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
str(PrivacyMetric.PPV): ppvs,
str(PrivacyMetric.AUC): aucs
})
return df
Expand Down Expand Up @@ -918,6 +999,14 @@ def summary(self, by_slices=False) -> str:
max_advantage_result_all.get_attacker_advantage(),
max_advantage_result_all.slice_spec))

max_ppv_result_all = self.get_result_with_max_ppv()
summary.append(
' %s (with %d training and %d test examples) achieved a positive '
'predictive value of %.2f on slice %s' %
(max_ppv_result_all.attack_type, max_ppv_result_all.data_size.ntrain,
max_ppv_result_all.data_size.ntest, max_ppv_result_all.get_ppv(),
max_ppv_result_all.slice_spec))

slice_dict = self._group_results_by_slice()

if by_slices and len(slice_dict.keys()) > 1:
Expand All @@ -937,6 +1026,12 @@ def summary(self, by_slices=False) -> str:
max_advantage_result.data_size.ntrain,
max_auc_result.data_size.ntest,
max_advantage_result.get_attacker_advantage()))
max_ppv_result = results.get_result_with_max_ppv()
summary.append(
' %s (with %d training and %d test examples) achieved a positive '
'predictive value of %.2f' %
(max_ppv_result.attack_type, max_ppv_result.data_size.ntrain,
max_ppv_result.data_size.ntest, max_ppv_result.get_ppv()))

return '\n'.join(summary)

Expand Down Expand Up @@ -966,6 +1061,11 @@ def get_result_with_max_attacker_advantage(self) -> SingleAttackResult:
result.get_attacker_advantage() for result in self.single_attack_results
])]

def get_result_with_max_ppv(self) -> SingleAttackResult:
  """Gets the result with max positive predictive value for all attacks and slices.

  Returns:
    The element of `self.single_attack_results` whose `get_ppv()` is largest.
    On ties the earliest such element is returned (`np.argmax` semantics).
  """
  ppv_per_result = [
      result.get_ppv() for result in self.single_attack_results
  ]
  best_index = np.argmax(ppv_per_result)
  return self.single_attack_results[best_index]

def save(self, filepath):
"""Saves self to a pickle file."""
with open(filepath, 'wb') as out:
Expand Down Expand Up @@ -1035,6 +1135,7 @@ def get_flattened_attack_metrics(results: AttackResults):
attack_metrics += ['adv', 'auc']
values += [
float(attack_result.get_attacker_advantage()),
float(attack_result.get_auc())
float(attack_result.get_auc()),
float(attack_result.get_ppv()),
]
return types, slices, attack_metrics, values
Loading

0 comments on commit 2b5d5b6

Please sign in to comment.