Skip to content

Commit

Permalink
Implement possibility to return slice indices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 540885025
  • Loading branch information
dvadym authored and tensorflower-gardener committed Jun 16, 2023
1 parent a4bdb05 commit 45da453
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,16 @@ def main(unused_argv):
labels_train=training_labels,
labels_test=test_labels,
probs_train=training_pred,
probs_test=test_pred),
probs_test=test_pred,
),
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
data_structures.AttackType.LOGISTIC_REGRESSION),
privacy_report_metadata=privacy_report_metadata)
attack_types=(
data_structures.AttackType.THRESHOLD_ATTACK,
data_structures.AttackType.LOGISTIC_REGRESSION,
),
privacy_report_metadata=privacy_report_metadata,
return_slice_indices=True,
)
epoch_results.append(attack_results)

# Generate privacy reports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,11 @@ class SingleAttackResult:
# test set samples will have lower scores than the training set samples.
membership_scores_test: Optional[np.ndarray] = None

# Indices of train and test examples from the input data that were used in
# this attack.
train_indices: Optional[np.ndarray] = None
test_indices: Optional[np.ndarray] = None

def get_attacker_advantage(self):
return self.roc_curve.get_attacker_advantage()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ def _slice_if_not_none(a, idx):
return None if a is None else a[idx]


def _slice_data_by_indices(data: AttackInputData, idx_train,
idx_test) -> AttackInputData:
def _slice_data_by_indices(
data: AttackInputData,
idx_train: np.ndarray,
idx_test: np.ndarray,
return_slice_indices: bool,
) -> AttackInputData:
"""Slices train fields with idx_train and test fields with idx_test."""

result = AttackInputData()
Expand Down Expand Up @@ -68,10 +72,16 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
# of the original dataset.
result.multilabel_data = data.is_multilabel_data()

if return_slice_indices:
result.train_indices = np.where(idx_train)[0]
result.test_indices = np.where(idx_test)[0]

return result


def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
def _slice_by_class(
data: AttackInputData, class_value: int, return_slice_indices: bool = False
) -> AttackInputData:
"""Gets the indices (boolean) for examples belonging to the given class."""
if not data.is_multilabel_data():
idx_train = data.labels_train == class_value
Expand All @@ -84,11 +94,15 @@ def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
)
idx_train = data.labels_train[:, class_value].astype(bool)
idx_test = data.labels_test[:, class_value].astype(bool)
return _slice_data_by_indices(data, idx_train, idx_test)
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)


def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
to_percentile: float):
def _slice_by_percentiles(
data: AttackInputData,
from_percentile: float,
to_percentile: float,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices samples by loss percentiles."""

# Find from_percentile and to_percentile percentiles in losses.
Expand All @@ -107,22 +121,31 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
idx_train = (from_loss <= loss_train) & (loss_train <= to_loss)
idx_test = (from_loss <= loss_test) & (loss_test <= to_loss)

return _slice_data_by_indices(data, idx_train, idx_test)
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)


def _indices_by_classification(logits_or_probs, labels, correctly_classified):
def _indices_by_classification(
logits_or_probs,
labels,
correctly_classified,
):
idx_correct = labels == np.argmax(logits_or_probs, axis=1)
return idx_correct if correctly_classified else np.invert(idx_correct)


def _slice_by_classification_correctness(data: AttackInputData,
correctly_classified: bool):
def _slice_by_classification_correctness(
data: AttackInputData,
correctly_classified: bool,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices attack inputs by whether they were classified correctly.
Args:
data: Data to be used as input to the attack models.
correctly_classified: Whether to use the indices corresponding to the
correctly classified samples.
return_slice_indices: if true, the returned AttackInputData will include
indices of the train and test data samples that were used for this slice.
Returns:
AttackInputData object containing the sliced data.
Expand All @@ -136,20 +159,25 @@ def _slice_by_classification_correctness(data: AttackInputData,
correctly_classified)
idx_test = _indices_by_classification(data.logits_or_probs_test,
data.labels_test, correctly_classified)
return _slice_data_by_indices(data, idx_train, idx_test)
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)


def _slice_by_custom_indices(data: AttackInputData,
custom_train_indices: np.ndarray,
custom_test_indices: np.ndarray,
group_value: int) -> AttackInputData:
def _slice_by_custom_indices(
data: AttackInputData,
custom_train_indices: np.ndarray,
custom_test_indices: np.ndarray,
group_value: int,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices attack inputs by custom indices.
Args:
data: Data to be used as input to the attack models.
custom_train_indices: The group indices of each training example.
custom_test_indices: The group indices of each test example.
group_value: The group value to pick.
return_slice_indices: if true, the returned AttackInputData will include
indices of the train and test data samples that were used for this slice.
Returns:
AttackInputData object containing the sliced data.
Expand All @@ -167,12 +195,12 @@ def _slice_by_custom_indices(data: AttackInputData,
f"{test_size}")
idx_train = custom_train_indices == group_value
idx_test = custom_test_indices == group_value
return _slice_data_by_indices(data, idx_train, idx_test)
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)


def get_single_slice_specs(
slicing_spec: SlicingSpec,
num_classes: Optional[int] = None) -> List[SingleSliceSpec]:
slicing_spec: SlicingSpec, num_classes: Optional[int] = None
) -> List[SingleSliceSpec]:
"""Returns slices of data according to slicing_spec.
Args:
Expand Down Expand Up @@ -242,22 +270,34 @@ def get_single_slice_specs(
return result


def get_slice(data: AttackInputData,
slice_spec: SingleSliceSpec) -> AttackInputData:
def get_slice(
data: AttackInputData,
slice_spec: SingleSliceSpec,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Returns a single slice of data according to slice_spec."""
if slice_spec.entire_dataset:
data_slice = copy.copy(data)
elif slice_spec.feature == SlicingFeature.CLASS:
data_slice = _slice_by_class(data, slice_spec.value)
data_slice = _slice_by_class(data, slice_spec.value, return_slice_indices)
elif slice_spec.feature == SlicingFeature.PERCENTILE:
from_percentile, to_percentile = slice_spec.value
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
data_slice = _slice_by_percentiles(
data, from_percentile, to_percentile, return_slice_indices
)
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
data_slice = _slice_by_classification_correctness(
data, slice_spec.value, return_slice_indices
)
elif slice_spec.feature == SlicingFeature.CUSTOM:
custom_train_indices, custom_test_indices, group_value = slice_spec.value
data_slice = _slice_by_custom_indices(data, custom_train_indices,
custom_test_indices, group_value)
data_slice = _slice_by_custom_indices(
data,
custom_train_indices,
custom_test_indices,
group_value,
return_slice_indices,
)
else:
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,12 @@ def test_slice_by_percentile(self):
self.assertTrue((output.labels_train == [1, 0, 2]).all())
self.assertTrue((output.labels_test == [1]).all())

def test_slice_by_correctness(self):
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
False)
output = get_slice(self.input_data, percentile_slice)
@parameterized.parameters(False, True)
def test_slice_by_correctness(self, return_slice_indices):
slice_spec = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False)
output = get_slice(
self.input_data, slice_spec, return_slice_indices=return_slice_indices
)

# Check logits.
self.assertLen(output.logits_train, 2)
Expand All @@ -284,6 +286,14 @@ def test_slice_by_correctness(self):
self.assertTrue((output.labels_train == [0, 2]).all())
self.assertTrue((output.labels_test == [1, 2, 0]).all())

# Check return indices
if return_slice_indices:
self.assertTrue((output.train_indices == [1, 3]).all())
self.assertTrue((output.test_indices == [0, 1, 2]).all())
else:
self.assertFalse(hasattr(output, 'train_indices'))
self.assertFalse(hasattr(output, 'test_indices'))

def test_slice_by_custom_indices(self):
custom_train_indices = np.array([2, 2, 100, 4])
custom_test_indices = np.array([100, 2, 2, 2])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,16 @@ def _run_attack(attack_input: AttackInputData,
return _run_threshold_attack(attack_input)


def run_attacks(attack_input: AttackInputData,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (
AttackType.THRESHOLD_ATTACK,),
privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True,
min_num_samples: int = 1,
backend: Optional[str] = None) -> AttackResults:
def run_attacks(
attack_input: AttackInputData,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True,
min_num_samples: int = 1,
backend: Optional[str] = None,
return_slice_indices=False,
) -> AttackResults:
"""Runs membership inference attacks on a classification model.
It runs attacks specified by attack_types on each attack_input slice which is
Expand All @@ -282,6 +284,9 @@ def run_attacks(attack_input: AttackInputData,
may not support multiprocessing and in those cases the `threading` backend
should be used. See https://joblib.readthedocs.io/en/latest/parallel.html
for more details.
return_slice_indices: if true, the result for each slice will include the
indices of train and test data examples that were used for this slice and
attacks. This does not return indices for the "whole dataset" slice.
Returns:
the attack result.
Expand All @@ -301,7 +306,9 @@ def run_attacks(attack_input: AttackInputData,
logging.info('Will run %s attacks on each of %s slice specifications.',
num_attacks, num_slice_specs)
for single_slice_spec in input_slice_specs:
attack_input_slice = get_slice(attack_input, single_slice_spec)
attack_input_slice = get_slice(
attack_input, single_slice_spec, return_slice_indices
)
for attack_type in attack_types:
logging.info('Running attack: %s', attack_type.name)
attack_result = _run_attack(attack_input_slice, attack_type,
Expand All @@ -313,6 +320,9 @@ def run_attacks(attack_input: AttackInputData,
'positive predictive value=%s', attack_type.name,
attack_result.get_auc(), attack_result.get_attacker_advantage(),
attack_result.get_ppv())
if return_slice_indices and not single_slice_spec.entire_dataset:
attack_result.train_indices = attack_input_slice.train_indices
attack_result.test_indices = attack_input_slice.test_indices
attack_results.append(attack_result)

privacy_report_metadata = _compute_missing_privacy_report_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,23 @@ def test_run_attacks_insufficient_examples(self):
AttackType.THRESHOLD_ATTACK.value,
)

def test_run_attacks_size_return_indices(self):
result = mia.run_attacks(
get_test_input(100, 100),
SlicingSpec(
entire_dataset=False,
by_class=True,
by_percentiles=True,
by_classification_correctness=True,
),
(AttackType.LOGISTIC_REGRESSION,),
return_slice_indices=True,
)

for attack_result in result.single_attack_results:
self.assertIsNotNone(attack_result.train_indices)
self.assertIsNotNone(attack_result.test_indices)


if __name__ == '__main__':
absltest.main()

0 comments on commit 45da453

Please sign in to comment.