From ec73ff92beae0abceabe89660bfd3f8ccc96e160 Mon Sep 17 00:00:00 2001 From: William Kong Date: Wed, 22 Nov 2023 07:08:56 -0800 Subject: [PATCH] Add support for fast clipping of dense layer gradients where the dimension of the input is larger than 1. This change specifically wraps the fast clipping logic used in EinsumDense layers, which is a generalization of the Gramian-based that was used for dense layer clipping. PiperOrigin-RevId: 584619076 --- .../registry_functions/BUILD | 1 + .../registry_functions/dense.py | 30 +++++-------------- .../registry_functions/einsum_utils.py | 9 +++--- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index b72263e5f..ab16da6e8 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -57,6 +57,7 @@ py_library( srcs = ["dense.py"], srcs_version = "PY3", deps = [ + ":einsum_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases", ], diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py index 7c49c5b24..4218a1ad7 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py @@ -16,8 +16,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional import tensorflow as tf -from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils def dense_layer_computation( @@ -74,28 +74,12 @@ def dense_layer_computation( outputs = orig_activation(base_vars) if orig_activation else base_vars def sqr_norm_fn(base_vars_grads): - def _compute_gramian(x): - if num_microbatches is not None: - x_microbatched = common_manip_utils.maybe_add_microbatch_axis( - x, - num_microbatches, - ) - return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) - else: - # Special handling for better efficiency - return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x))) - - inputs_gram = _compute_gramian(*input_args) - base_vars_grads_gram = _compute_gramian(base_vars_grads) - if layer_instance.use_bias: - # Adding a bias term is equivalent to a layer with no bias term and which - # adds an additional variable to the layer input that only takes a - # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum - # of the squared values of the inputs. - inputs_gram += 1.0 - return tf.reduce_sum( - inputs_gram * base_vars_grads_gram, - axis=tf.range(1, tf.rank(inputs_gram)), + return einsum_utils.compute_fast_einsum_squared_gradient_norm( + "...b,bc->...c", + input_args[0], + base_vars_grads, + "c" if layer_instance.use_bias else None, + num_microbatches, ) return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py index b7480ac79..84b79e332 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py @@ -19,7 +19,6 @@ import re from typing import Optional -import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils @@ -198,10 +197,10 @@ def _reshape_einsum_inputs( pivot_idx = b_idx # The output tensor is a batched set of matrices, split at the pivot index # of the previously prepped tensor. - base_tensor_shape = input_tensor.shape - batch_size = base_tensor_shape[0] - num_rows = int(np.prod(base_tensor_shape[1:pivot_idx])) - num_columns = int(np.prod(base_tensor_shape[pivot_idx:])) + input_shape = tf.shape(input_tensor) + batch_size = input_shape[0] + num_rows = tf.reduce_prod(input_shape[1:pivot_idx]) + num_columns = tf.reduce_prod(input_shape[pivot_idx:]) return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])