Adding Shaped Attention as described in Eq. 5 of https://arxiv.org/pdf/2311.01906.pdf

PiperOrigin-RevId: 623560053
lingvo-bot authored and copybara-github committed Jun 5, 2024
1 parent 8402baf commit 8fa400e
Showing 2 changed files with 417 additions and 11 deletions.
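
For orientation: Eq. 5 of the referenced paper ("Simplifying Transformer Blocks", arXiv:2311.01906) perturbs the post-softmax attention matrix A as alpha * I + beta * A - gamma * C, where alpha, beta and gamma are learned per-head scalars and C is the attention pattern produced by all-equal logits, i.e. uniform attention over the allowed key positions. A minimal NumPy sketch of the unmasked, single-segment case (the function name and defaults here are illustrative, not part of the commit):

    import numpy as np

    def shaped_attention_sketch(probs, alpha=1.0, beta=1.0, gamma=1.0):
      # probs: [T, S] softmax attention weights for a single head, with T == S.
      t, s = probs.shape
      identity = np.eye(t, s)            # the alpha * I term
      center = np.full((t, s), 1.0 / s)  # C: uniform attention (Softmax zero order term)
      return alpha * identity + beta * probs - gamma * center
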
132 changes: 121 additions & 11 deletions lingvo/core/batch_major_attention.py
@@ -516,6 +516,12 @@ def Params(cls):
            ' multi-headed. Based on https://arxiv.org/pdf/1911.02150.pdf.'
        ),
    )
    p.Define(
        'enable_shaped_attention',
        False,
        'If True, perturbs attention based on Eq. 5 of'
        ' https://arxiv.org/pdf/2311.01906.pdf.',
    )
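    # Illustrative usage (not part of this change): the flag would typically be
    # flipped on the layer's params when building the model, e.g., assuming
    # this Params() belongs to batch_major_attention.MultiHeadedAttention:
    #   atten_p = batch_major_attention.MultiHeadedAttention.Params().Set(
    #       enable_shaped_attention=True)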
    p.Define(
        'query_stride',
        1,
@@ -700,6 +706,35 @@ def ProjectInput(input_dim, dim_per_head=None, num_heads=None):
    if p.use_scale_invariant_atten:
      assert not (p.enable_scaling_code_motion or p.atten_extra_logit)

    if p.enable_shaped_attention:
      self.CreateVariable(
          'shaped_attn_alpha',
          py_utils.WeightParams(
              shape=[1, p.num_heads, 1, 1],
              init=py_utils.WeightInit.Constant(1.0),
              dtype=p.dtype,
              collections=[self.__class__.__name__ + '_vars'],
          ),
      )
      self.CreateVariable(
          'shaped_attn_beta',
          py_utils.WeightParams(
              shape=[1, p.num_heads, 1, 1],
              init=py_utils.WeightInit.Constant(1.0),
              dtype=p.dtype,
              collections=[self.__class__.__name__ + '_vars'],
          ),
      )
      self.CreateVariable(
          'shaped_attn_gamma',
          py_utils.WeightParams(
              shape=[1, p.num_heads, 1, 1],
              init=py_utils.WeightInit.Constant(1.0),
              dtype=p.dtype,
              collections=[self.__class__.__name__ + '_vars'],
          ),
      )
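      # Note (illustrative, not part of this change): each parameter has shape
      # [1, num_heads, 1, 1], so it broadcasts as a per-head scalar against
      # probs of shape [B, N, T, S]. With all three initialized to 1.0, the
      # perturbation in ShapedAttention below starts out as
      #   identity + probs - center,
      # i.e. Eq. 5 with alpha = beta = gamma = 1.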

  @property
  def dim_per_head(self):
    """Returns the dimension per attention head."""
Expand Down Expand Up @@ -897,14 +932,83 @@ def _AttenContextOneStep(self, theta, probs, value, time_step, h):
    encoded = self.QEinsum('SBN,SBNH->BNH', probs, value)
    return self.FromAqtActActMatmul(encoded)

  def _DotAtten(self,
                theta,
                query,
                key,
                value,
                paddings,
                segment_mask,
                per_step_padding=None):
  def _GetStridedIdentity(self, length, stride):
    diag = tf.eye(length, dtype=self.params.dtype)
    if stride > 1:
      diag = diag[:None:stride, :]
    diag = tf.expand_dims(tf.expand_dims(diag, axis=0), axis=0)
    return diag
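  # Illustrative example (not part of this change): for length=4, stride=2,
  # _GetStridedIdentity keeps every stride-th row of the 4x4 identity,
  #   [[1, 0, 0, 0],
  #    [0, 0, 1, 0]]
  # and returns it with shape [1, 1, 2, 4], lining up strided (shorter) query
  # positions with their corresponding key positions.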

  def _SoftmaxZeroOrderTerm(self, probs, segment_mask=None):
    """Zero order term in the Taylor series expansion of Softmax.

    Eq. 8 in https://arxiv.org/pdf/2306.17759.pdf.

    Args:
      probs: [B, N, T, S].
      segment_mask: [B, 1, T, S]. A mask that is applied to prevent attention
        between different segments. This has already been converted into large
        negative logits. Only applied if packed_input = True.

    Returns:
      center: Softmax zero order term of shape [B, 1, T, S].
    """
    p = self.params
    _, _, query_len, key_len = py_utils.GetShape(probs, 4)
    center = tf.ones(shape=[1, 1, query_len, key_len], dtype=probs.dtype)
    # Without packing, each sequence has a single segment, so the segment
    # length is the same as the key length.
    segment_len = key_len
    if p.packed_input and segment_mask is not None:
      # With packing, the segment mask holds 0s within a segment and large
      # negative values elsewhere. Counting the zeros along the key axis
      # therefore gives each query position's segment length, which is then
      # used to normalize center.
      zero_mask = tf.equal(segment_mask, 0)
      # [B, 1, T, 1]
      segment_len = tf.reduce_sum(
          tf.cast(zero_mask, tf.int32), axis=-1, keepdims=True
      )
      center = center * tf.cast(zero_mask, dtype=center.dtype)
    center = center / tf.cast(segment_len, center.dtype)
    return center
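  # Illustrative example (not part of this change): without packing and with
  # key_len=4, the zero order term is simply uniform attention,
  #   center[..., t, :] = [0.25, 0.25, 0.25, 0.25]    # shape [1, 1, T, 4]
  # With packed_input=True, keys outside a query's segment contribute 0 and
  # each row is normalized by that segment's length instead of key_len.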

  def ShapedAttention(self, theta, probs, segment_mask):
    """Shaped Attention as in Eq. 5 of https://arxiv.org/pdf/2311.01906.pdf.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      probs: [B, N, T, S].
      segment_mask: [B, 1, T, S]. A mask that is applied to prevent attention
        between different segments. This has already been converted into large
        negative logits. Only applied if packed_input = True.

    Returns:
      probs: Perturbed probs of shape [B, N, T, S].
    """
    shaped_attn_alpha = tf.cast(theta.shaped_attn_alpha, probs.dtype)
    shaped_attn_beta = tf.cast(theta.shaped_attn_beta, probs.dtype)
    shaped_attn_gamma = tf.cast(theta.shaped_attn_gamma, probs.dtype)
    _, _, query_len, key_len = py_utils.GetShape(probs, 4)
    query_stride = key_len // query_len
    diag = tf.cast(
        self._GetStridedIdentity(key_len, query_stride), probs.dtype
    )
    center = self._SoftmaxZeroOrderTerm(probs, segment_mask)
    shaped_probs = (
        shaped_attn_alpha * diag
        + shaped_attn_beta * probs
        - shaped_attn_gamma * center
    )
    return shaped_probs
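  # Worked example (illustrative, not part of this change): with T = S = 2,
  # alpha = beta = gamma = 1 and probs = [[0.9, 0.1], [0.4, 0.6]], center is
  # uniform 0.5, so Eq. 5 gives
  #   shaped = I + probs - center
  #          = [[1.4, -0.4],
  #             [-0.1, 1.1]]
  # The attention pattern is pulled toward the identity while each row still
  # sums to 1, because alpha + beta - gamma = 1 here.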

  def _DotAtten(
      self,
      theta,
      query,
      key,
      value,
      paddings,
      segment_mask,
      per_step_padding=None,
  ):
    """Main attention function.

    Args:
@@ -931,12 +1035,18 @@ def _DotAtten(self,
    if p.enable_per_dim_scale:
      query = self.per_dim_scale.FProp(theta.per_dim_scale, query)
    else:
      query *= (p.hidden_dim // p.num_heads)**-0.5
      query *= (p.hidden_dim // p.num_heads) ** -0.5

    # Compute prob with shape [batch, heads, target_time, source_time].
    with tf.name_scope('probs'):
      probs, probs_sum = self.AttenProbs(theta, query, key, paddings,
                                         segment_mask, per_step_padding)
      probs, probs_sum = self.AttenProbs(
          theta, query, key, paddings, segment_mask, per_step_padding
      )
      if p.enable_shaped_attention:
        assert (
            not p.enable_scaling_code_motion
        ), 'Shaped attention is not supported with scaling code motion.'
        probs = self.ShapedAttention(theta, probs, segment_mask)
      # Apply dropout to probs.
      probs = self.atten_dropout.FProp(theta.atten_dropout, probs)

