Implement and test a registry function for `tfm.nlp.layers.EinsumDense` + small formatting fixes.

PiperOrigin-RevId: 576215816
wwkong authored and tensorflower-gardener committed Oct 24, 2023
1 parent 8b52ba2 commit 39c8a8c
Showing 6 changed files with 442 additions and 5 deletions.
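For context, here is a minimal sketch of how a registry function like the one added in this commit is consumed. The registration pattern mirrors `get_einsum_layer_registry()` in the test file below; the commented-out `clip_grads.compute_gradient_norms` call is an assumption about the surrounding API, not part of this commit.

import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense

# Register the new per-layer computation for `tfm.nlp.layers.EinsumDense`.
registry = layer_registry.LayerRegistry()
registry.insert(
    tfm.nlp.layers.EinsumDense,
    einsum_dense.einsum_layer_computation,
)

# A toy model whose only trainable layer is an EinsumDense instance.
inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.EinsumDense('ab,bc->ac', output_shape=4, bias_axes='c')(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Hypothetical downstream call (argument names are assumptions):
# norms = clip_grads.compute_gradient_norms(
#     model, x_batch, y_batch, layer_registry=registry)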
@@ -13,6 +13,7 @@ py_library(
    name = "einsum_utils",
    srcs = ["einsum_utils.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
)

py_test(
@@ -24,6 +25,33 @@ py_test(
    deps = [":einsum_utils"],
)

py_library(
    name = "einsum_dense",
    srcs = ["einsum_dense.py"],
    srcs_version = "PY3",
    deps = [
        ":einsum_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
    ],
)

py_test(
    name = "einsum_dense_test",
    size = "large",
    srcs = ["einsum_dense_test.py"],
    python_version = "PY3",
    shard_count = 12,
    srcs_version = "PY3",
    deps = [
        ":dense",
        ":einsum_dense",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
    ],
)

py_library(
name = "dense",
srcs = ["dense.py"],
@@ -0,0 +1,70 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""

from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils


def einsum_layer_computation(
    layer_instance: tf.keras.layers.EinsumDense,
    input_args: Sequence[Any],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.EinsumDense`.
For the technical details, see the documentation of
`einsum_utils.compute_fast_einsum_gradient_norm()`.
Args:
layer_instance: A `tf.keras.layers.EinsumDense` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
  if input_kwargs:
    raise ValueError("EinsumDense layer calls should not receive kwargs.")
  del input_kwargs
  if len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  orig_activation = layer_instance.activation
  # Some activation functions may not apply a transform to the elements of the
  # output individually (which is needed for the fast clipping trick to work).
  # To avoid this case, we watch the variables that are only generated by the
  # linear transformation of the `EinsumDense` layer instance.
  layer_instance.activation = None
  base_vars = layer_instance(*input_args)
  tape.watch(base_vars)
  layer_instance.activation = orig_activation
  outputs = orig_activation(base_vars) if orig_activation else base_vars

  def sqr_norm_fn(grads):
    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
        layer_instance.equation,
        input_args[0],
        grads,
        layer_instance.bias_axes,
        num_microbatches,
    )

  return base_vars, outputs, sqr_norm_fn
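As an aside (not part of the commit), here is a self-contained sketch of the identity that `sqr_norm_fn` exploits, shown for the simplest equation 'ab,bc->ac': the per-example kernel gradient is the outer product of the input row x_i and the pre-activation gradient row g_i, so its squared Frobenius norm factors as ||x_i||^2 * ||g_i||^2, with an extra ||g_i||^2 for a bias on axis 'c'. The brute-force check via `tape.jacobian` is only for illustration.

import tensorflow as tf

layer = tf.keras.layers.EinsumDense('ab,bc->ac', output_shape=3, bias_axes='c')
x = tf.random.normal([4, 2])  # batch of 4 examples

# Fast route: one gradient w.r.t. the watched pre-activation outputs,
# mirroring how the registry function uses `base_vars`.
with tf.GradientTape() as tape:
  base_vars = layer(x)
  tape.watch(base_vars)
  loss = tf.reduce_sum(tf.square(base_vars))
grads = tape.gradient(loss, base_vars)  # row i is g_i
sqr_input_norms = tf.reduce_sum(tf.square(x), axis=1)
sqr_grad_norms = tf.reduce_sum(tf.square(grads), axis=1)
fast_sqr_norms = sqr_input_norms * sqr_grad_norms + sqr_grad_norms

# Brute-force route: per-example Jacobians w.r.t. the trainable variables.
with tf.GradientTape() as tape2:
  per_example_loss = tf.reduce_sum(tf.square(layer(x)), axis=1)
jacobians = tape2.jacobian(per_example_loss, layer.trainable_variables)
brute_sqr_norms = tf.add_n([
    tf.reduce_sum(tf.square(j), axis=list(range(1, j.shape.rank)))
    for j in jacobians
])

tf.debugging.assert_near(fast_sqr_norms, brute_sqr_norms)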
@@ -0,0 +1,171 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense


def get_einsum_layer_generators():
  def pure_einsum_layer(equation, output_dims, bias_axes):
    return tf.keras.layers.EinsumDense(
        equation, output_dims, bias_axes=bias_axes
    )

  def sigmoid_einsum_layer(equation, output_dims, bias_axes):
    return tf.keras.layers.EinsumDense(
        equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
    )

  return {
      'pure_einsum': pure_einsum_layer,
      'sigmoid_einsum': sigmoid_einsum_layer,
  }


def get_einsum_parameter_tuples():
  # (equation, input_dims, output_dims, bias_axes)
  return [
      # Case C1.
      ('ab,bc->ac', [2], [3], None),
      ('ab,bc->ac', [2], [3], 'c'),
      ('abc,cd->abd', [2, 3], [2, 4], None),
      ('abc,cd->abd', [2, 3], [2, 4], 'b'),
      ('abc,cd->abd', [2, 3], [2, 4], 'd'),
      ('abc,cd->abd', [2, 3], [2, 4], 'bd'),
      ('abc,cef->abef', [2, 3], [2, 4, 5], None),
      ('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
      # Case C2.
      ('...b,bc->...c', [2, 3], [4], None),
      ('...b,bc->...c', [2, 3], [4], 'c'),
      ('...ab,bc->...ac', [2, 3], [2, 4], None),
      ('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
      # Case C3.
      ('ab...,bc->ac...', [2, 3], [4, 3], None),
      ('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
  ]
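To make the tuple format concrete, here is one hypothetical reading (not part of the commit) of the Case C2 entry ('...b,bc->...c', [2, 3], [4], 'c'): the input dims exclude the batch axis, the ellipsis axes pass through, and the output dims describe only the non-ellipsis output axes.

import tensorflow as tf

layer = tf.keras.layers.EinsumDense('...b,bc->...c', output_shape=4, bias_axes='c')
# batch_size=4 prepended to input_dims=[2, 3], as in the test below.
x_batch = tf.reshape(tf.range(4 * 2 * 3, dtype=tf.float32), [4, 2, 3])
y = layer(x_batch)
print(y.shape)  # (4, 2, 4): the '...' axes pass through; 'b'=3 contracts to 'c'=4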


def get_einsum_layer_registry():
  einsum_registry = layer_registry.LayerRegistry()
  einsum_registry.insert(
      tfm.nlp.layers.EinsumDense,
      einsum_dense.einsum_layer_computation,
  )
  return einsum_registry


class GradNormTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()
    self.using_tpu = False

  @parameterized.product(
      layer_name=list(get_einsum_layer_generators()),
      param_tuple=get_einsum_parameter_tuples(),
      num_microbatches=[None, 2],
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self,
      layer_name,
      param_tuple,
      num_microbatches,
      is_eager,
  ):
    # Parse inputs to generate test data. Note that each batched input is a
    # reshape of a `tf.range()` call.
    equation, input_dims, output_dims, bias_axes = param_tuple
    batch_size = 4
    example_size = tf.reduce_prod(input_dims)
    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
    x_batch = tf.reshape(example_values, [batch_size] + input_dims)

    # Make the layer generator via currying.
    einsum_generator = get_einsum_layer_generators()[layer_name]

    def curried_generator(a, b):
      del a, b
      return einsum_generator(equation, output_dims, bias_axes)

    # Load shared assets onto all devices.
    with self.strategy.scope():
      model = common_test_utils.get_model_from_generator(
          model_generator=common_test_utils.make_one_layer_functional_model,
          layer_generator=curried_generator,
          input_dims=input_dims,
          output_dims=output_dims,
          is_eager=is_eager,
      )

    # Define the main testing ops. These may later be compiled to a graph op.
    def test_op(x):
      return common_test_utils.get_computed_and_true_norms_from_model(
          model=model,
          per_example_loss_fn=None,
          num_microbatches=num_microbatches,
          x_batch=x,
          registry=get_einsum_layer_registry(),
      )

    # TPUs can only run `tf.function`-decorated functions.
    if self.using_tpu:
      test_op = tf.function(test_op, autograph=False)

    # TPUs use lower precision than CPUs, so we relax our criterion.
    # E.g., one of the TPU runs generated the following results:
    #
    #   computed_norm = 93.48296
    #   true_norm     = 93.31176
    #   abs_diff      = 0.17120361
    #   rel_diff      = 0.00183475
    #
    # which is a reasonable level of error for computing gradient norms.
    # Other trials also give an absolute (resp. relative) error of around
    # 0.05 (resp. 0.0015).
    rtol = 1e-2 if self.using_tpu else 1e-3
    atol = 5e-1 if self.using_tpu else 1e-2

    # Set up the device ops and run the test.
    computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
    # TPUs return replica contexts, which must be unwrapped.
    if self.using_tpu:
      common_test_utils.assert_replica_values_are_close(self, computed_norms)
      common_test_utils.assert_replica_values_are_close(self, true_norms)
      computed_norms = computed_norms.values[0]
      true_norms = true_norms.values[0]
    expected_size = num_microbatches or batch_size
    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,30 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test


class GradNormTpuTest(einsum_dense_test.GradNormTest):

  def setUp(self):
    super(einsum_dense_test.GradNormTest, self).setUp()
    self.strategy = common_test_utils.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
    self.using_tpu = True


if __name__ == '__main__':
  tf.test.main()
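For reference, a TPU strategy like the one returned by `common_test_utils.create_tpu_strategy()` is typically built with the standard TF idiom below; this is a sketch for context, and the committed helper may differ in detail.

import tensorflow as tf

def create_tpu_strategy_sketch():
  # Standard TPU initialization idiom; assumes a reachable TPU worker.
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  return tf.distribute.TPUStrategy(resolver)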