Commit 29bcbed: bug fix
VarunGumma committed Aug 24, 2024
1 parent cb169ee commit 29bcbed
Showing 2 changed files with 50 additions and 56 deletions.
fairseq/modules/fast_grouped_query_attention.py (4 additions, 11 deletions)

@@ -247,17 +247,10 @@ def forward(

         if (self.num_heads != self.num_kv_heads) and k is not None and v is not None:
             # self.num_heads == self.num_kv_heads * self.q_per_kv
-            k = rearrange(k, "(b h) t d -> b h 1 t d", h=self.num_kv_heads)
-            k = k.expand(bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim)
-            k = rearrange(
-                k, "b h nq t d -> (b h nq) t d", h=self.num_kv_heads, nq=self.q_per_kv
-            )
-
-            v = rearrange(v, "(b h) t d -> b h 1 t d", h=self.num_kv_heads)
-            v = v.expand(bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim)
-            v = rearrange(
-                v, "b h nq t d -> (b h nq) t d", h=self.num_kv_heads, nq=self.q_per_kv
-            )
+            k = rearrange(k, "(b h) t d -> b h t d", h=self.num_kv_heads)
+            k = torch.repeat_interleave(k, dim=1, repeats=self.q_per_kv)
+            v = rearrange(v, "(b h) t d -> b h t d", h=self.num_kv_heads)
+            v = torch.repeat_interleave(v, dim=1, repeats=self.q_per_kv)

         if saved_state is not None:
             # saved states are stored with shape (bsz, num_kv_heads, q_per_kv, seq_len, head_dim)
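
Note (added for context, not part of the diff): a small self-contained check, with made-up tensor sizes, that the new repeat_interleave path duplicates each KV head the same way the old expand-and-flatten path did; only the final layout differs (4-D with a head axis vs. the flattened 3-D view).

import torch
from einops import rearrange

bsz, num_kv_heads, q_per_kv, seq_len, head_dim = 2, 2, 3, 5, 4
k = torch.randn(bsz * num_kv_heads, seq_len, head_dim)  # "(b h) t d", as in the module

# Old path: add a singleton group axis, expand it, then flatten the heads.
k_dup_old = rearrange(k, "(b h) t d -> b h 1 t d", h=num_kv_heads)
k_dup_old = k_dup_old.expand(bsz, num_kv_heads, q_per_kv, seq_len, head_dim)
k_dup_old = rearrange(k_dup_old, "b h nq t d -> (b h nq) t d")

# New path: keep the head axis and repeat each KV head q_per_kv times.
k_dup_new = rearrange(k, "(b h) t d -> b h t d", h=num_kv_heads)
k_dup_new = torch.repeat_interleave(k_dup_new, dim=1, repeats=q_per_kv)

# Same values, copies of each head adjacent in the same order.
assert torch.equal(
    k_dup_old, k_dup_new.reshape(bsz * num_kv_heads * q_per_kv, seq_len, head_dim)
)
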
fairseq/modules/transformer_layer.py (46 additions, 45 deletions)

@@ -150,42 +150,43 @@ def _prune_fc_layer(self, remove_index: List[int]):
         self.fc2.bias = torch.nn.Parameter(new_fc2_bias)

     def build_self_attention(self, embed_dim, cfg):
-        if self.attn_implementation == "fairseq":
-            return MultiheadAttention(
+        is_fused = self.attn_implementation.endswith("fused")
+
+        if self.attn_implementation.startswith("fast"):
+            return FastMultiheadAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.encoder.xformers_att_config,
+                rope_args=getattr(cfg, "rope_args", None),
             )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
-            return FastMultiheadAttention(
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.encoder, "kv_attention_heads", None) is not None
+        ):
+            return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
+                cfg.encoder.kv_attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_fused")
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
-            return FastGroupedQueryAttention(
+        else:
+            return MultiheadAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
-                cfg.encoder.kv_attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_gqa_fused")
-            )
-        else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
+                xformers_att_config=cfg.encoder.xformers_att_config,
             )

     def build_normalization(self, dim, rms=False):
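
Aside (illustrative only, not from the repo): how the refactored dispatch derives its flags from the attn_implementation string; the list of names below is an assumption based on the branches above.

for impl in ["fairseq", "fast", "fast_fused", "fast_gqa", "fast_gqa_fused"]:
    is_fused = impl.endswith("fused")        # fused QKV projection requested
    wants_gqa = impl.startswith("fast_gqa")  # FastGroupedQueryAttention family
    wants_fast = impl.startswith("fast")     # any of the Fast* implementations
    print(f"{impl:<15} fast={wants_fast} gqa={wants_gqa} fused={is_fused}")
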
Expand Down Expand Up @@ -401,19 +402,9 @@ def build_glu(self, input_dim, intermediate_dim, activation_fn="silu", bias=Fals
     def build_self_attention(
         self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
     ):
-        if self.attn_implementation == "fairseq":
-            return MultiheadAttention(
-                embed_dim,
-                cfg.decoder.attention_heads,
-                dropout=cfg.attention_dropout,
-                add_bias_kv=add_bias_kv,
-                add_zero_attn=add_zero_attn,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.decoder.xformers_att_config,
-            )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
+        is_fused = self.attn_implementation.endswith("fused")
+
+        if self.attn_implementation.startswith("fast"):
             return FastMultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -422,12 +413,15 @@ def build_self_attention(
                 add_zero_attn=add_zero_attn,
                 self_attention=True,
                 is_decoder=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_fused")
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.decoder, "kv_attention_heads", None) is not None
+        ):
             return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -437,30 +431,26 @@ def build_self_attention(
                 add_zero_attn=add_zero_attn,
                 self_attention=True,
                 is_decoder=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_gqa_fused")
             )
         else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
-            )
-
-    def build_encoder_attention(self, embed_dim, cfg):
-        if self.attn_implementation == "fairseq":
             return MultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
-                kdim=cfg.encoder.embed_dim,
-                vdim=cfg.encoder.embed_dim,
                 dropout=cfg.attention_dropout,
-                encoder_decoder_attention=True,
+                add_bias_kv=add_bias_kv,
+                add_zero_attn=add_zero_attn,
+                self_attention=True,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.encoder.xformers_att_config,
+                xformers_att_config=cfg.decoder.xformers_att_config,
             )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
+
+    def build_encoder_attention(self, embed_dim, cfg):
+        if self.attn_implementation.startswith("fast"):
             return FastMultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -473,7 +463,10 @@ def build_encoder_attention(self, embed_dim, cfg):
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.decoder, "kv_attention_heads", None) is not None
+        ):
             return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -488,8 +481,16 @@ def build_encoder_attention(self, embed_dim, cfg):
                 qn_block_size=self.quant_noise_block_size,
             )
         else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
+            return MultiheadAttention(
+                embed_dim,
+                cfg.decoder.attention_heads,
+                kdim=cfg.encoder.embed_dim,
+                vdim=cfg.encoder.embed_dim,
+                dropout=cfg.attention_dropout,
+                encoder_decoder_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                xformers_att_config=cfg.encoder.xformers_att_config,
             )

     def prepare_for_onnx_export_(self):
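
Aside (illustrative only, not from the repo; SimpleNamespace stands in for the fairseq config object): why the GQA branches above read kv_attention_heads through getattr with a None default rather than plain attribute access, so older configs without that field fall through to the non-GQA branches instead of raising.

from types import SimpleNamespace

decoder = SimpleNamespace(attention_heads=8)          # kv_attention_heads never set
print(getattr(decoder, "kv_attention_heads", None))   # None -> GQA branch skipped

decoder.kv_attention_heads = 2                        # GQA-style setting
print(getattr(decoder, "kv_attention_heads", None))   # 2 -> GQA branch eligible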