forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
decoder.py
369 lines (324 loc) · 14.9 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer decoder that mimics a BERT encoder, to load BERT checkpoints."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.transformer import model_utils as transformer_utils
class TransformerDecoder(tf.keras.layers.Layer):
  """A stack of `layers.TransformerDecoderBlock`s.

  Each block self-attends over the decoder sequence and cross-attends to one
  of the encoder outputs. Sub-blocks are named `layer_<i>`.
  """

  def __init__(self,
               num_hidden_layers=12,
               hidden_size=768,
               num_attention_heads=12,
               intermediate_size=3072,
               intermediate_activation="gelu",
               hidden_dropout_prob=0.0,
               attention_probs_dropout_prob=0.0,
               initializer_range=0.02,
               attend_to_last_layer=True,
               multi_channel_cross_attention=False,
               **kwargs):
    """Records the hyper-parameters; the blocks themselves are made in build().

    Args:
      num_hidden_layers: number of decoder blocks to stack.
      hidden_size: model width. Stored on the instance only; this class does
        not read it afterwards.
      num_attention_heads: attention heads per block.
      intermediate_size: width of each block's feed-forward layer.
      intermediate_activation: activation name/callable, resolved once via
        `tf_utils.get_activation`.
      hidden_dropout_prob: dropout rate passed to each block.
      attention_probs_dropout_prob: attention dropout rate passed to each
        block.
      initializer_range: stddev of the truncated-normal kernel initializer
        used by each block.
      attend_to_last_layer: if True, every block cross-attends to
        `encoder_outputs[-1]`; otherwise block i uses `encoder_outputs[i]`.
      multi_channel_cross_attention: forwarded to each block. Also changes
        how the cross-attention mask is reshaped, and requires an extra
        `doc_attention_probs` input.
      **kwargs: standard Keras layer arguments.
    """
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_hidden_layers = num_hidden_layers
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf_utils.get_activation(
        intermediate_activation)
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.initializer_range = initializer_range
    self.attend_to_last_layer = attend_to_last_layer
    self.multi_channel_cross_attention = multi_channel_cross_attention

  def build(self, unused_input_shapes):
    """Creates one `TransformerDecoderBlock` per hidden layer."""
    self.layers = [
        layers.TransformerDecoderBlock(
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            intermediate_activation=self.intermediate_activation,
            dropout_rate=self.hidden_dropout_prob,
            attention_dropout_rate=self.attention_probs_dropout_prob,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.initializer_range),
            multi_channel_cross_attention=self.multi_channel_cross_attention,
            name="layer_%d" % block_idx)
        for block_idx in range(self.num_hidden_layers)
    ]
    super(TransformerDecoder, self).build(unused_input_shapes)

  def call(self, inputs, cache=None, decode_loop_step=None):
    """Runs the decoder blocks in sequence.

    Args:
      inputs: dict with the keys
        - "decoder_inputs": rank-3 tensor (rank is enforced below), presumably
          [batch_size, target_length, hidden_size] embeddings.
        - "encoder_outputs": list of encoder tensors, presumably each
          [batch_size, input_length, hidden_size].
        - "self_attention_mask": presumably [1, 1, target_length,
          target_length]; axis 1 is squeezed and the batch axis is tiled.
        - "attention_mask": presumably [batch_size, 1, 1, input_length];
          axis 1 is squeezed and target_length is tiled in. In multi-channel
          mode it must instead be rank 3 (it is expanded at axis 2 and tiled
          with 4 multiples).
        - "doc_attention_probs": only read in multi-channel mode.
      cache: optional dict keyed by str(layer index). When given, each entry
        is passed to its block and replaced in place with the block's
        returned cache.
      decode_loop_step: decoding-loop step, forwarded to the blocks only when
        `cache` is given.

    Returns:
      Tuple `(output_tensor, cache)`: the last block's output, and the same
      `cache` object that was passed in.
    """
    decoder_inputs = inputs["decoder_inputs"]
    encoder_outputs = inputs["encoder_outputs"]
    self_attention_mask = inputs["self_attention_mask"]
    attention_mask = inputs["attention_mask"]

    batch_size, decoder_length = tf_utils.get_shape_list(
        decoder_inputs, expected_rank=3)[:2]

    # Reshape the cross-attention mask to cover every target position.
    if self.multi_channel_cross_attention:
      attention_mask = tf.tile(
          tf.expand_dims(attention_mask, axis=2), [1, 1, decoder_length, 1])
    else:
      attention_mask = tf.tile(
          tf.squeeze(attention_mask, axis=[1]), [1, decoder_length, 1])

    # [1, 1, T, T] -> [batch_size, T, T].
    self_attention_mask = tf.tile(
        tf.squeeze(self_attention_mask, axis=[1]), [batch_size, 1, 1])

    output_tensor = decoder_inputs
    for block_idx in range(self.num_hidden_layers):
      memory = (encoder_outputs[-1] if self.attend_to_last_layer
                else encoder_outputs[block_idx])

      block_inputs = [
          output_tensor, memory, attention_mask, self_attention_mask
      ]
      if self.multi_channel_cross_attention:
        block_inputs.append(inputs["doc_attention_probs"])

      block = self.layers[block_idx]
      if cache is None:
        output_tensor, _ = block(block_inputs)
      else:
        # Cache entries are keyed by the stringified block index.
        slot = str(block_idx)
        output_tensor, cache[slot] = block(
            block_inputs,
            cache=cache[slot],
            decode_loop_step=decode_loop_step)
    return output_tensor, cache
def get_attention_bias(input_tensor,
                       bias_type,
                       padding_value=0,
                       max_length=None):
  """Builds a 0/1 attention mask of the requested kind.

  An additive bias is computed first. It is then turned into a mask that is
  0.0 wherever the bias is negative and 1.0 everywhere else.

  Args:
    input_tensor: the tensor the mask is derived from. It must be rank 2 for
      "single_cross", and for "decoder_self" when `max_length` is None. It
      must be rank 3 for "multi_cross".
    bias_type: one of "single_cross", "multi_cross" or "decoder_self".
    padding_value: value treated as padding by the two cross-attention kinds.
    max_length: "decoder_self" only. If given, it is used as the sequence
      length instead of reading it from `input_tensor`.

  Returns:
    A mask tensor with the same shape and dtype as the intermediate bias.

  Raises:
    ValueError: if `bias_type` is not one of the supported kinds.
  """
  supported = ("single_cross", "multi_cross", "decoder_self")
  if bias_type not in supported:
    raise ValueError("Invalid attention bias type: %s" % bias_type)

  if bias_type == "decoder_self":
    seq_len = max_length
    if seq_len is None:
      seq_len = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_decoder_self_attention_bias(seq_len)
  elif bias_type == "multi_cross":
    # Called only for its rank check; the shape itself is not needed.
    tf_utils.get_shape_list(input_tensor, expected_rank=3)
    bias = transformer_utils.get_padding(
        input_tensor, padding_value=padding_value) * -1e9
  else:  # "single_cross"
    # Called only for its rank check; the shape itself is not needed.
    tf_utils.get_shape_list(input_tensor, expected_rank=2)
    bias = transformer_utils.get_padding_bias(
        input_tensor, padding_value=padding_value)

  return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))
class AttentionBias(tf.keras.layers.Layer):
  """Keras layer wrapper around `get_attention_bias`.

  Args:
    bias_type: forwarded to `get_attention_bias`; one of "single_cross",
      "multi_cross" or "decoder_self".
    **kwargs: standard Keras layer arguments.
  """

  def __init__(self, bias_type, **kwargs):
    super(AttentionBias, self).__init__(**kwargs)
    self.bias_type = bias_type

  def call(self, inputs):
    """Returns the 0/1 mask for `inputs`.

    Uses the `padding_value` and `max_length` defaults of
    `get_attention_bias`.
    """
    return get_attention_bias(input_tensor=inputs, bias_type=self.bias_type)
class EmbeddingPostprocessor(tf.keras.layers.Layer):
  """Performs various post-processing on a word embedding tensor.

  Optionally adds token-type embeddings and learned position embeddings to the
  input word embeddings. It then applies layer normalization and dropout.
  """

  def __init__(self,
               use_type_embeddings=False,
               token_type_vocab_size=None,
               use_position_embeddings=True,
               max_position_embeddings=512,
               dropout_prob=0.0,
               initializer_range=0.02,
               initializer=None,
               **kwargs):
    """Initializes the layer.

    Args:
      use_type_embeddings: whether to add token-type embeddings.
      token_type_vocab_size: number of token types. Required when
        `use_type_embeddings` is True.
      use_position_embeddings: whether to add learned position embeddings.
      max_position_embeddings: number of rows in the position table. Longer
        input sequences cannot be sliced from it.
      dropout_prob: dropout rate applied after layer normalization.
      initializer_range: stddev of the default truncated-normal initializer.
        Only used when `initializer` is not given.
      initializer: optional initializer for the embedding tables.
      **kwargs: standard Keras layer arguments.

    Raises:
      ValueError: if `use_type_embeddings` is True but `token_type_vocab_size`
        is not set.
    """
    super(EmbeddingPostprocessor, self).__init__(**kwargs)
    self.use_type_embeddings = use_type_embeddings
    self.token_type_vocab_size = token_type_vocab_size
    self.use_position_embeddings = use_position_embeddings
    self.max_position_embeddings = max_position_embeddings
    self.dropout_prob = dropout_prob
    self.initializer_range = initializer_range
    if not initializer:
      self.initializer = tf.keras.initializers.TruncatedNormal(
          stddev=initializer_range)
    else:
      self.initializer = initializer
    if self.use_type_embeddings and not self.token_type_vocab_size:
      raise ValueError("If `use_type_embeddings` is True, then "
                       "`token_type_vocab_size` must be specified.")

  def _new_initializer(self):
    """Returns an initializer instance for a single new weight.

    Previously `build()` ignored `self.initializer` and always created a fresh
    TruncatedNormal, so a caller-supplied initializer had no effect. Keras
    initializers are now copied through their config so that each weight
    gets its own instance. Newer Keras versions may warn about, or repeat
    values for, an unseeded instance that is called more than once. Plain
    callables without a config are returned unchanged.
    """
    init = self.initializer
    if hasattr(init, "get_config") and hasattr(type(init), "from_config"):
      return type(init).from_config(init.get_config())
    return init

  def build(self, input_shapes):
    """Creates the embedding tables and the normalization/dropout sub-layers.

    Args:
      input_shapes: pair `(word_embeddings_shape, token_type_ids_shape)`. The
        embedding width is the last dimension of the first entry.
    """
    (word_embeddings_shape, _) = input_shapes
    width = word_embeddings_shape.as_list()[-1]

    self.type_embeddings = None
    if self.use_type_embeddings:
      self.type_embeddings = self.add_weight(
          "type_embeddings",
          shape=[self.token_type_vocab_size, width],
          initializer=self._new_initializer(),
          dtype=self.dtype)

    self.position_embeddings = None
    if self.use_position_embeddings:
      self.position_embeddings = self.add_weight(
          "position_embeddings",
          shape=[self.max_position_embeddings, width],
          initializer=self._new_initializer(),
          dtype=self.dtype)

    # Normalization and dropout are pinned to float32 regardless of
    # self.dtype.
    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
    self.output_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_prob, dtype=tf.float32)
    super(EmbeddingPostprocessor, self).build(input_shapes)

  def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
    """Packs both inputs into one structure before invoking Keras.

    Presumably `tf_utils.pack_inputs` is used so that an omitted (None)
    `token_type_ids` survives Keras input handling.
    """
    inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
    return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    """Adds the enabled embeddings, then applies layer norm and dropout.

    Args:
      inputs: packed `[word_embeddings, token_type_ids]`, as produced by
        `__call__`. `word_embeddings` must be rank 3, presumably
        [batch_size, seq_length, width].

    Returns:
      A tensor with the same shape as `word_embeddings`.

    Raises:
      ValueError: if type embeddings are enabled but no `token_type_ids` were
        given.
    """
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    word_embeddings = unpacked_inputs[0]
    token_type_ids = unpacked_inputs[1]
    input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    output = word_embeddings
    if self.use_type_embeddings:
      if token_type_ids is None:
        raise ValueError("`token_type_ids` must be provided when "
                         "`use_type_embeddings` is True.")
      # One row per token id, reshaped back to [batch, seq, width].
      flat_token_type_ids = tf.reshape(token_type_ids, [-1])
      token_type_embeddings = tf.gather(self.type_embeddings,
                                        flat_token_type_ids)
      token_type_embeddings = tf.reshape(token_type_embeddings,
                                         [batch_size, seq_length, width])
      output += token_type_embeddings
    if self.use_position_embeddings:
      # The first seq_length rows, with a leading axis added so the result
      # broadcasts over the batch.
      position_embeddings = tf.expand_dims(
          tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
          axis=0)
      output += position_embeddings

    output = self.output_layer_norm(output)
    output = self.output_dropout(output)
    return output
class Decoder(tf.keras.layers.Layer):
  """Decoder network that can share the encoder's vocabulary embedding.

  Target ids are embedded, position information is added, and the result is
  run through a `TransformerDecoder` stack.
  """

  def __init__(self, config, embedding_lookup=None, **kwargs):
    """Initializes the decoder.

    Args:
      config: hyper-parameter object. It is read via both attribute access
        and `.get()` in `build()`.
      embedding_lookup: optional embedding layer to reuse, e.g. the
        encoder's. If it is falsy, `build()` creates a new
        `OnDeviceEmbedding`.
      **kwargs: standard Keras layer arguments.
    """
    super(Decoder, self).__init__(**kwargs)
    self.config = config
    # Shares vocabulary embedding when one is supplied.
    self.embedding_lookup = embedding_lookup if embedding_lookup else None

  def build(self, unused_input_shapes):
    """Creates the embedding, embedding post-processor and decoder stack."""
    cfg = self.config
    if self.embedding_lookup is None:
      self.embedding_lookup = layers.OnDeviceEmbedding(
          vocab_size=cfg.vocab_size,
          embedding_width=cfg.hidden_size,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=cfg.initializer_range),
          name="target_embeddings")

    postprocessor_init = tf.keras.initializers.VarianceScaling(
        scale=cfg.initializer_gain,
        mode="fan_avg",
        distribution="uniform")
    self.embedding_postprocessor = EmbeddingPostprocessor(
        use_type_embeddings=False,
        use_position_embeddings=True,
        max_position_embeddings=cfg.max_position_embeddings,
        dropout_prob=cfg.hidden_dropout_prob,
        initializer=postprocessor_init,
        name="embedding_postprocessor")

    self.multi_channel_cross_attention = cfg.get(
        "multi_channel_cross_attention", False)
    # The decoder may use a different intermediate size than the encoder.
    self.decoder = TransformerDecoder(
        num_hidden_layers=cfg.num_decoder_layers,
        hidden_size=cfg.hidden_size,
        num_attention_heads=cfg.num_decoder_attn_heads,
        intermediate_size=cfg.decoder_intermediate_size,
        intermediate_activation=cfg.hidden_act,
        hidden_dropout_prob=cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
        initializer_range=cfg.initializer_range,
        multi_channel_cross_attention=self.multi_channel_cross_attention,
        name="decoder")
    super(Decoder, self).build(unused_input_shapes)

  def _decoding_step_time_signal(self, target_embeds, decode_loop_step):
    """Adds the position embedding for one decoding step, then normalizes.

    This is the incremental-decoding counterpart of the post-processor's own
    call: the position comes from `decode_loop_step` rather than from the
    sequence length. The post-processor's layer norm and dropout are reused.
    """
    # TODO(hongkuny): migrate to keras bert and design a module to handle this.
    postprocessor = self.embedding_postprocessor
    output = target_embeds
    if postprocessor.use_position_embeddings:
      # Gathering a one-element index list yields a [1, width] row, which
      # broadcasts to all sequences inside a batch.
      output += tf.gather(postprocessor.position_embeddings,
                          [decode_loop_step])
    output = postprocessor.output_layer_norm(output)
    return postprocessor.output_dropout(output)

  def call(self,
           inputs,
           cache=None,
           decode_loop_step=None,
           padded_decode=False):
    """Embeds the target ids and runs them through the decoder stack.

    Args:
      inputs: dict with the keys "attention_bias", "target_ids",
        "all_encoder_outputs" (a tensor or a list of tensors) and
        "self_attention_bias". "doc_attention_probs" is also required in
        multi-channel mode. The two biases are passed to the stack as its
        "attention_mask" and "self_attention_mask".
      cache: optional dict of per-layer caches. The decoder stack replaces its
        entries in place, so callers observe the update through this dict.
      decode_loop_step: step inside a decoding loop. When None, the full
        embedding post-processor is applied; otherwise only the single-step
        position signal is added.
      padded_decode: if True, `decode_loop_step` is also forwarded to the
        decoder stack; otherwise the stack receives None.

    Returns:
      The decoder stack's output tensor.
    """
    attention_bias = inputs["attention_bias"]
    target_ids = inputs["target_ids"]
    all_encoder_outputs = inputs["all_encoder_outputs"]
    self_attention_bias = inputs["self_attention_bias"]
    if not isinstance(all_encoder_outputs, list):
      all_encoder_outputs = [all_encoder_outputs]

    embedded = self.embedding_lookup(target_ids)
    if decode_loop_step is None:
      embedded = self.embedding_postprocessor(embedded)
    else:
      embedded = self._decoding_step_time_signal(embedded, decode_loop_step)

    decoder_inputs = {
        "decoder_inputs": embedded,
        "encoder_outputs": all_encoder_outputs,
        "self_attention_mask": self_attention_bias,
        "attention_mask": attention_bias,
    }
    if self.multi_channel_cross_attention:
      decoder_inputs["doc_attention_probs"] = inputs["doc_attention_probs"]

    stack_step = decode_loop_step if padded_decode else None
    decode_outputs, _ = self.decoder(decoder_inputs, cache, stack_step)
    return decode_outputs