Commit
fixed lora bug
VarunGumma committed Jul 20, 2024
1 parent f08a220 commit f82a8de
Showing 4 changed files with 52 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
|-----------------------|-----------------------|-----------------------|------------|
| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $teacher_ckpt --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest. | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--use-rope --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
| **Gated Linear Unit (GLU)** ([Shazeer](https://arxiv.org/abs/2002.05202)) | A better Feed-Forward-Network variant | `--encoder-use-glu --decoder-use-glu` | [GLU Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | An efficient normalization technique | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) |
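As a minimal illustration of the low-rank update behind the LoRA row above (toy code, not taken from the repository; it only assumes the `alpha/r`, or `alpha/sqrt(r)` when rank-scaled, scaling convention used in `fairseq/modules/lora.py` in this commit):

```python
import math
import torch

# Toy illustration of the LoRA update: y = x @ W.T + scaling * x @ (B @ A).T
fan_in, fan_out, r, alpha, rank_scaled = 512, 512, 8, 16, False

W = torch.randn(fan_out, fan_in)         # frozen pretrained weight
A = torch.randn(r, fan_in) * 0.01        # trainable low-rank factor
B = torch.zeros(fan_out, r)              # trainable; zero init so the update starts at 0
scaling = alpha / math.sqrt(r) if rank_scaled else alpha / r

x = torch.randn(4, fan_in)
y = x @ W.T + (x @ (B @ A).T) * scaling  # equivalent to using the merged weight W + scaling * (B @ A)
```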
73 changes: 37 additions & 36 deletions fairseq/modules/lora.py
@@ -4,26 +4,24 @@
import torch.nn.functional as F


# source: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
class LoRALayer:
def __init__(
self,
r: int,
alpha: int,
dropout: float,
rank_scaled: bool,
dropout: float = 0.0,
rank_scaled: bool = False,
fan_in: int = None,
fan_out: int = None,
):
assert r > 0, "Rank must be greater than 0 for LoRA to work."

assert (
r > 0
), "rank must be greater than 0 for LoRA to work. Use nn.Linear/nn.Embedding if rank is 0."

self.r = r
self.alpha = alpha
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
# Mark the weight as unmerged initially
self.merged = False
self.rank_scaled = rank_scaled
self.dropout_p = dropout
self.scaling = (alpha / math.sqrt(r)) if rank_scaled else (alpha / r)
# keep the A and B matrices in the base class so both subclasses share the same setup
self.lora_A = nn.Parameter(torch.zeros((r, fan_in)))
self.lora_B = nn.Parameter(torch.zeros((fan_out, r)))


class LoRAEmbedding(nn.Embedding, LoRALayer):
@@ -34,26 +32,29 @@ def __init__(
embedding_dim: int,
r: int,
alpha: int = 1,
dropout: float = 0.0,
rank_scaled: bool = False,
**kwargs
):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(self, r=r, alpha=alpha, dropout=0, rank_scaled=rank_scaled)
# Actual trainable parameters
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
self.scaling = (
(self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
LoRALayer.__init__(
self,
r=r,
alpha=alpha,
dropout=dropout,
rank_scaled=rank_scaled,
fan_in=num_embeddings,
fan_out=embedding_dim,
)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()

def reset_parameters(self):
nn.Embedding.reset_parameters(self)
# initialize A to zero and B from a normal distribution so the LoRA update starts at zero
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
if hasattr(self, "lora_A"):
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)

def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
@@ -72,7 +73,7 @@ def forward(self, x: torch.Tensor):
result = nn.Embedding.forward(self, x)

if not self.merged:
after_A = F.embedding(
x = F.embedding(
x,
self.lora_A.T,
self.padding_idx,
Expand All @@ -81,7 +82,8 @@ def forward(self, x: torch.Tensor):
self.scale_grad_by_freq,
self.sparse,
)
result += (after_A @ self.lora_B.T) * self.scaling
x = F.dropout(x, p=self.dropout_p, training=self.training)
result += (x @ self.lora_B.T) * self.scaling

return result
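For readability, a standalone sketch of the tensor shapes in the low-rank path of `LoRAEmbedding.forward` above (illustrative values only; `lora_A` is `(r, num_embeddings)` and `lora_B` is `(embedding_dim, r)`):

```python
import torch
import torch.nn.functional as F

num_embeddings, embedding_dim, r, scaling = 1000, 64, 8, 2.0
lora_A = torch.zeros(r, num_embeddings)   # low-rank "vocabulary" factor
lora_B = torch.randn(embedding_dim, r)    # low-rank output factor

tokens = torch.randint(0, num_embeddings, (2, 5))
after_A = F.embedding(tokens, lora_A.T)   # (2, 5, r): per-token low-rank codes
delta = (after_A @ lora_B.T) * scaling    # (2, 5, embedding_dim): added to the frozen embedding output
```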

@@ -98,26 +100,26 @@ def __init__(
rank_scaled: bool = False,
**kwargs
):

nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(
self, r=r, alpha=alpha, dropout=dropout, rank_scaled=rank_scaled
)
# Actual trainable parameters
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = (
(self.alpha / math.sqrt(self.r)) if rank_scaled else (self.alpha / self.r)
self,
r=r,
alpha=alpha,
dropout=dropout,
rank_scaled=rank_scaled,
fan_in=in_features,
fan_out=out_features,
)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()

def reset_parameters(self):
nn.Linear.reset_parameters(self)
# initialize A the same way as the default for nn.Linear and B to zero
# this differs from the scheme described in the paper but should not affect performance
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
if hasattr(self, "lora_A"):
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)

def train(self, mode: bool = True):
nn.Linear.train(self, mode)
Expand All @@ -137,8 +139,7 @@ def forward(self, x: torch.Tensor):

if not self.merged:
# if the weights are not merged, apply LoRA
if self.dropout is not None:
x = self.dropout(x)
x = F.dropout(x, p=self.dropout_p, training=self.training)
result += x @ ((self.lora_B @ self.lora_A).T * self.scaling)

return result
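A short, hedged usage sketch of the `LoRALinear` module above; the import path and keyword names follow this diff, and `bias` is assumed to be forwarded to `nn.Linear` via `**kwargs`:

```python
import torch
from fairseq.modules.lora import LoRALinear

layer = LoRALinear(512, 512, r=8, alpha=16, dropout=0.05, rank_scaled=False, bias=False)
x = torch.randn(2, 10, 512)
out = layer(x)    # frozen base projection + scaled low-rank update
print(out.shape)  # expected: torch.Size([2, 10, 512])

# Only the LoRA factors should remain trainable; the pretrained weight is frozen in __init__.
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['lora_A', 'lora_B']
```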
2 changes: 1 addition & 1 deletion fairseq/tasks/translation.py
@@ -236,7 +236,7 @@ class TranslationConfig(FairseqDataclass):
eval_bleu_args: Optional[str] = field(
default="{}",
metadata={
"help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
"help": 'generation args for BLEU scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
},
)
eval_bleu_detok: str = field(
31 changes: 13 additions & 18 deletions fairseq_cli/train.py
@@ -43,12 +43,13 @@
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer
from fairseq.checkpoint_utils import load_model_ensemble
from fairseq.modules.lora import LoRALinear, LoRALayer, LoRAEmbedding
from fairseq.modules.lora import *


def mark_only_lora_as_trainable(model, bias, saved_modules=None) -> None:
def mark_only_lora_as_trainable(model, bias="none") -> None:
for n, p in model.named_parameters():
p.requires_grad = "lora_" in n and not any([m in n for m in saved_modules])
p.requires_grad = "lora_" in n
logger.info(f"Setting {n} to {'trainable' if p.requires_grad else 'frozen'}")
if bias == "none":
return
elif bias == "all":
@@ -73,6 +74,7 @@ def replace_with_lora(
- model: The original model to be modified.
- lora_modules: A list of module names that should be replaced with fairseq.modules.lora.Linear layers.
- lora_params: A dictionary containing parameters for the fairseq.modules.lora.Linear layer.
- parent_module_name: The name of the parent module. Used for recursive calls.
"""
for name, module in model.named_children():
full_module_name = (
@@ -87,20 +89,19 @@
module.in_features, module.out_features, **lora_params
)

# do this after the initialization always
if module.weight is not None:
with torch.no_grad():
new_module.weight.data = module.weight.data
if module.bias is not None:
new_module.bias.data = module.bias.data

logger.info(f"Replacing {full_module_name} with LoRALinear")
setattr(model, name, new_module)

elif isinstance(module, torch.nn.Embedding) and any(
[m in full_module_name for m in lora_modules]
):
lora_params_emb = lora_params.copy()
lora_params_emb.pop("dropout", None)

new_module = LoRAEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
Expand All @@ -109,13 +110,14 @@ def replace_with_lora(
max_norm=module.max_norm,
norm_type=module.norm_type,
sparse=module.sparse,
**lora_params_emb,
**lora_params,
)

if module.weight is not None:
with torch.no_grad():
new_module.weight.data = module.weight.data

logger.info(f" | > Replacing {full_module_name} with LoRAEmbedding")
setattr(model, name, new_module)

else:
@@ -243,16 +245,9 @@ def main(cfg: FairseqConfig) -> None:
lora_modules = set(lora_config.get("target_modules", "").split(","))
# assert there are target modules specified for LoRA
assert lora_modules != {""}, "No target modules specified for LoRA"
saved_modules = set(lora_config.get("saved_modules", "").split(","))
# assert there are no common modules between saved_modules and lora_modules
assert len(lora_modules.intersection(saved_modules)) == 0, (
"lora_modules and saved_modules cannot have common modules. "
"Please remove the following modules from either target_modules or saved_modules: "
f"{lora_modules.intersection(saved_modules)}"
)
lora_bias = lora_config.get("bias", "none")
replace_with_lora(model, lora_modules=lora_modules, lora_params=lora_params)
mark_only_lora_as_trainable(model, bias=lora_bias, saved_modules=saved_modules)
replace_with_lora(model, lora_modules, lora_params)
mark_only_lora_as_trainable(model, lora_bias)
### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###

logger.info(
@@ -315,8 +310,8 @@ def main(cfg: FairseqConfig) -> None:
# assign the teacher model (if present) to the trainer
# we had to build the teacher model first before the student and trainer
# to avoid over-writing the generator for beam-search of the student with that of the teacher
if (cfg.task._name == "translation_with_kd") and (
cfg.criterion._name == "label_smoothed_cross_entropy_with_kd"
if (cfg.task._name == "seq2seq_lm_distillation") and (
cfg.criterion._name == "seq2seq_lm_distillation"
):
trainer.assign_teacher_model(teacher_model)

