bug fixes in FA2
VarunGumma committed Jul 12, 2024
1 parent d9b28d6 commit 081df4d
Showing 12 changed files with 282 additions and 194 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -2,7 +2,6 @@
<img src="docs/fairseq_logo.png" width="150">
<br />
<br />
<a href="https://opensource.fb.com/support-ukraine"><img alt="Support Ukraine" src="https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB" /></a>
<a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
<a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
<a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
@@ -18,27 +17,28 @@ modeling and other text generation tasks.


# Usage
- This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `LoRA` `RoPE`, and `ALiBi` for the `Transformer` model and the `translation` task. You can add the following flags to `fairseq-train`/`fairseq-interactive`/`fairseq-generate` to use them:
+ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `FlashAttention2`, `LoRA`, `RoPE`, and `ALiBi` for the `Transformer` model and the `translation` task. You can add the following flags to `fairseq-train`/`fairseq-interactive`/`fairseq-generate` to use them:

| **Name and Citation** | **Description** | **Flags to Activate** | **Source** |
|-----------------------|-----------------------|-----------------------|------------|
| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $teacher_ckpt --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
- | **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none, "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation native --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
- | **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--use-rope --attn-implementation native --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
+ | **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
+ | **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--use-rope --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
| **Gated Linear Unit (GLU)** ([Shazeer](https://arxiv.org/abs/2002.05202)) | A gated Feed-Forward-Network variant | `--encoder-use-glu --decoder-use-glu` | [GLU Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | An efficient normalization technique | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) |
- | **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--use-alibi symmetrical --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
+ | **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--use-alibi --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
| **Factorized Embedding Parameterization** ([Lan _et al_.](https://arxiv.org/abs/1909.11942)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
| **Penultimate Linear Transformation Activation** | Adds activation to the penultimate linear transformation before the final projection onto the vocabulary | `--decoder-output-activation-fn $decoder_out_activation_fn` | - |
| **Sanity Validation Steps** | Runs a full pass over the validation set at the beginning of training | `--run-sanity-validation-steps` | - |
- | **Efficient/Debloated Attention Variants** | <ul><li>_NativeMultiHeadAttention_: A manual attention computation module to aid the usecases of LoRA and RoPE</li><li>_FastMultiHeadAttention_: A [torch functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _NativeMultiHeadAttention_ with a efficient context manager</li></ul> <span style="color: red;">Both these variants disable several checks and arguments in favour of their implementations, so please double check your requirements before enabling these flags.</span> | <ul><li>_NativeMultiHeadAttention_: `--attn-implementation native`</li><li>_FastMultiHeadAttention_: `--attn-implementation fast`</li></ul> | - |
+ | **Efficient/Debloated Attention Variants** | <ul><li>_FastMultiHeadAttention_: A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ with an efficient context manager</li><li>_FlashMultiHeadAttention_: A [flash-attention variant](https://github.com/Dao-AILab/flash-attention) of _MultiHeadAttention_ with even better speedup and efficiency.</li></ul> <span style="color: red;">Both these variants disable several checks and arguments in favour of their implementations, so please double-check your requirements before enabling these flags. _FlashMultiHeadAttention_ is still in $\alpha$-testing, so use it with **caution**.</span> (See the minimal SDPA sketch below this table.) | <ul><li>_FastMultiHeadAttention_: `--attn-implementation fast`</li><li>_FlashMultiHeadAttention_: `--attn-implementation flash`</li></ul> | [FlashAttention Implementation](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B/blob/main/modeling_indictrans.py) |
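To make the `fast` variant concrete, here is a minimal sketch of the fused PyTorch kernel it builds on; the shapes are illustrative and this is not the module's full forward pass:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: fairseq folds the heads into the batch dimension.
bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz * num_heads, seq_len, head_dim)
k = torch.randn(bsz * num_heads, seq_len, head_dim)
v = torch.randn(bsz * num_heads, seq_len, head_dim)

# The attention weights are computed inside the fused kernel and never
# materialized, which is why these variants return None for attn_weights.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([16, 16, 64])
```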


# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 2.1.1
- * Python version >= 3.8
+ * [FlashAttention2](https://github.com/Dao-AILab/flash-attention) >= 2.5.6
+ * Python version >= 3.8, <= 3.11
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:

6 changes: 3 additions & 3 deletions fairseq/models/transformer/transformer_config.py
@@ -235,7 +235,7 @@ class TransformerConfig(FairseqDataclass):
cross_self_attention: bool = field(
default=False, metadata={"help": "perform cross+self-attention"}
)
- attn_implementation: ChoiceEnum(["native", "fast", "fairseq"]) = field(
+ attn_implementation: ChoiceEnum(["fast", "flash", "fairseq"]) = field(
default="fairseq",
metadata={"help": "Mainly added for RoPE/LoRA and efficiency"},
)
@@ -249,9 +249,9 @@ class TransformerConfig(FairseqDataclass):
default=False,
metadata={"help": "use Rotary Positional Embedding (RoPE) in self-attention layers"},
)
- use_alibi: Optional[str] = field(
+ use_alibi: Optional[bool] = field(
default=None,
metadata={"help": "use ALiBi positional encoding (symmetrical/asymmetrical)"},
metadata={"help": "use ALiBi positional encoding (symmetrical)"},
)

# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
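With the `attn_implementation` choice above, the three values select different attention modules. A hedged sketch of how that selection could look (the helper below is illustrative, not the repo's actual builder, and it assumes the fork and its optional dependencies are installed):

```python
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
    FastMultiheadAttention,
    FlashMultiheadAttention,
    MultiheadAttention,
)

# Hypothetical mapping from the config choice to an attention class.
ATTN_CLASSES = {
    "fairseq": MultiheadAttention,      # original fairseq attention
    "fast": FastMultiheadAttention,     # torch SDPA-based variant
    "flash": FlashMultiheadAttention,   # FlashAttention2-based variant
}

def build_self_attention(cfg: TransformerConfig, embed_dim: int, num_heads: int):
    attn_cls = ATTN_CLASSES[cfg.attn_implementation]
    # Constructor arguments mirror MultiheadAttention; the subclasses may differ.
    return attn_cls(
        embed_dim,
        num_heads,
        dropout=cfg.attention_dropout,
        self_attention=True,
    )

cfg = TransformerConfig()
cfg.attn_implementation = "fast"
attn = build_self_attention(cfg, embed_dim=512, num_heads=8)
```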
3 changes: 2 additions & 1 deletion fairseq/models/transformer/transformer_decoder.py
@@ -158,14 +158,15 @@ def __init__(
else:
self.project_out_activation_fn = None

- if cfg.use_alibi is not None:
+ if cfg.use_alibi and cfg.attn_implementation != "flash":
assert (
self.embed_positions is None
), "ALiBi shouldn't be used with positional embedding"
self.alibi = utils.alibi(
cfg.decoder.attention_heads, self.max_target_positions
)
else:
# FA2 internally uses ALiBi, so we don't need to use it here
self.alibi = None

self.adaptive_softmax = None
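For reference, a minimal sketch of the symmetric ALiBi bias that `utils.alibi` builds in the decoder above (the encoder below passes `asymmetrical=False` explicitly); the real signature in `fairseq.utils` may differ, and the slope formula assumes a power-of-two head count as in Press et al.:

```python
import torch

def alibi_bias(num_heads: int, max_len: int) -> torch.Tensor:
    # Geometric slopes 2^(-8/h), 2^(-16/h), ... per head.
    start = 2 ** (-8.0 / num_heads)
    slopes = torch.tensor([start ** (i + 1) for i in range(num_heads)])
    # Symmetric distance penalty: bias[h, i, j] = -slope[h] * |i - j|.
    pos = torch.arange(max_len)
    dist = (pos[None, :] - pos[:, None]).abs()
    return -slopes[:, None, None] * dist[None, :, :]

bias = alibi_bias(num_heads=8, max_len=16)  # shape: (8, 16, 16)
```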
6 changes: 4 additions & 2 deletions fairseq/models/transformer/transformer_encoder.py
@@ -118,18 +118,20 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
else None
)

- if cfg.use_alibi is not None:
+ if cfg.use_alibi and cfg.attn_implementation != "flash":
assert (
self.embed_positions is None
), "ALiBi shouldn't be used with positional embedding"
self.alibi = utils.alibi(
cfg.encoder.attention_heads,
self.max_source_positions,
- asymmetrical=cfg.use_alibi,
+ asymmetrical=False,
)
else:
# FA2 internally uses ALiBi, so we don't need to use it here
self.alibi = None


def normalization(self, dim, rms=False):
return (
LayerNorm(dim, export=self.cfg.export)
3 changes: 2 additions & 1 deletion fairseq/modules/__init__.py
@@ -33,8 +33,8 @@
from .lstm_cell_with_zoneout import LSTMCellWithZoneOut
from .mlp import MLP, GLU
from .multihead_attention import MultiheadAttention
- from .native_multihead_attention import NativeMultiheadAttention
from .fast_multihead_attention import FastMultiheadAttention
+ from .flash_multihead_attention import FlashMultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rms_norm import RMSNorm
from .same_pad import SamePad, SamePad2d
@@ -73,6 +73,7 @@
"EMAModuleConfig",
"FactorizedEmbedding",
"FastMultiheadAttention",
"FlashMultiheadAttention",
"FairseqDropout",
"Fp32BatchNorm",
"Fp32GroupNorm",
30 changes: 14 additions & 16 deletions fairseq/modules/fast_multihead_attention.py
@@ -4,22 +4,25 @@
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple

from rotary_embedding_torch import RotaryEmbedding
from torch.nn import Parameter

import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from fairseq.modules.quant_noise import quant_noise
from torch.nn.functional import scaled_dot_product_attention

from fairseq.modules.multihead_attention import MultiheadAttention

try:
from rotary_embedding_torch import RotaryEmbedding
except ImportError:
raise ImportError("Please install the rotary-embedding-torch>=0.6.4")

# HACK: This attention variant is mainly for speedup.
# HACK: Attention weights are internalized and None is returned for them.
# HACK: Double check your requirements before using this variant.
- # BUG: FlashAttention does not work with an attn_mask.
+ # BUG: FlashAttention does not work with an attn_mask. Use FlashMultiheadAttention instead.


class FastMultiheadAttention(MultiheadAttention):
@@ -95,7 +98,6 @@ def __init__(
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)

self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
@@ -109,7 +111,6 @@ def __init__(
self.add_zero_attn = add_zero_attn
self.beam_size = 1
self.reset_parameters()

self.init_incremental_state()

def forward(
@@ -197,9 +198,9 @@ def forward(
k, v, attn_mask, key_padding_mask = self._add_bias(
k=k,
v=v,
- bsz=bsz,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
+ bsz=bsz,
)

q = (
@@ -313,6 +314,7 @@ def forward(
* torch.finfo(q.dtype).min
)

# SDPA cannot accept both is_causal=True and an explicit attn_mask
if attn_mask is not None:
if attn_mask.size() != key_padding_mask.size():
combined_mask = attn_mask.unsqueeze(0) + key_padding_mask
@@ -326,7 +328,7 @@

combined_mask = combined_mask.to(q.dtype) if combined_mask is not None else None

- # this the part that is different from NativeMultiheadAttention
+ # this is the part that is different from MultiheadAttention
# it uses a kernelized implementation of SDPA
# and the attn_weights are internalized
with torch.backends.cuda.sdp_kernel(
@@ -343,13 +345,9 @@
dropout_p=self.dropout_p,
)

- assert list(attn.size()) == [
- bsz * self.num_heads,
- tgt_len,
- self.head_dim,
- ], f"attn size should be {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {attn.shape()}"
-
- attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+ # attn shape: (bsz * self.num_heads, tgt_len, self.head_dim)
+ attn = rearrange(
+ attn, "(b h) t c -> t b (h c)", h=self.num_heads, c=self.head_dim
+ )
attn = self.out_proj(attn)

return attn, None
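The `rearrange` in the new code is equivalent to the `transpose`/`view` it replaces; a quick standalone check of that equivalence, with illustrative shapes:

```python
import torch
from einops import rearrange

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8
attn = torch.randn(bsz * num_heads, tgt_len, head_dim)

# einops form used in FastMultiheadAttention.forward above
a = rearrange(attn, "(b h) t c -> t b (h c)", h=num_heads, c=head_dim)
# transpose/view form it replaces
b = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, num_heads * head_dim)

assert torch.equal(a, b)  # same layout: (tgt_len, bsz, embed_dim)
```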