Commit: added rslora
VarunGumma committed Jun 6, 2024
1 parent c55b317 commit b548b71
Showing 4 changed files with 14 additions and 6 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -24,7 +24,7 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
|-----------------------|-----------------------|-----------------------|------------|
| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model | `--teacher-checkpoint-path $teacher_ckpt --task translation_with_kd --criterion label_smoothed_cross_entropy_with_kd --kd-args '{"strategy": "word_level"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj"}' --use-native-attention --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --use-native-attention --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla"}' --use-native-attention --no-token-positional-embeddings` | [RoPE Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Yet another RoPE extensioN method (YaRN)** ([Peng _et al_.](https://openreview.net/forum?id=wHBfxhZu1u)) | Compute-efficient method to extend the context window of models | `--yarn-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla", "original_max_position_embeddings": 256, "extrapolation_factor": 1, "attn_factor": 1, "beta_fast": 32, "beta_slow": 1}' --use-native-attention --no-token-positional-embeddings` | [YaRN Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"alibi_asymmetrical": "false"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
@@ -99,7 +99,8 @@ Please cite as:
pages = "103--114",
abstract = "Knowledge distillation (KD) is a well-known method for compressing neural models. However, works focusing on distilling knowledge from large multilingual neural machine translation (MNMT) models into smaller ones are practically nonexistent, despite the popularity and superiority of MNMT. This paper bridges this gap by presenting an empirical investigation of knowledge distillation for compressing MNMT models. We take Indic to English translation as a case study and demonstrate that commonly used language-agnostic and language-aware KD approaches yield models that are 4-5x smaller but also suffer from performance drops of up to 3.5 BLEU. To mitigate this, we then experiment with design considerations such as shallower versus deeper models, heavy parameter sharing, multistage training, and adapters. We observe that deeper compact models tend to be as good as shallower non-compact ones and that fine-tuning a distilled model on a high-quality subset slightly boosts translation quality. Overall, we conclude that compressing MNMT models via KD is challenging, indicating immense scope for further research.",
}
```
```
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
2 changes: 1 addition & 1 deletion fairseq/models/transformer/transformer_config.py
@@ -236,7 +236,7 @@ class TransformerConfig(FairseqDataclass):

lora_args: Optional[str] = field(
default=None,
metadata={"help": "LoRA arguments (rank, alpha, dropout, target_modules)"},
metadata={"help": "LoRA arguments (rank, alpha, dropout, target_modules, rank_scaled)"},
)
rope_args: Optional[str] = field(
default=None,
12 changes: 9 additions & 3 deletions fairseq/modules/lora.py
@@ -11,6 +11,7 @@ def __init__(
r: int,
alpha: int,
dropout: float,
rank_scaled: bool,
merge_weights: bool,
):
self.r = r
@@ -22,6 +23,7 @@ def __init__(
self.dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.rank_scaled = rank_scaled
self.merge_weights = merge_weights


@@ -33,6 +35,7 @@ def __init__(
embedding_dim: int,
r: int = 0,
alpha: int = 1,
rank_scaled: bool = False,
merge_weights: bool = True,
**kwargs
):
@@ -42,13 +45,14 @@
r=r,
alpha=alpha,
dropout=0,
rank_scaled=rank_scaled,
merge_weights=merge_weights,
)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
self.scaling = self.alpha / self.r
self.scaling = (self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
@@ -106,7 +110,8 @@ def __init__(
r: int = 0,
alpha: int = 1,
dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
fan_in_fan_out: bool = False,
rank_scaled: bool = False,
merge_weights: bool = True,
**kwargs
):
@@ -116,6 +121,7 @@ def __init__(
r=r,
alpha=alpha,
dropout=dropout,
rank_scaled=rank_scaled,
merge_weights=merge_weights,
)

@@ -124,7 +130,7 @@ def __init__(
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.alpha / self.r
self.scaling = (self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
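For context on the change above: the only functional difference is the scaling factor applied to the low-rank update. Vanilla LoRA scales by `alpha / r`, while rsLoRA (rank-stabilized LoRA; Kalajdzievski, 2023) scales by `alpha / sqrt(r)`, so the update is not down-weighted as aggressively at higher ranks. Below is a minimal, self-contained sketch of that behaviour; the helper names, toy shapes, and standalone forward pass are illustrative, not this module's actual API.

```python
import math

import torch


def lora_scaling(alpha: int, r: int, rank_scaled: bool) -> float:
    # Vanilla LoRA: alpha / r. rsLoRA: alpha / sqrt(r), matching the
    # rank_scaled branch added in fairseq/modules/lora.py above.
    return alpha / math.sqrt(r) if rank_scaled else alpha / r


def lora_forward(x, weight, lora_A, lora_B, alpha, r, rank_scaled):
    # Frozen base projection plus the scaled low-rank update (B @ A) x.
    base = x @ weight.t()
    update = (x @ lora_A.t()) @ lora_B.t()
    return base + lora_scaling(alpha, r, rank_scaled) * update


# Toy shapes mirroring the diff: lora_A is (r, in_features), lora_B is (out_features, r).
x = torch.randn(4, 16)
weight = torch.randn(32, 16)   # frozen pre-trained weight, (out_features, in_features)
lora_A = torch.zeros(8, 16)
lora_B = torch.zeros(32, 8)
out = lora_forward(x, weight, lora_A, lora_B, alpha=16, r=8, rank_scaled=True)
print(out.shape)  # torch.Size([4, 32])
```

For example, with `alpha=16` and `r=64`, the vanilla factor is 16/64 = 0.25 while the rank-scaled factor is 16/8 = 2.0, which is the motivation for exposing the flag when training with larger ranks.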
1 change: 1 addition & 0 deletions fairseq_cli/train.py
@@ -220,6 +220,7 @@ def main(cfg: FairseqConfig) -> None:
"r": lora_config.get("r", 0),
"alpha": lora_config.get("alpha", 1),
"dropout": lora_config.get("dropout", 0.0),
"rank_scaled": lora_config.get("rank_scaled", False),
"merge_weights": True,
}

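The new key is read straight from the `--lora-args` JSON string with a default of `False`, so existing configs keep the old behaviour. A rough sketch of that plumbing, assuming the string is parsed with `json.loads` (the variable names are illustrative, not the exact code in `train.py`):

```python
import json

# Example --lora-args value, matching the README row (rank_scaled enabled here).
lora_args = (
    '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", '
    '"target_modules": "k_proj,v_proj", "rank_scaled": true}'
)

lora_config = json.loads(lora_args)
lora_kwargs = {
    "r": lora_config.get("r", 0),
    "alpha": lora_config.get("alpha", 1),
    "dropout": lora_config.get("dropout", 0.0),
    "rank_scaled": lora_config.get("rank_scaled", False),  # new in this commit
    "merge_weights": True,
}
print(lora_kwargs["rank_scaled"])  # True
```

Omitting `rank_scaled` from the JSON falls back to the default `False`, i.e. the original `alpha / r` scaling.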
