
Commit

Add support for fast clipping of dense layer gradients where the dimension of the input is larger than 1.

This change specifically wraps the fast clipping logic used in EinsumDense layers, which is a generalization of the Gramian-based approach that was previously used for dense layer clipping.

PiperOrigin-RevId: 584619076
wwkong authored and tensorflower-gardener committed Nov 28, 2023
1 parent b19088f commit ec73ff9
Showing 3 changed files with 12 additions and 28 deletions.
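
For context, the Gramian-based identity behind fast clipping of dense layer gradients can be sketched in a few lines. The snippet below is illustrative only (the function name and shapes are invented, and it is not code from this change): for a layer output y = xW + b, the per-example kernel gradient is a sum of outer products, so its squared Frobenius norm can be computed from the input and output-gradient Gram matrices without materializing per-example gradients.

import tensorflow as tf

def gramian_sqr_grad_norms(x, g, use_bias=True):
  # Illustrative helper: `x` is the layer input, shape [batch, t, d]; `g` is
  # the gradient of the loss w.r.t. the layer's pre-activation output, shape
  # [batch, t, m].  The per-example kernel gradient is sum_t outer(x_t, g_t),
  # so its squared Frobenius norm equals sum_{s,t} <x_s, x_t> * <g_s, g_t>.
  x_gram = tf.matmul(x, x, transpose_b=True)  # [batch, t, t]
  g_gram = tf.matmul(g, g, transpose_b=True)  # [batch, t, t]
  if use_bias:
    # A bias acts like an extra constant-1.0 input feature, which adds 1.0 to
    # every entry of the input Gram matrix.
    x_gram += 1.0
  return tf.reduce_sum(x_gram * g_gram, axis=[1, 2])

# Sanity check against explicit per-example gradients (shapes assumed).
batch, t, d, m = 2, 3, 4, 5
x = tf.random.normal([batch, t, d])
g = tf.random.normal([batch, t, m])
kernel_grads = tf.einsum("btd,btm->bdm", x, g)  # per-example dL/dW
bias_grads = tf.reduce_sum(g, axis=1)           # per-example dL/db
explicit = tf.reduce_sum(tf.square(kernel_grads), axis=[1, 2]) + tf.reduce_sum(
    tf.square(bias_grads), axis=1
)
print(tf.reduce_max(tf.abs(explicit - gramian_sqr_grad_norms(x, g))))  # ~0.0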
@@ -57,6 +57,7 @@ py_library(
    srcs = ["dense.py"],
    srcs_version = "PY3",
    deps = [
        ":einsum_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
    ],
@@ -16,8 +16,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils


def dense_layer_computation(
@@ -74,28 +74,12 @@ def dense_layer_computation(
  outputs = orig_activation(base_vars) if orig_activation else base_vars

  def sqr_norm_fn(base_vars_grads):
    def _compute_gramian(x):
      if num_microbatches is not None:
        x_microbatched = common_manip_utils.maybe_add_microbatch_axis(
            x,
            num_microbatches,
        )
        return tf.matmul(x_microbatched, x_microbatched, transpose_b=True)
      else:
        # Special handling for better efficiency
        return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x)))

    inputs_gram = _compute_gramian(*input_args)
    base_vars_grads_gram = _compute_gramian(base_vars_grads)
    if layer_instance.use_bias:
      # Adding a bias term is equivalent to a layer with no bias term and which
      # adds an additional variable to the layer input that only takes a
      # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
      # of the squared values of the inputs.
      inputs_gram += 1.0
    return tf.reduce_sum(
        inputs_gram * base_vars_grads_gram,
        axis=tf.range(1, tf.rank(inputs_gram)),
    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
        "...b,bc->...c",
        input_args[0],
        base_vars_grads,
        "c" if layer_instance.use_bias else None,
        num_microbatches,
    )

  return base_vars, outputs, sqr_norm_fn
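
The arguments forwarded above reflect the fact that a Keras Dense layer is just the einsum "...b,bc->...c" applied to its input, with an optional bias on the output axis "c". A small sketch of that equivalence (shapes assumed purely for illustration):

import tensorflow as tf

layer = tf.keras.layers.Dense(5, use_bias=True)
x = tf.random.normal([2, 3, 4])  # input with more than one non-batch dimension
y = layer(x)                     # builds the layer and applies it
y_einsum = tf.einsum("...b,bc->...c", x, layer.kernel) + layer.bias
print(tf.reduce_max(tf.abs(y - y_einsum)))  # ~0.0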
@@ -19,7 +19,6 @@
import re
from typing import Optional

import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

@@ -198,10 +197,10 @@ def _reshape_einsum_inputs(
  pivot_idx = b_idx
  # The output tensor is a batched set of matrices, split at the pivot index
  # of the previously prepped tensor.
  base_tensor_shape = input_tensor.shape
  batch_size = base_tensor_shape[0]
  num_rows = int(np.prod(base_tensor_shape[1:pivot_idx]))
  num_columns = int(np.prod(base_tensor_shape[pivot_idx:]))
  input_shape = tf.shape(input_tensor)
  batch_size = input_shape[0]
  num_rows = tf.reduce_prod(input_shape[1:pivot_idx])
  num_columns = tf.reduce_prod(input_shape[pivot_idx:])
  return tf.reshape(input_tensor, shape=[batch_size, num_rows, num_columns])
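
The reshape above now derives its sizes from the runtime shape (tf.shape / tf.reduce_prod) rather than the static shape (np.prod over input_tensor.shape), so it also works when some non-batch dimensions are not statically known. A standalone sketch of the pattern (function name, pivot value, and shapes are assumed for illustration, not taken from the library):

import tensorflow as tf

def flatten_to_batched_matrices(t, pivot_idx):
  # Collapse the non-batch axes before and after `pivot_idx` into a
  # [batch, rows, columns] tensor using only runtime shape information.
  shape = tf.shape(t)
  batch_size = shape[0]
  num_rows = tf.reduce_prod(shape[1:pivot_idx])
  num_columns = tf.reduce_prod(shape[pivot_idx:])
  return tf.reshape(t, [batch_size, num_rows, num_columns])

print(flatten_to_batched_matrices(tf.zeros([2, 3, 4, 6]), pivot_idx=2).shape)  # (2, 3, 24)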


