Commit a35cda3

wwkong authored and tensorflower-gardener committed
Implement and test a registry function for tfm.nlp.layers.EinsumDense + small formatting fixes.
PiperOrigin-RevId: 568854078
1 parent 0d1bd9d commit a35cda3
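
For orientation, the sketch below shows how the new registry function plugs into a fast-gradient-clipping `LayerRegistry`. It mirrors the registration performed in the new `einsum_dense_test.py` further down; the imports and the `LayerRegistry.insert()` call are taken from this diff, and the snippet is an illustrative sketch rather than the canonical integration path.

# Illustrative sketch, based on the registration done in the test below.
import tensorflow_models as tfm

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense

registry = layer_registry.LayerRegistry()
registry.insert(
    tfm.nlp.layers.EinsumDense,
    einsum_dense.einsum_layer_computation,
)
# The populated registry is then handed to the fast gradient clipping
# utilities (e.g., via the `registry=` argument used by the test helpers
# in this commit).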

6 files changed: +446, -2 lines changed

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD

Lines changed: 28 additions & 0 deletions
@@ -13,6 +13,7 @@ py_library(
     name = "einsum_utils",
     srcs = ["einsum_utils.py"],
     srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils"],
 )
 
 py_test(
@@ -24,6 +25,33 @@ py_test(
     deps = [":einsum_utils"],
 )
 
+py_library(
+    name = "einsum_dense",
+    srcs = ["einsum_dense.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":einsum_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
+)
+
+py_test(
+    name = "einsum_dense_test",
+    size = "large",
+    srcs = ["einsum_dense_test.py"],
+    python_version = "PY3",
+    shard_count = 12,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":einsum_dense",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
+
 py_library(
     name = "dense",
     srcs = ["dense.py"],
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""

from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils


def einsum_layer_computation(
    layer_instance: tf.keras.layers.Layer,
    input_args: Sequence[Any],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Registry function for `tf.keras.layers.EinsumDense`.

  For the technical details, see the documentation of
  `einsum_utils.compute_fast_einsum_gradient_norm()`.

  Args:
    layer_instance: A `tf.keras.layers.EinsumDense` instance.
    input_args: See `dense_layer_computation()` in `dense.py`.
    input_kwargs: See `dense_layer_computation()` in `dense.py`.
    tape: See `dense_layer_computation()` in `dense.py`.
    num_microbatches: See `dense_layer_computation()` in `dense.py`.

  Returns:
    See `dense_layer_computation()` in `dense.py`.
  """
  if input_kwargs:
    raise ValueError("EinsumDense layer calls should not receive kwargs.")
  del input_kwargs
  if len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  orig_activation = layer_instance.activation
  layer_instance.activation = None
  base_vars = layer_instance(*input_args)
  tape.watch(base_vars)
  layer_instance.activation = orig_activation
  outputs = orig_activation(base_vars) if orig_activation else base_vars

  def sqr_norm_fn(grads):
    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
        layer_instance.equation,
        input_args[0],
        grads,
        layer_instance.bias_axes,
        num_microbatches,
    )

  return base_vars, outputs, sqr_norm_fn
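
The registry function above reads the layer's `equation` and `bias_axes` attributes when building `sqr_norm_fn`. As a reminder of what those arguments mean, here is a minimal, self-contained `tf.keras.layers.EinsumDense` example; the shapes match the ('abc,cd->abd', [2, 3], [2, 4], 'd') parameter tuple used in the test below, and the snippet is illustrative rather than part of the commit.

import tensorflow as tf

# 'abc,cd->abd': 'a' is the batch axis, each example has shape [2, 3]
# (b=2, c=3), and the kernel contracts axis 'c' into a new axis 'd'.
# bias_axes='d' adds a bias broadcast along the output axis 'd'.
layer = tf.keras.layers.EinsumDense(
    'abc,cd->abd', output_shape=[2, 4], bias_axes='d'
)
y = layer(tf.ones([5, 2, 3]))  # a batch of 5 examples of shape [2, 3]
print(y.shape)  # (5, 2, 4)
print(layer.equation, layer.bias_axes)  # attributes read by sqr_norm_fn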
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_test.py

Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense


def get_einsum_layer_generators():
  def pure_einsum_layer(equation, output_dims, bias_axes):
    return tf.keras.layers.EinsumDense(
        equation, output_dims, bias_axes=bias_axes
    )

  def sigmoid_einsum_layer(equation, output_dims, bias_axes):
    return tf.keras.layers.EinsumDense(
        equation, output_dims, bias_axes=bias_axes, activation='sigmoid'
    )

  return {
      'pure_einsum': pure_einsum_layer,
      'sigmoid_einsum': sigmoid_einsum_layer,
  }


def get_einsum_model_generators():
  return {
      'func1': common_test_utils.make_one_layer_functional_model,
  }


def get_einsum_parameter_tuples():
  # (equation, input_dims, output_dims, bias_axes)
  return [
      # Case C1.
      ('ab,bc->ac', [2], [3], None),
      ('ab,bc->ac', [2], [3], 'c'),
      ('abc,cd->abd', [2, 3], [2, 4], None),
      ('abc,cd->abd', [2, 3], [2, 4], 'b'),
      ('abc,cd->abd', [2, 3], [2, 4], 'd'),
      ('abc,cd->abd', [2, 3], [2, 4], 'bd'),
      ('abc,cef->abef', [2, 3], [2, 4, 5], None),
      ('abc,cef->abef', [2, 3], [2, 4, 5], 'bf'),
      # Case C2.
      ('...b,bc->...c', [2, 3], [4], None),
      ('...b,bc->...c', [2, 3], [4], 'c'),
      ('...ab,bc->...ac', [2, 3], [2, 4], None),
      ('...ab,bc->...ac', [2, 4], [2, 4], 'c'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], None),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'b'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'd'),
      ('...abc,cd->...abd', [2, 3, 4], [2, 3, 5], 'bd'),
      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], None),
      ('...abc,cef->...abef', [2, 3, 4], [2, 3, 5, 6], 'bf'),
      # Case C3.
      ('ab...,bc->ac...', [2, 3], [4, 3], None),
      ('ab...,bc->ac...', [2, 3], [4, 3], 'c'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], None),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'b'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'd'),
      ('abc...,cd->abd...', [2, 3, 4], [2, 5, 4], 'bd'),
      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], None),
      ('abc...,cef->abef...', [2, 3, 4], [2, 5, 6, 4], 'bf'),
  ]


def get_einsum_layer_registries():
  einsum_registry = layer_registry.LayerRegistry()
  einsum_registry.insert(
      tfm.nlp.layers.EinsumDense,
      einsum_dense.einsum_layer_computation,
  )
  return {
      'einsum_and_dense': einsum_registry,
  }


class GradNormTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()
    self.using_tpu = False

  @parameterized.product(
      model_name=list(get_einsum_model_generators()),
      layer_name=list(get_einsum_layer_generators()),
      layer_registry_name=list(get_einsum_layer_registries()),
      param_tuple=get_einsum_parameter_tuples(),
      num_microbatches=[None, 2],
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self,
      model_name,
      layer_name,
      layer_registry_name,
      param_tuple,
      num_microbatches,
      is_eager,
  ):
    # Parse inputs to generate test data. Note that each batched input is a
    # reshape of a `tf.range()` call.
    equation, input_dims, output_dims, bias_axes = param_tuple
    batch_size = 4
    example_size = tf.reduce_prod(input_dims)
    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
    x_batch = tf.reshape(example_values, [batch_size] + input_dims)
    x_input = [x_batch, x_batch] if model_name == 'tower2' else x_batch

    # Make the layer generator via currying.
    einsum_generator = get_einsum_layer_generators()[layer_name]

    def curried_generator(a, b):
      del a, b
      return einsum_generator(equation, output_dims, bias_axes)

    # Load shared assets to all devices.
    with self.strategy.scope():
      model = common_test_utils.get_model_from_generator(
          model_generator=get_einsum_model_generators()[model_name],
          layer_generator=curried_generator,
          input_dims=input_dims,
          output_dims=output_dims,
          is_eager=is_eager,
      )

    # Define the main testing ops. These may be later compiled to a Graph op.
    def test_op(x_input):
      return common_test_utils.get_computed_and_true_norms_from_model(
          model=model,
          per_example_loss_fn=None,
          num_microbatches=num_microbatches,
          x_batch=x_input,
          registry=get_einsum_layer_registries()[layer_registry_name],
      )

    # TPUs can only run `tf.function`-decorated functions.
    if self.using_tpu:
      test_op = tf.function(test_op, autograph=False)

    # TPUs use lower precision than CPUs, so we relax our criterion.
    # E.g., one of the TPU runs generated the following results:
    #
    #   computed_norm = 93.48296
    #   true_norm     = 93.31176
    #   abs_diff      = 0.17120361
    #   rel_diff      = 0.00183475
    #
    # which is a reasonable level of error for computing gradient norms.
    # Other trials also give an absolute (resp. relative) error of around
    # 0.05 (resp. 0.0015).
    rtol = 1e-2 if self.using_tpu else 1e-3
    atol = 5e-1 if self.using_tpu else 1e-2

    # Set up the device ops and run the test.
    computed_norms, true_norms = self.strategy.run(test_op, args=(x_input,))
    # TPUs return replica contexts, which must be unwrapped.
    if self.using_tpu:
      common_test_utils.assert_replica_values_are_close(self, computed_norms)
      common_test_utils.assert_replica_values_are_close(self, true_norms)
      computed_norms = computed_norms.values[0]
      true_norms = true_norms.values[0]
    expected_size = num_microbatches or batch_size
    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


if __name__ == '__main__':
  tf.test.main()
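
The `true_norms` that the assertions compare against are per-example gradient norms. The library computes them through `common_test_utils`, but the quantity itself can be illustrated naively, one example at a time. The sketch below is only an illustration (a hypothetical stand-in loss, not the library's implementation), with the layer shape taken from the first parameter tuple above.

import tensorflow as tf

layer = tf.keras.layers.EinsumDense('ab,bc->ac', output_shape=[3], bias_axes='c')
x_batch = tf.reshape(tf.range(8, dtype=tf.float32), [4, 2])  # batch_size = 4

def per_example_grad_norm(x):
  # Run a single example through the layer and take the gradient norm of a
  # stand-in per-example loss with respect to the layer's weights.
  with tf.GradientTape() as tape:
    y = layer(x[tf.newaxis, :])
    loss = tf.reduce_sum(tf.square(y))
  grads = tape.gradient(loss, layer.trainable_variables)
  return tf.sqrt(sum(tf.reduce_sum(tf.square(g)) for g in grads))

true_norms = tf.stack([per_example_grad_norm(x) for x in x_batch])
print(true_norms.shape)  # (4,): one norm per example, as asserted above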
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_dense_test


class GradNormTpuTest(einsum_dense_test.GradNormTest):

  def setUp(self):
    super(einsum_dense_test.GradNormTest, self).setUp()
    self.strategy = common_test_utils.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
    self.using_tpu = True


if __name__ == '__main__':
  tf.test.main()
