Skip to content

Commit

Permalink
Add a parameter to the noise function that explicitly specifies the l…
Browse files Browse the repository at this point in the history
…oss reduction type.

PiperOrigin-RevId: 583143642
  • Loading branch information
wwkong authored and tensorflower-gardener committed Nov 16, 2023
1 parent 39c8a8c commit 72c8226
Showing 1 changed file with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def add_aggregate_noise(
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[str] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
Expand All @@ -159,23 +160,33 @@ def add_aggregate_noise(
Args:
input_model: The `tf.keras.Model` to obtain the layers from.
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size, used for normalizing the noise, when the loss
reduction is AUTO or SUM_OVER_BATCH_SIZE.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An optional string description of how the loss is reduced
over examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type is inferred from `input_model.loss`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
"""
scale = l2_norm_clip
if input_model.loss.reduction in [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]:
if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
if loss_reduction is None:
implicit_sum_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = input_model.loss.reduction
loss_reduction = (
'mean' if model_reduction in implicit_sum_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')

if loss_reduction == 'sum':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
Expand Down

0 comments on commit 72c8226

Please sign in to comment.