From 96d516cb299d928af09ce953934695550d65ca9b Mon Sep 17 00:00:00 2001
From: William Kong
Date: Thu, 16 Nov 2023 13:21:24 -0800
Subject: [PATCH] Add a parameter to the noise function that explicitly
 specifies the loss reduction type.

PiperOrigin-RevId: 583143642
---
 .../privacy/fast_gradient_clipping/BUILD      |  1 +
 .../gradient_clipping_utils.py                | 51 +++++++++++++----
 .../gradient_clipping_utils_test.py           | 55 +++++++++++++++++++
 tensorflow_privacy/privacy/keras_models/BUILD |  1 +
 .../privacy/keras_models/dp_keras_model.py    |  3 +-
 5 files changed, 99 insertions(+), 12 deletions(-)

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
index d7fc5608e..f2eefd7c7 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -38,6 +38,7 @@ py_test(
     name = "gradient_clipping_utils_test",
     srcs = ["gradient_clipping_utils_test.py"],
     python_version = "PY3",
+    shard_count = 8,
     srcs_version = "PY3",
     deps = [
         ":gradient_clipping_utils",
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
index 3e9e996f8..adca9af66 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
@@ -14,7 +14,7 @@
 """Utility functions that help in the computation of per-example gradient norms."""
 
 from collections.abc import Sequence, Set
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 from absl import logging
 import tensorflow as tf
@@ -145,11 +145,12 @@ def all_trainable_layers_are_registered(
 
 
 def add_aggregate_noise(
-    input_model: tf.keras.Model,
     clipped_grads: list[tf.Tensor],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
+    loss_reduction: Optional[Literal['mean', 'sum']] = None,
+    loss_model: Optional[tf.keras.Model] = None,
 ) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
@@ -157,25 +158,53 @@ def add_aggregate_noise(
   input model's loss function.
 
   Args:
-    input_model: The `tf.keras.Model` to obtain the layers from.
     clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size, used for normalizing the noise, when the loss
-      reduction is AUTO or SUM_OVER_BATCH_SIZE.
+    batch_size: The batch size. Used for normalizing the noise when
+      `loss_reduction` is 'mean'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
     noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: A string description of how the loss is reduced over
+      examples. Currently supports 'mean' and 'sum'. If `None`, the reduction
+      type is inferred from `loss_model.loss`.
+    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
+      strategy when `loss_reduction` is `None`.
 
   Returns:
     A list of tensors containing the clipped gradients, but with the right
     amount of Gaussian noise added to them (depending on the reduction
     strategy of the loss function).
+
+  Raises:
+    ValueError: If `loss_reduction` and `loss_model` are both `None`, or if
+      both are not `None`.
   """
+  if loss_reduction is None and loss_model is None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were `None`.'
+    )
+  if loss_reduction is not None and loss_model is not None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were not `None`.'
+    )
+
+  if loss_reduction is None and loss_model is not None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = loss_model.loss.reduction
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
+      logging.info(
+          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
+      )
+
   scale = l2_norm_clip
-  if input_model.loss.reduction in [
-      tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
-      tf.keras.losses.Reduction.AUTO,
-  ]:
-    if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
-      logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
+  if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)
 
   def add_noise(g):
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
index 7069273d9..f7cf3b0fc 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
@@ -15,6 +15,7 @@
 from typing import Any
 
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -134,6 +135,60 @@ def test_outputs_are_consistent(
     self.assertAllClose(computed_outputs, true_outputs)
 
 
+class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    clipped_grads = [
+        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    ]
+    noised_grads = gradient_clipping_utils.add_aggregate_noise(
+        clipped_grads,
+        batch_size,
+        l2_norm_clip,
+        noise_multiplier,
+        noise_fn_reduction,
+        linear_model,
+    )
+    # Only the standard deviation of the added noise should vary.
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    computed_std = np.std(noised_grads[0] - clipped_grads[0])
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+
 class GenerateOutputsUsingCoreKerasLayers(
     tf.test.TestCase, parameterized.TestCase
 ):
diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD
index 508285965..738008492 100644
--- a/tensorflow_privacy/privacy/keras_models/BUILD
+++ b/tensorflow_privacy/privacy/keras_models/BUILD
@@ -25,6 +25,7 @@ py_test(
     name = "dp_keras_model_test",
     srcs = ["dp_keras_model_test.py"],
     python_version = "PY3",
+    shard_count = 16,
     srcs_version = "PY3",
     deps = [
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index 29cb7d59b..62172460e 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -264,11 +264,12 @@ def train_step(self, data):
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
       if self._noise_multiplier > 0:
         grads = gradient_clipping_utils.add_aggregate_noise(
-            self,
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
             self._noise_multiplier,
+            loss_reduction=None,
+            loss_model=self,
         )
       else:
         grads = clipped_grads
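
Usage note: the snippet below is a minimal sketch of how the new arguments are intended to be used, based only on the `add_aggregate_noise` signature introduced in this patch; the gradient tensors, clipping parameters, and model are illustrative placeholders. Exactly one of `loss_reduction` and `loss_model` must be supplied.

import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

# Placeholder clipped gradients for a single weight matrix.
clipped_grads = [tf.constant(np.ones((4, 1), dtype=np.float32))]
batch_size = tf.constant(32)

# Option 1: state the loss reduction explicitly; no model object is needed.
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',  # noise stddev is divided by the batch size
)

# Option 2: infer the reduction from a compiled Keras model's loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss=tf.keras.losses.MeanSquaredError(reduction='sum'))
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_model=model,  # 'sum' is read from `model.loss.reduction`
)

Passing `loss_reduction` directly is what `DPModel.train_step` could also do once it knows its own reduction; the patch instead keeps the old behavior by forwarding `loss_model=self` so the reduction is still read from the compiled loss.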