Skip to content

Commit

Permalink
Support fast clipping in DPAM.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 533589057
  • Loading branch information
walidk authored and tensorflower-gardener committed May 20, 2023
1 parent 60d237b commit 98d7722
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 211 deletions.
83 changes: 48 additions & 35 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs):

def compute_gradient_norms(
input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
Expand All @@ -84,15 +85,16 @@ def compute_gradient_norms(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers from. The
loss of the model *must* be a scalar loss.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
weight_batch: Optional batch of weights, passed to the loss function.
per_example_loss_fn: takes as input predictions, labels and weights, and
outputs a vector of per-example losses. If None, derived from
`input_model.loss` by disabling its reduction.
Expand All @@ -108,8 +110,9 @@ def compute_gradient_norms(
variables are included.
Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function.
A scalar vector, whose i-th entry is the norm of the gradient of the i-th
example loss (when num_microbatches is None) or the norm of the gradient of
the i-th microbatch loss (define as a mean over the microbatch).
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(
Expand All @@ -127,7 +130,7 @@ def compute_gradient_norms(
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
Expand All @@ -140,7 +143,7 @@ def compute_gradient_norms(
)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1
lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
Expand All @@ -165,8 +168,10 @@ def compute_gradient_norms(
vars_list,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
if not grads_list:
raise ValueError('Empty gradient list.')
sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list):
for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
sqr_norm_list.append(f(grads))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
Expand Down Expand Up @@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[lr.BatchSize] = None,
clipping_loss: Optional[LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
Expand All @@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers from.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`.
Expand All @@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before passing it to
the loss.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
Expand All @@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs(
the value of the clipped loss does not reflect the true loss.
Returns:
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
clipped gradient of the loss function, the second is the result of
applying `input_model` to `x_batch`, and the third is loss value of
`input_model`, weighted by the loss weights generated by a specific
`compute_clip_weights()` call.
clipped_grad: the clipped gradient of the loss function
y_pred: the result of applying `input_model` to `x_batch`
clipping_loss_value: the loss value weighted in such a way that its gradient
is `clipped_grad`.
"""
if input_model.loss.reduction == 'none':
raise NotImplementedError(
Expand All @@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs(
clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms(
input_model,
layer_registry,
x_batch,
y_batch,
layer_registry,
weight_batch,
num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables,
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
if weight_batch is not None:
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
else:
# In this case, weight_batch is of shape [batch_size], we first reshape to
# [num_microbatches, microbatch_size] then multiply by the clip_weights
# (which is of shape [num_microbatches])
weight_batch = lr.maybe_add_microbatch_axis(
weight_batch, num_microbatches
)
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
Expand All @@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs(
# is not defined in the contract so may not hold, especially for
# custom losses.
y_pred = input_model(x_batch, training=True)
loss_y_batch = (
y_batch
if num_microbatches is None
else lr.add_microbatch_axis(y_batch, num_microbatches)
)
loss_y_pred = (
y_pred
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
clipped_grads = tape.gradient(
clipping_loss_value,
input_model.trainable_variables,
Expand Down
Loading

0 comments on commit 98d7722

Please sign in to comment.