From c5b573158534f6c9bf87bbcfc98a9c22393e64a8 Mon Sep 17 00:00:00 2001 From: William Kong Date: Fri, 30 Aug 2024 10:25:58 -0700 Subject: [PATCH] Fix a gradient clipping bug for layer normalization layers with microbatch axes. The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug. PiperOrigin-RevId: 669369587 --- .../registry_functions/layer_normalization.py | 5 ++++- .../registry_functions/layer_normalization_test.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py index 849ace69..e79a52b1 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py @@ -80,8 +80,11 @@ def sqr_norm_fn(grads): stacked_grads = tf.stack(grads, axis=-1) if num_microbatches is not None: stacked_grads = common_manip_utils.maybe_add_microbatch_axis( - grads, num_microbatches + stacked_grads, num_microbatches ) + # We will need to sum over the new microbatch size axis (axis=1) in order + # to account for microbatch aggregation. + stacked_grads = tf.reduce_sum(stacked_grads, axis=1) reduction_axes = tf.range(1, tf.rank(stacked_grads)) return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py index c0b18d88..f7204354 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py @@ -134,7 +134,7 @@ def test_op(x_batch): atol = 1e-1 if self.using_tpu else 1e-2 # Each batched input is a reshape of a `tf.range()` call. - batch_size = 2 + batch_size = 6 example_size = np.prod(input_dims) example_values = tf.range(batch_size * example_size, dtype=tf.float32) x_batch = tf.reshape(example_values, [batch_size] + input_dims) @@ -147,7 +147,9 @@ def test_op(x_batch): common_test_utils.assert_replica_values_are_close(self, true_norms) computed_norms = computed_norms.values[0] true_norms = true_norms.values[0] - self.assertEqual(tf.shape(computed_norms)[0], batch_size) + self.assertEqual( + tf.shape(computed_norms)[0], num_microbatches or batch_size + ) self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)