Support for GQA, which is MQA with >1 kv_heads
PiperOrigin-RevId: 627817922
lingvo-bot authored and copybara-github committed Jun 5, 2024
1 parent b77456f commit d47a8fe
Showing 3 changed files with 67 additions and 16 deletions.
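Background for the change: in grouped-query attention (GQA), the num_heads query heads are partitioned into num_kv_heads groups, and all query heads in a group share one key/value head; num_kv_heads=1 recovers MQA and num_kv_heads=num_heads recovers ordinary multi-head attention, which is why the code below requires num_heads % num_kv_heads == 0. A minimal sketch of the head-to-group mapping (illustrative only, not part of the commit):

  # 8 query heads grouped onto 2 key/value heads.
  num_heads, num_kv_heads = 8, 2
  heads_per_group = num_heads // num_kv_heads  # 4
  # Query head i reads the key/value head of its group.
  kv_head_for_query = [i // heads_per_group for i in range(num_heads)]
  print(kv_head_for_query)  # [0, 0, 0, 0, 1, 1, 1, 1]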
56 changes: 46 additions & 10 deletions lingvo/core/batch_major_attention.py
@@ -519,6 +519,12 @@ def Params(cls):
         'When it is None, use `query_input_dim` for the output projection.',
     )
     p.Define('num_heads', 1, 'Num of attention heads.')
+    p.Define(
+        'num_kv_heads',
+        None,
+        'Number of kv heads. Defaults to num_heads. num_heads % num_kv_heads'
+        ' = 0. Based on GQA - https://arxiv.org/pdf/2305.13245.pdf',
+    )
     p.Define(
         'dim_per_head',
         None,
@@ -718,7 +724,18 @@ def ProjectInput(input_dim, dim_per_head=None, num_heads=None):
       value_input_dim = p.input_dim
       query_input_dim = p.input_dim
 
+    num_kv_heads = p.num_kv_heads
+    if num_kv_heads is None:
+      num_kv_heads = p.num_heads
+
     if p.use_mqa:
+      if num_kv_heads > 1:
+        assert (
+            num_kv_heads <= p.num_heads
+        ), 'num_kv_heads needs to be <= num_heads.'
+        assert (
+            p.num_heads % num_kv_heads == 0
+        ), 'num_kv_heads needs to divide num_heads exactly.'
       assert (
           not p.enable_qk_proj_in_onestep
       ), 'enable_qk_proj_in_onestep is not supported for use_mqa'
@@ -736,7 +753,9 @@ def ProjectInput(input_dim, dim_per_head=None, num_heads=None):
       self.CreateChild(
           'kv',
           ProjectInput(
-              key_input_dim, dim_per_head=self.dim_per_head * 2, num_heads=1
+              key_input_dim,
+              dim_per_head=self.dim_per_head * 2,
+              num_heads=num_kv_heads,
           ),
       )
     else:
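With GQA, the shared key/value projection above gets num_kv_heads heads instead of one, so its weight has shape [input_dim, num_kv_heads, 2 * dim_per_head]; the factor of 2 packs key and value into the single 'kv' projection, as in the existing MQA path. A standalone sketch of such a projection with made-up dimensions (the split into key and value is illustrative, not the exact lingvo code):

  import tensorflow as tf

  batch, seq, input_dim = 2, 7, 32
  num_kv_heads, dim_per_head = 2, 16
  x = tf.random.normal([batch, seq, input_dim])
  # Packed key/value weight: [D, K, 2H].
  w_kv = tf.random.normal([input_dim, num_kv_heads, 2 * dim_per_head])
  kv = tf.einsum('BTD,DKH->BTKH', x, w_kv)  # [B, T, K, 2H]
  key, value = tf.split(kv, 2, axis=-1)     # each [B, T, K, H]
  print(key.shape, value.shape)             # (2, 7, 2, 16) (2, 7, 2, 16)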
@@ -855,10 +874,25 @@ def _AttenLogits(self, theta, query, key):
       A Tensor of shape [B, N, T, S]
     """
     del theta
+    p = self.params
+    num_kv_heads = p.num_kv_heads
+    if num_kv_heads is None:
+      num_kv_heads = p.num_heads
     query, key = self.ToAqtActActInputs(query, key)
     qlayer = self if self.params.qdomain is not None else None
     if self.params.use_mqa:
-      logits = self.QEinsum('BTNH,BSH->BNTS', query, key)
+      _, s, k, _ = py_utils.GetShape(key, 4)
+      if num_kv_heads > 1:
+        b, t, n, h = py_utils.GetShape(query, 4)
+        query = tf.reshape(
+            query, [b, t, num_kv_heads, p.num_heads // num_kv_heads, h]
+        )
+        logits = self.QEinsum('BTKnH,BSKH->BnKTS', query, key)
+        logits = tf.reshape(logits, [b, n, t, s])
+      else:
+        assert k == 1
+        key = tf.squeeze(key, axis=2)
+        logits = self.QEinsum('BTNH,BSH->BNTS', query, key)
     else:
       logits = attention_util.AttenLogits(query, key, qlayer=qlayer)
     logits = self.FromAqtActActMatmul(logits)
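The GQA branch above reshapes the [B, T, N, H] query so its head axis is split into num_kv_heads groups, runs a grouped einsum against the [B, S, K, H] key, and reshapes back to the usual [B, N, T, S] logits layout. A self-contained sketch of the same idea (illustrative; it uses a group-major head layout, so the einsum labels differ slightly from the code above; assumes eager TF2):

  import tensorflow as tf

  def gqa_logits(query, key, num_kv_heads):
    # query: [B, T, N, H], key: [B, S, K, H], with N % K == 0.
    b, t, n, h = query.shape
    s = key.shape[1]
    # Split the N query heads into K groups; each group shares one key head.
    q = tf.reshape(query, [b, t, num_kv_heads, n // num_kv_heads, h])
    logits = tf.einsum('BTKnH,BSKH->BKnTS', q, key)
    return tf.reshape(logits, [b, n, t, s])  # back to [B, N, T, S]

  q = tf.random.normal([2, 5, 8, 16])   # 8 query heads
  k = tf.random.normal([2, 7, 2, 16])   # 2 shared kv heads
  print(gqa_logits(q, k, num_kv_heads=2).shape)  # (2, 8, 5, 7)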
@@ -927,10 +961,7 @@ def AttenProbs(
 
     key = py_utils.HasRank(key, 4)
     b, s, n, h = py_utils.GetShape(key, 4)
-    if p.use_mqa:
-      assert n == 1
-      key = tf.squeeze(key, axis=2)
-      n = p.num_heads
+    n = p.num_heads
     query = py_utils.HasShape(query, [b, -1, n, h])
     t = py_utils.GetShape(query)[1]
     if segment_mask is not None and self.params.packed_input:
@@ -1003,10 +1034,15 @@ def _AttenContext(self, theta, probs, value):
     )
     qlayer = self if self.params.qdomain is not None else None
     if self.params.use_mqa:
-      _, _, kv_heads, _ = py_utils.GetShape(value, 4)
-      assert kv_heads == 1
-      value = tf.squeeze(value, axis=2)
-      encoded = self.QEinsum('BNTS,BSH->BTNH', probs, value)
+      _, _, kv_heads, h = py_utils.GetShape(value, 4)
+      if kv_heads == 1:
+        value = tf.squeeze(value, axis=2)
+        encoded = self.QEinsum('BNTS,BSH->BTNH', probs, value)
+      else:
+        b, n, t, s = py_utils.GetShape(probs, 4)
+        probs = tf.reshape(probs, [b, n // kv_heads, kv_heads, t, s])
+        encoded = self.QEinsum('BnKTS,BSKH->BTnKH', probs, value)
+        encoded = tf.reshape(encoded, [b, t, n, h])
     else:
       encoded = attention_util.AttenContext(probs, value, qlayer=qlayer)
     return self.FromAqtActActMatmul(encoded)
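Similarly, the grouped context step regroups the [B, N, T, S] probabilities so that each group of query heads reads from its own value head, then flattens back to [B, T, N, H]. A matching sketch (same illustrative group-major layout as the logits sketch above, not the exact lingvo ordering):

  import tensorflow as tf

  def gqa_context(probs, value, num_kv_heads):
    # probs: [B, N, T, S], value: [B, S, K, H], with N % K == 0.
    b, n, t, s = probs.shape
    h = value.shape[-1]
    # Regroup the N probability heads so each group uses its own value head.
    p = tf.reshape(probs, [b, num_kv_heads, n // num_kv_heads, t, s])
    ctx = tf.einsum('BKnTS,BSKH->BTKnH', p, value)
    return tf.reshape(ctx, [b, t, n, h])  # back to [B, T, N, H]

  probs = tf.nn.softmax(tf.random.normal([2, 8, 5, 7]), axis=-1)
  v = tf.random.normal([2, 7, 2, 16])
  print(gqa_context(probs, v, num_kv_heads=2).shape)  # (2, 5, 8, 16)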
26 changes: 20 additions & 6 deletions lingvo/core/batch_major_attention_test.py
@@ -232,19 +232,21 @@ def testMultiHeadedProjectionLayerOutputMode(self, batch_dims):
       ('_qkv_one_step_false_qk_one_step_true', False, True),
       ('_qkv_one_step_true', True),
       ('_use_mqa', False, False, True),
+      ('_use_gqa', False, False, True, 2),
   )
   def testMultiHeadedAttentionDotProductOutputDim(
       self,
       enable_qkv_proj_in_onestep=False,
       enable_qk_proj_in_onestep=False,
       use_mqa=False,
+      num_kv_heads=1,
   ):
     # input_batch:6, seq_len:6. Test n = 2 case.
     bsz, slen = 6, 6
     input_dim = 2
     hidden_dim = 4
     output_dim = 4
-    num_heads = 2
+    num_heads = 4
     with self.session(use_gpu=True) as sess:
       input_vecs, input_padding, _, _ = self._AttentionInputs(
           input_dim=input_dim
@@ -258,6 +260,7 @@ def testMultiHeadedAttentionDotProductOutputDim(
           enable_qkv_proj_in_onestep=enable_qkv_proj_in_onestep,
           enable_qk_proj_in_onestep=enable_qk_proj_in_onestep,
           use_mqa=use_mqa,
+          num_kv_heads=num_kv_heads,
       )
 
       l = p.Instantiate()
@@ -579,18 +582,20 @@ def testMultiHeadedAttentionShapedAttentionEquality(self):
       ('qkv_one_step_false_qk_one_step_true', False, True),
       ('qkv_one_step_true', True),
       ('_use_mqa', False, False, True),
+      ('_use_gqa', False, False, True, 2),
   )
   def testMultiHeadedAttentionVariableDim(
       self,
       enable_qkv_proj_in_onestep=False,
       enable_qk_proj_in_onestep=False,
       use_mqa=False,
+      num_kv_heads=1,
   ):
     # input_batch:6, seq_len:6. Test n = 2 case.
     input_dim = 2
     hidden_dim = 4
     output_dim = 4
-    num_heads = 2
+    num_heads = 4
 
     p = attention.MultiHeadedAttention.Params().Set(
         name='self_atten',
@@ -601,6 +606,7 @@ def testMultiHeadedAttentionVariableDim(
         enable_qkv_proj_in_onestep=enable_qkv_proj_in_onestep,
         enable_qk_proj_in_onestep=enable_qk_proj_in_onestep,
         use_mqa=use_mqa,
+        num_kv_heads=num_kv_heads,
     )
 
     l = p.Instantiate()
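As the tests show, GQA is enabled through the existing MQA switch: use_mqa=True together with num_kv_heads > 1. A hedged configuration sketch based on the Params usage above; the import alias, the input_dim/hidden_dim parameter names, and the dimension values are assumptions taken from the test's variables, not from this diff:

  from lingvo.core import batch_major_attention as attention

  p = attention.MultiHeadedAttention.Params().Set(
      name='gqa_atten',
      input_dim=32,
      hidden_dim=64,
      num_heads=8,
      use_mqa=True,     # GQA reuses the MQA code path
      num_kv_heads=2,   # 8 query heads share 2 key/value heads
  )
  l = p.Instantiate()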
@@ -626,10 +632,18 @@ def testMultiHeadedAttentionVariableDim(
       self.assertNotIn('key', l.vars)
     elif use_mqa:
       self.assertIn('kv', l.vars)
-      self.assertEqual(
-          l.kv.theta.w.get_shape(),
-          tf.TensorShape([input_dim, 1, hidden_dim // num_heads * 2]),
-      )
+      if num_kv_heads is not None:
+        self.assertEqual(
+            l.kv.theta.w.get_shape(),
+            tf.TensorShape(
+                [input_dim, num_kv_heads, hidden_dim // num_heads * 2]
+            ),
+        )
+      else:
+        self.assertEqual(
+            l.kv.theta.w.get_shape(),
+            tf.TensorShape([input_dim, 1, hidden_dim // num_heads * 2]),
+        )
       self.assertIn('query', l.vars)
     else:
       self.assertNotIn('qkv', l.vars)
1 change: 1 addition & 0 deletions lingvo/core/self_attention_layer_test.py
@@ -399,6 +399,7 @@ def testTransformerStackV2WithSimplifiedTransformerWithMQA(
     atten_builder.atten_tpl.enable_shaped_attention = True
     atten_builder.atten_tpl.enable_ctx_post_proj = False
     atten_builder.atten_tpl.use_mqa = True
+    atten_builder.atten_tpl.num_kv_heads = 1
     builder = atten_builder.Instantiate()
     if use_v1_stack:
       p = builder.TransformerStack('atten', num_layers=3)
