Refactor: move loss computation utilities under privacy_tests.
PiperOrigin-RevId: 463391913
shs037 authored and tensorflower-gardener committed Jul 26, 2022
1 parent 44dc404 commit 17cd0c5
Showing 10 changed files with 119 additions and 88 deletions.
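For orientation (this sketch is not part of the diff), the move changes where downstream code imports the loss helpers from; both paths below are taken from the hunks in this commit:

# Old location, removed by this commit:
# from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
# from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss

# New location under privacy_tests:
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss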
19 changes: 17 additions & 2 deletions tensorflow_privacy/privacy/privacy_tests/BUILD
@@ -1,10 +1,25 @@
load("@rules_python//python:defs.bzl", "py_library")
load("@rules_python//python:defs.bzl", "py_library", "py_test")

package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])

licenses(["notice"])

py_library(
name = "privacy_tests",
srcs = ["__init__.py"],
)

py_test(
name = "utils_test",
timeout = "long",
srcs = ["utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":utils"],
)

py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY3",
)
@@ -15,21 +15,6 @@ py_library(
srcs_version = "PY3",
)

py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY3",
)

py_test(
name = "utils_test",
timeout = "long",
srcs = ["utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":utils"],
)

py_test(
name = "membership_inference_attack_test",
timeout = "long",
@@ -45,7 +30,10 @@ py_test(
srcs = ["data_structures_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":membership_inference_attack"],
deps = [
":membership_inference_attack",
"//tensorflow_privacy/privacy/privacy_tests:utils",
],
)

py_test(
@@ -95,7 +83,7 @@ py_library(
"seq2seq_mia.py",
],
srcs_version = "PY3",
deps = [":utils"],
deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
)

py_library(
@@ -122,8 +110,8 @@ py_library(
srcs_version = "PY3",
deps = [
":membership_inference_attack",
":utils",
":utils_tensorboard",
"//tensorflow_privacy/privacy/privacy_tests:utils",
],
)

@@ -144,8 +132,8 @@ py_library(
srcs_version = "PY3",
deps = [
":membership_inference_attack",
":utils",
":utils_tensorboard",
"//tensorflow_privacy/privacy/privacy_tests:utils",
],
)

@@ -185,7 +173,7 @@ py_library(
"advanced_mia.py",
],
srcs_version = "PY3",
deps = [":utils"],
deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
)

py_test(
@@ -205,6 +193,6 @@ py_binary(
deps = [
":advanced_mia",
":membership_inference_attack",
":utils",
"//tensorflow_privacy/privacy/privacy_tests:utils",
],
)
@@ -17,7 +17,7 @@
from typing import Sequence, Union
import numpy as np
import scipy.stats
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss


def replace_nan_with_column_mean(a: np.ndarray):
@@ -21,11 +21,10 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

FLAGS = flags.FLAGS
@@ -26,7 +26,7 @@
import pandas as pd
from scipy import special
from sklearn import metrics
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests import utils

# The minimum TPR or FPR below which they are considered equal.
_ABSOLUTE_TOLERANCE = 1e-3
@@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30):
return -np.log(np.maximum(probs, small_value))


class LossFunction(enum.Enum):
"""An enum that defines loss function to use in `AttackInputData`."""
CROSS_ENTROPY = 'cross_entropy'
SQUARED = 'squared'


@dataclasses.dataclass
class AttackInputData:
"""Input data for running an attack.
@@ -225,7 +219,7 @@ class AttackInputData:
# If a callable is provided, it should take in two arguments: the 1st is
# labels, the 2nd is logits or probs.
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction] = LossFunction.CROSS_ENTROPY
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
# Whether `loss_function` will be called with logits or probs. If not set
# (None), this is decided by the availability of logits and probs; logits
# are preferred when both are available.
@@ -298,52 +292,6 @@ def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
true_labels]
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)

@staticmethod
def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray],
np.ndarray], LossFunction],
loss_function_using_logits: Optional[bool],
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
Args:
loss: the loss of each example.
labels: the scalar label of each example.
logits: the logits vector of each example.
probs: the probability vector of each example.
loss_function: if `loss` is not available, `labels` and one of `logits`
and `probs` are available, we will use this function to compute loss. It
is supposed to take in (label, logits / probs) as input.
loss_function_using_logits: if `loss_function` expects `logits` or
`probs`.
multilabel_data: if the data is from a multilabel classification problem.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if loss is not None:
return loss
if labels is None or (logits is None and probs is None):
return None
if loss_function_using_logits and logits is None:
raise ValueError('We need logits to compute loss, but it is set to None.')
if not loss_function_using_logits and probs is None:
raise ValueError('We need probs to compute loss, but it is set to None.')

predictions = logits if loss_function_using_logits else probs
if loss_function == LossFunction.CROSS_ENTROPY:
if multilabel_data:
loss = utils.multilabel_bce_loss(labels, predictions,
loss_function_using_logits)
else:
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
elif loss_function == LossFunction.SQUARED:
loss = utils.squared_loss(labels, predictions)
else:
loss = loss_function(labels, predictions)
return loss

def __post_init__(self):
"""Checks performed after instantiation of the AttackInputData dataclass."""
# Check if the data is multilabel.
@@ -358,7 +306,7 @@ def get_loss_train(self):
"""
if self.loss_function_using_logits is None:
self.loss_function_using_logits = (self.logits_train is not None)
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
self.probs_train, self.loss_function,
self.loss_function_using_logits, self.multilabel_data)

@@ -370,7 +318,7 @@ def get_loss_test(self):
"""
if self.loss_function_using_logits is None:
self.loss_function_using_logits = bool(self.logits_test)
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
self.probs_test, self.loss_function,
self.loss_function_using_logits, self.multilabel_data)
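With `LossFunction` and the loss computation moved into the shared utils module, `AttackInputData` now takes `utils.LossFunction` and delegates to `utils.get_loss`, as the hunks above show. A minimal usage sketch, not part of the diff; the values and shapes are illustrative, patterned on the tests changed in this commit:

import numpy as np

from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Squared loss selected through the relocated enum; logits are used because
# loss_function_using_logits is True, so probs may be omitted.
attack_input = AttackInputData(
    logits_train=np.array([125.0, 100.0]),
    logits_test=np.array([150.0, 200.0]),
    labels_train=np.array([1.0, 0.0]),
    labels_test=np.array([0.0, 2.0]),
    loss_function=utils.LossFunction.SQUARED,  # previously data_structures.LossFunction
    loss_function_using_logits=True,
)
train_losses = attack_input.get_loss_train()  # computed via utils.get_loss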

@@ -20,13 +20,13 @@
from absl.testing import parameterized
import numpy as np
import pandas as pd
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -123,7 +123,7 @@ def test_get_squared_loss(self, loss_function_using_logits, expected_train,
probs_test=np.array([1, 1.]),
labels_train=np.array([1, 0.]),
labels_test=np.array([0, 2.]),
loss_function=LossFunction.SQUARED,
loss_function=utils.LossFunction.SQUARED,
loss_function_using_logits=loss_function_using_logits,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
@@ -175,7 +175,7 @@ def test_default_loss_function_using_logits(self, logits, probs, expected):
probs_test=probs,
labels_train=np.array([1, 0.]),
labels_test=np.array([1, 0.]),
loss_function=LossFunction.SQUARED,
loss_function=utils.LossFunction.SQUARED,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
np.testing.assert_allclose(attack_input.get_loss_test(), expected)
@@ -24,7 +24,6 @@
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard


@@ -20,9 +20,9 @@
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard


@@ -13,7 +13,10 @@
# limitations under the License.
"""Utility functions for membership inference attacks."""

import enum
import logging
from typing import Callable, Optional, Union

import numpy as np
from scipy import special

@@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray,
bce = labels * np.log(pred + small_value)
bce += (1 - labels) * np.log(1 - pred + small_value)
return -bce


class LossFunction(enum.Enum):
"""An enum that defines loss function."""
CROSS_ENTROPY = 'cross_entropy'
SQUARED = 'squared'


def string_to_loss_function(string: str):
"""Convert string to the corresponding LossFunction."""

if string == LossFunction.CROSS_ENTROPY.value:
return LossFunction.CROSS_ENTROPY
if string == LossFunction.SQUARED.value:
return LossFunction.SQUARED
raise ValueError(f'{string} is not a valid loss function name.')


def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray],
np.ndarray], LossFunction],
loss_function_using_logits: Optional[bool],
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
Args:
loss: the loss of each example.
labels: the scalar label of each example.
logits: the logits vector of each example.
probs: the probability vector of each example.
loss_function: if `loss` is not available, `labels` and one of `logits`
and `probs` are available, we will use this function to compute loss. It
is supposed to take in (label, logits / probs) as input.
loss_function_using_logits: if `loss_function` expects `logits` or
`probs`.
multilabel_data: if the data is from a multilabel classification problem.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if loss is not None:
return loss
if labels is None or (logits is None and probs is None):
return None
if loss_function_using_logits and logits is None:
raise ValueError('We need logits to compute loss, but it is set to None.')
if not loss_function_using_logits and probs is None:
raise ValueError('We need probs to compute loss, but it is set to None.')

predictions = logits if loss_function_using_logits else probs
if loss_function == LossFunction.CROSS_ENTROPY:
if multilabel_data:
loss = multilabel_bce_loss(labels, predictions,
loss_function_using_logits)
else:
loss = log_loss(labels, predictions, loss_function_using_logits)
elif loss_function == LossFunction.SQUARED:
loss = squared_loss(labels, predictions)
else:
loss = loss_function(labels, predictions)
return loss
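A minimal sketch of calling the relocated helpers directly (not part of the diff; the labels and probabilities are illustrative, and the custom callable follows the contract described in the docstring above):

import numpy as np

from tensorflow_privacy.privacy.privacy_tests import utils

labels = np.array([0, 1, 1])                             # per-example class labels
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])   # per-example probability vectors

# Enum-selected loss: cross entropy computed from probabilities.
ce_losses = utils.get_loss(
    loss=None, labels=labels, logits=None, probs=probs,
    loss_function=utils.LossFunction.CROSS_ENTROPY,
    loss_function_using_logits=False, multilabel_data=False)

# Custom callable: takes (labels, logits-or-probs) and returns per-example losses.
custom_losses = utils.get_loss(
    loss=None, labels=labels, logits=None, probs=probs,
    loss_function=lambda y, p: 1.0 - p[np.arange(len(y)), y],
    loss_function_using_logits=False, multilabel_data=False)

# Enum values can also be recovered from strings, e.g. when passed as a flag.
assert utils.string_to_loss_function('cross_entropy') is utils.LossFunction.CROSS_ENTROPY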
@@ -16,7 +16,24 @@
from absl.testing import parameterized
import numpy as np

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests import utils


class LossFunctionFromStringTest(parameterized.TestCase):

@parameterized.parameters(
(utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'),
(utils.LossFunction.SQUARED, 'squared'),
)
def test_from_str(self, en, string):
self.assertEqual(utils.string_to_loss_function(string), en)

@parameterized.parameters(
('random string'),
(''),
)
def test_from_str_wrong_input(self, string):
self.assertRaises(ValueError, utils.string_to_loss_function, string)


class TestLogLoss(parameterized.TestCase):