Commit 3dc184e

Sparsity Preserving DP-SGD in TF Privacy
Refactor model_forward_backward_pass out of compute_gradients so that other optimizations, such as sparsity-preserving noise, can integrate with it. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660503249
1 parent 8294cec commit 3dc184e
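
For orientation, the sketch below mirrors the new body of compute_gradient_norms shown in the clip_grads.py diff further down: a single forward/backward pass produces the per-layer gradients and the registry generator outputs, which are then reduced to per-example norms. The wrapper function and the layer_registry.make_default_layer_registry() call are illustrative assumptions, not part of this commit.

# A sketch of the refactored flow, assuming the signatures shown in the
# clip_grads.py diff below. The wrapper itself and the default registry
# factory (make_default_layer_registry) are illustrative assumptions.
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr


def per_example_gradient_norms(model, x_batch, y_batch, num_microbatches=None):
  """Per-example gradient norms via the factored-out forward/backward pass."""
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=lr.make_default_layer_registry(),  # assumed factory name
      num_microbatches=num_microbatches,
  )
  # Step 1: a single forward/backward pass, now reusable by other
  # optimizations (e.g. sparsity-preserving noise) besides norm clipping.
  layer_grad_vars, registry_fn_outputs_list = (
      gradient_clipping_utils.model_forward_backward_pass(
          tape=tape,
          input_model=model,
          x_batch=x_batch,
          y_batch=y_batch,
          registry_generator_fn=registry_generator_fn,
          num_microbatches=num_microbatches,
      )
  )
  # Step 2: reduce the per-layer gradients to per-example gradient norms.
  return clip_grads._compute_gradient_norms_internal(  # pylint: disable=protected-access
      registry_fn_outputs_list=registry_fn_outputs_list,
      layer_grad_vars=layer_grad_vars,
      trainable_vars=model.trainable_variables,
  )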

6 files changed: +310 -164 lines changed

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ py_library(
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
     deps = [
+        ":common_manip_utils",
         ":layer_registry",
         ":type_aliases",
     ],
@@ -94,6 +95,7 @@ py_test(
     deps = [
         ":clip_grads",
         ":common_test_utils",
+        ":gradient_clipping_utils",
         ":layer_registry",
         ":type_aliases",
     ],

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

Lines changed: 106 additions & 161 deletions
@@ -22,7 +22,7 @@
 """
 
 import collections
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Optional
 
 import tensorflow as tf
@@ -32,73 +32,81 @@
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 
 
-def _infer_per_example_loss_fn(model: tf.keras.Model):
-  """Infer the per-example loss from model config."""
+def _compute_gradient_norms_internal(
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
+):
+  """Computes the per-example loss gradient norms for given data.
 
-  def _convert(loss_fn):
-    loss_config = loss_fn.get_config()
-    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
-    return loss_fn.from_config(loss_config)
+  Args:
+    registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    trainable_vars: The list of variables included in computing the gradient
+      norm. When a layer has multiple variables, we include all the variables if
+      any of the variables is in the list. If `trainable_vars` is None, all the
+      variables are included.
 
-  model_loss = model.loss
-  if isinstance(model_loss, tf.keras.losses.Loss):
-    return _convert(model_loss)
-  elif isinstance(model_loss, dict):
-    # Note that we cannot call the public method `.get_compile_config()` because
-    # it calls a numpy function, which is not supported inside a `tf.function`
-    # wrapped function.
-    compile_config = model._compile_config.config  # pylint: disable=protected-access
-    if compile_config is None:
-      raise ValueError('Model must be compiled for loss function conversion')
-    # Does a weighted mean of the configured losses. Note that we cannot build
-    # from the config of the compiled loss because (i) it builds a
-    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
-    # during its construction, (ii) non-unique `tf.Variables` cannot be used
-    # inside a `tf.function`, which is usually where this function is used.
-    if 'loss_weights' not in compile_config:
-      raise ValueError(
-          'Models with multiple loss must have corresponding loss weights for'
-          ' loss function conversion'
-      )
-    weights = compile_config['loss_weights']
-    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
-    num_losses = len(weights)
+  Returns:
+    A scalar vector, whose i-th entry is the norm of the gradient of the i-th
+    weighted example loss (when num_microbatches is None) or the norm of the
+    gradient of the i-th microbatch loss (defined as a mean over the microbatch).
+    Note that when the loss is weighted (`weight_batch` is not None), weights
+    are applied prior to clipping.
 
-    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
-      loss_values = []
-      if model_loss.keys() - y_pred.keys():
-        raise ValueError(
-            'y_pred must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_pred.keys(), model_loss.keys())
-        )
-      if model_loss.keys() - y_true.keys():
-        raise ValueError(
-            'y_true must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_true.keys(), model_loss.keys())
-        )
-      if sample_weight is not None:
-        if model_loss.keys() - sample_weight.keys():
-          raise ValueError(
-              'sample_weight must contain the same keys and the model losses,'
-              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
-          )
-      for k in y_true.keys():
-        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
-        sgl_value = (
-            weights[k]
-            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
-            / num_losses
-        )
-        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
-      return tf.math.add_n(loss_values)
+  Raises:
+    ValueError: If `layer_grad_vars` is empty.
+    ValueError: If the number of gradients for a layer is not equal to the
+      number of squared norm functions for that layer.
+  """
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
 
-    return _per_example_loss_fn
-  else:
-    raise ValueError(
-        'Unsupported type for loss function conversion: {}'.format(
-            type(model_loss)
-        )
-    )
+  layer_sqr_norm_fns = collections.defaultdict(list)
+  # The case of shared weights:
+  # If a layer is called k times, it will appear k times in filtered_outputs,
+  # with the same id, but potentially with different v and f. The code below
+  # groups filtered_outputs by layer_id, so we can correctly compute gradient
+  # norms. The gradient norm of a layer that occurs k times is computed as
+  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
+  # occurrence. This is an over-estimate of the actual norm. For more details,
+  # see the explanation in go/dp-sgd-shared-weights.
+  for registry_fn_output in registry_fn_outputs_list:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_sqr_norm_fn
+      )
+
+  if not layer_grad_vars:
+    raise ValueError('The gradient list cannot be empty.')
+  sqr_norm_list = []
+  for layer_id in layer_sqr_norm_fns.keys():
+    fns = layer_sqr_norm_fns[layer_id]
+    grads = layer_grad_vars[layer_id]
+    # Number of duplicates for this layer in `filtered_outputs`.
+    num_passes = len(fns)
+    if len(fns) != len(grads):
+      raise ValueError(
+          'There must be as many gradients as squared norm functions.'
+      )
+    # See go/dp-sgd-shared-weights for more details.
+    for fn, grad in zip(fns, grads):
+      sqr_norm_list.append(num_passes * fn(grad))
+  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
+  gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  return gradient_norms
 
 
 def compute_gradient_norms(
@@ -110,7 +118,7 @@ def compute_gradient_norms(
     per_example_loss_fn: Optional[type_aliases.LossFn] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     trainable_vars: Optional[Sequence[tf.Variable]] = None,
-):
+) -> tf.Tensor:
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
@@ -154,90 +162,27 @@ def compute_gradient_norms(
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
-      tape, layer_registry, num_microbatches
+      tape=tape,
+      layer_registry=layer_registry,
+      num_microbatches=num_microbatches,
   )
-  # First loop computes the model outputs, summed loss, and generator outputs.
-  with tape:
-    model_outputs, generator_outputs_list = (
-        gradient_clipping_utils.model_forward_pass(
-            input_model, x_batch, generator_fn=registry_generator_fn
-        )
-    )
-
-  # Ignore the original loss function's reduction to get per-example loss.
-  if per_example_loss_fn is None:
-    per_example_loss_fn = _infer_per_example_loss_fn(input_model)
-
-  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
-  if losses.shape is None:
-    raise NotImplementedError(
-        "The unreduced (or per-example) loss's shape cannot be `None`"
-    )
-  if len(losses.shape) != 1:
-    raise NotImplementedError(
-        'The unreduced (or per-example) loss needs to have a shape of length '
-        'one, but received an unreduced loss of shape length %s'
-        % len(losses.shape)
-    )
-  if num_microbatches is not None:
-    losses = tf.reduce_mean(
-        common_manip_utils.maybe_add_microbatch_axis(
-            losses, num_microbatches
-        ),
-        axis=1,
-    )
-  summed_loss = tf.reduce_sum(losses)
-  # Unwrap the generator outputs so that the next loop avoids duplicating
-  # backprop ops.
-  filtered_outputs = [t for t in generator_outputs_list if t is not None]
-  if trainable_vars is not None:
-    # Create a set using `ref()` for fast set membership check. tf.Variable
-    # itself is not hashable.
-    trainable_vars = set([v.ref() for v in trainable_vars])
-  layer_vars = collections.defaultdict(list)
-  layer_sqr_norm_fns = collections.defaultdict(list)
-  # The case of shared weights:
-  # If a layer is called k times, it will appear k times in filtered_outputs,
-  # with the same id, but potentially with different v and f. The code below
-  # groups filtered_outputs by layer_id, so we can correctly compute gradient
-  # norms. The gradient norm of a layer that occurs k times is computed as
-  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
-  # occurrence. This is an over-estimate of the actual norm. For more details,
-  # see the explanation in go/dp-sgd-shared-weights.
-  for registry_fn_output in filtered_outputs:
-    if trainable_vars is None or any(
-        w.ref() in trainable_vars
-        for w in registry_fn_output.layer_trainable_weights
-    ):
-      layer_vars[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_vars
+  layer_grad_vars, generator_outputs_list = (
+      gradient_clipping_utils.model_forward_backward_pass(
+          tape=tape,
+          input_model=input_model,
+          x_batch=x_batch,
+          y_batch=y_batch,
+          registry_generator_fn=registry_generator_fn,
+          weight_batch=weight_batch,
+          per_example_loss_fn=per_example_loss_fn,
+          num_microbatches=num_microbatches,
       )
-      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_sqr_norm_fn
-      )
-  # Second loop evaluates the squared L2 norm functions and appends the results.
-  layer_grad_vars = tape.gradient(
-      summed_loss,
-      layer_vars,
-      unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
-  if not layer_grad_vars:
-    raise ValueError('The gradient list cannot be empty.')
-  sqr_norm_list = []
-  for layer_id in layer_sqr_norm_fns.keys():
-    fns = layer_sqr_norm_fns[layer_id]
-    grads = layer_grad_vars[layer_id]
-    # Number of duplicates for this layer in `filtered_outputs`.
-    num_passes = len(fns)
-    if len(fns) != len(grads):
-      raise ValueError(
-          'There must be as many gradients as squared norm functions.'
-      )
-    # See go/dp-sgd-shared-weights for more details.
-    for fn, grad in zip(fns, grads):
-      sqr_norm_list.append(num_passes * fn(grad))
-  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
-  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  return _compute_gradient_norms_internal(
+      registry_fn_outputs_list=generator_outputs_list,
+      layer_grad_vars=layer_grad_vars,
+      trainable_vars=trainable_vars,
+  )
 
 
 def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
@@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
 
 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
     l2_norm_clip: float,
-    layer_registry: lr.LayerRegistry,
     x_batch: type_aliases.InputTensors,
     y_batch: type_aliases.OutputTensors,
     weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     clipping_loss: Optional[type_aliases.LossFn] = None,
-) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
+) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.
 
   Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs(
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
+    registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
-    layer_registry: A `dict` of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
-      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
-      squared norms of a layer's pre-activation tensor, and `vars` are relevant
-      trainable weights (see `layer_registry_factories.py` for examples).
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axes of each tensor must be the batch dimension.
     y_batch: An `OutputTensor` representing a batch of output labels. The first
@@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs(
   )
   if clipping_loss is None:
     clipping_loss = input_model.compiled_loss
-  gradient_norms = compute_gradient_norms(
-      input_model,
-      layer_registry,
-      x_batch,
-      y_batch,
-      weight_batch,
-      num_microbatches=num_microbatches,
+  gradient_norms = _compute_gradient_norms_internal(
+      registry_fn_outputs_list=registry_fn_outputs_list,
+      layer_grad_vars=layer_grad_vars,
       trainable_vars=input_model.trainable_variables,
   )
   clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
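
To make the shared-weights handling in _compute_gradient_norms_internal concrete, here is a toy numeric sketch of the sqrt(k * sum_i c_i^2) over-estimate described in the code comments above; the per-occurrence squared-norm values are made up for illustration only.

# Toy sketch of the shared-weights norm over-estimate: a layer called
# k times contributes k * sum_i c_i^2 to the squared norm, where c_i^2 is
# the squared-norm estimate of its i-th occurrence.
import tensorflow as tf

# Per-example squared-norm estimates for one shared layer called k = 2 times,
# for a batch of 3 examples (what the per-layer `fn(grad)` calls would return).
per_occurrence_sqr_norms = [
    tf.constant([1.0, 4.0, 9.0]),  # first occurrence: c_1^2 per example
    tf.constant([0.0, 1.0, 4.0]),  # second occurrence: c_2^2 per example
]
num_passes = len(per_occurrence_sqr_norms)  # k = 2

# k * sum_i c_i^2, accumulated per example, then reduced to a norm.
sqr_norm_list = [num_passes * sqr for sqr in per_occurrence_sqr_norms]
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
print(gradient_norms.numpy())  # [sqrt(2*1), sqrt(2*5), sqrt(2*13)] ~= [1.41, 3.16, 5.10]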
