bug fixes
VarunGumma committed Aug 27, 2024
1 parent caeb62e commit 50e56c9
Showing 9 changed files with 30 additions and 30 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -17,7 +17,7 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `

| **Name and Citation** | **Description** | **Flags to Activate** | **Source** |
|-----------------------|-----------------------|-----------------------|------------|
-| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $path --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
+| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $path --task seq2seq_lm_distillation --criterion lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking 6 --decoder-recurrent-stacking 6` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest. | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"theta": 10000, "use_xpos": false, "xpos_scale_base": 512}' --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
@@ -28,10 +28,11 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
| **Sanity Validation Steps** | Runs a full pass over the validation set at the beginning of training | `--run-sanity-validation-steps` | - |
| **Efficient Multihead Attention (MHA)** | A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ | `--attn-implementation fast`. By default, the value is `fairseq` | - |
| **Grouped Query Attention (GQA)** ([Ainslie _et al._](https://aclanthology.org/2023.emnlp-main.298/)) | Clusters queries into groups, allowing for more efficient computation and enhanced scalability in processing large sets of queries within transformer models. | `--attn-implementation fast_gqa` | [GQA Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) |
+| **Fused Attention** | Combines the efficiency of Multi-Head Attention (MHA) and Grouped Query Attention (GQA) into a fused operation, providing faster computation. This fused version is not compatible with models trained using the unfused versions of MHA or GQA. | `--attn-implementation fast_fused` or `--attn-implementation fast_gqa_fused` | [Fused Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) |



## Upcoming features ($\alpha$-testing)
-* `fused` version of **MHA** (`fast_fused`) and **GQA** (`fast_gqa_fused`) for faster computation. Note, this cannot be used with models that were trained with the _un-fused_ version.
* `--bf16` has been decoupled from `--tpu`, and can be used independently to train the model with `bfloat16`.
* `--torch-compile $mode` can be used in the `interactive` and `generate` methods for faster inference.

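The `--kd-args` and `--lora-args` flags in the README table above take JSON strings, which need careful shell quoting. A minimal sketch of how such strings can be built; key names are copied from the table, and the helper itself is only illustrative, not part of fairseq:

```python
import json
import shlex

# Key/value pairs copied from the README table above; adjust to taste.
kd_args = {"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}
lora_args = {
    "r": 8,
    "alpha": 16,
    "dropout": 0.05,
    "bias": "none",
    "target_modules": "k_proj,v_proj",
    "rank_scaled": False,
}

# shlex.quote wraps the JSON in single quotes so the shell passes it through intact.
print(f"--kd-args {shlex.quote(json.dumps(kd_args))}")
print(f"--lora-args {shlex.quote(json.dumps(lora_args))}")
```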
@@ -21,7 +21,7 @@


@dataclass
-class Seq2SeqLMDistillationCriterionConfig(CrossEntropyCriterionConfig):
+class LMDistillationCriterionConfig(CrossEntropyCriterionConfig):
    kd_args: Optional[str] = field(
        default=None,
        metadata={"help": "arguments for knowledge distillation (kd_strategy)"},
@@ -37,10 +37,10 @@ class Seq2SeqLMDistillationCriterionConfig(CrossEntropyCriterionConfig):


@register_criterion(
-    "seq2seq_lm_distillation",
-    dataclass=Seq2SeqLMDistillationCriterionConfig,
+    "lm_distillation_loss",
+    dataclass=LMDistillationCriterionConfig,
)
-class Seq2SeqLMDistillationCriterion(CrossEntropyCriterion):
+class LMDistillationCriterion(CrossEntropyCriterion):
    def __init__(
        self,
        task,
@@ -157,12 +157,12 @@ def compute_kd_loss(self, model, net_output, sample, teacher_model, teacher_outp
                teacher_model, teacher_output, sample, log_probs=False
            )

-            m = self.beta * probs + (1 - self.beta) * teacher_probs
+            m_log = torch.log(self.beta * probs + (1 - self.beta) * teacher_probs)

            kd_loss = self.beta * F.kl_div(
-                m.log(), lprobs, log_target=True, reduction="none"
+                m_log, lprobs, log_target=True, reduction="none"
            ) + (1 - self.beta) * F.kl_div(
-                m.log(), teacher_lprobs, log_target=True, reduction="none"
+                m_log, teacher_lprobs, log_target=True, reduction="none"
            )
            kd_loss = kd_loss.masked_fill_(pad_mask, 0.0).sum()
        else:
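For context, the hunk above replaces the twice-evaluated `m.log()` with a single precomputed `m_log`. A self-contained sketch of the interpolated-KL term it computes is below; the function name, tensor names, and shapes are assumptions for illustration, not fairseq's internal API:

```python
import torch
import torch.nn.functional as F


def interpolated_kl_kd(student_lprobs, teacher_lprobs, beta, pad_mask):
    """beta * KL(p_student || m) + (1 - beta) * KL(p_teacher || m),
    where m = beta * p_student + (1 - beta) * p_teacher.

    student_lprobs, teacher_lprobs: log-probabilities of shape (tokens, vocab);
    pad_mask: boolean mask, True at padding positions, broadcastable to that shape.
    """
    probs = student_lprobs.exp()
    teacher_probs = teacher_lprobs.exp()
    # Take the log of the mixture once (the change made in this commit) instead of
    # materializing the mixture and calling .log() twice.
    m_log = torch.log(beta * probs + (1 - beta) * teacher_probs)

    # F.kl_div(input, target, log_target=True) computes target.exp() * (target - input)
    # element-wise when reduction="none", i.e. KL(target || exp(input)).
    kd = beta * F.kl_div(m_log, student_lprobs, log_target=True, reduction="none") + (
        1 - beta
    ) * F.kl_div(m_log, teacher_lprobs, log_target=True, reduction="none")
    return kd.masked_fill(pad_mask, 0.0).sum()
```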
6 changes: 1 addition & 5 deletions fairseq/models/transformer/transformer_decoder.py
@@ -164,11 +164,7 @@ def __init__(
            self.build_output_projection(cfg, dictionary, embed_tokens)

    def build_normalization(self, dim, rms=False):
-        return (
-            LayerNorm(dim, export=self.cfg.export)
-            if not rms
-            else RMSNorm(dim, export=self.cfg.export)
-        )
+        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

    def build_output_projection(self, cfg, dictionary, embed_tokens):
        if cfg.adaptive_softmax_cutoff is not None:
6 changes: 1 addition & 5 deletions fairseq/models/transformer/transformer_encoder.py
@@ -130,11 +130,7 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
            self.alibi = None

    def build_normalization(self, dim, rms=False):
-        return (
-            LayerNorm(dim, export=self.cfg.export)
-            if not rms
-            else RMSNorm(dim, export=self.cfg.export)
-        )
+        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
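Both `build_normalization` changes above shorten the conditional and drop the `export` argument from the `RMSNorm` call. As a reference for what the RMS variant computes, here is a minimal RMSNorm in the spirit of Zhang & Sennrich (2019); it is an illustrative sketch, not the `RMSNorm` class fairseq imports:

```python
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """LayerNorm without mean-centering: x is rescaled by its root-mean-square, then a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS over the last (feature) dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
```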
16 changes: 10 additions & 6 deletions fairseq/modules/factorized_embedding.py
@@ -1,5 +1,4 @@
import torch.nn as nn
-from fairseq import utils


class FactorizedEmbedding(nn.Module):
@@ -20,14 +19,19 @@ def __init__(
        embedding_dim,
        hid_dim=128,
        padding_idx=1,
-        bias=False,
    ):
        super().__init__()
-        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
+        self.embedding_dim = embedding_dim
+
+        self.up = nn.Linear(hid_dim, embedding_dim, bias=False)
+        self.m = nn.Embedding(num_embeddings, hid_dim, padding_idx=padding_idx)
+
+        self.reset_parameters()

-        self.up = nn.Linear(hid_dim, embedding_dim, bias=bias)
-        self.emb = nn.Embedding(num_embeddings, hid_dim, padding_idx=padding_idx)
+    def reset_parameters(self):
+        nn.init.normal_(self.m.weight, mean=0, std=self.embedding_dim**-0.5)
+        nn.init.constant_(self.m.weight[self.padding_idx], 0)

    def forward(self, x):
-        return self.up(self.emb(x))
+        return self.up(self.m(x))
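The rewritten module keeps the `nn.Embedding(num_embeddings, hid_dim)` bottleneck followed by `nn.Linear(hid_dim, embedding_dim)`. A quick back-of-the-envelope comparison of parameter counts, with hypothetical sizes chosen only for illustration:

```python
# Hypothetical sizes: 32k vocabulary, 1024-d model embeddings, 128-d bottleneck.
V, D, h = 32_000, 1024, 128

full = V * D                 # plain nn.Embedding(V, D)
factorized = V * h + h * D   # nn.Embedding(V, h) + nn.Linear(h, D, bias=False)

print(full, factorized)      # 32768000 vs 4227072, roughly a 7.8x reduction
```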
3 changes: 3 additions & 0 deletions fairseq/modules/fairseq_dropout.py
@@ -20,6 +20,9 @@ def __init__(self, p, module_name=None):
        self.module_name = module_name
        self.apply_during_inference = False

+    def extra_repr(self):
+        return f"p={self.p}"
+
    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
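The new `extra_repr` only changes how the module prints (e.g. in `print(model)`); a tiny stand-in class, not fairseq's `FairseqDropout`, shows the effect:

```python
import torch.nn as nn


class TinyDropout(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.p = p

    def extra_repr(self):
        # Appears inside the parentheses of the module's repr.
        return f"p={self.p}"


print(TinyDropout(0.1))  # TinyDropout(p=0.1)
```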
4 changes: 2 additions & 2 deletions fairseq/modules/lora.py
@@ -34,7 +34,7 @@ def __init__(
        alpha: int = 1,
        dropout: float = 0.0,
        rank_scaled: bool = False,
-        **kwargs
+        **kwargs,
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(
@@ -98,7 +98,7 @@ def __init__(
        alpha: int = 1,
        dropout: float = 0.0,
        rank_scaled: bool = False,
-        **kwargs
+        **kwargs,
    ):

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
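The `lora.py` hunks above only add trailing commas, but for readers new to the technique the README enables, here is a minimal LoRA-style linear layer (frozen dense weight plus a trainable low-rank update scaled by `alpha / r`). It is an illustrative sketch, not this repository's `LoRALayer` classes:

```python
import math
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """y = W x + (alpha / r) * B A dropout(x), with W frozen and A, B trainable."""

    def __init__(self, in_features, out_features, r=8, alpha=16, dropout=0.05):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)  # the pretrained weight stays frozen
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no change at step 0
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        low_rank = self.dropout(x) @ self.lora_A.t() @ self.lora_B.t()
        return self.base(x) + self.scaling * low_rank
```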
2 changes: 1 addition & 1 deletion fairseq/version.txt
@@ -1 +1 @@
-0.13.0
+0.13.1
4 changes: 2 additions & 2 deletions fairseq_cli/train.py
@@ -182,7 +182,7 @@ def main(cfg: FairseqConfig) -> None:

    # build teacher model here
    if (cfg.task._name == "seq2seq_lm_distillation") and (
-        cfg.criterion._name == "seq2seq_lm_distillation"
+        cfg.criterion._name == "lm_distillation_loss"
    ):
        logging.info("Building teacher model")

@@ -311,7 +311,7 @@ def main(cfg: FairseqConfig) -> None:
    # we had to build the teacher model first before the student and trainer
    # to avoid over-writing the generator for beam-search of the student with that of the teacher
    if (cfg.task._name == "seq2seq_lm_distillation") and (
-        cfg.criterion._name == "seq2seq_lm_distillation"
+        cfg.criterion._name == "lm_distillation_loss"
    ):
        trainer.assign_teacher_model(teacher_model)

