Skip to content

Commit

Permalink
Sparsity Preserving DP-SGD in TF Privacy
Browse files Browse the repository at this point in the history
Refactor model_forward_backward_pass out of compute_gradients to allow for other optimizations such as sparsity preserving noise to integrate with it.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660503249
  • Loading branch information
tensorflower-gardener committed Aug 7, 2024
1 parent d3f527e commit a7556da
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 195 deletions.
2 changes: 2 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ py_library(
srcs = ["gradient_clipping_utils.py"],
srcs_version = "PY3",
deps = [
":common_manip_utils",
":layer_registry",
":type_aliases",
],
Expand Down Expand Up @@ -94,6 +95,7 @@ py_test(
deps = [
":clip_grads",
":common_test_utils",
":gradient_clipping_utils",
":layer_registry",
":type_aliases",
],
Expand Down
293 changes: 103 additions & 190 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""

import collections
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Optional

import tensorflow as tf
Expand All @@ -32,110 +32,81 @@
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def get_registry_generator_fn(
tape: tf.GradientTape,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[type_aliases.BatchSize] = None,
def _compute_gradient_norms_internal(
registry_fn_outputs_list: Sequence[
gradient_clipping_utils.RegistryGeneratorFunctionOutput
],
layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
trainable_vars: Optional[Sequence[tf.Variable]] = None,
):
"""Creates the generator function for `compute_gradient_norms()`."""
if layer_registry is None:
# Needed for backwards compatibility.
registry_generator_fn = None
else:

def registry_generator_fn(layer_instance, args, kwargs):
if layer_instance.trainable_variables:
# Only trainable variables factor into the gradient.
if not layer_registry.is_elem(layer_instance):
raise NotImplementedError(
'Layer %s is not in the registry of known layers that can '
'be used for efficient gradient clipping.'
% layer_instance.__class__.__name__
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, kwargs, tape, num_microbatches
)
return layer_outputs, (
str(id(layer_instance)),
layer_vars,
layer_sqr_norm_fn,
layer_instance.trainable_weights,
)
else:
# Non-trainable layer.
return layer_instance(*args, **kwargs), None
"""Computes the per-example loss gradient norms for given data.
return registry_generator_fn
Args:
registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
containing information required to compute the gradient norms and
contribution counts. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
layer_grad_vars: A mapping of layer id to a list of gradients for each
trainablev ariable in the layer. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
trainable_vars: The list of variables included in computing the gradient
norm. When a layer has multiple variables, we include all the variables if
any of the variables is in the list. If `trainable_vars` is None, all the
variables are included.
Returns:
A scalar vector, whose i-th entry is the norm of the gradient of the i-th
weighted example loss (when num_microbatches is None) or the norm of the
gradient of the i-th microbatch loss (define as a mean over the microbatch).
Note that when the loss is weighted (`weight_batch` is not None), weights
are applied prior to clipping.
def _infer_per_example_loss_fn(model: tf.keras.Model):
"""Infer the per-example loss from model config."""
Raises:
ValueError: If `layer_grad_vars` is empty.
ValueError: If the number of gradients for a layer is not equal to the
number of squared norm functions for that layer.
"""
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])

def _convert(loss_fn):
loss_config = loss_fn.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
return loss_fn.from_config(loss_config)
layer_sqr_norm_fns = collections.defaultdict(list)
# The case of shared weights:
# If a layer is called k times, it will appear k times in filtered_outputs,
# with the same id, but potentially with different v and f. The code below
# groups filtered_outputs by layer_id, so we can correctly compute gradient
# norms. The gradient norm of a layer that occurs k times is computed as
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights.
for registry_fn_output in registry_fn_outputs_list:
if trainable_vars is None or any(
w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
):
layer_sqr_norm_fns[registry_fn_output.layer_id].append(
registry_fn_output.layer_sqr_norm_fn
)

model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return _convert(model_loss)
elif isinstance(model_loss, dict):
# Note that we cannot call the public method `.get_compile_config()` because
# it calls a numpy function, which is not supported inside a `tf.function`
# wrapped function.
compile_config = model._compile_config.config # pylint: disable=protected-access
if compile_config is None:
raise ValueError('Model must be compiled for loss function conversion')
# Does a weighted mean of the configured losses. Note that we cannot build
# from the config of the compiled loss because (i) it builds a
# `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
# during its construction, (ii) non-unique `tf.Variables` cannot be used
# inside a `tf.function`, which is usually where this function is used.
if 'loss_weights' not in compile_config:
if not layer_grad_vars:
raise ValueError('The gradient list cannot be empty.')
sqr_norm_list = []
for layer_id in layer_sqr_norm_fns.keys():
fns = layer_sqr_norm_fns[layer_id]
grads = layer_grad_vars[layer_id]
# Number of duplicates for this layer in `filtered_outputs`.
num_passes = len(fns)
if len(fns) != len(grads):
raise ValueError(
'Models with multiple loss must have corresponding loss weights for'
' loss function conversion'
'There must be as many gradients as squared norm functions.'
)
weights = compile_config['loss_weights']
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
num_losses = len(weights)

def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
loss_values = []
if model_loss.keys() - y_pred.keys():
raise ValueError(
'y_pred must contain the same keys and the model losses, but '
'got %s and %s' % (y_pred.keys(), model_loss.keys())
)
if model_loss.keys() - y_true.keys():
raise ValueError(
'y_true must contain the same keys and the model losses, but '
'got %s and %s' % (y_true.keys(), model_loss.keys())
)
if sample_weight is not None:
if model_loss.keys() - sample_weight.keys():
raise ValueError(
'sample_weight must contain the same keys and the model losses,'
' but got %s and %s' % (y_true.keys(), model_loss.keys())
)
for k in y_true.keys():
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
sgl_value = (
weights[k]
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
/ num_losses
)
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
return tf.math.add_n(loss_values)

return _per_example_loss_fn
else:
raise ValueError(
'Unsupported type for loss function conversion: {}'.format(
type(model_loss)
)
)
# See go/dp-sgd-shared-weights for more details.
for fn, grad in zip(fns, grads):
sqr_norm_list.append(num_passes * fn(grad))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
return gradient_norms


def compute_gradient_norms(
Expand All @@ -147,7 +118,7 @@ def compute_gradient_norms(
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
trainable_vars: Optional[Sequence[tf.Variable]] = None,
):
) -> tf.Tensor:
"""Computes the per-example loss gradient norms for given data.
Applies a variant of the approach given in
Expand Down Expand Up @@ -190,86 +161,28 @@ def compute_gradient_norms(
are applied prior to clipping.
"""
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = get_registry_generator_fn(
tape, layer_registry, num_microbatches
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=layer_registry,
num_microbatches=num_microbatches,
)
# First loop computes the model outputs, summed loss, and generator outputs.
with tape:
model_outputs, generator_outputs_list = (
gradient_clipping_utils.model_forward_pass(
input_model, x_batch, generator_fn=registry_generator_fn
)
)

# Ignore the original loss function's reduction to get per-example loss.
if per_example_loss_fn is None:
per_example_loss_fn = _infer_per_example_loss_fn(input_model)

losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
)
if len(losses.shape) != 1:
raise NotImplementedError(
'The unreduced (or per-example) loss needs to have a shape of length '
'one, but received an unreduced loss of shape length %s'
% len(losses.shape)
)
if num_microbatches is not None:
losses = tf.reduce_mean(
common_manip_utils.maybe_add_microbatch_axis(
losses, num_microbatches
),
axis=1,
layer_grad_vars, generator_outputs_list = (
gradient_clipping_utils.model_forward_backward_pass(
tape=tape,
input_model=input_model,
x_batch=x_batch,
y_batch=y_batch,
registry_generator_fn=registry_generator_fn,
weight_batch=weight_batch,
per_example_loss_fn=per_example_loss_fn,
num_microbatches=num_microbatches,
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None]
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])
layer_vars = collections.defaultdict(list)
layer_sqr_norm_fns = collections.defaultdict(list)
# The case of shared weights:
# If a layer is called k times, it will appear k times in filtered_outputs,
# with the same id, but potentially with different v and f. The code below
# groups filtered_outputs by layer_id, so we can correctly compute gradient
# norms. The gradient norm of a layer that occurs k times is computed as
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights.
for layer_id, v, f, weights_list in filtered_outputs:
if trainable_vars is None or any(
w.ref() in trainable_vars for w in weights_list
):
layer_vars[layer_id].append(v)
layer_sqr_norm_fns[layer_id].append(f)
# Second loop evaluates the squared L2 norm functions and appends the results.
layer_grad_vars = tape.gradient(
summed_loss,
layer_vars,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
if not layer_grad_vars:
raise ValueError('The gradient list cannot be empty.')
sqr_norm_list = []
for layer_id in layer_sqr_norm_fns.keys():
fns = layer_sqr_norm_fns[layer_id]
grads = layer_grad_vars[layer_id]
# Number of duplicates for this layer in `filtered_outputs`.
num_passes = len(fns)
if len(fns) != len(grads):
raise ValueError(
'There must be as many gradients as squared norm functions.'
)
# See go/dp-sgd-shared-weights for more details.
for fn, grad in zip(fns, grads):
sqr_norm_list.append(num_passes * fn(grad))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
return _compute_gradient_norms_internal(
registry_fn_outputs_list=generator_outputs_list,
layer_grad_vars=layer_grad_vars,
trainable_vars=trainable_vars,
)


def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
Expand Down Expand Up @@ -299,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
registry_fn_outputs_list: Sequence[
gradient_clipping_utils.RegistryGeneratorFunctionOutput
],
layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors,
y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None,
) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
Expand All @@ -319,15 +235,16 @@ def compute_clipped_gradients_and_outputs(
Args:
input_model: The `tf.keras.Model` from which to obtain the layers from.
registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
containing information required to compute the gradient norms and
contribution counts. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
layer_grad_vars: A mapping of layer id to a list of gradients for each
trainablev ariable in the layer. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`.
layer_registry: A `dict` of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axes of each tensor must be the batch dimension.
y_batch: An `OutputTensor` representing a batch of output labels. The first
Expand Down Expand Up @@ -362,13 +279,9 @@ def compute_clipped_gradients_and_outputs(
)
if clipping_loss is None:
clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms(
input_model,
layer_registry,
x_batch,
y_batch,
weight_batch,
num_microbatches=num_microbatches,
gradient_norms = _compute_gradient_norms_internal(
registry_fn_outputs_list=registry_fn_outputs_list,
layer_grad_vars=layer_grad_vars,
trainable_vars=input_model.trainable_variables,
)
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
Expand Down
Loading

0 comments on commit a7556da

Please sign in to comment.