Skip to content

Commit

Permalink
Add utility functions for unwrapping BERT encoder layers into individ…
Browse files Browse the repository at this point in the history
…ual Keras layers.

PiperOrigin-RevId: 584412294
  • Loading branch information
wwkong authored and tensorflower-gardener committed Nov 22, 2023
1 parent b19088f commit e67d39c
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ py_library(
srcs_version = "PY3",
)

py_library(
name = "bert_encoder_utils",
srcs = ["bert_encoder_utils.py"],
srcs_version = "PY3",
deps = [":gradient_clipping_utils"],
)

py_test(
name = "bert_encoder_utils_test",
srcs = ["bert_encoder_utils_test.py"],
srcs_version = "PY3",
deps = [":bert_encoder_utils"],
)

py_library(
name = "common_manip_utils",
srcs = ["common_manip_utils.py"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for manipulating official Tensorflow BERT encoders."""

import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils


def dedup_bert_encoder(input_bert_encoder: tfm.nlp.networks.BertEncoder):
"""Deduplicates the layer names in a BERT encoder."""

def _dedup(layer, attr_name, new_name):
sublayer = getattr(layer, attr_name)
if sublayer is None:
return
else:
sublayer_config = sublayer.get_config()
sublayer_config["name"] = new_name
setattr(layer, attr_name, sublayer.from_config(sublayer_config))

for layer in input_bert_encoder.layers:
# NOTE: the ordering of the renames is important for the ordering of the
# variables in the computed gradients. This is why we use three `for-loop`
# instead of one.
if isinstance(layer, tfm.nlp.layers.TransformerEncoderBlock):
# pylint: disable=protected-access
for attr_name in ["inner_dropout_layer", "attention_dropout"]:
_dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
# Some layers are nested within the main attention layer (if it exists).
if layer._attention_layer is not None:
prefix = layer.name + "/" + layer._attention_layer.name
_dedup(layer, "_attention_layer", prefix + "/attention_layer")
_dedup(
layer._attention_layer,
"_dropout_layer",
prefix + "/attention_inner_dropout_layer",
)
for attr_name in ["attention_layer_norm", "intermediate_dense"]:
_dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
# This is one of the few times that we cannot build from a config, due
# to the presence of lambda functions.
if layer._intermediate_activation_layer is not None:
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
policy = tf.float32
layer._intermediate_activation_layer = tf.keras.layers.Activation(
layer._inner_activation,
dtype=policy,
name=layer.name + "/intermediate_activation_layer",
)
for attr_name in ["output_dense", "output_dropout", "output_layer_norm"]:
_dedup(layer, "_" + attr_name, layer.name + "/" + attr_name)
# pylint: enable=protected-access


def get_unwrapped_bert_encoder(
input_bert_encoder: tfm.nlp.networks.BertEncoder,
) -> tfm.nlp.networks.BertEncoder:
"""Creates a new BERT encoder whose layers are core Keras layers."""
dedup_bert_encoder(input_bert_encoder)
core_test_outputs = (
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
input_bert_encoder,
custom_layer_set={tfm.nlp.layers.TransformerEncoderBlock},
)
)
return tf.keras.Model(
inputs=input_bert_encoder.inputs,
outputs=core_test_outputs,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of `bert_encoder_utils.py`."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_models as tfm
from tensorflow_privacy.privacy.fast_gradient_clipping import bert_encoder_utils


def compute_bert_sample_inputs(
batch_size, sequence_length, vocab_size, num_types
):
"""Returns a set of BERT encoder inputs."""
word_id_sample = np.random.randint(
vocab_size, size=(batch_size, sequence_length)
)
mask_sample = np.random.randint(2, size=(batch_size, sequence_length))
type_id_sample = np.random.randint(
num_types,
size=(batch_size, sequence_length),
)
return [word_id_sample, mask_sample, type_id_sample]


def get_small_bert_encoder_and_sample_inputs(dict_outputs=False):
"""Returns a small BERT encoder for testing."""
hidden_size = 2
vocab_size = 3
num_types = 4
max_sequence_length = 5
inner_dense_units = 6
output_range = 1
num_heads = 2
num_transformer_layers = 3
seed = 777

bert_encoder = tfm.nlp.networks.BertEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_heads,
num_layers=num_transformer_layers,
max_sequence_length=max_sequence_length,
inner_dim=inner_dense_units,
type_vocab_size=num_types,
output_range=output_range,
initializer=tf.keras.initializers.GlorotUniform(seed),
dict_outputs=dict_outputs,
)

batch_size = 3
bert_sample_inputs = compute_bert_sample_inputs(
batch_size,
max_sequence_length,
vocab_size,
num_types,
)

return bert_encoder, bert_sample_inputs


def get_shared_trainable_variables(model1, model2):
"""Returns the shared trainable variables (by name) between models."""
common_names = {v.name for v in model1.trainable_variables} & {
v.name for v in model2.trainable_variables
}
tvars1 = [v for v in model1.trainable_variables if v.name in common_names]
tvars2 = [v for v in model2.trainable_variables if v.name in common_names]
return tvars1, tvars2


def custom_reduced_loss(y_batch, y_pred):
del y_batch
# Create a loss multiplier to avoid small gradients.
large_value_multiplier = 1e10
sqr_outputs = []
for t in y_pred:
reduction_axes = tf.range(1, len(t.shape))
sqr_outputs.append(tf.reduce_sum(tf.square(t), axis=reduction_axes))
sqr_tsr = tf.stack(sqr_outputs, axis=1)
return large_value_multiplier * tf.reduce_sum(sqr_tsr, axis=1)


class BertEncoderUtilsTest(tf.test.TestCase, parameterized.TestCase):

def test_outputs_are_equal(self):
true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
true_encoder
)
true_outputs = true_encoder(sample_inputs)
computed_outputs = unwrapped_encoder(sample_inputs)
self.assertAllClose(true_outputs, computed_outputs)

def test_shared_trainable_variables_are_equal(self):
true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
true_encoder
)
# Initializes the trainable variable shapes.
true_encoder(sample_inputs)
unwrapped_encoder(sample_inputs)
# The official BERT encoder may initialize trainable variables that are
# not used in a model forward pass. Hence, they are invisible when we
# try to unwrapping layers using our utility function.
true_vars, computed_vars = get_shared_trainable_variables(
true_encoder, unwrapped_encoder
)
self.assertAllClose(true_vars, computed_vars)

def test_shared_gradients_are_equal(self):
true_encoder, sample_inputs = get_small_bert_encoder_and_sample_inputs()
unwrapped_encoder = bert_encoder_utils.get_unwrapped_bert_encoder(
true_encoder
)
# Create a loss multiplier to avoid small gradients.
dummy_labels = None
with tf.GradientTape(persistent=True) as tape:
true_outputs = true_encoder(sample_inputs)
true_sqr_sum = tf.reduce_sum(
custom_reduced_loss(dummy_labels, true_outputs)
)
computed_outputs = unwrapped_encoder(sample_inputs)
computed_sqr_sum = tf.reduce_sum(
custom_reduced_loss(dummy_labels, computed_outputs)
)
# The official BERT encoder may initialize trainable variables that are
# not used in a model forward pass. Hence, they are invisible when we
# try to unwrapping layers using our utility function.
true_vars, computed_vars = get_shared_trainable_variables(
true_encoder, unwrapped_encoder
)
true_grads = tape.gradient(true_sqr_sum, true_vars)
computed_grads = tape.gradient(computed_sqr_sum, computed_vars)
self.assertEqual(len(true_grads), len(computed_grads))
for g1, g2 in zip(true_grads, computed_grads):
self.assertEqual(type(g1), type(g2))
if isinstance(g1, tf.IndexedSlices):
self.assertAllClose(g1.values, g2.values)
self.assertAllEqual(g2.indices, g2.indices)
else:
self.assertAllClose(g1, g2)


if __name__ == '__main__':
tf.test.main()

0 comments on commit e67d39c

Please sign in to comment.