diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index a3168499..100e66ab 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -32,43 +32,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases -def get_registry_generator_fn( - tape: tf.GradientTape, - layer_registry: lr.LayerRegistry, - num_microbatches: Optional[type_aliases.BatchSize] = None, -): - """Creates the generator function for `compute_gradient_norms()`.""" - if layer_registry is None: - # Needed for backwards compatibility. - registry_generator_fn = None - else: - - def registry_generator_fn(layer_instance, args, kwargs): - if layer_instance.trainable_variables: - # Only trainable variables factor into the gradient. - if not layer_registry.is_elem(layer_instance): - raise NotImplementedError( - 'Layer %s is not in the registry of known layers that can ' - 'be used for efficient gradient clipping.' - % layer_instance.__class__.__name__ - ) - registry_fn = layer_registry.lookup(layer_instance) - (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( - layer_instance, args, kwargs, tape, num_microbatches - ) - return layer_outputs, ( - str(id(layer_instance)), - layer_vars, - layer_sqr_norm_fn, - layer_instance.trainable_weights, - ) - else: - # Non-trainable layer. - return layer_instance(*args, **kwargs), None - - return registry_generator_fn - - def _infer_per_example_loss_fn(model: tf.keras.Model): """Infer the per-example loss from model config.""" @@ -190,7 +153,7 @@ def compute_gradient_norms( are applied prior to clipping. 
""" tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) - registry_generator_fn = get_registry_generator_fn( + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( tape, layer_registry, num_microbatches ) # First loop computes the model outputs, summed loss, and generator outputs. @@ -241,12 +204,17 @@ def compute_gradient_norms( # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th # occurrence. This is an over-estimate of the actual norm. For more details, # see the explanation in go/dp-sgd-shared-weights. - for layer_id, v, f, weights_list in filtered_outputs: + for registry_fn_output in filtered_outputs: if trainable_vars is None or any( - w.ref() in trainable_vars for w in weights_list + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights ): - layer_vars[layer_id].append(v) - layer_sqr_norm_fns[layer_id].append(f) + layer_vars[registry_fn_output.layer_id].append( + registry_fn_output.layer_vars + ) + layer_sqr_norm_fns[registry_fn_output.layer_id].append( + registry_fn_output.layer_sqr_norm_fn + ) # Second loop evaluates the squared L2 norm functions and appends the results. layer_grad_vars = tape.gradient( summed_loss, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index f179a556..7a060e9b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,14 +13,23 @@ # limitations under the License. 
"""Utility functions that help in the computation of per-example gradient norms.""" -from collections.abc import Sequence, Set -from typing import Any, Optional +from collections.abc import Callable, Sequence, Set +import dataclasses +from typing import Any, Optional, Tuple import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +@dataclasses.dataclass(frozen=True) +class RegistryGeneratorFunctionOutput: + layer_id: str + layer_vars: Optional[Sequence[tf.Variable]] + layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction] + layer_trainable_weights: Optional[Sequence[tf.Variable]] + + def has_internal_compute_graph(input_object: Any): """Checks if input is a TF model and has a TF internal compute graph.""" return ( @@ -32,6 +41,63 @@ def has_internal_compute_graph(input_object: Any): ) +def get_registry_generator_fn( + tape: tf.GradientTape, + layer_registry: lr.LayerRegistry, + num_microbatches: Optional[type_aliases.BatchSize] = None, +) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]: + """Creates the generator function for `model_forward_backward_pass()`. + + Args: + tape: The `tf.GradientTape` to use for the gradient computation. + layer_registry: A `dict` of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where + `output` is the pre-activation tensor, `sqr_grad_norms` is related to the + squared norms of a layer's pre-activation tensor, and `vars` are relevant + trainable variables. + num_microbatches: An optional number or scalar `tf.Tensor` for the number of + microbatches. If not None, indicates that the loss is grouped into + num_microbatches (in this case, the batch dimension needs to be a multiple + of num_microbatches). 
+ + Returns: + `None` if `layer_registry` is `None`. Otherwise, a function that takes + `(layer_instance, args, kwargs)` and returns a `tuple` `(layer_outputs, + registry_fn_output)`, where `registry_fn_output` is a + `RegistryGeneratorFunctionOutput`, or `None` for a non-trainable layer. + """ + if layer_registry is None: + # Needed for backwards compatibility. + registry_generator_fn = None + else: + + def registry_generator_fn(layer_instance, args, kwargs): + if layer_instance.trainable_variables: + # Only trainable variables factor into the gradient. + if not layer_registry.is_elem(layer_instance): + raise NotImplementedError( + 'Layer %s is not in the registry of known layers that can ' + 'be used for efficient gradient clipping.' + % layer_instance.__class__.__name__ + ) + registry_fn = layer_registry.lookup(layer_instance) + (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( + layer_instance, args, kwargs, tape, num_microbatches + ) + return layer_outputs, RegistryGeneratorFunctionOutput( + layer_id=str(id(layer_instance)), + layer_vars=layer_vars, + layer_sqr_norm_fn=layer_sqr_norm_fn, + layer_trainable_weights=layer_instance.trainable_weights, + ) + else: + # Non-trainable layer. + return layer_instance(*args, **kwargs), None + + return registry_generator_fn + + def model_forward_pass( input_model: tf.keras.Model, inputs: type_aliases.PackedTensors,