Skip to content

Commit

Permalink
Add a parameter to the noise function that explicitly specifies the l…
Browse files Browse the repository at this point in the history
…oss reduction type.

PiperOrigin-RevId: 583143642
  • Loading branch information
wwkong authored and tensorflower-gardener committed Nov 17, 2023
1 parent 39c8a8c commit 96d516c
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 12 deletions.
1 change: 1 addition & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ py_test(
name = "gradient_clipping_utils_test",
srcs = ["gradient_clipping_utils_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":gradient_clipping_utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Utility functions that help in the computation of per-example gradient norms."""

from collections.abc import Sequence, Set
from typing import Any, Optional
from typing import Any, Literal, Optional

from absl import logging
import tensorflow as tf
Expand Down Expand Up @@ -145,37 +145,66 @@ def all_trainable_layers_are_registered(


def add_aggregate_noise(
input_model: tf.keras.Model,
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
input_model: The `tf.keras.Model` to obtain the layers from.
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size, used for normalizing the noise, when the loss
reduction is AUTO or SUM_OVER_BATCH_SIZE.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An string description of how the loss is reduced over
examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
"""
if loss_reduction is None and loss_model is None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were `None`.'
)
if loss_reduction is not None and loss_model is not None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were not `None`.'
)

if loss_reduction is None and loss_model is not None:
implicit_mean_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = loss_model.loss.reduction
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)

scale = l2_norm_clip
if input_model.loss.reduction in [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]:
if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

Expand Down Expand Up @@ -134,6 +135,60 @@ def test_outputs_are_consistent(
self.assertAllClose(computed_outputs, true_outputs)


class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):

@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make an simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
clipped_grads = [
tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
]
noised_grads = gradient_clipping_utils.add_aggregate_noise(
clipped_grads,
batch_size,
l2_norm_clip,
noise_multiplier,
noise_fn_reduction,
linear_model,
)
# The only measure that varies is the standard deviation of the variation.
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std)


class GenerateOutputsUsingCoreKerasLayers(
tf.test.TestCase, parameterized.TestCase
):
Expand Down
1 change: 1 addition & 0 deletions tensorflow_privacy/privacy/keras_models/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ py_test(
name = "dp_keras_model_test",
srcs = ["dp_keras_model_test.py"],
python_version = "PY3",
shard_count = 16,
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_privacy/privacy/keras_models/dp_keras_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,12 @@ def train_step(self, data):
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
if self._noise_multiplier > 0:
grads = gradient_clipping_utils.add_aggregate_noise(
self,
clipped_grads,
num_microbatches,
self._l2_norm_clip,
self._noise_multiplier,
loss_reduction=None,
loss_model=self,
)
else:
grads = clipped_grads
Expand Down

0 comments on commit 96d516c

Please sign in to comment.