Support fast clipping in DPAM. #476

Open · wants to merge 1 commit into master
83 changes: 48 additions & 35 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs):

def compute_gradient_norms(
input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
@@ -84,15 +85,16 @@ def compute_gradient_norms(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers. The
loss of the model *must* be a scalar loss.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
layer_registry: A `LayerRegistry` instance containing functions that help
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
weight_batch: Optional batch of weights, passed to the loss function.
per_example_loss_fn: takes as input predictions, labels and weights, and
outputs a vector of per-example losses. If None, derived from
`input_model.loss` by disabling its reduction.
@@ -108,8 +110,9 @@ def compute_gradient_norms(
variables are included.

Returns:
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
per-example loss function.
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
example loss (when `num_microbatches` is None) or the norm of the gradient of
the i-th microbatch loss (defined as the mean over the microbatch).
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(
@@ -127,7 +130,7 @@ def compute_gradient_norms(
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
@@ -140,7 +143,7 @@
)
if num_microbatches is not None:
losses = tf.reduce_mean(
lr.add_microbatch_axis(losses, num_microbatches), axis=1
lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
@@ -165,8 +168,10 @@
vars_list,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
if not grads_list:
raise ValueError('Empty gradient list.')
sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list):
for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
sqr_norm_list.append(f(grads))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
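
(Illustrative sketch, not part of this diff: the hunk above stacks each registered layer's squared per-example gradient norms and reduces them to one norm per example. The numbers below are made up, and the min(1, C / norm) clip-weight rule is an assumption about `compute_clip_weights`, which this diff does not show.)

import tensorflow as tf

# Two registered layers, batch of three examples: each registry function
# f(grads) yields that layer's squared per-example gradient norms,
# with shape [batch_size].
layer1_sqr_norms = tf.constant([1.0, 4.0, 9.0])
layer2_sqr_norms = tf.constant([3.0, 0.0, 7.0])

# Stack to [batch_size, num_layers], sum over layers, and take the square
# root, mirroring the sqr_norm_tsr / tf.sqrt(tf.reduce_sum(...)) lines above.
sqr_norm_tsr = tf.stack([layer1_sqr_norms, layer2_sqr_norms], axis=1)
gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
print(gradient_norms.numpy())  # [2. 2. 4.]

# Assumed clip-weight rule: scale example i by min(1, l2_norm_clip / norm_i).
l2_norm_clip = 3.0
clip_weights = l2_norm_clip / tf.maximum(gradient_norms, l2_norm_clip)
print(clip_weights.numpy())  # [1.   1.   0.75]
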
@@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: InputTensor,
y_batch: tf.Tensor,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[lr.BatchSize] = None,
clipping_loss: Optional[LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
@@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs(

Args:
input_model: The `tf.keras.Model` from which to obtain the layers.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`.
@@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs(
`output` is the pre-activation tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before being passed to
the loss.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs(
the value of the clipped loss does not reflect the true loss.

Returns:
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
clipped gradient of the loss function, the second is the result of
applying `input_model` to `x_batch`, and the third is loss value of
`input_model`, weighted by the loss weights generated by a specific
`compute_clip_weights()` call.
clipped_grad: The clipped gradient of the loss function.
y_pred: The result of applying `input_model` to `x_batch`.
clipping_loss_value: The loss value, weighted in such a way that its gradient
is `clipped_grad`.
"""
if input_model.loss.reduction == 'none':
raise NotImplementedError(
@@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs(
clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms(
input_model,
layer_registry,
x_batch,
y_batch,
layer_registry,
weight_batch,
num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables,
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
if weight_batch is not None:
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [batch_size]
else:
# In this case, weight_batch has shape [batch_size]; we first reshape it to
# [num_microbatches, microbatch_size], then multiply by clip_weights
# (which has shape [num_microbatches]).
weight_batch = lr.maybe_add_microbatch_axis(
weight_batch, num_microbatches
)
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
@@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs(
# is not defined in the contract so may not hold, especially for
# custom losses.
y_pred = input_model(x_batch, training=True)
loss_y_batch = (
y_batch
if num_microbatches is None
else lr.add_microbatch_axis(y_batch, num_microbatches)
)
loss_y_pred = (
y_pred
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
clipped_grads = tape.gradient(
clipping_loss_value,
input_model.trainable_variables,
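
(Illustrative usage sketch, not part of this diff: after this change, both entry points take `layer_registry` before the batch tensors and accept an optional `weight_batch` that is threaded through to the unreduced loss. The model, data, and `make_default_layer_registry()` setup below are assumptions made only for illustration.)

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# A small compiled model; compute_gradient_norms reads `model.loss`.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([4, 3])
y_batch = tf.random.normal([4, 2])
weight_batch = tf.constant([1.0, 0.5, 2.0, 1.0])  # one weight per example

# New argument order: layer_registry comes right after the model, followed by
# x_batch and y_batch; weight_batch weights each example's loss.
norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),
    x_batch,
    y_batch,
    weight_batch=weight_batch,
    num_microbatches=None,
)
print(norms.shape)  # (4,): one gradient norm per example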