"""

import collections
- from collections.abc import Sequence
+ from collections.abc import Mapping, Sequence
from typing import Optional

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


- def _infer_per_example_loss_fn(model: tf.keras.Model):
-   """Infer the per-example loss from model config."""
+ def _compute_gradient_norms_internal(
+     registry_fn_outputs_list: Sequence[
+         gradient_clipping_utils.RegistryGeneratorFunctionOutput
+     ],
+     layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
+     trainable_vars: Optional[Sequence[tf.Variable]] = None,
+ ):
+   """Computes the per-example loss gradient norms for given data.

-   def _convert(loss_fn):
-     loss_config = loss_fn.get_config()
-     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
-     return loss_fn.from_config(loss_config)
+   Args:
+     registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+       containing information required to compute the gradient norms and
+       contribution counts. Output from
+       `gradient_clipping_utils.model_forward_backward_pass()`.
+     layer_grad_vars: A mapping of layer id to a list of gradients for each
+       trainable variable in the layer. Output from
+       `gradient_clipping_utils.model_forward_backward_pass()`.
+     trainable_vars: The list of variables included in computing the gradient
+       norm. When a layer has multiple variables, we include all the variables if
+       any of the variables is in the list. If `trainable_vars` is None, all the
+       variables are included.

-   model_loss = model.loss
-   if isinstance(model_loss, tf.keras.losses.Loss):
-     return _convert(model_loss)
-   elif isinstance(model_loss, dict):
-     # Note that we cannot call the public method `.get_compile_config()` because
-     # it calls a numpy function, which is not supported inside a `tf.function`
-     # wrapped function.
-     compile_config = model._compile_config.config  # pylint: disable=protected-access
-     if compile_config is None:
-       raise ValueError('Model must be compiled for loss function conversion')
-     # Does a weighted mean of the configured losses. Note that we cannot build
-     # from the config of the compiled loss because (i) it builds a
-     # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
-     # during its construction, (ii) non-unique `tf.Variables` cannot be used
-     # inside a `tf.function`, which is usually where this function is used.
-     if 'loss_weights' not in compile_config:
-       raise ValueError(
-           'Models with multiple loss must have corresponding loss weights for'
-           ' loss function conversion'
-       )
-     weights = compile_config['loss_weights']
-     per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
-     num_losses = len(weights)
+   Returns:
+     A scalar vector, whose i-th entry is the norm of the gradient of the i-th
+     weighted example loss (when num_microbatches is None) or the norm of the
+     gradient of the i-th microbatch loss (defined as a mean over the microbatch).
+     Note that when the loss is weighted (`weight_batch` is not None), weights
+     are applied prior to clipping.

-     def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
-       loss_values = []
-       if model_loss.keys() - y_pred.keys():
-         raise ValueError(
-             'y_pred must contain the same keys and the model losses, but '
-             'got %s and %s' % (y_pred.keys(), model_loss.keys())
-         )
-       if model_loss.keys() - y_true.keys():
-         raise ValueError(
-             'y_true must contain the same keys and the model losses, but '
-             'got %s and %s' % (y_true.keys(), model_loss.keys())
-         )
-       if sample_weight is not None:
-         if model_loss.keys() - sample_weight.keys():
-           raise ValueError(
-               'sample_weight must contain the same keys and the model losses,'
-               ' but got %s and %s' % (y_true.keys(), model_loss.keys())
-           )
-       for k in y_true.keys():
-         sgl_sample_weight = None if sample_weight is None else sample_weight[k]
-         sgl_value = (
-             weights[k]
-             * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
-             / num_losses
-         )
-         loss_values.append(tf.reshape(sgl_value, shape=[-1]))
-       return tf.math.add_n(loss_values)
+   Raises:
+     ValueError: If `layer_grad_vars` is empty.
+     ValueError: If the number of gradients for a layer is not equal to the
+       number of squared norm functions for that layer.
+   """
+   if trainable_vars is not None:
+     # Create a set using `ref()` for fast set membership check. tf.Variable
+     # itself is not hashable.
+     trainable_vars = set([v.ref() for v in trainable_vars])

-     return _per_example_loss_fn
-   else:
-     raise ValueError(
-         'Unsupported type for loss function conversion: {}'.format(
-             type(model_loss)
-         )
-     )
+   layer_sqr_norm_fns = collections.defaultdict(list)
+   # The case of shared weights:
+   # If a layer is called k times, it will appear k times in
+   # registry_fn_outputs_list, with the same id, but potentially with different
+   # v and f. The code below groups registry_fn_outputs_list by layer_id, so we
+   # can correctly compute gradient norms. The gradient norm of a layer that
+   # occurs k times is computed as $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the
+   # norm estimate of its i-th occurrence. This is an over-estimate of the
+   # actual norm. For more details, see the explanation in
+   # go/dp-sgd-shared-weights.
+   for registry_fn_output in registry_fn_outputs_list:
+     if trainable_vars is None or any(
+         w.ref() in trainable_vars
+         for w in registry_fn_output.layer_trainable_weights
+     ):
+       layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+           registry_fn_output.layer_sqr_norm_fn
+       )
+
+   if not layer_grad_vars:
+     raise ValueError('The gradient list cannot be empty.')
+   sqr_norm_list = []
+   for layer_id in layer_sqr_norm_fns.keys():
+     fns = layer_sqr_norm_fns[layer_id]
+     grads = layer_grad_vars[layer_id]
+     # Number of duplicates for this layer in `registry_fn_outputs_list`.
+     num_passes = len(fns)
+     if len(fns) != len(grads):
+       raise ValueError(
+           'There must be as many gradients as squared norm functions.'
+       )
+     # See go/dp-sgd-shared-weights for more details.
+     for fn, grad in zip(fns, grads):
+       sqr_norm_list.append(num_passes * fn(grad))
+   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
+   gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+   return gradient_norms


def compute_gradient_norms(
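
Illustrative note: a minimal, self-contained sketch of the shared-weights rule described in the comment block of `_compute_gradient_norms_internal` above, where a layer visited k times contributes the combined per-example bound sqrt(k * sum_i c_i^2). The tensors and variable names here are made up for illustration and are not part of the file.

import tensorflow as tf

# Hypothetical per-example squared-norm estimates from k = 2 occurrences of a
# shared layer; each tensor has one entry per example in the batch.
per_occurrence_sqr_norms = [tf.constant([1.0, 4.0]), tf.constant([9.0, 16.0])]
num_passes = len(per_occurrence_sqr_norms)  # k

# Scale each occurrence by k, stack along a new axis, and reduce, mirroring the
# `num_passes * fn(grad)`, `tf.stack`, and `tf.sqrt(tf.reduce_sum(...))` steps
# in `_compute_gradient_norms_internal`.
scaled = [num_passes * s for s in per_occurrence_sqr_norms]
sqr_norm_tsr = tf.stack(scaled, axis=1)  # shape [batch, k]
layer_norm_bound = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
# layer_norm_bound is [sqrt(2 * (1 + 9)), sqrt(2 * (4 + 16))], one entry per example.
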
@@ -110,7 +118,7 @@ def compute_gradient_norms(
    per_example_loss_fn: Optional[type_aliases.LossFn] = None,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
    trainable_vars: Optional[Sequence[tf.Variable]] = None,
- ):
+ ) -> tf.Tensor:
  """Computes the per-example loss gradient norms for given data.

  Applies a variant of the approach given in
@@ -154,90 +162,27 @@ def compute_gradient_norms(
  """
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
-       tape, layer_registry, num_microbatches
+       tape=tape,
+       layer_registry=layer_registry,
+       num_microbatches=num_microbatches,
  )
-   # First loop computes the model outputs, summed loss, and generator outputs.
-   with tape:
-     model_outputs, generator_outputs_list = (
-         gradient_clipping_utils.model_forward_pass(
-             input_model, x_batch, generator_fn=registry_generator_fn
-         )
-     )
-
-     # Ignore the original loss function's reduction to get per-example loss.
-     if per_example_loss_fn is None:
-       per_example_loss_fn = _infer_per_example_loss_fn(input_model)
-
-     losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
-     if losses.shape is None:
-       raise NotImplementedError(
-           "The unreduced (or per-example) loss's shape cannot be `None`"
-       )
-     if len(losses.shape) != 1:
-       raise NotImplementedError(
-           'The unreduced (or per-example) loss needs to have a shape of length '
-           'one, but received an unreduced loss of shape length %s'
-           % len(losses.shape)
-       )
-     if num_microbatches is not None:
-       losses = tf.reduce_mean(
-           common_manip_utils.maybe_add_microbatch_axis(
-               losses, num_microbatches
-           ),
-           axis=1,
-       )
-     summed_loss = tf.reduce_sum(losses)
-   # Unwrap the generator outputs so that the next loop avoids duplicating
-   # backprop ops.
-   filtered_outputs = [t for t in generator_outputs_list if t is not None]
-   if trainable_vars is not None:
-     # Create a set using `ref()` for fast set membership check. tf.Variable
-     # itself is not hashable.
-     trainable_vars = set([v.ref() for v in trainable_vars])
-   layer_vars = collections.defaultdict(list)
-   layer_sqr_norm_fns = collections.defaultdict(list)
-   # The case of shared weights:
-   # If a layer is called k times, it will appear k times in filtered_outputs,
-   # with the same id, but potentially with different v and f. The code below
-   # groups filtered_outputs by layer_id, so we can correctly compute gradient
-   # norms. The gradient norm of a layer that occurs k times is computed as
-   # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
-   # occurrence. This is an over-estimate of the actual norm. For more details,
-   # see the explanation in go/dp-sgd-shared-weights.
-   for registry_fn_output in filtered_outputs:
-     if trainable_vars is None or any(
-         w.ref() in trainable_vars
-         for w in registry_fn_output.layer_trainable_weights
-     ):
-       layer_vars[registry_fn_output.layer_id].append(
-           registry_fn_output.layer_vars
+   layer_grad_vars, generator_outputs_list = (
+       gradient_clipping_utils.model_forward_backward_pass(
+           tape=tape,
+           input_model=input_model,
+           x_batch=x_batch,
+           y_batch=y_batch,
+           registry_generator_fn=registry_generator_fn,
+           weight_batch=weight_batch,
+           per_example_loss_fn=per_example_loss_fn,
+           num_microbatches=num_microbatches,
      )
-       layer_sqr_norm_fns[registry_fn_output.layer_id].append(
-           registry_fn_output.layer_sqr_norm_fn
-       )
-   # Second loop evaluates the squared L2 norm functions and appends the results.
-   layer_grad_vars = tape.gradient(
-       summed_loss,
-       layer_vars,
-       unconnected_gradients=tf.UnconnectedGradients.ZERO,
  )
-   if not layer_grad_vars:
-     raise ValueError('The gradient list cannot be empty.')
-   sqr_norm_list = []
-   for layer_id in layer_sqr_norm_fns.keys():
-     fns = layer_sqr_norm_fns[layer_id]
-     grads = layer_grad_vars[layer_id]
-     # Number of duplicates for this layer in `filtered_outputs`.
-     num_passes = len(fns)
-     if len(fns) != len(grads):
-       raise ValueError(
-           'There must be as many gradients as squared norm functions.'
-       )
-     # See go/dp-sgd-shared-weights for more details.
-     for fn, grad in zip(fns, grads):
-       sqr_norm_list.append(num_passes * fn(grad))
-   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
-   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+   return _compute_gradient_norms_internal(
+       registry_fn_outputs_list=generator_outputs_list,
+       layer_grad_vars=layer_grad_vars,
+       trainable_vars=trainable_vars,
+   )


def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
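
For orientation, a hedged usage sketch of the refactored `compute_gradient_norms()` on a toy model. The module paths and the `make_default_layer_registry()` helper are assumptions about the surrounding package rather than part of this change.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Toy compiled model; with per_example_loss_fn left unset, the per-example
# loss is inferred from the compiled loss.
model = tf.keras.Sequential(
    [tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(2)]
)
model.compile(loss=tf.keras.losses.MeanSquaredError())
x_batch = tf.random.normal([4, 3])
y_batch = tf.random.normal([4, 2])

# The registry maps supported layer classes (e.g. Dense) to their fast
# squared-norm rules; make_default_layer_registry() is assumed here.
norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),
    x_batch,
    y_batch,
)
# `norms` has one entry per example: that example's loss-gradient norm.
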
@@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

def compute_clipped_gradients_and_outputs(
    input_model: tf.keras.Model,
+     registry_fn_outputs_list: Sequence[
+         gradient_clipping_utils.RegistryGeneratorFunctionOutput
+     ],
+     layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
    l2_norm_clip: float,
-     layer_registry: lr.LayerRegistry,
    x_batch: type_aliases.InputTensors,
    y_batch: type_aliases.OutputTensors,
    weight_batch: Optional[tf.Tensor] = None,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
    clipping_loss: Optional[type_aliases.LossFn] = None,
- ) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
+ ) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
  """Computes the per-example clipped loss gradient and other useful outputs.

  Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs(

  Args:
    input_model: The `tf.keras.Model` from which to obtain the layers from.
+     registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+       containing information required to compute the gradient norms and
+       contribution counts. Output from
+       `gradient_clipping_utils.model_forward_backward_pass()`.
+     layer_grad_vars: A mapping of layer id to a list of gradients for each
+       trainable variable in the layer. Output from
+       `gradient_clipping_utils.model_forward_backward_pass()`.
    l2_norm_clip: A `float` indicating the norm to which per-example gradients
      will be clipped. That is, all gradients of the per-example loss functions
      will have norm at most `l2_norm_clip`.
-     layer_registry: A `dict` of layers that support "fast" gradient norm
-       computations. The key is the class of the layer and the value is a
-       function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
-       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
-       squared norms of a layer's pre-activation tensor, and `vars` are relevant
-       trainable weights (see `layer_registry_factories.py` for examples).
    x_batch: An `InputTensor` representing a batch of inputs to the model. The
      first axes of each tensor must be the batch dimension.
    y_batch: An `OutputTensor` representing a batch of output labels. The first
@@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs(
    )
  if clipping_loss is None:
    clipping_loss = input_model.compiled_loss
-   gradient_norms = compute_gradient_norms(
-       input_model,
-       layer_registry,
-       x_batch,
-       y_batch,
-       weight_batch,
-       num_microbatches=num_microbatches,
+   gradient_norms = _compute_gradient_norms_internal(
+       registry_fn_outputs_list=registry_fn_outputs_list,
+       layer_grad_vars=layer_grad_vars,
      trainable_vars=input_model.trainable_variables,
  )
  clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
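
A sketch of the new calling convention for `compute_clipped_gradients_and_outputs()`: the caller now runs `gradient_clipping_utils.model_forward_backward_pass()` once and passes its two outputs in, instead of a layer registry. The imports, the default-registry helper, and the order of the returned triple are assumptions for illustration, not guarantees from this change.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(2)])
model.compile(loss=tf.keras.losses.MeanSquaredError())
x_batch = tf.random.normal([4, 3])
y_batch = tf.random.normal([4, 2])

# The forward/backward pass is now driven by the caller, mirroring the
# refactored body of compute_gradient_norms() above.
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
    tape=tape,
    layer_registry=layer_registry.make_default_layer_registry(),
    num_microbatches=None,
)
layer_grad_vars, registry_fn_outputs_list = (
    gradient_clipping_utils.model_forward_backward_pass(
        tape=tape,
        input_model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        registry_generator_fn=registry_generator_fn,
        weight_batch=None,
        per_example_loss_fn=None,
        num_microbatches=None,
    )
)
# Three values come back, per the annotated return type; the order
# (clipped gradients, model outputs, clipping loss) is assumed here.
clipped_grads, outputs, clipping_loss_value = (
    clip_grads.compute_clipped_gradients_and_outputs(
        input_model=model,
        registry_fn_outputs_list=registry_fn_outputs_list,
        layer_grad_vars=layer_grad_vars,
        l2_norm_clip=1.0,
        x_batch=x_batch,
        y_batch=y_batch,
    )
)
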