diff --git a/.gitignore b/.gitignore
index c9dd3db88..fbd98dca5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ _pycache__/
 # PyPI distribution artifacts.
 build/
 dist/
+data/

 # Sublime project files
 *.sublime-project
diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py
index 1a8b2c79d..94d75b48d 100644
--- a/tensor2tensor/models/common_attention.py
+++ b/tensor2tensor/models/common_attention.py
@@ -344,23 +344,33 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(
-    q, k, v, block_length=128, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True,
+                       use_whole_block=False, truncate_bias=True, name=None):
+  """Attention to the source position and a neighborhood around it.

-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+  The sequence is divided into blocks of length block_length. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and behind it.

-  If mask_right is True, then a target position cannot see greater source
+  If look_right is True then each query will attend to block_length//2
+  positions on either side of it; otherwise it will attend to the block_length
+  previous positions.
+
+  If use_whole_block is True then no mask will be applied to the local blocks,
+  meaning the full blocks are used (if look_right is True then the elements to
+  the right of the current position are still masked out). This allows us to
+  attend to more elements without additional overhead, but means we have
+  inconsistent window positions and sizes.

   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: not currently used; shape [batch, heads, length_q, length_k]
     block_length: an integer
+    look_right: a bool
+    use_whole_block: a bool
     name: an optional string

   Returns:
@@ -372,146 +382,110 @@ def masked_local_attention_1d(
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
-    # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
-                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
+    original_length = length
+
+    # Pad to the desired length.
+    # If (length < block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length),
+                            length, block_length)
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
-    q = tf.pad(q, padding)
-    k = tf.pad(k, padding)
-    v = tf.pad(v, padding)
     num_blocks = tf.div(length, block_length)

-    # compute attention for the first query block.
-    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_output = dot_product_attention(
-        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
-        name="fist_block")
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)

-    # compute attention for all subsequent query blocks.
+    if not look_right:
+      # Add extra padding so we don't have to do an initial query block.
+      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+      bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
+    else:
+      # We shift everything over by half a block so the query is in the centre.
+      pad_right = block_length // 2
+      pad_left = block_length - pad_right
+      extra_padding = [[0, 0], [0, 0],
+                       [pad_left, padding_size+pad_right], [0, 0]]
+      bp = [[0, 0], [0, 0],
+            [0, padding_size], [pad_left, padding_size+pad_right]]
+    k = tf.pad(k, extra_padding)
+    v = tf.pad(v, extra_padding)
+
+    # Reshape into blocks.
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v])
+    # Get local blocks by slicing.
     def local(x):
       """Create a local version of the keys or values."""
       prev_block = tf.slice(
-          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
       cur_block = tf.slice(
           x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
     local_k = local(k)
     local_v = local(v)
-    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
-
     local_length = tf.shape(local_k)[3]

-    # [batch, heads, num_blocks - 1, block_length, local_length]
-    attention = tf.matmul(tail_q, local_k, transpose_b=True)
-
-    # make sure source_pos <= target_pos
-    good_part = tf.matrix_band_part(
-        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    mask = (1.0 - good_part) * -1e9
-    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks, block_length, local_length]
+    attention = tf.matmul(q, local_k, transpose_b=True)
+
+    # Apply bias (N.B.: this is not currently working).
+    if bias is not None:
+      with tf.name_scope('bias'):
+        b_batch = tf.shape(bias)[0]
+        b_heads = tf.shape(bias)[1]
+        bias_ = bias
+        #bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
+        if truncate_bias:
+          # Use only the query dimension.
+          bias = tf.expand_dims(bias[:, :, :, 0], 2)
+          bias = tf.pad(bias, extra_padding, name='bias_pad_b')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, 1, num_blocks+1, block_length],
+                            name='divide_blocks')
+          local_b = tf.reshape(local(bias),
+                               [b_batch, b_heads, num_blocks, 1, -1],
+                               name='reshape_local')
+        else:
+          bias = tf.pad(bias, bp, name='pad')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, num_blocks, block_length,
+                             num_blocks+1, block_length], name='divide_blocks')
+          bias = tf.transpose(bias, [4, 2, 0, 1, 3, 5])
+          bias = tf.reshape(bias,
+                            [num_blocks*(num_blocks+1), b_batch, b_heads,
+                             block_length, block_length], name='combine')
+          indices = (num_blocks+1)*tf.range(num_blocks)
+          prev_block = tf.gather(bias, indices)
+          cur_block = tf.gather(bias, indices+num_blocks)
+          local_b = tf.concat([prev_block, cur_block], 4)
+          local_b = tf.transpose(local_b, [1, 2, 0, 3, 4])
+
+      attention += local_b
+
     attention = tf.nn.softmax(attention)
-    # TODO(noam): figure out how to show a summary for the remaining blocks.
-    # The naive way currently causes errors due to empty tensors.
-    # output: [batch, heads, num_blocks-1, block_length, depth_v]
-    output = tf.matmul(attention, local_v)
-    output = tf.reshape(output, [batch, heads, -1, depth_v])
-    output = tf.concat([first_output, output], axis=2)
-    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
-    output.set_shape(v_shape)
-    return output
-
+
+    # Get local mask
+    if not use_whole_block:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+    elif not look_right:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    else:
+      good_part = tf.ones([block_length, local_length])

-def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100,
-                                name=None):
-  """strided block local self-attention.
+    #good_part = tf.cast(good_part, tf.float64)
+    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])

-  Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
-    block_length: an integer
-    filter_width: an integer indicating how much to look left.
-    name: an optional string
+
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])

-  Returns:
-    a Tensor of shape [batch, heads, length, depth_v]
-  """
-  with tf.variable_scope(name, default_name="local_self_attention_1d",
-                         values=[q, k, v]):
-    v_shape = v.get_shape()
-    depth_v = tf.shape(v)[3]
-    batch_size = tf.shape(q)[0]
-    num_heads = tf.shape(q)[1]
-    original_length = tf.shape(q)[2]
-    # making sure q is a multiple of d
-    def pad_to_multiple(x, pad_length):
-      x_length = tf.shape(x)[2]
-      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
-    def pad_l_and_r(x, pad_length):
-      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
-    q = pad_to_multiple(q, block_length)
-    k = pad_to_multiple(k, block_length)
-    v = pad_to_multiple(v, block_length)
-
-    # Setting up q blocks
-    new_q_shape = tf.shape(q)
-    # Setting up q blocks
-    q = tf.reshape(q, [new_q_shape[0], new_q_shape[1],
-                       new_q_shape[2]//block_length,
-                       block_length, new_q_shape[3]])
-
-    # Setting up k and v values
-    k = pad_l_and_r(k, filter_width)
-    v = pad_l_and_r(v, filter_width)
-
-    length = tf.shape(k)[2]
-    full_filter_width = block_length + 2*filter_width
-    # getting gather indices
-    indices = tf.range(0, length, delta=1, name="index_range")
-    # making indices [1, length, 1] to appy convs
-    indices = tf.reshape(indices, [1, -1, 1])
-    kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1)
-    gather_indices = tf.nn.conv1d(
-        tf.cast(indices, tf.float32),
-        kernel,
-        block_length,
-        padding="VALID",
-        name="gather_conv")
-
-    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
-
-    # [length, batch, heads, dim]
-    k_t = tf.transpose(k, [2, 0, 1, 3])
-    k_new = tf.gather(k_t, gather_indices)
-
-    # [batch, heads, blocks, block_length, dim]
-    k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
-
-    attention_bias = tf.expand_dims(
-        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
-
-    v_t = tf.transpose(v, [2, 0, 1, 3])
-    v_new = tf.gather(v_t, gather_indices)
-    v_new = tf.transpose(v_new, [2, 3, 0, 1, 4])
-
-    
logits = tf.matmul(q, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits+attention_bias) - output = tf.matmul(attention, v_new) - - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) - # Remove the padding if introduced + # Remove added padding output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) output.set_shape(v_shape) return output @@ -542,8 +516,8 @@ def multihead_attention(query_antecedent, dropout_rate: a floating point number image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() - attention_type: a string, either "dot_product" or "local_mask_right" or - "local_unmasked" + attention_type: a string, either "dot_product" or "local" or + "local_mask_right" block_length: an integer - relevant for "local_mask_right" name: an optional string @@ -592,11 +566,12 @@ def multihead_attention(query_antecedent, if attention_type == "dot_product": x = dot_product_attention( q, k, v, bias, dropout_rate, image_shapes) - elif attention_type == "local_mask_right": - x = masked_local_attention_1d(q, k, v, block_length=block_length) + elif attention_type == "local": + x = local_attention_1d(q, k, v, block_length=block_length) else: - assert attention_type == "local_unmasked" - x = unmasked_local_attention_1d(q, k, v, block_length=block_length) + assert attention_type == "local_mask_right" + x = local_attention_1d( + q, k, v, block_length=block_length, look_right=False) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/common_attention_test.py b/tensor2tensor/models/common_attention_test.py new file mode 100644 index 000000000..2e534ba1a --- /dev/null +++ b/tensor2tensor/models/common_attention_test.py @@ -0,0 +1,82 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for common layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +from tensor2tensor.models import common_attention + +import tensorflow as tf + + +class CommonAttentionTest(tf.test.TestCase): + + def testLocalAttention(self): + q = np.array([[[ [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0] ]]]) + + k = np.array([[[ [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0] ]]]) + + b = np.array([[[ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ]]]) + + #b = np.ones((1,1,8,8)) + #b = (1-b) * (-1e9) + v = np.ones((1, 1, 8, 1)) + + #q = np.random.rand(5, 7, 13, 3) + #k = np.random.rand(5, 7, 13, 3) + #v = np.random.rand(5, 7, 13, 11) + #b = np.random.rand(5, 1, 13, 1) + + with self.test_session() as session: + q_ = tf.constant(q) + k_ = tf.constant(k) + v_ = tf.constant(v) + b_ = tf.constant(b) + y = common_attention.local_attention_1d(q_, k_, v_, b_, block_length=tf.constant(2)) + res = session.run(y) + #print(q) + #rint(k) + print(res) + #self.assertEqual(res.shape, (5, 7, 13, 11)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index e98531d88..ae6d0cede 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -1420,22 +1420,22 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence): return xentropy - normalizing -def global_pool_1d(inputs, pooling_type="MAX", mask=None): - """Pool elements across the last dimension. - Useful to convert a list of vectors into a single vector so as - to get a representation of a set. - - Args: - inputs: A tensor of dimensions batch_size x sequence_length x input_dims - containing the sequences of input vectors. - pooling_type: the pooling type to use, MAX or AVR - mask: A tensor of dimensions batch_size x sequence_length containing a - mask for the inputs with 1's for existing elements, and 0's elsewhere. - - Returns: - output: A tensor of dimensions batch_size x input_dims - dimension containing the sequences of transformed vectors. +def global_pool_1d(inputs, pooling_type='MAX', mask=None): + """ + Pools elements across the last dimension. Useful to a list of vectors into a + single vector to get a representation of a set. + Concatenating + + Args + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + pooling_type: the pooling type to use, MAX or AVR + mask: A tensor of dimensions batch_size x sequence_length containing a + mask for the inputs with 1's for existing elements, and 0's elsewhere. + Returns + output: A tensor of dimensions batch_size x input_dims + dimension containing the sequences of transformed vectors. 
""" with tf.name_scope("global_pool", [inputs]): if mask is not None: @@ -1457,6 +1457,38 @@ def global_pool_1d(inputs, pooling_type="MAX", mask=None): return output + +def running_global_pool_1d(inputs, pooling_type='MAX'): + """ + Same global pool, but only for the elements up to the current element. Useful + for outputs where the state of future elements is not known. + Takes no mask as all elements up to the current element are assumed to exist. + Currently only supports maximum. Equivalent to using a lower triangle bias. + + Args + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + pooling_type: Pooling type to use. Currently only supports 'MAX'. + Returns + output: A tensor of dimensions batch_size x sequence_length x input_dims + dimension containing the running 'totals'. + """ + + with tf.name_scope("running_global_pool", [inputs]): + scan_fct = tf.maximum + + # Permute inputs so seq_length is first + elems = tf.transpose(inputs, [1, 0, 2]) + + # Perform scan + cumulatives = tf.scan(scan_fct, elems, swap_memory=True) + + # Permute output to get back to original order + output = tf.transpose(cumulatives, [1, 0, 2]) + + return output + + def linear_set_layer(layer_size, inputs, context=None, @@ -1470,21 +1502,19 @@ def linear_set_layer(layer_size, e.g. One can use global_pool_1d to get a representation of the set which can then be used as the context for the next layer. - TODO: Add bias add (or control the biases used). + Args + layer_size: Dimension to transform the input vectors to + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + context: A tensor of dimensions batch_size x context_dims or batch_size x + sequence_length x context_dims containing a global statistic about the + set. + dropout: Dropout probability. + activation_fn: The activation function to use. + Returns + output: A tensor of dimensions batch_size x sequence_length x output_dims + dimension containing the sequences of transformed vectors. - Args: - layer_size: Dimension to transform the input vectors to. - inputs: A tensor of dimensions batch_size x sequence_length x input_dims - containing the sequences of input vectors. - context: A tensor of dimensions batch_size x context_dims - containing a global statistic about the set. - activation_fn: The activation function to use. - dropout: Dropout probability. - name: name. - - Returns: - output: A tensor of dimensions batch_size x sequence_length x output_dims - dimension containing the sequences of transformed vectors. """ with tf.variable_scope(name, "linear_set_layer", [inputs]): # Apply 1D convolution to apply linear filter to each element @@ -1494,10 +1524,12 @@ def linear_set_layer(layer_size, # Apply the context if it exists. if context is not None: # Unfortunately tf doesn't support broadcasting via concat, but we can - # simply add the transformed context to get the same effect. 
-      context = tf.expand_dims(context, axis=1)
-      cont_tfm = conv1d(
-          context, layer_size, 1, activation=None, name="cont_conv")
+      # simply add the transformed context to get the same effect
+      if len(context.get_shape().as_list()) == 2:
+        context = tf.expand_dims(context, axis=1)
+      cont_tfm = conv1d(context, layer_size, 1,
+                        activation=None, name="cont_conv")
       outputs += cont_tfm

     if activation_fn is not None:
@@ -1512,6 +1544,7 @@ def linear_set_layer(layer_size,

 def ravanbakhsh_set_layer(layer_size,
                           inputs,
                           mask=None,
+                          sequential=False,
                           activation_fn=tf.nn.tanh,
                           dropout=0.0,
                           name=None):
@@ -1519,26 +1552,35 @@ def ravanbakhsh_set_layer(layer_size,

   More parameter-efficient verstion of a linear-set-layer with context.

-  Args:
-    layer_size: Dimension to transform the input vectors to.
-    inputs: A tensor of dimensions batch_size x sequence_length x vector
-      containing the sequences of input vectors.
-    mask: A tensor of dimensions batch_size x sequence_length containing a
-      mask for the inputs with 1's for existing elements, and 0's elsewhere.
-    activation_fn: The activation function to use.
-    dropout: dropout.
-    name: name.
-
-  Returns:
-    output: A tensor of dimensions batch_size x sequence_length x vector
-      dimension containing the sequences of transformed vectors.
+  Args:
+    layer_size: Dimension to transform the input vectors to.
+    inputs: A tensor of dimensions batch_size x sequence_length x vector
+      containing the sequences of input vectors.
+    mask: A tensor of dimensions batch_size x sequence_length containing a
+      mask for the inputs with 1's for existing elements, and 0's elsewhere.
+    sequential: If True, will use a running global pool so each element will
+      only depend on those before it. Set True if this layer is being used in
+      an output sequence.
+
+  Returns:
+    output: A tensor of dimensions batch_size x sequence_length x vector
+      dimension containing the sequences of transformed vectors.
   """
   with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]):
-    output = linear_set_layer(
-        layer_size,
-        inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1),
-        activation_fn=activation_fn,
-        dropout=dropout,
-        name=name)
+    if sequential:
+      output = linear_set_layer(
+          layer_size,
+          inputs - running_global_pool_1d(inputs),
+          activation_fn=activation_fn,
+          name=name)
+    else:
+      output = linear_set_layer(
+          layer_size,
+          inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1),
+          activation_fn=activation_fn,
+          name=name)
+
   return output


diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py
index 62413c325..78398471a 100644
--- a/tensor2tensor/models/transformer_alternative.py
+++ b/tensor2tensor/models/transformer_alternative.py
@@ -50,17 +50,13 @@ def model_fn_body(self, features):
     inputs = common_layers.flatten4d3d(inputs)
     targets = common_layers.flatten4d3d(targets)

-    (encoder_input, encoder_attention_bias,
-     _) = transformer.transformer_prepare_encoder(inputs, target_space, hparams)
-    (decoder_input,
-     decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
-         targets, hparams)
-    # We need masks of the form batch size x input sequences
-    # Biases seem to be of the form batch_size x 1 x input sequences x vec dim
-    # Squeeze out dim one, and get the first element of each vector.
- encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:, :, 0] - decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:, :, 0] + (encoder_input, encoder_attention_bias, _) = (transformer.\ + transformer_prepare_encoder(inputs, target_space, hparams) ) + (decoder_input, decoder_self_attention_bias) = transformer.\ + transformer_prepare_decoder(targets, hparams) + + encoder_mask = bias_to_mask(encoder_attention_bias) def residual_fn(x, y): return common_layers.layer_norm(x + tf.nn.dropout( @@ -68,20 +64,20 @@ def residual_fn(x, y): encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout) + encoder_output = alt_transformer_encoder( encoder_input, residual_fn, encoder_mask, hparams) decoder_output = alt_transformer_decoder( - decoder_input, encoder_output, residual_fn, decoder_mask, + decoder_input, encoder_output, residual_fn, encoder_attention_bias, hparams) decoder_output = tf.expand_dims(decoder_output, 2) return decoder_output - -def composite_layer(inputs, mask, hparams): - """Composite layer.""" + +def composite_layer(inputs, mask, hparams, for_output=False): x = inputs # Applies ravanbakhsh on top of each other. @@ -89,28 +85,32 @@ def composite_layer(inputs, mask, hparams): for layer in xrange(hparams.layers_per_layer): with tf.variable_scope(".%d" % layer): x = common_layers.ravanbakhsh_set_layer( - hparams.hidden_size, - x, - mask=mask, - dropout=0.0) - - # Transforms elements to get a context, and then uses this in a final layer. + hparams.hidden_size, + x, + mask=mask, + sequential=for_output, + dropout=hparams.relu_dropout) + + # Transforms elements to get a context, and then uses this in a final layer elif hparams.composite_layer_type == "reembedding": # Transform elements n times and then pool. for layer in xrange(hparams.layers_per_layer): - with tf.variable_scope(".%d" % layer): + with tf.variable_scope("sub_layer_%d" % layer): x = common_layers.linear_set_layer( + hparams.hidden_size, + x, + dropout=hparams.relu_dropout) + if for_output: + context = common_layers.running_global_pool_1d(x) + else: + context = common_layers.global_pool_1d(x, mask=mask) + + #Final layer + x = common_layers.linear_set_layer( hparams.hidden_size, x, - dropout=0.0) - context = common_layers.global_pool_1d(x, mask=mask) - - # Final layer. 
-      x = common_layers.linear_set_layer(
-          hparams.hidden_size,
-          x,
-          context=context,
-          dropout=0.0)
+        context=context,
+        dropout=hparams.relu_dropout)

   return x

@@ -120,10 +120,12 @@ def alt_transformer_encoder(encoder_input,
                             mask,
                             hparams,
                             name="encoder"):
+  """Alternative encoder."""
   x = encoder_input

   with tf.variable_scope(name):
+    x = encoder_input
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
         x = residual_fn(x, composite_layer(x, mask, hparams))
@@ -134,14 +136,12 @@ def alt_transformer_encoder(encoder_input,

 def alt_transformer_decoder(decoder_input,
                             encoder_output,
                             residual_fn,
-                            mask,
                             encoder_decoder_attention_bias,
                             hparams,
                             name="decoder"):
-  """Alternative decoder."""
-  x = decoder_input
   with tf.variable_scope(name):
+    x = decoder_input
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):

         x_ = common_attention.multihead_attention(
@@ -156,17 +156,33 @@ def alt_transformer_decoder(decoder_input,
             hparams.attention_dropout,
             name="encdec_attention")

-        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
+        x_ = residual_fn(
+            x_, composite_layer(x_, None, hparams, for_output=True))
         x = residual_fn(x, x_)
-
+
   return x


+def bias_to_mask(bias):
+  """Convert an attention bias into a mask of shape [batch_size, length].
+
+  Biases are of the form batch_size x num_heads x input sequences x output
+  sequences. Squeeze out dim one, and take the first element of each vector.
+  """
+  bias = tf.squeeze(bias, [1])[:, :, 0]
+  bias = -tf.clip_by_value(bias, -1.0, 1.0)
+  mask = 1 - bias
+  return mask
+
+
 @registry.register_hparams
 def transformer_alt():
   """Set of hyperparameters."""
   hparams = transformer.transformer_base()
-  hparams.batch_size = 64
+  hparams.batch_size = 2048
+  hparams.num_hidden_layers = 10
   hparams.add_hparam("layers_per_layer", 4)
-  hparams.add_hparam("composite_layer_type", "reembedding")
+  # Either "ravanbakhsh" or "reembedding".
+  hparams.add_hparam("composite_layer_type", "ravanbakhsh")
+
   return hparams
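For reference, a minimal usage sketch of the two new public helpers introduced by this patch: local_attention_1d (common_attention.py) and running_global_pool_1d (common_layers.py). This sketch is not part of the diff; it assumes TF 1.x as used elsewhere in the repo, the shapes and block_length are illustrative only, and bias is left at None because the bias path is flagged above as not currently working.

import tensorflow as tf

from tensor2tensor.models import common_attention
from tensor2tensor.models import common_layers

batch, heads, length, depth = 2, 4, 256, 64
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])

# Unmasked local attention: each query attends to roughly block_length // 2
# positions on either side of its own block.
y_local = common_attention.local_attention_1d(q, k, v, block_length=128)

# Masked variant used for decoding: attend only to previous positions.
y_masked = common_attention.local_attention_1d(
    q, k, v, block_length=128, look_right=False)

# Running max-pool over a [batch, length, depth] set representation; element t
# only sees elements 0..t, so it is safe to use on output sequences.
x = tf.random_normal([batch, length, depth])
running = common_layers.running_global_pool_1d(x)

with tf.Session() as sess:
  out_local, out_masked, out_running = sess.run([y_local, y_masked, running])
  print(out_local.shape, out_masked.shape, out_running.shape)
  # Expected: (2, 4, 256, 64) (2, 4, 256, 64) (2, 256, 64)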