diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py index 849ace69..e79a52b1 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py @@ -80,8 +80,11 @@ def sqr_norm_fn(grads): stacked_grads = tf.stack(grads, axis=-1) if num_microbatches is not None: stacked_grads = common_manip_utils.maybe_add_microbatch_axis( - grads, num_microbatches + stacked_grads, num_microbatches ) + # We will need to sum over the new microbatch size axis (axis=1) in order + # to account for microbatch aggregation. + stacked_grads = tf.reduce_sum(stacked_grads, axis=1) reduction_axes = tf.range(1, tf.rank(stacked_grads)) return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py index c0b18d88..f7204354 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py @@ -134,7 +134,7 @@ def test_op(x_batch): atol = 1e-1 if self.using_tpu else 1e-2 # Each batched input is a reshape of a `tf.range()` call. - batch_size = 2 + batch_size = 6 example_size = np.prod(input_dims) example_values = tf.range(batch_size * example_size, dtype=tf.float32) x_batch = tf.reshape(example_values, [batch_size] + input_dims) @@ -147,7 +147,9 @@ def test_op(x_batch): common_test_utils.assert_replica_values_are_close(self, true_norms) computed_norms = computed_norms.values[0] true_norms = true_norms.values[0] - self.assertEqual(tf.shape(computed_norms)[0], batch_size) + self.assertEqual( + tf.shape(computed_norms)[0], num_microbatches or batch_size + ) self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)