Commit 0bf0225

Add a MLMHead layer (#132)
* Add a MLMClassificationHead layer
* improve docs
* Review comments
* rename and reformat
* Address more review comments
* Fixup
* typo fix
1 parent fca13e8 commit 0bf0225

File tree

3 files changed: +393 -0 lines changed


keras_nlp/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 from keras_nlp.layers.fnet_encoder import FNetEncoder
+from keras_nlp.layers.mlm_head import MLMHead
 from keras_nlp.layers.position_embedding import PositionEmbedding
 from keras_nlp.layers.preprocessing import MLMMaskGenerator
 from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
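
With this export, the new layer becomes reachable from the `keras_nlp.layers` namespace. A quick sketch of the two equivalent import paths (assuming an install that includes this commit):

```python
import keras_nlp
from keras_nlp.layers import MLMHead  # enabled by the line added above

# Both references point to the same class.
head = keras_nlp.layers.MLMHead(vocabulary_size=100)
assert head.__class__ is MLMHead
```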

keras_nlp/layers/mlm_head.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked Language Model (MLM) head."""

import tensorflow as tf
from tensorflow import keras


class MLMHead(keras.layers.Layer):
    """Masked Language Model (MLM) head.

    This layer takes two inputs:

    - `inputs`: which should be a tensor of encoded tokens with shape
      `(batch_size, sequence_length, encoding_dim)`.
    - `mask_positions`: which should be a tensor of integer positions to
      predict with shape `(batch_size, masks_per_sequence)`.

    The token encodings should usually be the last output of an encoder model,
    and mask positions should be the integer positions you would like to
    predict for the MLM task.

    The layer will first gather the token encodings at the mask positions.
    These gathered tokens will be passed through a dense layer the same size
    as the encoding dimension, then transformed to predictions the same size
    as the input vocabulary. This layer will produce a single output with
    shape `(batch_size, masks_per_sequence, vocabulary_size)`, which can be
    used to compute an MLM loss function.

    This layer is often paired with `keras_nlp.layers.MLMMaskGenerator`,
    which will help prepare inputs for the MLM task.

    Args:
        vocabulary_size: The total size of the vocabulary for predictions.
        embedding_weights: Optional. The weights of the word embedding used
            to transform input token ids. The transpose of this weight matrix
            will be used to project a token embedding vector to a prediction
            over all input words, as described in [1].
        intermediate_activation: The activation function of the inner dense
            layer.
        activation: The activation function for the outputs of the layer.
            Usually either `None` (return logits), or `"softmax"`
            (return probabilities).
        layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in
            layer normalization components.
        kernel_initializer: string or tf.keras.initializers initializer,
            defaults to "glorot_uniform". The kernel initializer for the
            dense and output projection layers.
        bias_initializer: string or tf.keras.initializers initializer,
            defaults to "zeros". The bias initializer for the dense and
            output projection layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    Examples:

    ```python
    batch_size = 32
    vocab_size = 100
    encoding_size = 32
    seq_length = 50
    mask_length = 10

    # Generate a random encoding.
    encoded_tokens = tf.random.normal([batch_size, seq_length, encoding_size])
    # Generate random positions and labels.
    mask_positions = tf.random.uniform(
        [batch_size, mask_length], maxval=seq_length, dtype="int32"
    )
    mask_ids = tf.random.uniform(
        [batch_size, mask_length], maxval=vocab_size, dtype="int32"
    )

    # Predict an output word for each masked input token.
    mask_preds = keras_nlp.layers.MLMHead(
        vocabulary_size=vocab_size,
        activation="softmax",
    )(encoded_tokens, mask_positions=mask_positions)
    # Calculate a loss.
    keras.losses.sparse_categorical_crossentropy(mask_ids, mask_preds)
    ```

    References:
        [1] [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
    """

    def __init__(
        self,
        vocabulary_size=None,
        embedding_weights=None,
        intermediate_activation="relu",
        activation=None,
        layer_norm_epsilon=1e-05,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        self.vocabulary_size = vocabulary_size
        self.embedding_weights = embedding_weights
        self.intermediate_activation = keras.activations.get(
            intermediate_activation
        )
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self._built = False

        if vocabulary_size is None and embedding_weights is None:
            raise ValueError(
                "One of `vocabulary_size` or `embedding_weights` must be set. "
                "Received: `vocabulary_size=None`, `embedding_weights=None`"
            )

        if embedding_weights is not None:
            shape = embedding_weights.shape
            if vocabulary_size is not None and vocabulary_size != shape[0]:
                raise ValueError(
                    "`vocabulary_size` should match the first dimension of the "
                    "shape of `embedding_weights`. Received: "
                    f"`vocabulary_size={vocabulary_size}`, "
                    f"`embedding_weights.shape={shape}`"
                )
            self.vocabulary_size = shape[0]

    def _build(self, input_shape):
        # Create layers based on input shape.
        self._built = True
        feature_size = input_shape[-1]

        self._dense = keras.layers.Dense(
            feature_size,
            activation=self.intermediate_activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._layer_norm = tf.keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
        )
        if self.embedding_weights is None:
            self._kernel = self.add_weight(
                name="output_kernel",
                shape=[feature_size, self.vocabulary_size],
                initializer=self.kernel_initializer,
                dtype=self.dtype,
            )
        self._bias = self.add_weight(
            name="output_bias",
            shape=[self.vocabulary_size],
            initializer=self.bias_initializer,
            dtype=self.dtype,
        )

    def call(self, inputs, mask_positions):
        if not self._built:
            self._build(inputs.shape)

        # Gather the encoded tokens at the masked indices.
        x = tf.gather(inputs, mask_positions, axis=1, batch_dims=1)

        # Apply a trainable linear transformation and a layer norm.
        x = self._dense(x)
        x = self._layer_norm(x)

        # Transform encodings to vocabulary_size predictions.
        if self.embedding_weights is None:
            outputs = tf.matmul(x, self._kernel)
        else:
            outputs = tf.matmul(
                x,
                tf.cast(self.embedding_weights, self.compute_dtype),
                transpose_b=True,
            )
        outputs = outputs + self._bias

        # Apply a final activation.
        if self.activation is not None:
            outputs = self.activation(outputs)

        return outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "intermediate_activation": keras.activations.serialize(
                    self.intermediate_activation
                ),
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config
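
One capability worth calling out in this new file is the `embedding_weights` argument, which ties the head's output projection to the input word embedding as in Press and Wolf (2016) [1]. Below is a minimal sketch of that usage; it relies only on the API shown above plus standard Keras calls. The `keras.layers.Embedding` here is a stand-in for a real encoder's embedding, and the random tensors mirror the docstring example:

```python
import tensorflow as tf
from tensorflow import keras
import keras_nlp

vocab_size = 100
embedding_dim = 32
seq_length = 50
mask_length = 10
batch_size = 8

# A word embedding whose weight matrix will be reused (transposed) as the
# output projection of the MLM head.
embedding = keras.layers.Embedding(vocab_size, embedding_dim)
embedding.build((None, seq_length))

token_ids = tf.random.uniform(
    [batch_size, seq_length], maxval=vocab_size, dtype="int32"
)
# Stand-in for real encoder output of shape (batch, seq_length, embedding_dim).
encoded_tokens = embedding(token_ids)

mask_positions = tf.random.uniform(
    [batch_size, mask_length], maxval=seq_length, dtype="int32"
)
mask_ids = tf.random.uniform(
    [batch_size, mask_length], maxval=vocab_size, dtype="int32"
)

# Tie the head's output projection to the embedding matrix; vocabulary_size
# is inferred from embedding.embeddings.shape[0]. Leaving `activation=None`
# returns logits, so the loss uses `from_logits=True`.
head = keras_nlp.layers.MLMHead(embedding_weights=embedding.embeddings)
logits = head(encoded_tokens, mask_positions=mask_positions)
# logits.shape == (batch_size, mask_length, vocab_size)

loss = keras.losses.sparse_categorical_crossentropy(
    mask_ids, logits, from_logits=True
)
```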
