From 72c8226f4967aa8f68af84e1ee5852b0d3bc9287 Mon Sep 17 00:00:00 2001
From: William Kong
Date: Thu, 16 Nov 2023 13:21:24 -0800
Subject: [PATCH] Add a parameter to the noise function that explicitly
 specifies the loss reduction type.

PiperOrigin-RevId: 583143642
---
 .../gradient_clipping_utils.py               | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
index 3e9e996f8..1b638fc57 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
@@ -150,6 +150,7 @@ def add_aggregate_noise(
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
+    loss_reduction: Optional[str] = None,
 ) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
@@ -159,10 +160,13 @@ def add_aggregate_noise(
   Args:
     input_model: The `tf.keras.Model` to obtain the layers from.
     clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size, used for normalizing the noise, when the loss
-      reduction is AUTO or SUM_OVER_BATCH_SIZE.
+    batch_size: The batch size. Used for normalizing the noise when
+      `loss_reduction` is 'mean'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
     noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: An optional string description of how the loss is reduced
+      over examples. Currently supports 'mean' and 'sum'. If `None`, then the
+      aggregation type is inferred from `input_model.loss`.
 
   Returns:
     A list of tensors containing the clipped gradients, but with the right
@@ -170,12 +174,19 @@ def add_aggregate_noise(
     strategy of the loss function).
   """
   scale = l2_norm_clip
-  if input_model.loss.reduction in [
-      tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
-      tf.keras.losses.Reduction.AUTO,
-  ]:
-    if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
+  if loss_reduction is None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = input_model.loss.reduction
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
       logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
+
+  if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)
 
   def add_noise(g):
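
Note for reviewers: below is a minimal usage sketch of the new parameter; the
model, optimizer, gradient values, and hyperparameters are illustrative
placeholders, not part of this change. Passing loss_reduction='sum' skips the
inspection of input_model.loss, so the noise stddev stays at
l2_norm_clip * noise_multiplier. With 'mean' (or with None and a
SUM_OVER_BATCH_SIZE/AUTO loss), the stddev is additionally divided by the
batch size, matching the 1/n scaling of mean-reduced gradients.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import (
    gradient_clipping_utils,
)

# Illustrative setup: any compiled Keras model with trainable variables.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer='sgd',
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.SUM
    ),
)
model.build(input_shape=(None, 3))

# Stand-ins for gradients that have already been clipped to l2_norm_clip,
# one tensor per trainable variable.
clipped_grads = [tf.zeros_like(v) for v in model.trainable_variables]

noised_grads = gradient_clipping_utils.add_aggregate_noise(
    input_model=model,
    clipped_grads=clipped_grads,
    batch_size=tf.constant(32),
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='sum',  # Explicit; model.loss.reduction is not consulted.
)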