diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index 9828852f..2f7ec0e2 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -13,6 +13,7 @@ py_library( name = "einsum_utils", srcs = ["einsum_utils.py"], srcs_version = "PY3", + deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"], ) py_test( @@ -24,6 +25,33 @@ py_test( deps = [":einsum_utils"], ) +py_library( + name = "einsum_dense", + srcs = ["einsum_dense.py"], + srcs_version = "PY3", + deps = [ + ":einsum_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases", + ], +) + +py_test( + name = "einsum_dense_test", + size = "large", + srcs = ["einsum_dense_test.py"], + python_version = "PY3", + shard_count = 12, + srcs_version = "PY3", + deps = [ + ":dense", + ":einsum_dense", + "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + ], +) + py_library( name = "dense", srcs = ["dense.py"], diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense.py new file mode 100644 index 00000000..8c5ecbc7 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense.py @@ -0,0 +1,70 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast clipping function for `tfm.nlp.layers.EinsumDense`.""" + +from collections.abc import Mapping, Sequence +from typing import Any, Optional +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils + + +def einsum_layer_computation( + layer_instance: tf.keras.layers.EinsumDense, + input_args: Sequence[Any], + input_kwargs: Mapping[str, Any], + tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, +) -> type_aliases.RegistryFunctionOutput: + """Registry function for `tf.keras.layers.EinsumDense`. + + For the technical details, see the documentation of + `einsum_utils.compute_fast_einsum_gradient_norm()`. + + Args: + layer_instance: A `tf.keras.layers.EinsumDense` instance. + input_args: See `dense_layer_computation()` in `dense.py`. + input_kwargs: See `dense_layer_computation()` in `dense.py`. + tape: See `dense_layer_computation()` in `dense.py`. + num_microbatches: See `dense_layer_computation()` in `dense.py`. + + Returns: + See `dense_layer_computation()` in `dense.py`. + """ + if input_kwargs: + raise ValueError("EinsumDense layer calls should not receive kwargs.") + del input_kwargs + if len(input_args) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") + orig_activation = layer_instance.activation + # Some activation functions may not apply a transform to the elements of the + # output individually (which is needed for the fast clipping trick to work). 
+ # To avoid this case, we watch the variables that are only generated by the + # linear transformation of the `EinsumDense` layer instance. + layer_instance.activation = None + base_vars = layer_instance(*input_args) + tape.watch(base_vars) + layer_instance.activation = orig_activation + outputs = orig_activation(base_vars) if orig_activation else base_vars + + def sqr_norm_fn(grads): + return einsum_utils.compute_fast_einsum_squared_gradient_norm( + layer_instance.equation, + input_args[0], + grads, + layer_instance.bias_axes, + num_microbatches, + ) + + return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_test.py new file mode 100644 index 00000000..197c5766 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_test.py @@ -0,0 +1,171 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import parameterized +import tensorflow as tf +import tensorflow_models as tfm +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense + + +def get_einsum_layer_generators(): + def pure_einsum_layer(equation, output_dims, bias_axes): + return tf.keras.layers.EinsumDense( + equation, output_dims, bias_axes=bias_axes + ) + + def sigmoid_einsum_layer(equation, output_dims, bias_axes): + return tf.keras.layers.EinsumDense( + equation, output_dims, bias_axes=bias_axes, activation='sigmoid' + ) + + return { + 'pure_einsum': pure_einsum_layer, + 'sigmoid_einsum': sigmoid_einsum_layer, + } + + +def get_einsum_parameter_tuples(): + # (equation, input_dims, output_dims, bias_axes) + return [ + # Case C1. + ('ab,bc->ac', [2], [3], None), + ('ab,bc->ac', [2], [3], 'c'), + ('abc,cd->abd', [2, 3], [2, 4], None), + ('abc,cd->abd', [2, 3], [2, 4], 'b'), + ('abc,cd->abd', [2, 3], [2, 4], 'd'), + ('abc,cd->abd', [2, 3], [2, 4], 'bd'), + ('abc,cef->abef', [2, 3], [2, 4, 5], None), + ('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'), + # Case C2. + ('...b,bc->...c', [2, 3], [4], None), + ('...b,bc->...c', [2, 3], [4], 'c'), + ('...ab,bc->...ac', [2, 3], [2, 4], None), + ('...ab,bc->...ac', [2, 4], [2, 4], 'c'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'), + ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None), + ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'), + # Case C3. 
+ ('ab...,bc->ac...', [2, 3], [4, 3], None), + ('ab...,bc->ac...', [2, 3], [4, 3], 'c'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'), + ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None), + ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'), + ] + + +def get_einsum_layer_registry(): + einsum_registry = layer_registry.LayerRegistry() + einsum_registry.insert( + tfm.nlp.layers.EinsumDense, + einsum_dense.einsum_layer_computation, + ) + return einsum_registry + + +class GradNormTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.strategy = tf.distribute.get_strategy() + self.using_tpu = False + + @parameterized.product( + layer_name=list(get_einsum_layer_generators()), + param_tuple=get_einsum_parameter_tuples(), + num_microbatches=[None, 2], + is_eager=[True, False], + ) + def test_gradient_norms_on_various_models( + self, + layer_name, + param_tuple, + num_microbatches, + is_eager, + ): + # Parse inputs to generate test data. Note that each batched input is a + # reshape of a `tf.range()` call. + equation, input_dims, output_dims, bias_axes = param_tuple + batch_size = 4 + example_size = tf.reduce_prod(input_dims) + example_values = tf.range(batch_size * example_size, dtype=tf.float32) + x_batch = tf.reshape(example_values, [batch_size] + input_dims) + + # Make the layer generator via currying. + einsum_generator = get_einsum_layer_generators()[layer_name] + + def curried_generator(a, b): + del a, b + return einsum_generator(equation, output_dims, bias_axes) + + # Load shared assets to all devices. 
+ with self.strategy.scope(): + model = common_test_utils.get_model_from_generator( + model_generator=common_test_utils.make_one_layer_functional_model, + layer_generator=curried_generator, + input_dims=input_dims, + output_dims=output_dims, + is_eager=is_eager, + ) + + # Define the main testing ops. These may be later compiled to a Graph op. + def test_op(x): + return common_test_utils.get_computed_and_true_norms_from_model( + model=model, + per_example_loss_fn=None, + num_microbatches=num_microbatches, + x_batch=x, + registry=get_einsum_layer_registry(), + ) + + # TPUs can only run `tf.function`-decorated functions. + if self.using_tpu: + test_op = tf.function(test_op, autograph=False) + + # TPUs use lower precision than CPUs, so we relax our criterion. + # E.g., one of the TPU runs generated the following results: + # + # computed_norm = 93.48296 + # true_norm = 93.31176 + # abs_diff = 0.17120361 + # rel_diff = 0.00183475 + # + # which is a reasonable level of error for computing gradient norms. + # Other trials also give an absolute (resp. relative) error of around + # 0.05 (resp. 0.0015). + rtol = 1e-2 if self.using_tpu else 1e-3 + atol = 5e-1 if self.using_tpu else 1e-2 + + # Set up the device ops and run the test. + computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,)) + # TPUs return replica contexts, which must be unwrapped. 
+ if self.using_tpu: + common_test_utils.assert_replica_values_are_close(self, computed_norms) + common_test_utils.assert_replica_values_are_close(self, true_norms) + computed_norms = computed_norms.values[0] + true_norms = true_norms.values[0] + expected_size = num_microbatches or batch_size + self.assertEqual(tf.shape(computed_norms)[0], expected_size) + self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_tpu_test.py new file mode 100644 index 00000000..69b31b10 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_tpu_test.py @@ -0,0 +1,30 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test + + +class GradNormTpuTest(einsum_dense_test.GradNormTest): + + def setUp(self): + super(einsum_dense_test.GradNormTest, self).setUp() + self.strategy = common_test_utils.create_tpu_strategy() + self.assertIn('TPU', self.strategy.extended.worker_devices[0]) + self.using_tpu = True + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py index 4912939b..b7480ac7 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_utils.py @@ -17,9 +17,11 @@ import itertools import os import re +from typing import Optional import numpy as np import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils EquationType = enum.Enum( "EquationType", @@ -139,9 +141,7 @@ def _reshape_einsum_inputs( (num_batches, num_rows, num_columns) ``` When `input_tensor` is a rank-2 `tf.Tensor`, the number of output rows is 1 - and the number of output columns is the second dimension of the input. The - product of the non-trivial dimensions of the output should be equal to - the product of the dimensions of `input_tensor`. + and the number of output columns is the second dimension of the input. 
Raises: ValueError: If `equation` is not a valid einsum equation in the context of @@ -345,3 +345,141 @@ def _get_einsum_bias_adjoint_reduction_axes( else: reduction_axes.append(idx) return reduction_axes + + +def compute_fast_einsum_squared_gradient_norm( + equation: str, + input_tensor: tf.Tensor, + grad_tensor: tf.Tensor, + bias_axes: Optional[str], + num_microbatches: Optional[int] = None, +) -> tf.Tensor: + """Computes the batch gradient norms of an Einsum gradient decomposition. + + This logic generalizes the one for `tf.keras.layers.Dense` and assumes that + the `equation` parameter is one of the following forms: + + C1. ab,bc->ac, + C2. ...ab,bc->...ac + C3. ab...,bc->ac... + + where `a`, `b`, and `c` are non-empty substrings. + + For reference, we describe part of the mathematical analysis below. It can be + safely skipped upon the first reading of this docstring. + + ----------------------------------------------------------------------------- + BEGIN ANALYSIS + ----------------------------------------------------------------------------- + For ease of exposition, all analysis is done for a single example, i.e., + batch dimension is excluded from our consideration. + + Recall that the einsum dense computation, excluding activation functions is of + the form + ``` + output = tf.einsum(equation, input, kernel) + bias, + ``` + where `bias` is broadcasted and summed with the output of the `tf.einsum()` + call, and equation is of the forms in C1, C2, and C3. + + Mathematically, the above computation is equivalent to: + ``` + output = tf.matmul(X, W) + Q(bias) + ``` + where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`) + and `Q` is a linear operator that transforms `bias` to comport with the + tensor output by the `tf.matmul()` call. When generalizing to a batch of + examples, `X` is a 3D tensor whose first dimension is the batch dimension.
+ + Following the same trick as for `tf.keras.layers.Dense` layers, suppose that + we have: + ``` + loss = f(output) + G = tape.gradient(loss, output) + ``` + Then, using the chain rule and denoting `A'` to be the adjoint of a matrix + `A`, one can show that the gradient of `loss` with respect to `W` is given by + the block matrix `K := [X'G; Q'G]`. Hence, the square norm of `K`, i.e., what + is returned by `sqr_norm_fn` is given by + ``` + sqr_norm = <X X', G G'> + ||Q'G||_F^2 + ``` + where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner + product for matrices. + ----------------------------------------------------------------------------- + END ANALYSIS + ----------------------------------------------------------------------------- + + Args: + equation: A `string` representing the einsum equation. + input_tensor: A `tf.Tensor` representing the einsum input. + grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with + respect to the pre-activation tensor. + bias_axes: An optional `string` that specifies the einsum biases in + `equation`. + num_microbatches: An optional `int` that specifies the number of + microbatches used in a batch. + + Returns: + A 1D `tf.Tensor` whose i-th entry is the squared gradient norm corresponding + to the i-th example in `input_tensor`. + """ + # Compute the matrix `X X'` and `G G'` for each example or microbatch. + # `x.shape = (batch_size, num_rows, num_columns)` + x = _reshape_einsum_inputs(input_tensor, equation) + g = _reshape_einsum_outputs(grad_tensor, equation) + # Adding microbatches is equivalent to splitting the first `(batch_size)` + # axis into `(num_microbatches, microbatch_size)` axes and merging the + # `microbatch_size` axis with the `num_rows` axis via a reshape.
+ if num_microbatches is not None: + # `x.shape = (num_microbatches, microbatch_size, num_rows, num_columns)` + x = common_manip_utils.maybe_add_microbatch_axis(x, num_microbatches) + g = common_manip_utils.maybe_add_microbatch_axis(g, num_microbatches) + sx = tf.shape(x) + sg = tf.shape(g) + # `x.shape = (num_microbatches, microbatch_size * num_rows, num_columns)` + x = tf.reshape(x, shape=[sx[0], sx[1] * sx[2], sx[3]]) + g = tf.reshape(g, shape=[sg[0], sg[1] * sg[2], sg[3]]) + # NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do + # a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`. + if ( + _is_batch_of_vectors(input_tensor) + and _is_batch_of_vectors(grad_tensor) + and num_microbatches is None + ): + x_matrix = tf.reshape(x, [tf.shape(x)[0], -1]) + g_matrix = tf.reshape(g, [tf.shape(g)[0], -1]) + batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1) + batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1) + else: + batch_xxt = tf.matmul(x, x, transpose_b=True) + batch_ggt = tf.matmul(g, g, transpose_b=True) + # Compute the (micro)batch inner product; adjust for biases if necessary. + batch_xxt_ggt = tf.multiply(batch_xxt, batch_ggt) + reduction_axes = tf.range(1, tf.rank(batch_xxt_ggt)) + sqr_norms = tf.reduce_sum(batch_xxt_ggt, axis=reduction_axes) + if bias_axes is not None: + # The adjoint operator `Q` on `G` is a reduce sum on the axes in `G` that + # are not broadcasted from `bias`. + grad_rank = len(grad_tensor.shape) + adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes( + equation, + bias_axes, + grad_rank, + ) + # Adding microbatches with non-trivial bias axes is equivalent to splitting + # the first `(batch_size)` axis into `(num_microbatches, microbatch_size)` + # axes, and adding the `microbatch_size` axis (=1) to the reduction axes + # needed to compute the bias broadcast adjoint operator.
+ if num_microbatches is not None: + grad_tensor = common_manip_utils.maybe_add_microbatch_axis( + grad_tensor, num_microbatches + ) + adjoint_reduction_axes = [i + 1 for i in adjoint_reduction_axes] + adjoint_reduction_axes = [1] + adjoint_reduction_axes + qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes) + qg_reduction_axes = tf.range(1, tf.rank(qg)) + bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes) + sqr_norms += bias_sqr_norms + + return sqr_norms diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py index d92b38f7..96f0e304 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py @@ -13,7 +13,7 @@ # limitations under the License. import tensorflow as tf -from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test @@ -21,7 +21,7 @@ class GradNormTpuTest(layer_normalization_test.GradNormTest): def setUp(self): super(layer_normalization_test.GradNormTest, self).setUp() - self.strategy = ctu.create_tpu_strategy() + self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0]) self.using_tpu = True