Commit

removed FA2
VarunGumma committed Jul 19, 2024
1 parent 20a22d4 commit f08a220
Showing 9 changed files with 33 additions and 541 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -31,14 +31,13 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
| **Penultimate Linear Transformation Activation** | Adds activation to the penultimate linear transformation before the final projection onto the vocabulary | `--decoder-output-activation-fn $decoder_out_activation_fn` | - |
| **Sanity Validation Steps** | Runs a full pass over the validation set at the beginning of training | `--run-sanity-validation-steps` | - |
-| **Efficient/Debloated Attention Variants** | <ul><li>_FastMultiHeadAttention_: A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ with an efficient context manager</li><li>_FlashMultiHeadAttention_: A [flash-attention variant](https://github.com/Dao-AILab/flash-attention) of _MultiHeadAttention_ with even better speedup and efficiency.</li></ul> <span style="color: red;">Both these variants disable several checks and arguments in favour of their implementations, so please double-check your requirements before enabling these flags. _FlashMultiheadAttention_ is still in $\alpha$-testing, so use it with **caution**.</span> | <ul><li>_FastMultiHeadAttention_: `--attn-implementation fast`</li><li>_FlashMultiHeadAttention_: `--attn-implementation flash`</li></ul> | [FlashAttention Implementation](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B/blob/main/modeling_indictrans.py) |
+| **Efficient MHA** | A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ with an efficient context manager | `--attn-implementation fast`. By default, the value is `fairseq` | - |


# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 2.1.1
-* [FlashAttention2](https://github.com/Dao-AILab/flash-attention) >= 2.5.6
-* Python version >= 3.8, <= 3.11
+* Python version >= 3.8, <= 3.12
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:

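For context on the surviving option: the `fast` implementation documented in the README row above is built around PyTorch's fused `scaled_dot_product_attention` (the function linked in that row). Below is a minimal, self-contained sketch of that call with arbitrary shapes; it is an illustration, not the fork's `FastMultiheadAttention` module.

```python
# Minimal illustration of the fused SDPA call behind `--attn-implementation fast`;
# shapes are arbitrary and this is not the fork's FastMultiheadAttention itself.
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# One fused call replaces the explicit softmax(QK^T / sqrt(d)) @ V sequence;
# the attention weights are never materialized.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```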
2 changes: 1 addition & 1 deletion fairseq/models/transformer/transformer_config.py
@@ -235,7 +235,7 @@ class TransformerConfig(FairseqDataclass):
    cross_self_attention: bool = field(
        default=False, metadata={"help": "perform cross+self-attention"}
    )
-    attn_implementation: ChoiceEnum(["fast", "flash", "fairseq"]) = field(
+    attn_implementation: ChoiceEnum(["fast", "fairseq"]) = field(
        default="fairseq",
        metadata={"help": "Mainly added for RoPE/LoRA and efficiency"},
    )
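A hedged sketch of what the narrowed `ChoiceEnum` implies downstream: only two implementations remain to dispatch on. The builder function below is hypothetical; the class names match the `fairseq.modules` exports shown in the `__init__.py` hunk further down, and `FastMultiheadAttention` is assumed to take the same basic constructor arguments as `MultiheadAttention`.

```python
# Hypothetical dispatch sketch; not the repository's actual builder code.
from fairseq.modules import FastMultiheadAttention, MultiheadAttention

def build_self_attention(cfg, embed_dim: int, num_heads: int):
    # With "flash" removed from the ChoiceEnum, only these two branches remain.
    if cfg.attn_implementation == "fast":
        # Assumed to mirror MultiheadAttention's basic constructor arguments.
        return FastMultiheadAttention(embed_dim, num_heads, dropout=cfg.attention_dropout)
    # Default ("fairseq"): the original implementation.
    return MultiheadAttention(embed_dim, num_heads, dropout=cfg.attention_dropout)
```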
3 changes: 1 addition & 2 deletions fairseq/models/transformer/transformer_decoder.py
@@ -158,15 +158,14 @@ def __init__(
        else:
            self.project_out_activation_fn = None

-        if cfg.use_alibi and cfg.attn_implementation != "flash":
+        if cfg.use_alibi:
            assert (
                self.embed_positions is None
            ), "ALiBi shouldn't be used with positional embedding"
            self.alibi = utils.alibi(
                cfg.decoder.attention_heads, self.max_target_positions
            )
        else:
-            # FA2 internally uses ALiBi, so we don't need to use it here
            self.alibi = None

        self.adaptive_softmax = None
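Both this decoder hunk and the encoder hunk below drop the flash-specific exemption around ALiBi. As a refresher: ALiBi adds a per-head linear distance penalty directly to the attention logits, which is why the assertion forbids combining it with positional embeddings. The sketch below is conceptual and assumes the standard slopes from Press et al.; the fork's `utils.alibi` may differ in detail (for example in slope handling for non-power-of-two head counts or the exact sign and masking convention).

```python
# Conceptual ALiBi bias; illustrative only, not the fork's utils.alibi.
import torch

def alibi_bias(num_heads: int, max_len: int, causal: bool = True) -> torch.Tensor:
    # Per-head slopes m_h = 2^(-8h/H), as in Press et al. (ALiBi).
    slopes = 2.0 ** (-8.0 * torch.arange(1, num_heads + 1) / num_heads)  # (H,)
    pos = torch.arange(max_len)
    rel = (pos[None, :] - pos[:, None]).float()  # rel[i, j] = j - i
    if not causal:
        # Symmetric (bidirectional) variant: penalty depends only on |i - j|.
        rel = -rel.abs()
    # (H, L, L) additive bias for the attention logits; in the causal case the
    # positive entries above the diagonal are removed by the usual future mask.
    return slopes[:, None, None] * rel[None, :, :]

print(alibi_bias(num_heads=8, max_len=4)[0])  # head 0: zero on the diagonal, -0.5, -1.0, ... to the left
```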
6 changes: 2 additions & 4 deletions fairseq/models/transformer/transformer_encoder.py
@@ -118,20 +118,18 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
            else None
        )

-        if cfg.use_alibi and cfg.attn_implementation != "flash":
+        if cfg.use_alibi:
            assert (
                self.embed_positions is None
            ), "ALiBi shouldn't be used with positional embedding"
            self.alibi = utils.alibi(
                cfg.encoder.attention_heads,
                self.max_source_positions,
-                asymmetrical=False
+                asymmetrical=False,
            )
        else:
-            # FA2 internally uses ALiBi, so we don't need to use it here
            self.alibi = None

-
    def normalization(self, dim, rms=False):
        return (
            LayerNorm(dim, export=self.cfg.export)
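Under the same caveats, the encoder-side call with `asymmetrical=False` would correspond to the symmetric variant of the sketch above; treating the flag this way is an assumption about its meaning, not a statement about what `utils.alibi` actually does.

```python
# Reuses the illustrative alibi_bias() sketch from the decoder section above.
enc_bias = alibi_bias(num_heads=8, max_len=4, causal=False)
print(enc_bias[0])  # zeros on the diagonal, increasingly negative with |i - j|
```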
2 changes: 0 additions & 2 deletions fairseq/modules/__init__.py
@@ -34,7 +34,6 @@
from .mlp import MLP, GLU
from .multihead_attention import MultiheadAttention
from .fast_multihead_attention import FastMultiheadAttention
-from .flash_multihead_attention import FlashMultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rms_norm import RMSNorm
from .same_pad import SamePad, SamePad2d
@@ -73,7 +72,6 @@
"EMAModuleConfig",
"FactorizedEmbedding",
"FastMultiheadAttention",
"FlashMultiheadAttention",
"FairseqDropout",
"Fp32BatchNorm",
"Fp32GroupNorm",
23 changes: 9 additions & 14 deletions fairseq/modules/fast_multihead_attention.py
@@ -328,22 +328,17 @@ def forward(

        combined_mask = combined_mask.to(q.dtype) if combined_mask is not None else None

        # this is the part that is different from MultiheadAttention
        # it uses a kernelized implementation of SDPA
        # and the attn_weights are internalized
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        ):
-            attn = scaled_dot_product_attention(
-                query=q,
-                key=k,
-                value=v,
-                is_causal=False,
-                attn_mask=combined_mask,
-                dropout_p=self.dropout_p,
-            )
+        attn = scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            is_causal=False,
+            attn_mask=combined_mask,
+            dropout_p=self.dropout_p,
+        )

        # attn shape: (bsz * self.num_heads, tgt_len, self.head_dim)
        attn = rearrange(
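The hunk above drops the `torch.backends.cuda.sdp_kernel` context manager and calls `scaled_dot_product_attention` directly, leaving backend selection (flash, memory-efficient, or math) to PyTorch. The sketch below illustrates that contrast; the `torch.nn.attention` import assumes a recent PyTorch (roughly 2.3+) and is shown only for comparison with the deprecated context manager that was removed.

```python
# Illustration of automatic vs. pinned SDPA backends; not code from this repository.
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel  # assumption: PyTorch >= 2.3

bsz_heads, tgt_len, head_dim = 16, 32, 64  # (bsz * num_heads, tgt_len, head_dim) layout
q = k = v = torch.randn(bsz_heads, tgt_len, head_dim)

# As in the added lines above: a bare call, backend chosen automatically.
attn = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Pinning a backend (e.g. the always-available math backend, useful for debugging)
# is the modern replacement for the deprecated torch.backends.cuda.sdp_kernel(...)
# context manager deleted in this commit.
with sdpa_kernel([SDPBackend.MATH]):
    attn_ref = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

print(attn.shape, attn_ref.shape)  # both torch.Size([16, 32, 64])
```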
(The remaining changed files in this commit are not shown here.)
