From 6512e442e62c51367f6105b2642ed982761f26f4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 23 Feb 2023 13:01:51 -0800 Subject: [PATCH] Add registry function and tests for `tf.keras.layers.EinsumDense`. PiperOrigin-RevId: 511864469 --- .../privacy/fast_gradient_clipping/BUILD | 8 + .../fast_gradient_clipping/clip_grads.py | 30 +- .../fast_gradient_clipping/clip_grads_test.py | 256 ++++++++++++--- .../fast_gradient_clipping/einsum_utils.py | 302 ++++++++++++++++++ .../gradient_clipping_utils.py | 38 ++- .../fast_gradient_clipping/layer_registry.py | 107 +++++-- 6 files changed, 652 insertions(+), 89 deletions(-) create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/einsum_utils.py diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index ffa666f3a..2df0e1cb3 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -6,12 +6,20 @@ py_library( name = "gradient_clipping_utils", srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", + deps = [":layer_registry"], +) + +py_library( + name = "einsum_utils", + srcs = ["einsum_utils.py"], + srcs_version = "PY3", ) py_library( name = "layer_registry", srcs = ["layer_registry.py"], srcs_version = "PY3", + deps = [":einsum_utils"], ) py_library( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 6a37ae30f..42fdc1e95 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,11 +21,20 @@ `compute_gradient_norms()` function). """ +from typing import Union, Iterable, Text, TypeAlias + import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr + +InputTensor: TypeAlias = Union[ + tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor] +] -def get_registry_generator_fn(tape, layer_registry): +def get_registry_generator_fn( + tape: tf.GradientTape, layer_registry: lr.LayerRegistry +): """Creates the generator function for `compute_gradient_norms()`.""" if layer_registry is None: # Needed for backwards compatibility. @@ -53,7 +62,12 @@ def registry_generator_fn(layer_instance, args, kwargs): return registry_generator_fn -def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): +def compute_gradient_norms( + input_model: tf.keras.Model, + x_batch: InputTensor, + y_batch: tf.Tensor, + layer_registry: lr.LayerRegistry, +): """Computes the per-example loss gradient norms for given data. Applies a variant of the approach given in @@ -62,7 +76,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): Args: input_model: The `tf.keras.Model` from which to obtain the layers from. The loss of the model *must* be a scalar loss. - x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. y_batch: A `tf.Tensor` representing a batch of output labels. The first axis must be the batch dimension. The number of examples should match the @@ -106,7 +120,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) -def compute_clip_weights(l2_norm_clip, gradient_norms): +def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): """Computes the per-example loss/clip weights for clipping. When the sum of the per-example losses is replaced a weighted sum, where @@ -132,7 +146,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms): def compute_pred_and_clipped_gradients( - input_model, x_batch, y_batch, l2_norm_clip, layer_registry + input_model: tf.keras.Model, + x_batch: InputTensor, + y_batch: tf.Tensor, + l2_norm_clip: float, + layer_registry: lr.LayerRegistry, ): """Computes the per-example predictions and per-example clipped loss gradient. @@ -147,7 +165,7 @@ def compute_pred_and_clipped_gradients( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. - x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. y_batch: A `tf.Tensor` representing a batch of output labels. The first axis must be the batch dimension. The number of examples should match the diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 183d890ce..6f923db8f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -13,6 +13,8 @@ # limitations under the License. import itertools +from typing import Callable, Any, List, TypeAlias + from absl.testing import parameterized import tensorflow as tf @@ -20,23 +22,37 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +# ============================================================================== +# Type aliases +# ============================================================================== +LayerGenerator: TypeAlias = Callable[ + [List[int], List[int]], tf.keras.layers.Layer +] + +ModelGenerator: TypeAlias = Callable[ + [LayerGenerator, List[int], List[int]], tf.keras.Model +] + + # ============================================================================== # Helper functions and classes. # ============================================================================== class DoubleDense(tf.keras.layers.Layer): """Generates two dense layers nested together.""" - def __init__(self, units): + def __init__(self, shape: List[int]): super().__init__() - self.dense1 = tf.keras.layers.Dense(units) + self.dense1 = tf.keras.layers.Dense(shape) self.dense2 = tf.keras.layers.Dense(1) - def call(self, inputs): + def call(self, inputs: Any): x = self.dense1(inputs) return self.dense2(x) -def double_dense_layer_computation(layer_instance, inputs, tape): +def double_dense_layer_computation( + layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape +): """Layer registry function for the custom `DoubleDense` layer class.""" vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( layer_instance.dense1, inputs, tape @@ -53,7 +69,9 @@ def sqr_norm_fn(base_vars): return [vars1, vars2], outputs, sqr_norm_fn -def compute_true_gradient_norms(input_model, x_batch, y_batch): +def compute_true_gradient_norms( + input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor +): """Computes the real gradient norms for an input `(model, x, y)`.""" loss_config = input_model.loss.get_config() loss_config['reduction'] = tf.keras.losses.Reduction.NONE @@ -73,14 +91,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch): def get_computed_and_true_norms( - model_generator, - layer_generator, - input_dims, - output_dim, - is_eager, - x_input, - rng_seed=777, - registry=None, + model_generator: ModelGenerator, + layer_generator: LayerGenerator, + input_dims: List[int], + output_dims: List[int], + is_eager: bool, + x_input: tf.Tensor, + rng_seed: int = 777, + registry: layer_registry.LayerRegistry = None, ): """Obtains the true and computed gradient norms for a model and batch input. @@ -96,7 +114,7 @@ def get_computed_and_true_norms( Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension `idim` and returns output tensors of dimension `odim`. input_dims: The input dimension(s) of the test `tf.keras.Model` instance. - output_dim: The output dimension of the test `tf.keras.Model` instance. + output_dims: The output dimension(s) of the test `tf.keras.Model` instance. is_eager: A `bool` that is `True` if the model should be run eagerly. x_input: `tf.Tensor` inputs to be tested. rng_seed: An `int` used to initialize model weights. @@ -109,7 +127,7 @@ def get_computed_and_true_norms( model and layer generators. The second element contains the true clipped gradient norms under the aforementioned setting. """ - model = model_generator(layer_generator, input_dims, output_dim) + model = model_generator(layer_generator, input_dims, output_dims) model.compile( optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), loss=tf.keras.losses.MeanSquaredError( @@ -131,61 +149,71 @@ def get_computed_and_true_norms( # ============================================================================== # Model generators. # ============================================================================== -def make_two_layer_sequential_model(layer_generator, input_dim, output_dim): +def make_one_layer_sequential_model(layer_generator, input_dims, output_dims): + """Creates a 1-layer sequential model.""" + inputs = tf.keras.Input(shape=input_dims) + layer1 = layer_generator(input_dims, output_dims) + temp1 = layer1(inputs) + reduction_axes1 = tf.range(1, len(temp1.shape)) + outputs = tf.reduce_sum(temp1, axis=reduction_axes1, keepdims=True) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + +def make_two_layer_sequential_model(layer_generator, input_dims, output_dims): """Creates a 2-layer sequential model.""" model = tf.keras.Sequential() - model.add(tf.keras.Input(shape=(input_dim,))) - model.add(layer_generator(input_dim, output_dim)) + model.add(tf.keras.Input(shape=input_dims)) + model.add(layer_generator(input_dims, output_dims)) model.add(tf.keras.layers.Dense(1)) return model -def make_three_layer_sequential_model(layer_generator, input_dim, output_dim): +def make_three_layer_sequential_model(layer_generator, input_dims, output_dims): """Creates a 3-layer sequential model.""" model = tf.keras.Sequential() - model.add(tf.keras.Input(shape=(input_dim,))) - layer1 = layer_generator(input_dim, output_dim) + model.add(tf.keras.Input(shape=input_dims)) + layer1 = layer_generator(input_dims, output_dims) model.add(layer1) if isinstance(layer1, tf.keras.layers.Embedding): # Having multiple consecutive embedding layers does not make sense since # embedding layers only map integers to real-valued vectors. - model.add(tf.keras.layers.Dense(output_dim)) + model.add(tf.keras.layers.Dense(output_dims)) else: - model.add(layer_generator(output_dim, output_dim)) + model.add(layer_generator(output_dims, output_dims)) model.add(tf.keras.layers.Dense(1)) return model -def make_two_layer_functional_model(layer_generator, input_dim, output_dim): +def make_two_layer_functional_model(layer_generator, input_dims, output_dims): """Creates a 2-layer 1-input functional model with a pre-output square op.""" - inputs = tf.keras.Input(shape=(input_dim,)) - layer1 = layer_generator(input_dim, output_dim) + inputs = tf.keras.Input(shape=input_dims) + layer1 = layer_generator(input_dims, output_dims) temp1 = layer1(inputs) temp2 = tf.square(temp1) outputs = tf.keras.layers.Dense(1)(temp2) return tf.keras.Model(inputs=inputs, outputs=outputs) -def make_two_tower_model(layer_generator, input_dim, output_dim): +def make_two_tower_model(layer_generator, input_dims, output_dims): """Creates a 2-layer 2-input functional model.""" - inputs1 = tf.keras.Input(shape=(input_dim,)) - layer1 = layer_generator(input_dim, output_dim) + inputs1 = tf.keras.Input(shape=input_dims) + layer1 = layer_generator(input_dims, output_dims) temp1 = layer1(inputs1) - inputs2 = tf.keras.Input(shape=(input_dim,)) - layer2 = layer_generator(input_dim, output_dim) + inputs2 = tf.keras.Input(shape=input_dims) + layer2 = layer_generator(input_dims, output_dims) temp2 = layer2(inputs2) temp3 = tf.add(temp1, temp2) outputs = tf.keras.layers.Dense(1)(temp3) return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs) -def make_bow_model(layer_generator, input_dims, output_dim): +def make_bow_model(layer_generator, input_dims, output_dims): del layer_generator inputs = tf.keras.Input(shape=input_dims) # For the Embedding layer, input_dim is the vocabulary size. This should # be distinguished from the input_dim argument, which is the number of ids # in eache example. - emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim) + emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dims[0]) feature_embs = emb_layer(inputs) reduction_axes = tf.range(1, len(feature_embs.shape)) example_embs = tf.expand_dims( @@ -194,7 +222,7 @@ def make_bow_model(layer_generator, input_dims, output_dim): return tf.keras.Model(inputs=inputs, outputs=example_embs) -def make_dense_bow_model(layer_generator, input_dims, output_dim): +def make_dense_bow_model(layer_generator, input_dims, output_dims): del layer_generator inputs = tf.keras.Input(shape=input_dims) # For the Embedding layer, input_dim is the vocabulary size. This should @@ -202,7 +230,7 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim): # in eache example. cardinality = 10 emb_layer = tf.keras.layers.Embedding( - input_dim=cardinality, output_dim=output_dim + input_dim=cardinality, output_dim=output_dims[0] ) feature_embs = emb_layer(inputs) reduction_axes = tf.range(1, len(feature_embs.shape)) @@ -213,7 +241,7 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim): return tf.keras.Model(inputs=inputs, outputs=outputs) -def make_weighted_bow_model(layer_generator, input_dims, output_dim): +def make_weighted_bow_model(layer_generator, input_dims, output_dims): # NOTE: This model only accepts dense input tensors. del layer_generator inputs = tf.keras.Input(shape=input_dims) @@ -222,7 +250,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim): # in eache example. cardinality = 10 emb_layer = tf.keras.layers.Embedding( - input_dim=cardinality, output_dim=output_dim + input_dim=cardinality, output_dim=output_dims[0] ) feature_embs = emb_layer(inputs) feature_weights = tf.random.uniform(tf.shape(feature_embs)) @@ -238,7 +266,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim): # ============================================================================== # Factory functions. # ============================================================================== -def get_nd_test_tensors(n): +def get_nd_test_tensors(n: int): """Returns a list of candidate tests for a given dimension n.""" return [ tf.zeros((n,), dtype=tf.float64), @@ -246,7 +274,7 @@ def get_nd_test_tensors(n): ] -def get_nd_test_batches(n): +def get_nd_test_batches(n: int): """Returns a list of candidate input batches of dimension n.""" result = [] tensors = get_nd_test_tensors(n) @@ -263,17 +291,17 @@ def sigmoid_dense_layer(b): return tf.keras.layers.Dense(b, activation='sigmoid') return { - 'pure_dense': lambda a, b: tf.keras.layers.Dense(b), - 'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b), + 'pure_dense': lambda a, b: tf.keras.layers.Dense(b[0]), + 'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b[0]), } def get_dense_model_generators(): return { - 'seq1': make_two_layer_sequential_model, - 'seq2': make_three_layer_sequential_model, - 'func1': make_two_layer_functional_model, - 'tower1': make_two_tower_model, + 'seq2': make_two_layer_sequential_model, + 'seq3': make_three_layer_sequential_model, + 'func2': make_two_layer_functional_model, + 'tower2': make_two_tower_model, } @@ -285,6 +313,64 @@ def get_embedding_model_generators(): } +def get_einsum_layer_generators(): + def pure_einsum_layer(equation, output_dims, bias_axes): + return tf.keras.layers.EinsumDense( + equation, output_dims, bias_axes=bias_axes + ) + + def sigmoid_einsum_layer(equation, output_dims, bias_axes): + return tf.keras.layers.EinsumDense( + equation, output_dims, bias_axes=bias_axes, activation='sigmoid' + ) + + return { + 'pure_einsum': pure_einsum_layer, + 'sigmoid_einsum': sigmoid_einsum_layer, + } + + +def get_einsum_model_generators(): + return { + 'seq1': make_one_layer_sequential_model, + } + + +def get_einsum_parameter_tuples(): + """Consists of (equation, input_dims, output_dims, bias_axes).""" + return [ + # Case (C1). + ('ab,bc->ac', [2], [3], None), + ('ab,bc->ac', [2], [3], 'c'), + ('abc,cd->abd', [2, 3], [2, 4], None), + ('abc,cd->abd', [2, 3], [2, 4], 'b'), + ('abc,cd->abd', [2, 3], [2, 4], 'd'), + ('abc,cd->abd', [2, 3], [2, 4], 'bd'), + ('abc,cef->abef', [2, 3], [2, 4, 5], None), + ('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'), + # Case (C2). + ('...b,bc->...c', [2, 3], [4], None), + ('...b,bc->...c', [2, 3], [4], 'c'), + ('...ab,bc->...ac', [2, 3], [2, 4], None), + ('...ab,bc->...ac', [2, 4], [2, 4], 'c'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'), + ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'), + ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None), + ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'), + # Case (C3). + ('ab...,bc->ac...', [2, 3], [4, 3], None), + ('ab...,bc->ac...', [2, 3], [4, 3], 'c'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'), + ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'), + ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None), + ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'), + ] + + # ============================================================================== # Main tests. # ============================================================================== @@ -301,7 +387,7 @@ def test_clip_weights(self, input_dim, clip_value): self.assertAllLessEqual(t * weights, clip_value + tol) -class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): +class ClipGradsOneDimDenseLayerTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( model_name=list(get_dense_model_generators().keys()), @@ -318,15 +404,15 @@ def test_gradient_norms_on_various_models( x_batches = get_nd_test_batches(input_dim) default_registry = layer_registry.make_default_layer_registry() for x_batch in x_batches: - if model_name == 'tower1': + if model_name == 'tower2': x_input = [x_batch, x_batch] else: x_input = x_batch (computed_norms, true_norms) = get_computed_and_true_norms( model_generator, layer_generator, - input_dim, - output_dim, + [input_dim], + [output_dim], is_eager, x_input, registry=default_registry, @@ -334,6 +420,34 @@ def test_gradient_norms_on_various_models( self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) +class ClipGradsTwoDimDenseLayerTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + layer_name=list(get_dense_layer_generators().keys()), + is_eager=[True, False], + ) + def test_gradient_norms_on_various_models(self, layer_name, is_eager): + batch_size = 2 + input_dims = [3, 4] + output_dims = [5] + model_generator = make_one_layer_sequential_model + layer_generator = get_dense_layer_generators()[layer_name] + example_size = tf.reduce_prod(input_dims) + example_values = tf.range(batch_size * example_size, dtype=tf.float32) + x_batch = tf.reshape(example_values, [batch_size] + input_dims) + default_registry = layer_registry.make_default_layer_registry() + (computed_norms, true_norms) = get_computed_and_true_norms( + model_generator, + layer_generator, + input_dims, + output_dims, + is_eager, + x_batch, + registry=default_registry, + ) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + + class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): # TODO(wkong): Test sparse input tensors when the GitHub CI environment @@ -374,7 +488,7 @@ def test_gradient_norms_on_various_models( model_generator=model_generator, layer_generator=None, input_dims=x_batch.shape[1:], - output_dim=output_dim, + output_dims=[output_dim], is_eager=is_eager, x_input=x_batch, registry=default_registry, @@ -398,9 +512,9 @@ def test_gradient_norms_on_various_models( for x_batch in x_batches: (computed_norms, true_norms) = get_computed_and_true_norms( model_generator=make_two_layer_sequential_model, - layer_generator=lambda a, b: DoubleDense(b), - input_dims=input_dim, - output_dim=output_dim, + layer_generator=lambda a, b: DoubleDense(b[0]), + input_dims=[input_dim], + output_dims=[output_dim], is_eager=is_eager, x_input=x_batch, registry=registry, @@ -408,5 +522,41 @@ def test_gradient_norms_on_various_models( self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) +class ClipGradsEinsumDenseTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + model_name=list(get_einsum_model_generators().keys()), + layer_name=list(get_einsum_layer_generators().keys()), + param_tuple=get_einsum_parameter_tuples(), + is_eager=[False, True], + ) + def test_gradient_norms_on_various_models( + self, model_name, layer_name, param_tuple, is_eager + ): + equation, input_dims, output_dims, bias_axes = param_tuple + model_generator = get_einsum_model_generators()[model_name] + einsum_generator = get_einsum_layer_generators()[layer_name] + registry = layer_registry.make_default_layer_registry() + + def curried_generator(a, b): # pylint: disable=unused-argument + return einsum_generator(equation, output_dims, bias_axes) + + # Each batched input is a reshape of a `tf.range()` call. + batch_size = 2 + example_size = tf.reduce_prod(input_dims) + example_values = tf.range(batch_size * example_size, dtype=tf.float32) + x_batch = tf.reshape(example_values, [batch_size] + input_dims) + (computed_norms, true_norms) = get_computed_and_true_norms( + model_generator=model_generator, + layer_generator=curried_generator, + input_dims=input_dims, + output_dims=output_dims, + is_eager=is_eager, + x_input=x_batch, + registry=registry, + ) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/einsum_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/einsum_utils.py new file mode 100644 index 000000000..ab3e325c2 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/einsum_utils.py @@ -0,0 +1,302 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Various helper functions related to `tf.keras.layers.EinsumDense`.""" + +import itertools +import re + +import numpy as np +import tensorflow as tf + + +def _is_batch_of_vectors(t: tf.Tensor) -> bool: + """Checks if an input is a batch of 1D vectors.""" + num_nontrivial_indices = 0 + for s in t.shape[1:]: + if num_nontrivial_indices > 1: + return False + if s > 1: + num_nontrivial_indices += 1 + return num_nontrivial_indices <= 1 + + +def _parse_einsum_equation(equation: str) -> tuple[int, tuple[str, str, str]]: + """Returns a case number and I/O substrings of an einsum equation.""" + case_number = 0 + match1 = re.match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", equation) + if match1 is not None: + case_number = 1 + match2 = re.match( + r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)", equation + ) + if match2 is not None: + case_number = 2 + match3 = re.match( + r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.", equation + ) + if match3 is not None: + case_number = 3 + matched = [g for g in [match1, match2, match3] if g is not None] + if len(matched) != 1: + raise ValueError( + "Invalid Einsum eqution string " + + equation + + " ." + "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, " + "{ab...,bc->ac...}" + ) + return case_number, matched[0].groups() + + +def _reshape_einsum_inputs( + input_tensor: tf.Tensor, + equation: str, +) -> tf.Tensor: + """Converts input tensor to a batched matrix according to an einsum equation. + + Args: + input_tensor: A `tf.Tensor` corresponding to the first input of the einsum + equation. + equation: The einsum equation `string`. + + Returns: + A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The + product of the non-trivial dimensions of the output should be equal to + the product of the non-trivial dimensions of `input_tensor`. + """ + # Find the components `ab`, `bc`, and `ac` given that `equation` can only be + # one of the following mutually exclusive forms: + # + # (C1) ab,bc->ac, + # (C2) ...ab,bc->...ac + # (C3) ab...,bc->ac... + # + # NOTE: `a`, `b`, and `c` are (possibly) also substrings. + + # Compute the first index of the `b` part of the `ab` component. + input_shape = input_tensor.shape + input_len = len(input_shape) + case_number, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation) + if case_number == 2: + # In case (C2), the `a` part of this component can be empty, so we have no + # choice but to compare the `c` part of `ac` with the `bc` component. + c_len = 0 + for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)): + if s1 == s2: + c_len += 1 + else: + break + b_len = len(bc_str) - c_len + b_idx = input_len - b_len + else: + # For the other cases, we simply compare `ab` with `ac` to get the length + # of the `a` component, i.e., the first index of `b`. + b_idx = 0 + for s1, s2 in itertools.zip_longest(ab_str, ac_str): + if s1 == s2: + b_idx += 1 + else: + break + # Prepare `input_tensor` for reshaping and get the pivot index of the prepped + # tensor. Note that case (C3) requires a transpose to ensure that matrix + # multiplication is performed by the caller. + if case_number == 3: + ellipses_idx = len(ab_str) + # Convert `ab...` to `a...b`. + new_ordering = ( + list(range(0, b_idx)) + + list(range(ellipses_idx, input_len)) + + list(range(b_idx, ellipses_idx)) + ) + base_tensor = tf.transpose(input_tensor, perm=new_ordering) + ellipses_len = input_len - ellipses_idx + pivot_idx = b_idx + ellipses_len + else: + base_tensor = input_tensor + pivot_idx = b_idx + # The output tensor is a batched set of matrices, split at the pivot index + # of the previously prepped tensor. + base_tensor_shape = base_tensor.shape + batch_size = base_tensor_shape[0] + num_rows = int(np.product(base_tensor_shape[1:pivot_idx])) + num_columns = int(np.product(base_tensor_shape[pivot_idx:])) + return tf.reshape(base_tensor, shape=[batch_size, num_rows, num_columns]) + + +def _reshape_einsum_outputs( + output_tensor: tf.Tensor, + equation: str, +) -> tf.Tensor: + """Converts output tensor to a batched matrix according to an einsum equation. + + The logic is almost the same as in `_reshape_einsum_inputs()` except + in the case where the equation is left-elided by ellipses. For this case, + we need to pass in a reversed kernel shape. + + Args: + output_tensor: A `tf.Tensor` corresponding to the output of the einsum + equation. + equation: The einsum equation `string`. + + Returns: + A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The + product of the non-trivial dimensions of the output should be equal to + the product of the non-trivial dimensions of `output_tensor`. + """ + match = re.match(r"([a-zA-Z|.]+),([a-zA-Z|.]+)->([a-zA-Z|.]+)", equation) + if match is not None: + s1, s2, s3 = match.groups() + else: + raise ValueError( + "Invalid Einsum eqution string " + + equation + + " ." + "Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, " + "{ab...,bc->ac...}" + ) + reversed_equation = s3 + "," + s2[::-1] + "->" + s1 + return _reshape_einsum_inputs(output_tensor, reversed_equation) + + +def _get_einsum_bias_adjoint_reduction_axes( + equation: str, + bias_axes: str, + grad_shape: tf.TensorShape, +) -> list[int]: + """Computes axes related to the adjoint of the einsum bias broadcast op.""" + reduction_axes = [] + case_number, (_, _, ac_str) = _parse_einsum_equation(equation) + # If `equation` of the form `...ab,bc->...ac`, i.e., case (C2), we do a + # right to left traversal; the other cases do a left to right traversal. + left_elided = case_number == 2 + grad_indices = range(len(grad_shape)) + traversal_zip = ( + itertools.zip_longest(reversed(grad_indices), reversed(ac_str)) + if left_elided + else itertools.zip_longest(grad_indices, ac_str) + ) + bias_traversal_str = bias_axes[::-1] if left_elided else bias_axes + # Perform the traversal. + ptr = 0 + for idx, output_chr in traversal_zip: + if idx != 0: + if output_chr is not None and ptr < len(bias_axes): + if bias_traversal_str[ptr] == output_chr: + ptr += 1 + else: + reduction_axes.append(idx) + else: + reduction_axes.append(idx) + return reduction_axes + + +def compute_fast_einsum_squared_gradient_norm( + equation: str, + input_tensor: tf.Tensor, + grad_tensor: tf.Tensor, + bias_axes: str | None, +): + """Computes the batch gradient norms of an Einsum gradient decompostion. + + This logic generalizes the one for `tf.keras.layers.Dense`. For reference, + we describe part of the mathematical analysis below. It can be safely skipped + upon first reading of this docstring. + + ----------------------------------------------------------------------------- + BEGIN ANALYSIS + ----------------------------------------------------------------------------- + Recall that the einsum dense computation for a single example is of the form + ``` + output = tf.einsum(equation, input, kernel) + bias, + ``` + where `bias` is broadcasted and summed with the output of the `tf.einsum()` + call, and equation has one of the following forms: + + (C1) ab,bc->ac, + (C2) ...ab,bc->...ac + (C3) ab...,bc->ac... + + Mathematically, the above computation is equivalent to: + ``` + output = tf.matmul(X, W) + Q(bias) + ``` + where `X` (resp. `W`) is a 2D tensor reshaped from `input` (resp. `kernel`) + and `Q` is a linear operator that transforms `bias` to comport with the + tensor output by the `tf.matmul()` call. + + Following the same trick as for `tf.keras.layers.Dense` layers, suppose that + we have: + ``` + loss = f(base_vars) + G = tape.gradient(loss, base_vars) + ``` + Then, using the chain rule and denoting `A'` to be the adjoint of a matrix + `A`, it is straightforward to show that the gradient of `loss` with respect + to `W` is given by the block matrix `K := [X' G; Q' G]`. Hence, the square + norm of `K`, i.e., what is returned by `sqr_norm_fn` is given by + ``` + sqr_norm = + || Q' G ||_F^2 + ``` + where `||.||_F` is the Frobenius norm and `<.,.>` is the Euclidean inner + product for matrices. + ----------------------------------------------------------------------------- + END ANALYSIS + ----------------------------------------------------------------------------- + + Args: + equation: A `string` representing the einsum equation. + input_tensor: A `tf.Tensor` reprenting the einsum input. + grad_tensor: A `tf.Tensor` that is the gradient of the scalar loss with + respect to the pre-activation tensor. + bias_axes: A `string` that specifies the einsum biases in `equation`. + + Returns: + A 1D `tf.Tensor` whose i-th entry is the squared gradient corresponding + to the i-th example in `input_tensor`. + """ + # NOTE: When the input/gradient tensors are 1D, it is MUCH faster to do + # a `tf.square()` + `tf.reduce_sum()` than a single `tf.matmul()`. + + # Compute the matrix `X X'` for each example. + x = _reshape_einsum_inputs(input_tensor, equation) + if _is_batch_of_vectors(x): + x_matrix = tf.reshape(x, [x.shape[0], -1]) + batch_xxt = tf.reduce_sum(tf.square(x_matrix), axis=1) + else: + batch_xxt = tf.matmul(x, x, transpose_b=True) + # Compute the matrix `G G'` for each example. + g = _reshape_einsum_outputs(grad_tensor, equation) + if _is_batch_of_vectors(g): + g_matrix = tf.reshape(g, [g.shape[0], -1]) + batch_ggt = tf.reduce_sum(tf.square(g_matrix), axis=1) + else: + batch_ggt = tf.matmul(g, g, transpose_b=True) + # Compute the inner product and adjust for bias (if it exists). + reduction_axes = tf.range(1, len(batch_ggt.shape)) + sqr_norms = tf.reduce_sum(batch_xxt * batch_ggt, axis=reduction_axes) + if bias_axes is not None: + # The adjoint operator `Q` on `G` is a reduce sum on the axes in `G` that + # are not broadcasted from `bias`. + grads_shape = grad_tensor.shape + adjoint_reduction_axes = _get_einsum_bias_adjoint_reduction_axes( + equation, + bias_axes, + grads_shape, + ) + qg = tf.reduce_sum(grad_tensor, axis=adjoint_reduction_axes) + qg_reduction_axes = tf.range(1, len(qg.shape)) + bias_sqr_norms = tf.reduce_sum(tf.square(qg), axis=qg_reduction_axes) + sqr_norms += bias_sqr_norms + + return sqr_norms diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 6dd0d49af..d5284ebc6 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,11 +13,20 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" +from typing import Any, Union, Iterable, Text, Callable, Tuple, TypeAlias + from absl import logging import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr + +InputTensor: TypeAlias = Union[ + tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor] +] +GeneratorFunction: TypeAlias = Callable[[Any, Tuple, dict], Tuple[Any, Any]] + -def has_internal_compute_graph(input_object): +def has_internal_compute_graph(input_object: Any): """Checks if input is a TF model and has a TF internal compute graph.""" return ( isinstance(input_object, tf.keras.Model) @@ -28,7 +37,7 @@ def has_internal_compute_graph(input_object): ) -def _get_internal_layers(input_layer): +def _get_internal_layers(input_layer: tf.keras.layers.Layer): """Returns a list of layers that are nested within a given layer.""" internal_layers = [] if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'): @@ -39,7 +48,11 @@ def _get_internal_layers(input_layer): return internal_layers -def model_forward_pass(input_model, inputs, generator_fn=None): +def model_forward_pass( + input_model: tf.keras.Model, + inputs: InputTensor, + generator_fn: GeneratorFunction | None = None, +): """Does a forward pass of a model and returns useful intermediates. NOTE: the graph traversal algorithm is an adaptation of the logic in the @@ -118,7 +131,9 @@ def generator_fn(layer_instance, args, kwargs): return node_layer_outputs, generator_outputs_list -def all_trainable_layers_are_registered(input_model, layer_registry): +def all_trainable_layers_are_registered( + input_model: tf.keras.Model, layer_registry: lr.LayerRegistry +): """Check if an input model's trainable layers are all registered. Args: @@ -140,7 +155,11 @@ def all_trainable_layers_are_registered(input_model, layer_registry): def add_aggregate_noise( - input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier + input_model: tf.keras.Model, + x_batch: InputTensor, + clipped_grads: list[tf.Tensor], + l2_norm_clip: float, + noise_multiplier: float, ): """Adds noise to a collection of clipped gradients. @@ -148,10 +167,9 @@ def add_aggregate_noise( input model's loss function. Args: - input_model: The Keras model to obtain the layers from. - x_batch: A collection of Tensors to be fed into the input layer of the - model. - clipped_grads: A list of tensors representing the clipped gradients. + input_model: The `tf.keras.Model` to obtain the layers from. + x_batch: An `InputTensor` to be fed into the input layer of the model. + clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. l2_norm_clip: Clipping norm (max L2 norm of each gradient). noise_multiplier: Ratio of the standard deviation to the clipping norm. @@ -187,7 +205,7 @@ def add_noise(g): return tf.nest.map_structure(add_noise, clipped_grads) -def generate_model_outputs_using_core_keras_layers(input_model): +def generate_model_outputs_using_core_keras_layers(input_model: tf.keras.Model): """Returns the model outputs generated by only core Keras layers.""" cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects()) cust_hash_set = set([hash(v) for v in cust_obj_dict.values()]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index c8279baf4..b19bb18cb 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -40,8 +40,22 @@ Details of this decomposition can be found in https://arxiv.org/abs/1510.01799 """ +from typing import Callable, Type, Any, TypeAlias import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import einsum_utils + +# ============================================================================== +# Type aliases +# ============================================================================== +SquareNormFunction: TypeAlias = Callable[[Any], tf.Tensor] + +RegistryFunctionOutput: TypeAlias = tuple[Any, tf.Tensor, SquareNormFunction] + +RegistryFunction: TypeAlias = Callable[ + [Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput +] + # ============================================================================== # Main class @@ -54,15 +68,19 @@ def __init__(self): self._layer_class_dict = {} self._registry = {} - def is_elem(self, layer_instance): + def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool: """Checks if a layer instance's class is in the registry.""" return hash(layer_instance.__class__) in self._registry - def lookup(self, layer_instance): + def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction: """Returns the layer registry function for a given layer instance.""" return self._registry[hash(layer_instance.__class__)] - def insert(self, layer_class, layer_registry_function): + def insert( + self, + layer_class: Type[tf.keras.layers.Layer], + layer_registry_function: RegistryFunction, + ): """Inserts a layer registry function into the internal dictionaries.""" layer_key = hash(layer_class) self._layer_class_dict[layer_key] = layer_class @@ -72,7 +90,11 @@ def insert(self, layer_class, layer_registry_function): # ============================================================================== # Supported Keras layers # ============================================================================== -def dense_layer_computation(layer_instance, inputs, tape): +def dense_layer_computation( + layer_instance: tf.keras.layers.Dense, + inputs: tuple[tf.Tensor], + tape: tf.GradientTape, +) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Dense`. The logic for this computation is based on the following paper: @@ -106,26 +128,70 @@ def dense_layer_computation(layer_instance, inputs, tape): tape.watch(base_vars) layer_instance.activation = orig_activation outputs = orig_activation(base_vars) if orig_activation else base_vars - def sqr_norm_fn(base_vars_grads): - sqr_inputs = tf.square(*inputs) - inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) - input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes) - if layer_instance.use_bias: - # Adding a bias term is equivalent to a layer with no bias term and which - # adds an additional variable to the layer input that only takes a - # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum - # of the squared values of the inputs. - input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype) - reduction_axes = tf.range(1, tf.rank(base_vars_grads)) - base_vars_sqr_norms = tf.reduce_sum( - tf.square(base_vars_grads), axis=reduction_axes + def sqr_norm_fn(grads): + # `Dense` layers are special instances of `EinsumDense` layers + return einsum_utils.compute_fast_einsum_squared_gradient_norm( + "...b,bc->...c", + inputs[0], + grads, + "c" if layer_instance.use_bias else None, + ) + + return base_vars, outputs, sqr_norm_fn + + +def einsum_layer_computation( + layer_instance: tf.keras.layers.EinsumDense, + inputs: tuple[tf.Tensor], + tape: tf.GradientTape, +): + """Registry function for `tf.keras.layers.EinsumDense`. + + For the technical details, see the documentation of + `einsum_utils.compute_fast_einsum_gradient_norm()`. + + Args: + layer_instance: A `tf.keras.layers.EinsumDense` instance. + inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., + `layer_instance(inputs)` returns a valid output. + tape: A `tf.GradientTape` instance that will be used to watch the output + `base_vars`. + + Returns: + A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the + intermediate Tensor used in the chain-rule / "fast" clipping trick, + `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is + a function that takes one input, a `tf.Tensor` that represents the output + of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a + `tf.GradientTape` instance that records the dense layer computation and + `summed_loss` is the sum of the per-example losses of the underlying model. + This function then returns the per-example squared L2 gradient norms of the + trainable variables in `layer_instance`. These squared norms should be a 1D + `tf.Tensor` of length `batch_size`. + """ + orig_activation = layer_instance.activation + layer_instance.activation = None + base_vars = layer_instance(*inputs) + tape.watch(base_vars) + layer_instance.activation = orig_activation + outputs = orig_activation(base_vars) if orig_activation else base_vars + + def sqr_norm_fn(grads): + return einsum_utils.compute_fast_einsum_squared_gradient_norm( + layer_instance.equation, + inputs[0], + grads, + layer_instance.bias_axes, ) - return input_sqr_norms * base_vars_sqr_norms return base_vars, outputs, sqr_norm_fn -def embedding_layer_computation(layer_instance, inputs, tape): +def embedding_layer_computation( + layer_instance: tf.keras.layers.Embedding, + inputs: tuple[tf.Tensor], + tape: tf.GradientTape, +) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Embedding`. The logic of this computation is based on the `tf.keras.layers.Dense` @@ -225,8 +291,9 @@ def sqr_norm_fn(base_vars_grads): # ============================================================================== # Main factory methods # ============================================================================== -def make_default_layer_registry(): +def make_default_layer_registry() -> LayerRegistry: registry = LayerRegistry() registry.insert(tf.keras.layers.Dense, dense_layer_computation) registry.insert(tf.keras.layers.Embedding, embedding_layer_computation) + registry.insert(tf.keras.layers.EinsumDense, einsum_layer_computation) return registry