This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Merge pull request #176 from EndingCredits/master
Alternative Transformer Fix + Sliding Window Attention
lukaszkaiser committed Aug 1, 2017
2 parents 69e40fb + af52f5f commit 0df0f50
Showing 5 changed files with 338 additions and 222 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ __pycache__/
# PyPI distribution artifacts.
build/
dist/
data/

# Sublime project files
*.sublime-project
241 changes: 108 additions & 133 deletions tensor2tensor/models/common_attention.py
@@ -344,23 +344,33 @@ def dot_product_attention(q,
return tf.matmul(weights, v)


def masked_local_attention_1d(
q, k, v, block_length=128, name=None):
"""Attention to the source position and a neigborhood to the left of it.
def local_attention_1d(q, k, v, bias=None,
block_length=128, look_right=True, use_whole_block=False,
truncate_bias=True, name=None):
"""Attention to the source position and a neigborhood around it.
The sequence is divided into blocks of length block_size.
Attention for a given query position can only see memory positions
less than or equal to the query position, in the corresponding block
and the previous block.
The sequence is divided into blocks of length block_size. Attention for a
given query position can only see memory positions within a certain number
of positions before and behind it.
If mask_right is True, then a target position cannot see greater source
If look_right is True then each query will attend to block_length//2
positions either side, otherwise it will attend to block_length previous
positions.
If use_whole_block is True then no mask will be applied to the local blocks
meaning the full blocks are used (if look_right is True then the elements to
the right of the current position are still masked out). This allows us to
attend to more elements without additional overhead, but means we have
inconsistent window positions and sizes.
Args:
q: a Tensor with shape [batch, heads, length, depth_k]
k: a Tensor with shape [batch, heads, length, depth_k]
v: a Tensor with shape [batch, heads, length, depth_v]
q: a Tensor with shape [batch, heads, length_q, depth_k]
k: a Tensor with shape [batch, heads, length_kv, depth_k]
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: Not currently used [batch, heads, length_q, length_k]
block_length: an integer
look_right: a bool
use_whole_block: a bool
name: an optional string
Returns:
@@ -372,146 +372,110 @@ def masked_local_attention_1d(
batch = tf.shape(q)[0]
heads = tf.shape(q)[1]
length = tf.shape(q)[2]
# If (length < 2 * block_length), then we use only one block.
block_length = tf.where(tf.less(length, block_length * 2),
length, block_length)
depth_k = tf.shape(q)[3]
depth_v = tf.shape(v)[3]

original_length = length

#Pad to desired length
#If (length < block_length), then we use only one block.
block_length = tf.where(tf.less(length, block_length),
length, block_length)
padding_size = tf.mod(-length, block_length)
length += padding_size
padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
q = tf.pad(q, padding)
k = tf.pad(k, padding)
v = tf.pad(v, padding)
num_blocks = tf.div(length, block_length)

# compute attention for the first query block.
first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_output = dot_product_attention(
first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
name="fist_block")
padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
q = tf.pad(q, padding)

# compute attention for all subsequent query blocks.
if not look_right:
#Add extra padding so we don't have to do an initial query
extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
else:
#We shift everything over by half a block so the query is in the centre
pad_right = block_length // 2
pad_left = block_length - pad_right
extra_padding = [[0, 0], [0, 0],
[pad_left, padding_size+pad_right], [0, 0]]
bp = [[0, 0], [0, 0],
[0, padding_size], [pad_left, padding_size+pad_right]]
k = tf.pad(k, extra_padding)
v = tf.pad(v, extra_padding)
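# Worked example (illustrative): with length=6 and block_length=4,
# padding_size = tf.mod(-6, 4) = 2, so q is padded to length 8 and
# num_blocks = 2. With look_right=True, pad_right = pad_left = 2, so k and v
# are padded to 12 positions and reshape below into num_blocks+1 = 3 blocks
# of 4; query block i then sees key blocks i and i+1, a window of
# 2*block_length positions roughly centred on the query block.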

# Reshape into blocks
q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k])
v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v])

# Get local blocks by slicing
def local(x):
"""Create a local version of the keys or values."""
prev_block = tf.slice(
x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
cur_block = tf.slice(
x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
return tf.concat([prev_block, cur_block], 3)
local_k = local(k)
local_v = local(v)
tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])

local_length = tf.shape(local_k)[3]

# [batch, heads, num_blocks - 1, block_length, local_length]
attention = tf.matmul(tail_q, local_k, transpose_b=True)

# make sure source_pos <= target_pos
good_part = tf.matrix_band_part(
tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
mask = (1.0 - good_part) * -1e9
attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
# [batch, heads, num_blocks, block_length, local_length]
attention = tf.matmul(q, local_k, transpose_b=True)

# Apply bias (N.B: This is not currently working)
if bias is not None:
with tf.name_scope('bias'):
b_batch = tf.shape(bias)[0]
b_heads = tf.shape(bias)[1]
bias_ = bias
#bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
if truncate_bias:
# Use only the query dimension
bias = tf.expand_dims(bias[:,:,:,0], 2)
bias = tf.pad(bias, extra_padding, name='bias_pad_b')# 17, 5, 3
bias = tf.reshape(bias,
[b_batch, b_heads, 1, num_blocks+1, block_length],
name='divide_blocks')
local_b = tf.reshape(local(bias),
[b_batch, b_heads, num_blocks, 1, -1], name='reshape_local')
else:
bias = tf.pad(bias, bp, name='pad')
bias = tf.reshape(bias,
[b_batch, b_heads, num_blocks, block_length,
num_blocks+1, block_length], name='divide_blocks')
bias = tf.transpose(bias, [4,2,0,1,3,5])
bias = tf.reshape(bias,
[num_blocks*(num_blocks+1), b_batch, b_heads,
block_length, block_length], name='combine')
indices = (num_blocks+1)*tf.range(num_blocks)
prev_block = tf.gather(bias, indices)
cur_block = tf.gather(bias, indices+num_blocks)
local_b = tf.concat([prev_block, cur_block], 4)
local_b = tf.transpose(local_b, [1,2,0,3,4])
attention += local_b

attention = tf.nn.softmax(attention)
# TODO(noam): figure out how to show a summary for the remaining blocks.
# The naive way currently causes errors due to empty tensors.
# output: [batch, heads, num_blocks-1, block_length, depth_v]
output = tf.matmul(attention, local_v)
output = tf.reshape(output, [batch, heads, -1, depth_v])
output = tf.concat([first_output, output], axis=2)
output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
output.set_shape(v_shape)
return output


# Get local mask
if not use_whole_block:
good_part = tf.matrix_band_part(
tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
elif not look_right:
good_part = tf.matrix_band_part(
tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
else:
good_part = tf.ones([block_length, local_length])

def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100,
name=None):
"""strided block local self-attention.
#good_part = tf.cast(good_part, tf.float64)
attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
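# Illustrative mask with block_length=2 (so local_length=4) and the default
# use_whole_block=False: tf.matrix_band_part(tf.ones([2, 4]), 0, 2) is
#   [[1, 1, 1, 0],
#    [0, 1, 1, 1]]
# i.e. each query keeps itself plus block_length//2 = 1 position on each side
# of it within the concatenated pair of key blocks.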

Args:
q: a Tensor with shape [batch, heads, length, depth_k]
k: a Tensor with shape [batch, heads, length, depth_k]
v: a Tensor with shape [batch, heads, length, depth_v]
block_length: an integer
filter_width: an integer indicating how much to look left.
name: an optional string

output = tf.matmul(attention, local_v)
output = tf.reshape(output, [batch, heads, -1, depth_v])

Returns:
a Tensor of shape [batch, heads, length, depth_v]
"""
with tf.variable_scope(name, default_name="local_self_attention_1d",
values=[q, k, v]):
v_shape = v.get_shape()
depth_v = tf.shape(v)[3]
batch_size = tf.shape(q)[0]
num_heads = tf.shape(q)[1]
original_length = tf.shape(q)[2]
# making sure q is a multiple of block_length
def pad_to_multiple(x, pad_length):
x_length = tf.shape(x)[2]
return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
def pad_l_and_r(x, pad_length):
return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
q = pad_to_multiple(q, block_length)
k = pad_to_multiple(k, block_length)
v = pad_to_multiple(v, block_length)

# Setting up q blocks
new_q_shape = tf.shape(q)
q = tf.reshape(q, [new_q_shape[0], new_q_shape[1],
new_q_shape[2]//block_length,
block_length, new_q_shape[3]])

# Setting up k and v values
k = pad_l_and_r(k, filter_width)
v = pad_l_and_r(v, filter_width)

length = tf.shape(k)[2]
full_filter_width = block_length + 2*filter_width
# getting gather indices
indices = tf.range(0, length, delta=1, name="index_range")
# making indices [1, length, 1] to apply convs
indices = tf.reshape(indices, [1, -1, 1])
kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1)
gather_indices = tf.nn.conv1d(
tf.cast(indices, tf.float32),
kernel,
block_length,
padding="VALID",
name="gather_conv")

gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
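# Illustrative values: with block_length=2 and filter_width=1
# (full_filter_width=4), the identity-kernel conv with stride block_length
# gives gather_indices rows [0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], ...:
# each block of 2 queries gathers its own keys plus filter_width positions
# of context on either side (taken from the padded k and v).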

# [length, batch, heads, dim]
k_t = tf.transpose(k, [2, 0, 1, 3])
k_new = tf.gather(k_t, gather_indices)

# [batch, heads, blocks, block_length, dim]
k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])

attention_bias = tf.expand_dims(
tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)

v_t = tf.transpose(v, [2, 0, 1, 3])
v_new = tf.gather(v_t, gather_indices)
v_new = tf.transpose(v_new, [2, 3, 0, 1, 4])

logits = tf.matmul(q, k_new, transpose_b=True)

attention = tf.nn.softmax(logits+attention_bias)
output = tf.matmul(attention, v_new)

output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
# Remove the padding if introduced
# Remove added padding
output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
output.set_shape(v_shape)
return output
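
For intuition, a minimal NumPy sketch of the per-position sliding window that the blocked implementations above approximate (illustrative only; it ignores the bias path and the use_whole_block option):

import numpy as np

def windowed_attention_reference(q, k, v, block_length=4, look_right=True):
  """Reference sliding-window attention; q, k, v have shape [length, depth]."""
  length = q.shape[0]
  out = np.zeros((length, v.shape[-1]))
  for i in range(length):
    if look_right:
      # Attend to block_length//2 positions on either side of the query.
      lo = max(0, i - block_length // 2)
      hi = min(length, i + block_length // 2 + 1)
    else:
      # Attend to the query and block_length previous positions.
      lo, hi = max(0, i - block_length), i + 1
    logits = k[lo:hi] @ q[i]                 # [window]
    weights = np.exp(logits - logits.max())  # softmax over the window
    weights /= weights.sum()
    out[i] = weights @ v[lo:hi]              # [depth_v]
  return out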
@@ -542,8 +516,8 @@ def multihead_attention(query_antecedent,
dropout_rate: a floating point number
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
attention_type: a string, either "dot_product" or "local_mask_right" or
"local_unmasked"
attention_type: a string, either "dot_product" or "local" or
"local_mask_right"
block_length: an integer - relevant for "local_mask_right"
name: an optional string
@@ -592,11 +566,12 @@ def multihead_attention(query_antecedent,
if attention_type == "dot_product":
x = dot_product_attention(
q, k, v, bias, dropout_rate, image_shapes)
elif attention_type == "local_mask_right":
x = masked_local_attention_1d(q, k, v, block_length=block_length)
elif attention_type == "local":
x = local_attention_1d(q, k, v, block_length=block_length)
else:
assert attention_type == "local_unmasked"
x = unmasked_local_attention_1d(q, k, v, block_length=block_length)
assert attention_type == "local_mask_right"
x = local_attention_1d(
q, k, v, block_length=block_length, look_right=False)
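# Summary of the new dispatch:
#   "dot_product"      -> dot_product_attention(q, k, v, bias, ...)
#   "local"            -> local_attention_1d(q, k, v, block_length=block_length)
#   "local_mask_right" -> local_attention_1d(..., look_right=False)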
x = combine_heads(x)
x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
return x
82 changes: 82 additions & 0 deletions tensor2tensor/models/common_attention_test.py
@@ -0,0 +1,82 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for common layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np
from tensor2tensor.models import common_attention

import tensorflow as tf


class CommonAttentionTest(tf.test.TestCase):

def testLocalAttention(self):
q = np.array([[[ [1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0] ]]])

k = np.array([[[ [1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0] ]]])

b = np.array([[[ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ]]])

#b = np.ones((1,1,8,8))
#b = (1-b) * (-1e9)
v = np.ones((1, 1, 8, 1))

#q = np.random.rand(5, 7, 13, 3)
#k = np.random.rand(5, 7, 13, 3)
#v = np.random.rand(5, 7, 13, 11)
#b = np.random.rand(5, 1, 13, 1)

with self.test_session() as session:
q_ = tf.constant(q)
k_ = tf.constant(k)
v_ = tf.constant(v)
b_ = tf.constant(b)
y = common_attention.local_attention_1d(q_, k_, v_, b_, block_length=tf.constant(2))
res = session.run(y)
#print(q)
#print(k)
print(res)
#self.assertEqual(res.shape, (5, 7, 13, 11))
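# With these inputs the output should come back with shape (1, 1, 8, 1):
# the batch, heads and length of q, and depth_v = 1 from v.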


if __name__ == "__main__":
tf.test.main()
