# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked Language Model (MLM) head."""

import tensorflow as tf
from tensorflow import keras


class MLMHead(keras.layers.Layer):
| 22 | + """Masked Language Model (MLM) head. |
| 23 | +
|
| 24 | + This layer takes two inputs: |
| 25 | + - `inputs`: which should be a tensor of encoded tokens with shape |
| 26 | + `(batch_size, sequence_length, encoding_dim)`. |
| 27 | + - `mask_positions`: which should be a tensor of integer positions to |
| 28 | + predict with shape `(batch_size, masks_per_sequence)`. |
| 29 | +
|
| 30 | + The token encodings should usually be the last output of an encoder model, |
| 31 | + and mask positions should be the interger positions you would like to |
| 32 | + predict for the MLM task. |
| 33 | +
|
| 34 | + The layer will first gather the token encodings at the mask positions. These |
| 35 | + gathered tokens will be passed through a dense layer the same size as |
| 36 | + encoding dimension, then transformed to predictions the same size as the |
| 37 | + input vocabulary. This layer will produce a single output with shape |
| 38 | + `(batch_size, masks_per_sequence, vocabulary_size)`, which can be used to |
| 39 | + compute an MLM loss function. |
| 40 | +
|
| 41 | + This layer is often be paired with `keras_nlp.layers.MLMMaskGenerator`, |
| 42 | + which will help prepare inputs for the MLM task. |
| 43 | +
|
| 44 | + Args: |
        vocabulary_size: The total size of the vocabulary for predictions.
        embedding_weights: Optional. The weights of the word embedding used
            to transform input token ids. The transpose of this weight matrix
            will be used to project a token embedding vector to a prediction
            over all input words, as described in [1].
        intermediate_activation: The activation function of the inner dense
            layer.
        activation: The activation function for the outputs of the layer.
            Usually either `None` (return logits), or `"softmax"`
            (return probabilities).
        layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in
            layer normalization components.
        kernel_initializer: string or tf.keras.initializers initializer,
            defaults to "glorot_uniform". The kernel initializer for the
            dense and output projection layers.
        bias_initializer: string or tf.keras.initializers initializer,
            defaults to "zeros". The bias initializer for the dense and
            output projection layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    Examples:

    ```python
    batch_size = 32
    vocab_size = 100
    encoding_size = 32
    seq_length = 50
    mask_length = 10

    # Generate a random encoding.
    encoded_tokens = tf.random.normal([batch_size, seq_length, encoding_size])
    # Generate random positions and labels.
    mask_positions = tf.random.uniform(
        [batch_size, mask_length], maxval=seq_length, dtype="int32"
    )
    mask_ids = tf.random.uniform(
        [batch_size, mask_length], maxval=vocab_size, dtype="int32"
    )

    # Predict an output word for each masked input token.
    mask_preds = keras_nlp.layers.MLMHead(
        vocabulary_size=vocab_size,
        activation="softmax",
    )(encoded_tokens, mask_positions=mask_positions)
    # Calculate a loss.
    keras.losses.sparse_categorical_crossentropy(mask_ids, mask_preds)
    ```
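
    The output projection can reuse (tie) the weights of the input word
    embedding, as described in [1]. A minimal sketch, continuing the example
    above and assuming the embedding layer has already been built so that its
    `embeddings` variable exists:

    ```python
    embedding = keras.layers.Embedding(vocab_size, encoding_size)
    embedding.build((batch_size, seq_length))
    tied_preds = keras_nlp.layers.MLMHead(
        embedding_weights=embedding.embeddings,
        activation="softmax",
    )(encoded_tokens, mask_positions=mask_positions)
    ```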

    References:
     [1] [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
    """

    def __init__(
        self,
        vocabulary_size=None,
        embedding_weights=None,
        intermediate_activation="relu",
        activation=None,
        layer_norm_epsilon=1e-05,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        self.vocabulary_size = vocabulary_size
        self.embedding_weights = embedding_weights
        self.intermediate_activation = keras.activations.get(
            intermediate_activation
        )
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self._built = False

        if vocabulary_size is None and embedding_weights is None:
            raise ValueError(
                "One of `vocabulary_size` or `embedding_weights` must be set. "
                "Received: `vocabulary_size=None`, `embedding_weights=None`"
            )

        if embedding_weights is not None:
            shape = embedding_weights.shape
            if vocabulary_size is not None and vocabulary_size != shape[0]:
                raise ValueError(
                    "`vocabulary_size` should match the first dimension of the "
                    "shape of `embedding_weights`. Received: "
                    f"`vocabulary_size={vocabulary_size}`, "
                    f"`embedding_weights.shape={shape}`"
                )
            self.vocabulary_size = shape[0]

    def _build(self, input_shape):
        # Create layers based on input shape.
        self._built = True
        feature_size = input_shape[-1]

        self._dense = keras.layers.Dense(
            feature_size,
            activation=self.intermediate_activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
        )
        if self.embedding_weights is None:
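            # Without tied `embedding_weights`, learn a standalone projection
            # from the encoding dimension to the vocabulary.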
            self._kernel = self.add_weight(
                name="output_kernel",
                shape=[feature_size, self.vocabulary_size],
                initializer=self.kernel_initializer,
                dtype=self.dtype,
            )
        self._bias = self.add_weight(
            name="output_bias",
            shape=[self.vocabulary_size],
            initializer=self.bias_initializer,
            dtype=self.dtype,
        )

    def call(self, inputs, mask_positions):
        if not self._built:
            self._build(inputs.shape)

        # Gather the encoded tokens at the masked indices.
        x = tf.gather(inputs, mask_positions, axis=1, batch_dims=1)

        # Apply a trainable linear transformation and a layer norm.
        x = self._dense(x)
        x = self._layer_norm(x)

        # Transform encodings to vocabulary_size predictions.
        if self.embedding_weights is None:
            outputs = tf.matmul(x, self._kernel)
        else:
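            # Tie the output projection to the word embedding: the transposed
            # embedding matrix maps encodings back to the vocabulary, as in [1].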
            outputs = tf.matmul(
                x,
                tf.cast(self.embedding_weights, self.compute_dtype),
                transpose_b=True,
            )
        outputs = outputs + self._bias

        # Apply a final activation.
        if self.activation is not None:
            outputs = self.activation(outputs)

        return outputs

    def get_config(self):
        config = super().get_config()
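        # Note: `embedding_weights` is not serialized in the config; it must
        # be passed again when re-creating the layer from a saved config.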
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "intermediate_activation": keras.activations.serialize(
                    self.intermediate_activation
                ),
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config