From 3b4a000bc2d8d5f7b977428716036b687b72f095 Mon Sep 17 00:00:00 2001 From: William Woof Date: Thu, 13 Jul 2017 14:36:39 +0100 Subject: [PATCH 1/7] Create notes.md --- notes.md | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 notes.md diff --git a/notes.md b/notes.md new file mode 100644 index 000000000..3c082f6e9 --- /dev/null +++ b/notes.md @@ -0,0 +1,75 @@ + +```python + +def unmagic_encoder(encoder_input, + hparams, + name="encoder"): + x = encoder_input + + # Summaries don't work in multi-problem setting yet. + summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 + + with tf.variable_scope(name): + pass + return x + +def magic_decoder(decoder_input, + encoder_output, + residual_fn, + encoder_self_attention_bias, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + name="decoder"): + x = decoder_input + y = encoder_output + # Summaries don't work in multi-problem setting yet. + summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 + with tf.variable_scope(name): + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % layer): + x = residual_fn( + x, + common_attention.multihead_attention( + x, + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=summaries, + name="decoder_self_attention")) + with tf.variable_scope("enc"): + y = residual_fn( + y, + common_attention.multihead_attention( + y, + None, + encoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=summaries, + name="encoder_self_attention")) + y = residual_fn(y, transformer.transformer_ffn_layer(y, hparams)) + + x = residual_fn( + x, + common_attention.multihead_attention( + x, + y, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=summaries, + name="encdec_attention")) + x = residual_fn(x, transformer.transformer_ffn_layer(x, hparams)) + return x +``` From 8fa26668202fa436a2e399288f25af6851cc2b73 Mon Sep 17 00:00:00 2001 From: William Date: Thu, 13 Jul 2017 16:10:16 +0100 Subject: [PATCH 2/7] Updated masking --- tensor2tensor/models/common_layers.py | 33 +++++++++++++++- .../models/transformer_alternative.py | 38 +++++++++++-------- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 4c63ce8ba..2597ccf7a 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -1416,6 +1416,36 @@ def global_pool_1d(inputs, pooling_type='MAX', mask=None): return output + +def running_global_pool_1d(inputs): + """ + Same global pool, but only for the elements up to the current element. Useful + for outputs where the state of future elements is not known. + Takes no mask as all elements up to the current element are assumed to exist. + Currently only supports maximum. + + Args + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + Outputs + output: A tensor of dimensions batch_size x sequence_length x input_dims + dimension containing the running 'totals'. + """ + + with tf.name_scope("running_global_pool", [inputs]): + scan_fct = tf.maximum + + # Permute inputs so seq_length is first + elems = tf.transpose(inputs, [1, 0, 2]) + + # Perform scan + cumulatives = tf.scan(scan_fct, elems, swap_memory=True) + + # Permute output to get back to original order + output = tf.transpose(cumulatives, [1, 0, 2]) + + return output + def linear_set_layer(layer_size, inputs, @@ -1455,7 +1485,8 @@ def linear_set_layer(layer_size, if context is not None: # Unfortunately tf doesn't support broadcasting via concat, but we can # simply add the transformed context to get the same effect - context = tf.expand_dims(context, axis=1) + if len(context.get_shape().as_list())==2: + context = tf.expand_dims(context, axis=1) #context_size = context.get_shape().as_list()[-1] cont_tfm = conv1d(context, layer_size, 1, activation=None, name="cont_conv") diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index 90fea6139..b6c2adc74 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -45,8 +45,7 @@ class TransformerAlt(t2t_model.T2TModel): def model_fn_body(self, features): - # - + # Remove dropout if not training hparams = copy.copy(self._hparams) targets = features["targets"] @@ -61,11 +60,8 @@ def model_fn_body(self, features): (decoder_input, decoder_self_attention_bias) = transformer.\ transformer_prepare_decoder(targets, hparams) - # We need masks of the form batch size x input sequences - # Biases seem to be of the form batch_size x 1 x input sequences x vec dim - # Squeeze out dim one, and get the first element of each vector - encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:,:,0] - decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:,:,0] + encoder_mask = bias_to_mask(encoder_attention_bias) + decoder_mask = bias_to_mask(decoder_self_attention_bias) def residual_fn(x, y): return common_layers.layer_norm(x + tf.nn.dropout( @@ -86,7 +82,7 @@ def residual_fn(x, y): -def composite_layer(inputs, mask, hparams): +def composite_layer(inputs, mask, hparams, for_output=False): x = inputs # Applies ravanbakhsh on top of each other @@ -97,26 +93,29 @@ def composite_layer(inputs, mask, hparams): hparams.hidden_size, x, mask=mask, - dropout=0.0) + dropout=hparams.relu_dropout) # Transforms elements to get a context, and then uses this in a final layer elif hparams.composite_layer_type == "reembedding": initial_elems = x # Transform elements n times and then pool for layer in xrange(hparams.layers_per_layer): - with tf.variable_scope(".%d" % layer): + with tf.variable_scope("sub_layer_%d" % layer): x = common_layers.linear_set_layer( hparams.hidden_size, x, - dropout=0.0) - context = common_layers.global_pool_1d(x, mask=mask) + dropout=hparams.relu_dropout) + if for_output: + context = common_layers.running_global_pool_1d(x) + else: + context = common_layers.global_pool_1d(x, mask=mask) #Final layer x = common_layers.linear_set_layer( hparams.hidden_size, x, context=context, - dropout=0.0) + dropout=hparams.relu_dropout) return x @@ -169,12 +168,19 @@ def alt_transformer_decoder(decoder_input, summaries=summaries, name="encdec_attention") - x_ = residual_fn(x_, composite_layer(x_, mask, hparams)) + x_ = residual_fn(x_, composite_layer(x_, mask, hparams, for_output=True)) x = residual_fn(x, x_) return x - +def bias_to_mask(bias): + # We need masks of the form batch size x input sequences + # Biases seem to be of the form batch_size x 1 x input sequences x vec dim + # Squeeze out dim one, and get the first element of each vector + bias = tf.squeeze(bias, [1])[:,:,0] + bias = - tf.clip_by_value(bias, -1.0, 1.0) + mask = 1 - bias + return mask @@ -182,7 +188,7 @@ def alt_transformer_decoder(decoder_input, def transformer_alt(): """Set of hyperparameters.""" hparams = transformer.transformer_base() - hparams.batch_size = 64 + hparams.batch_size = 2048 hparams.add_hparam("layers_per_layer", 4) #hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding hparams.add_hparam("composite_layer_type", "reembedding") From 10ba26822689cc0428e239426f5f1062b3c7f4cf Mon Sep 17 00:00:00 2001 From: William Woof Date: Fri, 21 Jul 2017 14:01:30 +0100 Subject: [PATCH 3/7] Update notes.md --- notes.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/notes.md b/notes.md index 3c082f6e9..a08b96244 100644 --- a/notes.md +++ b/notes.md @@ -73,3 +73,36 @@ def magic_decoder(decoder_input, x = residual_fn(x, transformer.transformer_ffn_layer(x, hparams)) return x ``` + +``` +def sliding_window(q, + k, + v, + bias, + window_size=None, + dropout_rate=0.0, + summaries=False, + name=None): + + def single(index, size, q, k, v, **kwargs): + # q initially of form batch x heads x depth + + length = tf.shape(k)[2] + index_begin = tf.maximum(0, index-size) + index_end = tf.minimum(length-1, index+size) + + q = tf.expand_dims(q, 2) + k = k[:,:,index_begin:index_end,:] + v = v[:,:,index_begin:index_end,:] + out = dot_product_attention(q, k, v, **kwargs) + out = tf.squeeze(out, 2) + return out + + q = tf.transpose(q, [2, 0, 1, 3]) + indices = tf.range(tf.shape(q)[0]) + + out = tf.map_fn(lambda ii: single(ii, 10, q[ii], k, v, bias=None), indices, dtype=tf.float32) + out = tf.transpose(out, [1, 2, 0, 3]) + + return out +``` From 73fa681f5a0caa722f71c07f5de5ea99dc77d2b3 Mon Sep 17 00:00:00 2001 From: William Date: Fri, 21 Jul 2017 21:18:42 +0100 Subject: [PATCH 4/7] Added running pooling and sliding window attention. --- notes.md | 108 ------------------ tensor2tensor/models/common_attention.py | 62 +++++++++- tensor2tensor/models/common_layers.py | 45 +++++--- tensor2tensor/models/transformer.py | 3 +- .../models/transformer_alternative.py | 25 ++-- 5 files changed, 104 insertions(+), 139 deletions(-) delete mode 100644 notes.md diff --git a/notes.md b/notes.md deleted file mode 100644 index a08b96244..000000000 --- a/notes.md +++ /dev/null @@ -1,108 +0,0 @@ - -```python - -def unmagic_encoder(encoder_input, - hparams, - name="encoder"): - x = encoder_input - - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 - - with tf.variable_scope(name): - pass - return x - -def magic_decoder(decoder_input, - encoder_output, - residual_fn, - encoder_self_attention_bias, - decoder_self_attention_bias, - encoder_decoder_attention_bias, - hparams, - name="decoder"): - x = decoder_input - y = encoder_output - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 - with tf.variable_scope(name): - for layer in xrange(hparams.num_hidden_layers): - with tf.variable_scope("layer_%d" % layer): - x = residual_fn( - x, - common_attention.multihead_attention( - x, - None, - decoder_self_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout, - summaries=summaries, - name="decoder_self_attention")) - with tf.variable_scope("enc"): - y = residual_fn( - y, - common_attention.multihead_attention( - y, - None, - encoder_self_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout, - summaries=summaries, - name="encoder_self_attention")) - y = residual_fn(y, transformer.transformer_ffn_layer(y, hparams)) - - x = residual_fn( - x, - common_attention.multihead_attention( - x, - y, - encoder_decoder_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout, - summaries=summaries, - name="encdec_attention")) - x = residual_fn(x, transformer.transformer_ffn_layer(x, hparams)) - return x -``` - -``` -def sliding_window(q, - k, - v, - bias, - window_size=None, - dropout_rate=0.0, - summaries=False, - name=None): - - def single(index, size, q, k, v, **kwargs): - # q initially of form batch x heads x depth - - length = tf.shape(k)[2] - index_begin = tf.maximum(0, index-size) - index_end = tf.minimum(length-1, index+size) - - q = tf.expand_dims(q, 2) - k = k[:,:,index_begin:index_end,:] - v = v[:,:,index_begin:index_end,:] - out = dot_product_attention(q, k, v, **kwargs) - out = tf.squeeze(out, 2) - return out - - q = tf.transpose(q, [2, 0, 1, 3]) - indices = tf.range(tf.shape(q)[0]) - - out = tf.map_fn(lambda ii: single(ii, 10, q[ii], k, v, bias=None), indices, dtype=tf.float32) - out = tf.transpose(out, [1, 2, 0, 3]) - - return out -``` diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index b6a5e09d6..e8700433a 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -345,6 +345,57 @@ def dot_product_attention(q, return tf.matmul(weights, v) +def sliding_window_attention(window_size, + q, + k, + v, + bias, + *args): + """ Sliding window wrapper for dot product attention. Each element only + attends to the elements (window_size/2) before and after it. This reduces + the computational complexity for long sequences at the expense of eliminating + long-term dependencies. + + N.B: For short input sequences this is much slower than just using + un-windowed attention. use only for long sequences. + + Args: + window_size: an integer + q: a Tensor with shape [batch, heads, length_q, depth_k] + k: a Tensor with shape [batch, heads, length_kv, depth_k] + v: a Tensor with shape [batch, heads, length_kv, depth_v] + bias: bias Tensor (see attention_bias()) + + Returns: + A Tensor. + """ + + half_size = window_size // 2 + + # Wrapper function for dot product attention with a single query vector + def single(index, size, q, k, v, bias, **kwargs): + length_kv = tf.shape(k)[2] + index_begin = tf.maximum(0, index-size) + index_end = tf.minimum(length_kv-1, index+size) + q = tf.expand_dims(q, 2) + bias = tf.expand_dims(bias, 3) + k = k[:,:,index_begin:index_end,:] + v = v[:,:,index_begin:index_end,:] + out = dot_product_attention(q, k, v, bias, *args) + out = tf.squeeze(out, 2) + return out + + # We'll loop over each element of q, computing it's corresponding output. + q = tf.transpose(q, [2, 0, 1, 3]) + indices = tf.range(tf.shape(q)[0]) + out = tf.map_fn( + lambda ii: single(ii, half_size, q[ii], k, v, bias[:,:,:,ii]), + indices, + dtype=tf.float32) + out = tf.transpose(out, [1, 2, 0, 3]) + return out + + def multihead_attention(query_antecedent, memory_antecedent, bias, @@ -355,6 +406,7 @@ def multihead_attention(query_antecedent, dropout_rate, summaries=False, image_shapes=None, + window_size=None, name=None): """Multihead scaled-dot-product attention with input/output transformations. @@ -370,6 +422,8 @@ def multihead_attention(query_antecedent, summaries: a boolean image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() + window_size: option size of window for attention. Useful only for very long + sequence lengths. name: an optional string Returns: @@ -403,8 +457,12 @@ def multihead_attention(query_antecedent, v = split_heads(v, num_heads) key_depth_per_head = total_key_depth // num_heads q *= key_depth_per_head**-0.5 - x = dot_product_attention( - q, k, v, bias, dropout_rate, summaries, image_shapes) + if window_size is None: + x = dot_product_attention( + q, k, v, bias, dropout_rate, summaries, image_shapes) + else: + x = sliding_window_attention( + window_size, q, k, v, bias, dropout_rate, False, image_shapes) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 2597ccf7a..1c93077aa 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -1379,11 +1379,13 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence): logits=logits, labels=soft_targets) return xentropy - normalizing + def global_pool_1d(inputs, pooling_type='MAX', mask=None): """ Pools elements across the last dimension. Useful to a list of vectors into a single vector to get a representation of a set. + Concatenating Args inputs: A tensor of dimensions batch_size x sequence_length x input_dims @@ -1415,18 +1417,19 @@ def global_pool_1d(inputs, pooling_type='MAX', mask=None): output = tf.reduce_mean(inputs, axis=1) return output - -def running_global_pool_1d(inputs): + +def running_global_pool_1d(inputs, pooling_type='MAX'): """ Same global pool, but only for the elements up to the current element. Useful for outputs where the state of future elements is not known. Takes no mask as all elements up to the current element are assumed to exist. - Currently only supports maximum. + Currently only supports maximum. Equivalent to using a lower triangle bias. Args inputs: A tensor of dimensions batch_size x sequence_length x input_dims containing the sequences of input vectors. + pooling_type: Pooling type to use. Currently only supports 'MAX'. Outputs output: A tensor of dimensions batch_size x sequence_length x input_dims dimension containing the running 'totals'. @@ -1438,7 +1441,7 @@ def running_global_pool_1d(inputs): # Permute inputs so seq_length is first elems = tf.transpose(inputs, [1, 0, 2]) - # Perform scan + # Perform scan cumulatives = tf.scan(scan_fct, elems, swap_memory=True) # Permute output to get back to original order @@ -1446,7 +1449,7 @@ def running_global_pool_1d(inputs): return output - + def linear_set_layer(layer_size, inputs, context=None, @@ -1464,15 +1467,14 @@ def linear_set_layer(layer_size, layer_size: Dimension to transform the input vectors to inputs: A tensor of dimensions batch_size x sequence_length x input_dims containing the sequences of input vectors. - context: A tensor of dimensions batch_size x context_dims - containing a global statistic about the set. + context: A tensor of dimensions batch_size x context_dims or batch_size x + sequence_length x context_dims containing a global statistic about the + set. dropout: Dropout probability. activation_fn: The activation function to use. Outputs output: A tensor of dimensions batch_size x sequence_length x output_dims dimension containing the sequences of transformed vectors. - - TODO: Add bias add. """ with tf.variable_scope(name, "linear_set_layer", [inputs]): @@ -1500,10 +1502,12 @@ def linear_set_layer(layer_size, return outputs - + + def ravanbakhsh_set_layer(layer_size, inputs, mask=None, + sequential=False, activation_fn=tf.nn.tanh, dropout=0.0, name=None): @@ -1518,18 +1522,27 @@ def ravanbakhsh_set_layer(layer_size, containing the sequences of input vectors. mask: A tensor of dimensions batch_size x sequence_length containing a mask for the inputs with 1's for existing elements, and 0's elsewhere. - activation_fn: The activation function to use. + sequential: If true, will use a running global pool so each element will + only depend on those before it. Set true if this layer is being used in + an ouput sequence. Outputs output: A tensor of dimensions batch_size x sequence_length x vector dimension containing the sequences of transformed vectors. """ with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]): - output = linear_set_layer( - layer_size, - inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1), - activation_fn=activation_fn, - name=name) + if sequential: + output = linear_set_layer( + layer_size, + inputs - running_global_pool_1d(inputs), + activation_fn=activation_fn, + name=name) + else: + output = linear_set_layer( + layer_size, + inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1), + activation_fn=activation_fn, + name=name) return output diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 544035efd..0b6c97153 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -162,7 +162,8 @@ def transformer_encoder(encoder_input, hparams.num_heads, hparams.attention_dropout, summaries=summaries, - name="encoder_self_attention")) + name="encoder_self_attention", + window_size=20)) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index b6c2adc74..5ea6942a4 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -59,9 +59,8 @@ def model_fn_body(self, features): transformer_prepare_encoder(inputs, target_space, hparams) ) (decoder_input, decoder_self_attention_bias) = transformer.\ transformer_prepare_decoder(targets, hparams) - + encoder_mask = bias_to_mask(encoder_attention_bias) - decoder_mask = bias_to_mask(decoder_self_attention_bias) def residual_fn(x, y): return common_layers.layer_norm(x + tf.nn.dropout( @@ -69,11 +68,12 @@ def residual_fn(x, y): encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout) + encoder_output = alt_transformer_encoder( encoder_input, residual_fn, encoder_mask, hparams) decoder_output = alt_transformer_decoder( - decoder_input, encoder_output, residual_fn, decoder_mask, + decoder_input, encoder_output, residual_fn, encoder_attention_bias, hparams) decoder_output = tf.expand_dims(decoder_output, 2) @@ -93,6 +93,7 @@ def composite_layer(inputs, mask, hparams, for_output=False): hparams.hidden_size, x, mask=mask, + sequential=for_output, dropout=hparams.relu_dropout) # Transforms elements to get a context, and then uses this in a final layer @@ -127,12 +128,11 @@ def alt_transformer_encoder(encoder_input, hparams, name="encoder"): - x = encoder_input - # Summaries don't work in multi-problem setting yet. summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 - + with tf.variable_scope(name): + x = encoder_input for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): x = residual_fn(x, composite_layer(x, mask, hparams)) @@ -143,16 +143,15 @@ def alt_transformer_encoder(encoder_input, def alt_transformer_decoder(decoder_input, encoder_output, residual_fn, - mask, encoder_decoder_attention_bias, hparams, name="decoder"): - x = decoder_input - # Summaries don't work in multi-problem setting yet. summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 + with tf.variable_scope(name): + x = decoder_input for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -168,7 +167,7 @@ def alt_transformer_decoder(decoder_input, summaries=summaries, name="encdec_attention") - x_ = residual_fn(x_, composite_layer(x_, mask, hparams, for_output=True)) + x_ = residual_fn(x_, composite_layer(x_, None, hparams, for_output=True)) x = residual_fn(x, x_) return x @@ -177,6 +176,7 @@ def bias_to_mask(bias): # We need masks of the form batch size x input sequences # Biases seem to be of the form batch_size x 1 x input sequences x vec dim # Squeeze out dim one, and get the first element of each vector + bias = tf.squeeze(bias, [1])[:,:,0] bias = - tf.clip_by_value(bias, -1.0, 1.0) mask = 1 - bias @@ -189,8 +189,9 @@ def transformer_alt(): """Set of hyperparameters.""" hparams = transformer.transformer_base() hparams.batch_size = 2048 + hparams.num_hidden_layers = 3 hparams.add_hparam("layers_per_layer", 4) - #hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding - hparams.add_hparam("composite_layer_type", "reembedding") + hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding + #hparams.add_hparam("composite_layer_type", "reembedding") return hparams From 37e7dedf22063d3e1e1cc965a8b98e29ce5964a6 Mon Sep 17 00:00:00 2001 From: William Date: Fri, 28 Jul 2017 14:06:55 +0100 Subject: [PATCH 5/7] Updated sliding window --- .gitignore | 1 + tensor2tensor/models/common_attention.py | 263 +++++++++++++++++++---- tensor2tensor/models/models.py | 1 + tensor2tensor/models/transformer.py | 3 +- 4 files changed, 229 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index c9dd3db88..fbd98dca5 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ _pycache__/ # PyPI distribution artifacts. build/ dist/ +data/ # Sublime project files *.sublime-project diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index e8700433a..c1d469eb1 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -345,55 +345,244 @@ def dot_product_attention(q, return tf.matmul(weights, v) -def sliding_window_attention(window_size, - q, - k, - v, - bias, - *args): - """ Sliding window wrapper for dot product attention. Each element only - attends to the elements (window_size/2) before and after it. This reduces +def masked_local_attention_1d( + q, k, v, block_length=128, name=None): + """Attention to the source position and a neigborhood to the left of it. + + The sequence is divided into blocks of length block_size. + Attention for a given query position can only see memory positions + less than or equal to the query position, in the corresponding block + and the previous block. + + If mask_right is True, then a target position cannot see greater source + positions. + + Args: + q: a Tensor with shape [batch, heads, length, depth_k] + k: a Tensor with shape [batch, heads, length, depth_k] + v: a Tensor with shape [batch, heads, length, depth_v] + block_length: an integer + name: an optional string + + Returns: + a Tensor of shape [batch, heads, length, depth_v] + """ + with tf.variable_scope(name, default_name="local_attention_1d", + values=[q, k, v]): + v_shape = v.get_shape() + batch = tf.shape(q)[0] + heads = tf.shape(q)[1] + length = tf.shape(q)[2] + # If (length < 2 * block_length), then we use only one block. + block_length = tf.where(tf.less(length, block_length * 2), + length, block_length) + depth_k = tf.shape(q)[3] + depth_v = tf.shape(v)[3] + original_length = length + padding_size = tf.mod(-length, block_length) + length += padding_size + padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]] + q = tf.pad(q, padding) + k = tf.pad(k, padding) + v = tf.pad(v, padding) + num_blocks = tf.div(length, block_length) + + # compute attention for the first query block. + first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_output = dot_product_attention( + first_q, first_k, first_v, attention_bias_lower_triangle(block_length), + name="fist_block") + + # compute attention for all subsequent query blocks. + q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) + k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k]) + v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v]) + + def local(x): + """Create a local version of the keys or values.""" + prev_block = tf.slice( + x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1]) + cur_block = tf.slice( + x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) + return tf.concat([prev_block, cur_block], 3) + local_k = local(k) + local_v = local(v) + tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) + + local_length = tf.shape(local_k)[3] + + # [batch, heads, num_blocks - 1, block_length, local_length] + attention = tf.matmul(tail_q, local_k, transpose_b=True) + + # make sure source_pos <= target_pos + good_part = tf.matrix_band_part( + tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) + mask = (1.0 - good_part) * -1e9 + attention += tf.reshape(mask, [1, 1, 1, block_length, local_length]) + attention = tf.nn.softmax(attention) + # TODO(noam): figure out how to show a summary for the remaining blocks. + # The naive way currently causes errors due to empty tensors. + # output: [batch, heads, num_blocks-1, block_length, depth_v] + output = tf.matmul(attention, local_v) + output = tf.reshape(output, [batch, heads, -1, depth_v]) + output = tf.concat([first_output, output], axis=2) + output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) + output.set_shape(v_shape) + return output + + +def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100, + name=None): + """strided block local self-attention. + + Args: + q: a Tensor with shape [batch, heads, length, depth_k] + k: a Tensor with shape [batch, heads, length, depth_k] + v: a Tensor with shape [batch, heads, length, depth_v] + block_length: an integer + filter_width: an integer indicating how much to look left. + name: an optional string + + Returns: + a Tensor of shape [batch, heads, length, depth_v] + """ + with tf.variable_scope(name, default_name="local_self_attention_1d", + values=[q, k, v]): + v_shape = v.get_shape() + depth_v = tf.shape(v)[3] + batch_size = tf.shape(q)[0] + num_heads = tf.shape(q)[1] + original_length = tf.shape(q)[2] + # making sure q is a multiple of d + def pad_to_multiple(x, pad_length): + x_length = tf.shape(x)[2] + return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]]) + def pad_l_and_r(x, pad_length): + return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]]) + q = pad_to_multiple(q, block_length) + k = pad_to_multiple(k, block_length) + v = pad_to_multiple(v, block_length) + + # Setting up q blocks + new_q_shape = tf.shape(q) + # Setting up q blocks + q = tf.reshape(q, [new_q_shape[0], new_q_shape[1], + new_q_shape[2]//block_length, + block_length, new_q_shape[3]]) + + # Setting up k and v values + k = pad_l_and_r(k, filter_width) + v = pad_l_and_r(v, filter_width) + + length = tf.shape(k)[2] + full_filter_width = block_length + 2*filter_width + # getting gather indices + indices = tf.range(0, length, delta=1, name="index_range") + # making indices [1, length, 1] to appy convs + indices = tf.reshape(indices, [1, -1, 1]) + kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1) + gather_indices = tf.nn.conv1d( + tf.cast(indices, tf.float32), + kernel, + block_length, + padding="VALID", + name="gather_conv") + + gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0) + + # [length, batch, heads, dim] + k_t = tf.transpose(k, [2, 0, 1, 3]) + k_new = tf.gather(k_t, gather_indices) + + # [batch, heads, blocks, block_length, dim] + k_new = tf.transpose(k_new, [2, 3, 0, 1, 4]) + + attention_bias = tf.expand_dims( + tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) + + v_t = tf.transpose(v, [2, 0, 1, 3]) + v_new = tf.gather(v_t, gather_indices) + v_new = tf.transpose(v_new, [2, 3, 0, 1, 4]) + + logits = tf.matmul(q, k_new, transpose_b=True) + + attention = tf.nn.softmax(logits+attention_bias) + output = tf.matmul(attention, v_new) + + output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) + # Remove the padding if introduced + output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) + output.set_shape(v_shape) + return output + + +def windowed_local_attention_1d(q, + k, + v, + window_start, + window_end, + bias, + *args): + """ Local window wrapper for dot product attention. Each element only + attends to the elements from window_start to window_end. This reduces the computational complexity for long sequences at the expense of eliminating long-term dependencies. N.B: For short input sequences this is much slower than just using - un-windowed attention. use only for long sequences. + un-windowed attention. Use only for long sequences. Args: window_size: an integer q: a Tensor with shape [batch, heads, length_q, depth_k] k: a Tensor with shape [batch, heads, length_kv, depth_k] v: a Tensor with shape [batch, heads, length_kv, depth_v] + window_start: an integer Tensor with shape [length_q] + window_end: an integer Tensor with shape [length_q] bias: bias Tensor (see attention_bias()) Returns: A Tensor. """ - - half_size = window_size // 2 - - # Wrapper function for dot product attention with a single query vector - def single(index, size, q, k, v, bias, **kwargs): - length_kv = tf.shape(k)[2] - index_begin = tf.maximum(0, index-size) - index_end = tf.minimum(length_kv-1, index+size) - q = tf.expand_dims(q, 2) - bias = tf.expand_dims(bias, 3) - k = k[:,:,index_begin:index_end,:] - v = v[:,:,index_begin:index_end,:] - out = dot_product_attention(q, k, v, bias, *args) - out = tf.squeeze(out, 2) + with tf.name_scope("windowed"): + + # Wrapper function for dot product attention with a single query vector + def single(index_begin, index_end, q, k, v, bias): + #Normalise range + #Reshape to right shape + q = tf.expand_dims(q, 2) + bias = tf.expand_dims(bias, 3) + #Get slices + k = k[:,:,index_begin:index_end,:] + v = v[:,:,index_begin:index_end,:] + out = dot_product_attention(q, k, v, bias, *args) + out = tf.squeeze(out, 2) + return out + + # We'll loop over each element of q, computing its corresponding output. + q = tf.transpose(q, [2, 0, 1, 3]) + indices = tf.range(tf.shape(q)[0]) + out = tf.map_fn( + lambda ii: single( + window_start[ii], + window_end[ii], + q[ii], + k, + v, + bias[:,:,:,ii]), + indices, + dtype=tf.float32) + out = tf.transpose(out, [1, 2, 0, 3]) return out - - # We'll loop over each element of q, computing it's corresponding output. - q = tf.transpose(q, [2, 0, 1, 3]) - indices = tf.range(tf.shape(q)[0]) - out = tf.map_fn( - lambda ii: single(ii, half_size, q[ii], k, v, bias[:,:,:,ii]), - indices, - dtype=tf.float32) - out = tf.transpose(out, [1, 2, 0, 3]) - return out + + +def local_sliding_window(length, window_size, look_right=True): + indices = tf.range(length) + size = window_size + starts = tf.maximum(0, indices-size) + ends = tf.minimum(length-1, indices+size) + return starts, ends def multihead_attention(query_antecedent, @@ -420,8 +609,6 @@ def multihead_attention(query_antecedent, num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number summaries: a boolean - image_shapes: optional tuple of integer scalars. - see comments for attention_image_summary() window_size: option size of window for attention. Useful only for very long sequence lengths. name: an optional string @@ -461,8 +648,10 @@ def multihead_attention(query_antecedent, x = dot_product_attention( q, k, v, bias, dropout_rate, summaries, image_shapes) else: - x = sliding_window_attention( - window_size, q, k, v, bias, dropout_rate, False, image_shapes) + length = tf.shape(k)[2] + window_start, window_end = local_sliding_window(length, window_size) + x = windowed_local_attention_1d( + q, k, v, window_start, window_end, bias, dropout_rate, False) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index b8f0811e5..ae0e0da61 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -32,5 +32,6 @@ from tensor2tensor.models import neural_gpu from tensor2tensor.models import slicenet from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_alternative from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 0b6c97153..544035efd 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -162,8 +162,7 @@ def transformer_encoder(encoder_input, hparams.num_heads, hparams.attention_dropout, summaries=summaries, - name="encoder_self_attention", - window_size=20)) + name="encoder_self_attention")) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x From d6a6924886b78f1f8f75d27b523d9140fedc3e10 Mon Sep 17 00:00:00 2001 From: William Date: Fri, 28 Jul 2017 18:10:55 +0100 Subject: [PATCH 6/7] Added middle window for local attention --- tensor2tensor/models/common_attention.py | 160 +++++------------- tensor2tensor/models/common_attention_test.py | 64 +++++++ 2 files changed, 110 insertions(+), 114 deletions(-) create mode 100644 tensor2tensor/models/common_attention_test.py diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index c1d469eb1..abf989402 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -346,7 +346,7 @@ def dot_product_attention(q, def masked_local_attention_1d( - q, k, v, block_length=128, name=None): + q, k, v, block_length=128, mask_right=False, name=None): """Attention to the source position and a neigborhood to the left of it. The sequence is divided into blocks of length block_size. @@ -362,6 +362,7 @@ def masked_local_attention_1d( k: a Tensor with shape [batch, heads, length, depth_k] v: a Tensor with shape [batch, heads, length, depth_v] block_length: an integer + mask_right: a bool name: an optional string Returns: @@ -373,150 +374,76 @@ def masked_local_attention_1d( batch = tf.shape(q)[0] heads = tf.shape(q)[1] length = tf.shape(q)[2] - # If (length < 2 * block_length), then we use only one block. - block_length = tf.where(tf.less(length, block_length * 2), - length, block_length) depth_k = tf.shape(q)[3] depth_v = tf.shape(v)[3] + original_length = length + + # If (length < 2 * block_length), then we use only one block. + block_length = tf.where(tf.less(length, block_length * 2), + length, block_length) padding_size = tf.mod(-length, block_length) length += padding_size + num_blocks = tf.div(length, block_length) + padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]] q = tf.pad(q, padding) - k = tf.pad(k, padding) - v = tf.pad(v, padding) - num_blocks = tf.div(length, block_length) - # compute attention for the first query block. - first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1]) - first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1]) - first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1]) - first_output = dot_product_attention( - first_q, first_k, first_v, attention_bias_lower_triangle(block_length), - name="fist_block") + if mask_right: + #Add extra padding so we son't have to do an initial query + extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]] + else: + #We shift everything over by half a block so query is in centre + pad_right = block_length // 2 + pad_left = block_length - pad_right + extra_padding = [[0, 0], [0, 0], + [pad_left,padding_size+pad_right], [0, 0]] + + k = tf.pad(k, extra_padding) + v = tf.pad(v, extra_padding) + # compute attention for all subsequent query blocks. q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) - k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k]) - v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v]) + k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k]) + v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v]) def local(x): """Create a local version of the keys or values.""" prev_block = tf.slice( - x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1]) + x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1]) cur_block = tf.slice( x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) return tf.concat([prev_block, cur_block], 3) + local_k = local(k) local_v = local(v) - tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) local_length = tf.shape(local_k)[3] - # [batch, heads, num_blocks - 1, block_length, local_length] - attention = tf.matmul(tail_q, local_k, transpose_b=True) + # [batch, heads, num_blocks, block_length, local_length] + attention = tf.matmul(q, local_k, transpose_b=True) - # make sure source_pos <= target_pos good_part = tf.matrix_band_part( - tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) - mask = (1.0 - good_part) * -1e9 - attention += tf.reshape(mask, [1, 1, 1, block_length, local_length]) + tf.ones([block_length, local_length]), 0, tf.to_int64(block_length)) + + good_part = tf.cast(good_part, tf.float64) + attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length]) attention = tf.nn.softmax(attention) - # TODO(noam): figure out how to show a summary for the remaining blocks. - # The naive way currently causes errors due to empty tensors. - # output: [batch, heads, num_blocks-1, block_length, depth_v] + output = tf.matmul(attention, local_v) output = tf.reshape(output, [batch, heads, -1, depth_v]) - output = tf.concat([first_output, output], axis=2) + + # remove added padding output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) output.set_shape(v_shape) return output -def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100, - name=None): - """strided block local self-attention. - - Args: - q: a Tensor with shape [batch, heads, length, depth_k] - k: a Tensor with shape [batch, heads, length, depth_k] - v: a Tensor with shape [batch, heads, length, depth_v] - block_length: an integer - filter_width: an integer indicating how much to look left. - name: an optional string - Returns: - a Tensor of shape [batch, heads, length, depth_v] - """ - with tf.variable_scope(name, default_name="local_self_attention_1d", - values=[q, k, v]): - v_shape = v.get_shape() - depth_v = tf.shape(v)[3] - batch_size = tf.shape(q)[0] - num_heads = tf.shape(q)[1] - original_length = tf.shape(q)[2] - # making sure q is a multiple of d - def pad_to_multiple(x, pad_length): - x_length = tf.shape(x)[2] - return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]]) - def pad_l_and_r(x, pad_length): - return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]]) - q = pad_to_multiple(q, block_length) - k = pad_to_multiple(k, block_length) - v = pad_to_multiple(v, block_length) - - # Setting up q blocks - new_q_shape = tf.shape(q) - # Setting up q blocks - q = tf.reshape(q, [new_q_shape[0], new_q_shape[1], - new_q_shape[2]//block_length, - block_length, new_q_shape[3]]) - - # Setting up k and v values - k = pad_l_and_r(k, filter_width) - v = pad_l_and_r(v, filter_width) - - length = tf.shape(k)[2] - full_filter_width = block_length + 2*filter_width - # getting gather indices - indices = tf.range(0, length, delta=1, name="index_range") - # making indices [1, length, 1] to appy convs - indices = tf.reshape(indices, [1, -1, 1]) - kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1) - gather_indices = tf.nn.conv1d( - tf.cast(indices, tf.float32), - kernel, - block_length, - padding="VALID", - name="gather_conv") - - gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0) - - # [length, batch, heads, dim] - k_t = tf.transpose(k, [2, 0, 1, 3]) - k_new = tf.gather(k_t, gather_indices) - - # [batch, heads, blocks, block_length, dim] - k_new = tf.transpose(k_new, [2, 3, 0, 1, 4]) - - attention_bias = tf.expand_dims( - tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) - - v_t = tf.transpose(v, [2, 0, 1, 3]) - v_new = tf.gather(v_t, gather_indices) - v_new = tf.transpose(v_new, [2, 3, 0, 1, 4]) - - logits = tf.matmul(q, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits+attention_bias) - output = tf.matmul(attention, v_new) - - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) - # Remove the padding if introduced - output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) - output.set_shape(v_shape) - return output +############################################################################### +### Not used, left in for reference ########################################### def windowed_local_attention_1d(q, k, @@ -556,12 +483,13 @@ def single(index_begin, index_end, q, k, v, bias): #Get slices k = k[:,:,index_begin:index_end,:] v = v[:,:,index_begin:index_end,:] - out = dot_product_attention(q, k, v, bias, *args) + out = dot_product_attention(q, k, v, *args) out = tf.squeeze(out, 2) return out # We'll loop over each element of q, computing its corresponding output. q = tf.transpose(q, [2, 0, 1, 3]) + bias = tf.transpose(bias, [3, 0, 1, 2]) indices = tf.range(tf.shape(q)[0]) out = tf.map_fn( lambda ii: single( @@ -570,13 +498,12 @@ def single(index_begin, index_end, q, k, v, bias): q[ii], k, v, - bias[:,:,:,ii]), + bias[ii]), indices, dtype=tf.float32) out = tf.transpose(out, [1, 2, 0, 3]) return out - def local_sliding_window(length, window_size, look_right=True): indices = tf.range(length) size = window_size @@ -584,6 +511,11 @@ def local_sliding_window(length, window_size, look_right=True): ends = tf.minimum(length-1, indices+size) return starts, ends +### ### +############################################################################### + + + def multihead_attention(query_antecedent, memory_antecedent, @@ -648,7 +580,7 @@ def multihead_attention(query_antecedent, x = dot_product_attention( q, k, v, bias, dropout_rate, summaries, image_shapes) else: - length = tf.shape(k)[2] + length = tf.shape(q)[2] window_start, window_end = local_sliding_window(length, window_size) x = windowed_local_attention_1d( q, k, v, window_start, window_end, bias, dropout_rate, False) diff --git a/tensor2tensor/models/common_attention_test.py b/tensor2tensor/models/common_attention_test.py new file mode 100644 index 000000000..14754794c --- /dev/null +++ b/tensor2tensor/models/common_attention_test.py @@ -0,0 +1,64 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for common layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +from tensor2tensor.models import common_attention + +import tensorflow as tf + + +class CommonAttentionTest(tf.test.TestCase): + + def testLocalAttention(self): + #q = np.array([[[ [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0] ]]]) + #k = np.array([[[ [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 0.0] ]]]) + #v = np.ones((1, 1, 8, 1)) + + q = np.random.rand(5, 7, 13, 3) + k = np.random.rand(5, 7, 13, 3) + v = np.random.rand(5, 7, 13, 11) + + with self.test_session() as session: + q_ = tf.constant(q) + k_ = tf.constant(k) + v_ = tf.constant(v) + y = common_attention.masked_local_attention_1d(q_, k_, v_, block_length=tf.constant(3)) + res = session.run(y) + self.assertEqual(res.shape, (5, 7, 13, 11)) + + +if __name__ == "__main__": + tf.test.main() From 2ced78dbb2a9bb921ebd3e327c704efb790dc140 Mon Sep 17 00:00:00 2001 From: William Date: Tue, 1 Aug 2017 11:08:24 +0100 Subject: [PATCH 7/7] Unify methods and started work on Bias --- tensor2tensor/models/common_attention.py | 207 ++++++++---------- tensor2tensor/models/common_attention_test.py | 62 ++++-- .../models/transformer_alternative.py | 7 +- 3 files changed, 138 insertions(+), 138 deletions(-) diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index abf989402..2004e1bac 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -345,24 +345,34 @@ def dot_product_attention(q, return tf.matmul(weights, v) -def masked_local_attention_1d( - q, k, v, block_length=128, mask_right=False, name=None): - """Attention to the source position and a neigborhood to the left of it. - The sequence is divided into blocks of length block_size. - Attention for a given query position can only see memory positions - less than or equal to the query position, in the corresponding block - and the previous block. +def local_attention_1d(q, k, v, bias=None, + block_length=128, look_right=True, use_whole_block=False, + truncate_bias=True, name=None): + """Attention to the source position and a neigborhood around it. - If mask_right is True, then a target position cannot see greater source + The sequence is divided into blocks of length block_size. Attention for a + given query position can only see memory positions within a certain number + of positions before and behind it. + + If look_right is True then each query will attend to block_length//2 + positions either side, otherwise it will attend to block_length previous positions. + If use_whole_block is True then no mask will be applied to the local blocks + meaning the full blocks are used (if look_right is True then the elements to + the right of the current position are still masked out). This allows use to + attend to more elements without additional overhead, but means we have + inconsistent window positions and sizes. + Args: - q: a Tensor with shape [batch, heads, length, depth_k] - k: a Tensor with shape [batch, heads, length, depth_k] - v: a Tensor with shape [batch, heads, length, depth_v] + q: a Tensor with shape [batch, heads, length_q, depth_k] + k: a Tensor with shape [batch, heads, length_kv, depth_k] + v: a Tensor with shape [batch, heads, length_kv, depth_v] + bias: Not currently used [batch, heads, length_q, length_k] block_length: an integer - mask_right: a bool + look_right: a bool + use_whole_block: a bool name: an optional string Returns: @@ -379,8 +389,9 @@ def masked_local_attention_1d( original_length = length - # If (length < 2 * block_length), then we use only one block. - block_length = tf.where(tf.less(length, block_length * 2), + #Pad to desired length + #If (length < 2 * block_length), then we use only one block. + block_length = tf.where(tf.less(length, block_length), length, block_length) padding_size = tf.mod(-length, block_length) length += padding_size @@ -389,25 +400,27 @@ def masked_local_attention_1d( padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]] q = tf.pad(q, padding) - if mask_right: + if not look_right: #Add extra padding so we son't have to do an initial query extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]] + bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]] else: #We shift everything over by half a block so query is in centre pad_right = block_length // 2 pad_left = block_length - pad_right extra_padding = [[0, 0], [0, 0], - [pad_left,padding_size+pad_right], [0, 0]] - + [pad_left, padding_size+pad_right], [0, 0]] + bp = [[0, 0], [0, 0], + [0, padding_size], [pad_left, padding_size+pad_right]] k = tf.pad(k, extra_padding) v = tf.pad(v, extra_padding) - - # compute attention for all subsequent query blocks. + # Reshape into blocks q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k]) v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v]) + # Get local blocks by slicing def local(x): """Create a local version of the keys or values.""" prev_block = tf.slice( @@ -415,108 +428,72 @@ def local(x): cur_block = tf.slice( x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) return tf.concat([prev_block, cur_block], 3) - local_k = local(k) local_v = local(v) - local_length = tf.shape(local_k)[3] # [batch, heads, num_blocks, block_length, local_length] attention = tf.matmul(q, local_k, transpose_b=True) + + # Apply bias (N.B: This is not currently working) + if bias is not None: + with tf.name_scope('bias'): + b_batch = tf.shape(bias)[0] + b_heads = tf.shape(bias)[1] + bias_ = bias + #bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0) + if truncate_bias: + # Use only the query dimension + bias = tf.expand_dims(bias[:,:,:,0], 2) + bias = tf.pad(bias, extra_padding, name='bias_pad_b')# 17, 5, 3 + bias = tf.reshape(bias, + [b_batch, b_heads, 1, num_blocks+1, block_length], + name='divide_blocks') + local_b = tf.reshape(local(bias), + [b_batch, b_heads, num_blocks, 1, -1], name='reshape_local') + else: + bias = tf.pad(bias, bp, name='pad') + bias = tf.reshape(bias, + [b_batch, b_heads, num_blocks, block_length, + num_blocks+1, block_length], name='divide_blocks') + bias = tf.transpose(bias, [4,2,0,1,3,5]) + bias = tf.reshape(bias, + [num_blocks*(num_blocks+1), b_batch, b_heads, + block_length, block_length], name='combine') + indices = (num_blocks+1)*tf.range(num_blocks) + prev_block = tf.gather(bias, indices) + cur_block = tf.gather(bias, indices+num_blocks) + local_b = tf.concat([prev_block, cur_block], 4) + local_b = tf.transpose(local_b, [1,2,0,3,4]) + return l-local_b + attention += local_b + + attention = tf.nn.softmax(attention) + + # Get local mask + if not use_whole_block: + good_part = tf.matrix_band_part( + tf.ones([block_length, local_length]), 0, tf.to_int64(block_length)) + elif not look_right: + good_part = tf.matrix_band_part( + tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) + else: + good_part = tf.ones([block_length, local_length]) - good_part = tf.matrix_band_part( - tf.ones([block_length, local_length]), 0, tf.to_int64(block_length)) - - good_part = tf.cast(good_part, tf.float64) + #good_part = tf.cast(good_part, tf.float64) attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length]) - attention = tf.nn.softmax(attention) + output = tf.matmul(attention, local_v) output = tf.reshape(output, [batch, heads, -1, depth_v]) - # remove added padding + # Remove added padding output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) output.set_shape(v_shape) return output - -############################################################################### -### Not used, left in for reference ########################################### - -def windowed_local_attention_1d(q, - k, - v, - window_start, - window_end, - bias, - *args): - """ Local window wrapper for dot product attention. Each element only - attends to the elements from window_start to window_end. This reduces - the computational complexity for long sequences at the expense of eliminating - long-term dependencies. - - N.B: For short input sequences this is much slower than just using - un-windowed attention. Use only for long sequences. - - Args: - window_size: an integer - q: a Tensor with shape [batch, heads, length_q, depth_k] - k: a Tensor with shape [batch, heads, length_kv, depth_k] - v: a Tensor with shape [batch, heads, length_kv, depth_v] - window_start: an integer Tensor with shape [length_q] - window_end: an integer Tensor with shape [length_q] - bias: bias Tensor (see attention_bias()) - - Returns: - A Tensor. - """ - with tf.name_scope("windowed"): - - # Wrapper function for dot product attention with a single query vector - def single(index_begin, index_end, q, k, v, bias): - #Normalise range - #Reshape to right shape - q = tf.expand_dims(q, 2) - bias = tf.expand_dims(bias, 3) - #Get slices - k = k[:,:,index_begin:index_end,:] - v = v[:,:,index_begin:index_end,:] - out = dot_product_attention(q, k, v, *args) - out = tf.squeeze(out, 2) - return out - - # We'll loop over each element of q, computing its corresponding output. - q = tf.transpose(q, [2, 0, 1, 3]) - bias = tf.transpose(bias, [3, 0, 1, 2]) - indices = tf.range(tf.shape(q)[0]) - out = tf.map_fn( - lambda ii: single( - window_start[ii], - window_end[ii], - q[ii], - k, - v, - bias[ii]), - indices, - dtype=tf.float32) - out = tf.transpose(out, [1, 2, 0, 3]) - return out - -def local_sliding_window(length, window_size, look_right=True): - indices = tf.range(length) - size = window_size - starts = tf.maximum(0, indices-size) - ends = tf.minimum(length-1, indices+size) - return starts, ends - -### ### -############################################################################### - - - - def multihead_attention(query_antecedent, memory_antecedent, bias, @@ -527,7 +504,8 @@ def multihead_attention(query_antecedent, dropout_rate, summaries=False, image_shapes=None, - window_size=None, + attention_type="dot_product", + block_length=128, name=None): """Multihead scaled-dot-product attention with input/output transformations. @@ -540,9 +518,11 @@ def multihead_attention(query_antecedent, output_depth: an integer num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number - summaries: a boolean - window_size: option size of window for attention. Useful only for very long - sequence lengths. + image_shapes: optional tuple of integer scalars. + see comments for attention_image_summary() + attention_type: a string, either "dot_product" or "local" or + "local_mask_right" + block_length: an integer - relevant for "local_mask_right" name: an optional string Returns: @@ -576,14 +556,15 @@ def multihead_attention(query_antecedent, v = split_heads(v, num_heads) key_depth_per_head = total_key_depth // num_heads q *= key_depth_per_head**-0.5 - if window_size is None: + if attention_type == "dot_product": x = dot_product_attention( - q, k, v, bias, dropout_rate, summaries, image_shapes) + q, k, v, bias, dropout_rate, image_shapes) + elif attention_type == "local": + x = local_attention_1d(q, k, v, block_length=block_length) else: - length = tf.shape(q)[2] - window_start, window_end = local_sliding_window(length, window_size) - x = windowed_local_attention_1d( - q, k, v, window_start, window_end, bias, dropout_rate, False) + assert attention_type == "local_mask_right" + x = local_attention_1d( + q, k, v, block_length=block_length, look_right=False) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/common_attention_test.py b/tensor2tensor/models/common_attention_test.py index 14754794c..2e534ba1a 100644 --- a/tensor2tensor/models/common_attention_test.py +++ b/tensor2tensor/models/common_attention_test.py @@ -29,35 +29,53 @@ class CommonAttentionTest(tf.test.TestCase): def testLocalAttention(self): - #q = np.array([[[ [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0], - # [1.0, 0.0, 0.0, 0.0] ]]]) - #k = np.array([[[ [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0], - # [0.0, 0.0, 0.0, 0.0] ]]]) - #v = np.ones((1, 1, 8, 1)) + q = np.array([[[ [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0] ]]]) - q = np.random.rand(5, 7, 13, 3) - k = np.random.rand(5, 7, 13, 3) - v = np.random.rand(5, 7, 13, 11) + k = np.array([[[ [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0] ]]]) + + b = np.array([[[ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ]]]) + + #b = np.ones((1,1,8,8)) + #b = (1-b) * (-1e9) + v = np.ones((1, 1, 8, 1)) + + #q = np.random.rand(5, 7, 13, 3) + #k = np.random.rand(5, 7, 13, 3) + #v = np.random.rand(5, 7, 13, 11) + #b = np.random.rand(5, 1, 13, 1) with self.test_session() as session: q_ = tf.constant(q) k_ = tf.constant(k) v_ = tf.constant(v) - y = common_attention.masked_local_attention_1d(q_, k_, v_, block_length=tf.constant(3)) + b_ = tf.constant(b) + y = common_attention.local_attention_1d(q_, k_, v_, b_, block_length=tf.constant(2)) res = session.run(y) - self.assertEqual(res.shape, (5, 7, 13, 11)) + #print(q) + #rint(k) + print(res) + #self.assertEqual(res.shape, (5, 7, 13, 11)) if __name__ == "__main__": diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index 5ea6942a4..d0e04f078 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -174,8 +174,9 @@ def alt_transformer_decoder(decoder_input, def bias_to_mask(bias): # We need masks of the form batch size x input sequences - # Biases seem to be of the form batch_size x 1 x input sequences x vec dim - # Squeeze out dim one, and get the first element of each vector + # Biases are of the form batch_size x num_heads x input sequences x + # output sequences. Squeeze out dim one, and get the first element of + # each vector. bias = tf.squeeze(bias, [1])[:,:,0] bias = - tf.clip_by_value(bias, -1.0, 1.0) @@ -189,7 +190,7 @@ def transformer_alt(): """Set of hyperparameters.""" hparams = transformer.transformer_base() hparams.batch_size = 2048 - hparams.num_hidden_layers = 3 + hparams.num_hidden_layers = 10 hparams.add_hparam("layers_per_layer", 4) hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding #hparams.add_hparam("composite_layer_type", "reembedding")