From f82a8dee8203c8fcc227efb070501762a2810fb5 Mon Sep 17 00:00:00 2001
From: Varun Gumma
Date: Sat, 20 Jul 2024 08:00:21 +0000
Subject: [PATCH] fixed lora bug

---
 README.md                    |  2 +-
 fairseq/modules/lora.py      | 73 ++++++++++++++++++------------------
 fairseq/tasks/translation.py |  2 +-
 fairseq_cli/train.py         | 31 +++++++--------
 4 files changed, 52 insertions(+), 56 deletions(-)

diff --git a/README.md b/README.md
index ddf1ca9d26..da504123f3 100755
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
 |-----------------------|-----------------------|-----------------------|------------|
 | **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $teacher_ckpt --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
 | **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
-| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none, "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
+| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest. | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
 | **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--use-rope --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
 | **Gated Linear Unit (GLU)** ([Shazeer](https://arxiv.org/abs/2002.05202)) | A better Feed-Forward-Network variant | `--encoder-use-glu --decoder-use-glu` | [GLU Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
 | **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | An efficient normalization technique | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) |
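The `--lora-args` value in the LoRA row above is a plain JSON string. As a rough, hypothetical sketch of how such a string can be split into the pieces that `fairseq_cli/train.py` works with further down (`lora_modules`, `lora_bias`, and the remaining `lora_params`), assuming only the standard library; the exact plumbing that produces `lora_config` inside fairseq is not shown in this patch:

```python
import json

# Hypothetical parsing sketch; mirrors the lora_config.get(...) accesses in
# fairseq_cli/train.py below, but the real parsing code is not part of this patch.
lora_args = (
    '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", '
    '"target_modules": "k_proj,v_proj", "rank_scaled": false}'
)

lora_config = json.loads(lora_args)
lora_modules = set(lora_config.get("target_modules", "").split(","))  # {"k_proj", "v_proj"}
lora_bias = lora_config.get("bias", "none")
# everything else (r, alpha, dropout, rank_scaled) is forwarded to the LoRA layers
lora_params = {k: v for k, v in lora_config.items() if k not in ("target_modules", "bias")}
print(lora_modules, lora_bias, lora_params)
```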
diff --git a/fairseq/modules/lora.py b/fairseq/modules/lora.py
index 6cc55cce17..89d06e7c51 100644
--- a/fairseq/modules/lora.py
+++ b/fairseq/modules/lora.py
@@ -4,26 +4,24 @@
 import torch.nn.functional as F
 
 
-# source: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 class LoRALayer:
     def __init__(
         self,
         r: int,
         alpha: int,
-        dropout: float,
-        rank_scaled: bool,
+        dropout: float = 0.0,
+        rank_scaled: bool = False,
+        fan_in: int = None,
+        fan_out: int = None,
     ):
+        assert r > 0, "Rank must be greater than 0 for LoRA to work."
-        assert (
-            r > 0
-        ), "rank must be greater than 0 for LoRA to work. Use nn.Linear/nn.Embedding if rank is 0."
-
-        self.r = r
-        self.alpha = alpha
-        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
-
         # Mark the weight as unmerged initially
         self.merged = False
-        self.rank_scaled = rank_scaled
+        self.dropout_p = dropout
+        self.scaling = (alpha / math.sqrt(r)) if rank_scaled else (alpha / r)
+        # better to contain the A and B matrices in the same class
+        self.lora_A = nn.Parameter(torch.zeros((r, fan_in)))
+        self.lora_B = nn.Parameter(torch.zeros((fan_out, r)))
 
 
 class LoRAEmbedding(nn.Embedding, LoRALayer):
@@ -34,16 +32,19 @@ def __init__(
         self,
         num_embeddings: int,
         embedding_dim: int,
         r: int,
         alpha: int = 1,
+        dropout: float = 0.0,
         rank_scaled: bool = False,
         **kwargs
     ):
         nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
-        LoRALayer.__init__(self, r=r, alpha=alpha, dropout=0, rank_scaled=rank_scaled)
-        # Actual trainable parameters
-        self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
-        self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
-        self.scaling = (
-            (self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
+        LoRALayer.__init__(
+            self,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            rank_scaled=rank_scaled,
+            fan_in=num_embeddings,
+            fan_out=embedding_dim,
         )
         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False
@@ -51,9 +52,9 @@ def __init__(
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        # initialize A the same way as the default for nn.Linear and B to zero
-        nn.init.zeros_(self.lora_A)
-        nn.init.normal_(self.lora_B)
+        if hasattr(self, "lora_A"):
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
@@ -72,7 +73,7 @@ def forward(self, x: torch.Tensor):
         result = nn.Embedding.forward(self, x)
 
         if not self.merged:
-            after_A = F.embedding(
+            x = F.embedding(
                 x,
                 self.lora_A.T,
                 self.padding_idx,
@@ -81,7 +82,8 @@
                 self.scale_grad_by_freq,
                 self.sparse,
             )
-            result += (after_A @ self.lora_B.T) * self.scaling
+            x = F.dropout(x, p=self.dropout_p, training=self.training)
+            result += (x @ self.lora_B.T) * self.scaling
         return result
 
 
@@ -98,15 +100,16 @@ def __init__(
         rank_scaled: bool = False,
         **kwargs
     ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
         LoRALayer.__init__(
-            self, r=r, alpha=alpha, dropout=dropout, rank_scaled=rank_scaled
-        )
-        # Actual trainable parameters
-        self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
-        self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
-        self.scaling = (
-            (self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
+            self,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            rank_scaled=rank_scaled,
+            fan_in=in_features,
+            fan_out=out_features,
         )
         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False
 
@@ -114,10 +117,9 @@ def __init__(
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
-        # initialize B the same way as the default for nn.Linear and A to zero
-        # this is different than what is described in the paper but should not affect performance
-        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_B)
+        if hasattr(self, "lora_A"):
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
 
     def train(self, mode: bool = True):
         nn.Linear.train(self, mode)
@@ -137,8 +139,7 @@ def forward(self, x: torch.Tensor):
 
         if not self.merged:
             # if the weights are not merged, apply LoRA
-            if self.dropout is not None:
-                x = self.dropout(x)
+            x = F.dropout(x, p=self.dropout_p, training=self.training)
             result += x @ ((self.lora_B @ self.lora_A).T * self.scaling)
         return result
 
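As a quick sanity check of the `LoRALinear` layer above, here is a minimal usage sketch. It assumes the patched `fairseq.modules.lora` is on the import path and that the constructor keyword names match the signature shown in this file; it only exercises the shapes and trainability implied by the diff:

```python
import torch
from fairseq.modules.lora import LoRALinear

# r=8, alpha=16 as in the README example; extra kwargs are forwarded to nn.Linear
layer = LoRALinear(in_features=512, out_features=512, r=8, alpha=16, dropout=0.05)

# LoRALayer stores A as (r, fan_in) and B as (fan_out, r)
assert layer.lora_A.shape == (8, 512)
assert layer.lora_B.shape == (512, 8)

# the pre-trained weight is frozen; only the LoRA factors are trainable
assert not layer.weight.requires_grad
assert layer.lora_A.requires_grad and layer.lora_B.requires_grad

x = torch.randn(4, 512)
y = layer(x)  # base output + (dropout(x) @ (B @ A).T) * scaling
assert y.shape == (4, 512)
```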
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index 02489f3406..ddd2ab3a21 100755
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -236,7 +236,7 @@ class TranslationConfig(FairseqDataclass):
     eval_bleu_args: Optional[str] = field(
         default="{}",
         metadata={
-            "help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
+            "help": 'generation args for BLEU scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
         },
     )
     eval_bleu_detok: str = field(
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index ee1a48f67e..6c076abf3d 100755
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -43,12 +43,13 @@
 from fairseq.model_parallel.megatron_trainer import MegatronTrainer
 from fairseq.trainer import Trainer
 from fairseq.checkpoint_utils import load_model_ensemble
-from fairseq.modules.lora import LoRALinear, LoRALayer, LoRAEmbedding
+from fairseq.modules.lora import *
 
 
-def mark_only_lora_as_trainable(model, bias, saved_modules=None) -> None:
+def mark_only_lora_as_trainable(model, bias="none") -> None:
     for n, p in model.named_parameters():
-        p.requires_grad = "lora_" in n and not any([m in n for m in saved_modules])
+        p.requires_grad = "lora_" in n
+        logger.info(f"Setting {n} to {'trainable' if p.requires_grad else 'frozen'}")
     if bias == "none":
         return
     elif bias == "all":
@@ -73,6 +74,7 @@ def replace_with_lora(
     - model: The original model to be modified.
     - lora_modules: A list of module names that should be replaced with fairseq.modules.lora.Linear layers.
     - lora_params: A dictionary containing parameters for the fairseq.modules.lora.Linear layer.
+    - parent_module_name: The name of the parent module. Used for recursive calls.
     """
     for name, module in model.named_children():
         full_module_name = (
@@ -87,20 +89,19 @@
                 module.in_features, module.out_features, **lora_params
             )
 
+            # do this after the initialization always
             if module.weight is not None:
                 with torch.no_grad():
                     new_module.weight.data = module.weight.data
                     if module.bias is not None:
                         new_module.bias.data = module.bias.data
 
+            logger.info(f"Replacing {full_module_name} with LoRALinear")
             setattr(model, name, new_module)
 
         elif isinstance(module, torch.nn.Embedding) and any(
             [m in full_module_name for m in lora_modules]
         ):
-            lora_params_emb = lora_params.copy()
-            lora_params_emb.pop("dropout", None)
-
             new_module = LoRAEmbedding(
                 num_embeddings=module.num_embeddings,
                 embedding_dim=module.embedding_dim,
@@ -109,13 +110,14 @@
                 max_norm=module.max_norm,
                 norm_type=module.norm_type,
                 sparse=module.sparse,
-                **lora_params_emb,
+                **lora_params,
             )
 
             if module.weight is not None:
                 with torch.no_grad():
                     new_module.weight.data = module.weight.data
 
+            logger.info(f" | > Replacing {full_module_name} with LoRAEmbedding")
             setattr(model, name, new_module)
 
         else:
@@ -243,16 +245,9 @@ def main(cfg: FairseqConfig) -> None:
         lora_modules = set(lora_config.get("target_modules", "").split(","))
         # assert there are target modules specified for LoRA
         assert len(lora_modules) != [""], "No target modules specified for LoRA"
-        saved_modules = set(lora_config.get("saved_modules", "").split(","))
-        # assert there are no common modules between saved_modules and lora_modules
-        assert len(lora_modules.intersection(saved_modules)) == 0, (
-            "lora_modules and saved_modules cannot have common modules. "
-            "Please remove the following modules from either target_modules or saved_modules: "
-            f"{lora_modules.intersection(saved_modules)}"
-        )
         lora_bias = lora_config.get("bias", "none")
-        replace_with_lora(model, lora_modules=lora_modules, lora_params=lora_params)
-        mark_only_lora_as_trainable(model, bias=lora_bias, saved_modules=saved_modules)
+        replace_with_lora(model, lora_modules, lora_params)
+        mark_only_lora_as_trainable(model, lora_bias)
 
         ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###
         logger.info(
" - "Please remove the following modules from either target_modules or saved_modules: " - f"{lora_modules.intersection(saved_modules)}" - ) lora_bias = lora_config.get("bias", "none") - replace_with_lora(model, lora_modules=lora_modules, lora_params=lora_params) - mark_only_lora_as_trainable(model, bias=lora_bias, saved_modules=saved_modules) + replace_with_lora(model, lora_modules, lora_params) + mark_only_lora_as_trainable(model, lora_bias) ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ### logger.info( @@ -315,8 +310,8 @@ def main(cfg: FairseqConfig) -> None: # assign the teacher model (is present) to the trainer # we had to build the teacher model first before the student and trainer # to avoid over-writing the generator for beam-search of the student with that of the teacher - if (cfg.task._name == "translation_with_kd") and ( - cfg.criterion._name == "label_smoothed_cross_entropy_with_kd" + if (cfg.task._name == "seq2seq_lm_distillation") and ( + cfg.criterion._name == "seq2seq_lm_distillation" ): trainer.assign_teacher_model(teacher_model)