Commit

mha+gqa bug fixes
VarunGumma committed Aug 12, 2024
1 parent 8ac156e commit cb169ee
Showing 3 changed files with 108 additions and 131 deletions.
129 changes: 60 additions & 69 deletions fairseq/modules/fast_grouped_query_attention.py
@@ -7,6 +7,7 @@
 from torch.nn import Parameter
 
 import json
+import math
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -48,29 +49,29 @@ def __init__(
         rope_args=None,
         fused_qkv=False,
     ):
-        super().__init__(embed_dim, num_heads, dictionary=dictionary)
-        self.embed_dim = embed_dim
+        super().__init__(
+            embed_dim,
+            num_heads,
+            kdim=kdim,
+            vdim=vdim,
+            dropout=dropout,
+            bias=bias,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
+            self_attention=self_attention,
+            encoder_decoder_attention=encoder_decoder_attention,
+            dictionary=dictionary,
+            q_noise=q_noise,
+            qn_block_size=qn_block_size,
+        )
+        del self.dropout_module
 
         self.num_kv_heads = num_kv_heads
         self.q_per_kv = self.num_heads // self.num_kv_heads
-        self.kdim = kdim if kdim is not None else embed_dim
-        self.vdim = vdim if vdim is not None else embed_dim
-        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
-
-        self.is_decoder = is_decoder
-        self.num_heads = num_heads
-        self.dropout_p = dropout
-
-        self.fused_qkv = fused_qkv and self.qkv_same_dim and self_attention
-
-        self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
-
-        self.self_attention = self_attention
-        self.encoder_decoder_attention = encoder_decoder_attention
-
-        self.rope = rope_args is not None and self.self_attention
+        self.fused_qkv = fused_qkv and self_attention
+        self.rope = rope_args is not None and self_attention
+        self._new_kv_dim = self.num_kv_heads * self.head_dim
 
         if self.rope:
@@ -90,50 +91,52 @@ def __init__(
             else None
         )
 
-        assert (
-            not self.self_attention or self.qkv_same_dim
-        ), "Self-attention requires query, key and value to be of the same size"
+        if self.fused_qkv:
+            # remove the q_proj, k_proj, v_proj from the parent class
+            del self.q_proj, self.k_proj, self.v_proj
+            self.split_dims = [embed_dim, self._new_kv_dim, self._new_kv_dim]
 
-        if not self.fused_qkv:
-            self.k_proj = quant_noise(
-                nn.Linear(self.kdim, self._new_kv_dim, bias=bias),
-                q_noise,
-                qn_block_size,
-            )
-            self.v_proj = quant_noise(
-                nn.Linear(self.vdim, self._new_kv_dim, bias=bias),
+            self.qkv_proj = quant_noise(
+                nn.Linear(embed_dim, sum(self.split_dims), bias=bias),
                 q_noise,
                 qn_block_size,
             )
-            self.q_proj = quant_noise(
-                nn.Linear(embed_dim, embed_dim, bias=bias),
+        else:
+            self.k_proj = quant_noise(
+                nn.Linear(embed_dim, self._new_kv_dim, bias=bias),
                 q_noise,
                 qn_block_size,
            )
-        else:
-            fused_dim = embed_dim + (2 * self._new_kv_dim)
-            self.split_dims = [embed_dim, self._new_kv_dim, self._new_kv_dim]
-
-            self.qkv_proj = quant_noise(
-                nn.Linear(embed_dim, fused_dim, bias=bias),
+            self.v_proj = quant_noise(
+                nn.Linear(embed_dim, self._new_kv_dim, bias=bias),
                 q_noise,
                 qn_block_size,
             )
 
         self.out_proj = quant_noise(
             nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
         )
+        self.reset_parameters()
 
-        if add_bias_kv:
-            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
-            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
-        else:
-            self.bias_k = self.bias_v = None
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            if hasattr(self, "qkv_proj"):
+                nn.init.xavier_uniform_(self.qkv_proj.weight, gain=1 / math.sqrt(2))
+            else:
+                nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+                nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+                nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
 
-        self.add_zero_attn = add_zero_attn
-        self.beam_size = 1
-        self.reset_parameters()
-        self.init_incremental_state()
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
 
     def forward(
         self,
@@ -189,34 +192,21 @@ def forward(
 
         if self.self_attention:
             if not self.fused_qkv:
-                q, k, v = self.q_proj(query), self.k_proj(query), self.v_proj(query)
+                q = self.q_proj(query)
+                k = self.k_proj(query)
+                v = self.v_proj(query)
             else:
                 q, k, v = self.qkv_proj(query).split(self.split_dims, dim=-1)
-        elif self.encoder_decoder_attention:
+        else:
+            # encoder-decoder attention
             q = self.q_proj(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                if self.beam_size > 1 and bsz == key.size(1):
-                    # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
-                        :, :, 0, :
-                    ]
-                    if key_padding_mask is not None:
-                        key_padding_mask = key_padding_mask.view(
-                            -1, self.beam_size, key_padding_mask.size(1)
-                        )[:, 0, :]
                 k = self.k_proj(key)
                 v = self.v_proj(key)
 
-        else:
-            assert key is not None and value is not None
-            q = self.q_proj(query)
-            k = self.k_proj(key)
-            v = self.v_proj(value)
-
         if self.bias_k is not None:
             assert self.bias_v is not None
             k, v, attn_mask, key_padding_mask = self._add_bias(
@@ -236,6 +226,7 @@ def forward(
             )
             # q shape: (bsz * self.num_heads, tgt_len, head_dim)
         kv_bsz = bsz  # need default value for scripting
+
         if k is not None:
             kv_bsz = k.size(1)
             k = rearrange(
@@ -254,7 +245,7 @@ def forward(
             )
             # v shape: (bsz * self.num_kv_heads, src_len, head_dim)
 
-        if self.num_heads != self.num_kv_heads:
+        if (self.num_heads != self.num_kv_heads) and k is not None and v is not None:
             # self.num_heads == self.num_kv_heads * self.q_per_kv
             k = rearrange(k, "(b h) t d -> b h 1 t d", h=self.num_kv_heads)
             k = k.expand(bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim)
@@ -308,10 +299,10 @@ def forward(
                 )
 
                 saved_state["prev_key"] = k.view(
-                    kv_bsz, self.num_heads, self.q_per_kv, -1, self.head_dim
+                    kv_bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim
                 )
                 saved_state["prev_value"] = v.view(
-                    kv_bsz, self.num_heads, self.q_per_kv, -1, self.head_dim
+                    kv_bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim
                 )
                 saved_state["prev_key_padding_mask"] = key_padding_mask
                 # In this branch incremental_state is never None
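The last two hunks are the core GQA fix: the cached prev_key/prev_value tensors hold num_kv_heads heads, each shared by q_per_kv query heads, so viewing the cache with num_heads (as the old code did) does not match the number of cached elements. A minimal standalone sketch of the shapes involved, with made-up dimensions and not the fairseq API itself:

import torch
from einops import rearrange

bsz, src_len, head_dim = 2, 5, 64
num_heads, num_kv_heads = 8, 2
q_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

k = torch.randn(bsz * num_kv_heads, src_len, head_dim)  # projected keys

# expand each KV head across its q_per_kv query heads, as in the module
k = rearrange(k, "(b h) t d -> b h 1 t d", h=num_kv_heads)
k = k.expand(bsz, num_kv_heads, q_per_kv, src_len, head_dim)
k = rearrange(k, "b h q t d -> (b h q) t d")
assert k.shape == (bsz * num_heads, src_len, head_dim)

# the incremental-state cache must be viewed with num_kv_heads (the fix);
# view(bsz, num_heads, q_per_kv, -1, head_dim) would not divide the
# element count evenly and raises a RuntimeError here
prev_key = k.view(bsz, num_kv_heads, q_per_kv, -1, head_dim)
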
109 changes: 47 additions & 62 deletions fairseq/modules/fast_multihead_attention.py
@@ -7,6 +7,7 @@
 from torch.nn import Parameter
 
 import json
+import math
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -49,27 +50,26 @@ def __init__(
         rope_args=None,
         fused_qkv=False,
     ):
-        super().__init__(embed_dim, num_heads, dictionary=dictionary)
-        self.embed_dim = embed_dim
-        self.kdim = kdim if kdim is not None else embed_dim
-        self.vdim = vdim if vdim is not None else embed_dim
-        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+        super().__init__(
+            embed_dim,
+            num_heads,
+            kdim=kdim,
+            vdim=vdim,
+            dropout=dropout,
+            bias=bias,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
+            self_attention=self_attention,
+            encoder_decoder_attention=encoder_decoder_attention,
+            dictionary=dictionary,
+            q_noise=q_noise,
+            qn_block_size=qn_block_size,
+        )
+        del self.dropout_module
 
-        self.is_decoder = is_decoder
-        self.num_heads = num_heads
-        self.dropout_p = dropout
-
-        self.fused_qkv = fused_qkv and self.qkv_same_dim and self_attention
-
-        self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
-
-        self.self_attention = self_attention
-        self.encoder_decoder_attention = encoder_decoder_attention
-
-        self.rope = rope_args is not None and self.self_attention
+        self.fused_qkv = fused_qkv and self_attention
+        self.rope = rope_args is not None and self_attention
 
         if self.rope:
             rope_args = json.loads(rope_args)
@@ -88,39 +88,37 @@ def __init__(
             else None
         )
 
-        assert (
-            not self.self_attention or self.qkv_same_dim
-        ), "Self-attention requires query, key and value to be of the same size"
-
-        if not self.fused_qkv:
-            self.k_proj = quant_noise(
-                nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
-            )
-            self.v_proj = quant_noise(
-                nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
-            )
-            self.q_proj = quant_noise(
-                nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
-            )
-        else:
+        if self.fused_qkv:
+            # remove the q_proj, k_proj, v_proj from the parent class
+            del self.q_proj, self.k_proj, self.v_proj
             self.qkv_proj = quant_noise(
                 nn.Linear(embed_dim, 3 * embed_dim, bias=bias), q_noise, qn_block_size
             )
 
         self.out_proj = quant_noise(
             nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
         )
+        self.reset_parameters()
 
-        if add_bias_kv:
-            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
-            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
-        else:
-            self.bias_k = self.bias_v = None
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            if hasattr(self, "qkv_proj"):
+                nn.init.xavier_uniform_(self.qkv_proj.weight, gain=1 / math.sqrt(2))
+            else:
+                nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+                nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+                nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
 
-        self.add_zero_attn = add_zero_attn
-        self.beam_size = 1
-        self.reset_parameters()
-        self.init_incremental_state()
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
 
     def forward(
         self,
@@ -177,34 +175,21 @@ def forward(
 
         if self.self_attention:
             if not self.fused_qkv:
-                q, k, v = self.q_proj(query), self.k_proj(query), self.v_proj(query)
+                q = self.q_proj(query)
+                k = self.k_proj(query)
+                v = self.v_proj(query)
             else:
                 q, k, v = self.qkv_proj(query).chunk(3, dim=-1)
-        elif self.encoder_decoder_attention:
+        else:
+            # encoder-decoder attention
             q = self.q_proj(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                if self.beam_size > 1 and bsz == key.size(1):
-                    # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
-                        :, :, 0, :
-                    ]
-                    if key_padding_mask is not None:
-                        key_padding_mask = key_padding_mask.view(
-                            -1, self.beam_size, key_padding_mask.size(1)
-                        )[:, 0, :]
                 k = self.k_proj(key)
                 v = self.v_proj(key)
 
-        else:
-            assert key is not None and value is not None
-            q = self.q_proj(query)
-            k = self.k_proj(key)
-            v = self.v_proj(value)
-
         if self.bias_k is not None:
             assert self.bias_v is not None
             k, v, attn_mask, key_padding_mask = self._add_bias(
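For reference, a minimal sketch of the fused-QKV path this file keeps: a single projection of width 3 * embed_dim whose output is split with chunk(3, dim=-1), initialised with the scaled Xavier gain of 1 / sqrt(2) from reset_parameters above. Toy dimensions, not the fairseq module itself:

import math
import torch
import torch.nn as nn

embed_dim = 512
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
nn.init.xavier_uniform_(qkv_proj.weight, gain=1 / math.sqrt(2))

x = torch.randn(10, 2, embed_dim)        # (tgt_len, bsz, embed_dim)
q, k, v = qkv_proj(x).chunk(3, dim=-1)   # each (tgt_len, bsz, embed_dim)
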
1 change: 1 addition & 0 deletions fairseq/modules/multihead_attention.py
@@ -103,6 +103,7 @@ def __init__(
         self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
 
         self.num_heads = num_heads
+        self.dropout_p = dropout
         self.dropout_module = FairseqDropout(
             dropout, module_name=self.__class__.__name__
         )
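The fast_* modules above delete self.dropout_module, so storing dropout_p on the base class presumably lets them pass the raw probability to a fused attention kernel such as torch.nn.functional.scaled_dot_product_attention. A hedged sketch of that assumed usage, not shown in this diff:

import torch
import torch.nn.functional as F

q = torch.randn(16, 10, 64)  # (bsz * num_heads, tgt_len, head_dim)
k = torch.randn(16, 12, 64)  # (bsz * num_heads, src_len, head_dim)
v = torch.randn(16, 12, 64)

dropout_p = 0.1   # what self.dropout_p would hold
training = True   # dropout is only applied during training
attn = F.scaled_dot_product_attention(
    q, k, v, dropout_p=dropout_p if training else 0.0
)  # (bsz * num_heads, tgt_len, head_dim)
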
