Commit

removed old implementations
VarunGumma committed Aug 3, 2024
1 parent dbb185f commit 390bb36
Showing 8 changed files with 36 additions and 48 deletions.
13 changes: 9 additions & 4 deletions README.md
@@ -21,17 +21,22 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `

| **Name and Citation** | **Description** | **Flags to Activate** | **Source** |
|-----------------------|-----------------------|-----------------------|------------|
-| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $teacher_ckpt --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
-| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
+| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $path --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
+| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking 6 --decoder-recurrent-stacking 6` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest. | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"theta": 10000, "use_xpos": false, "xpos_scale_base": 512}' --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
| **Gated Linear Unit (GLU)** ([Shazeer](https://arxiv.org/abs/2002.05202)) | A better Feed-Forward-Network variant | `--encoder-use-glu --decoder-use-glu` | [GLU Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | An efficient normalization technique | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"type": "symmetrical"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
-| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
+| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim 128 --decoder-factorized-embed-dim 128` | - |
| **Sanity Validation Steps** | Runs a full pass over the validation set at the beginning of training | `--run-sanity-validation-steps` | - |
| **Efficient Multihead Attention (MHA)** | A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ with an efficient context manager | `--attn-implementation fast`. By default, the value is `fairseq` | - |
-| **Grouped Query Attention (GQA)** [Ainslie _et al._](https://aclanthology.org/2023.emnlp-main.298/) | Clusters queries into groups, allowing for more efficient computation and enhanced scalability in processing large sets of queries within transformer models. | `--attn-implementation fast_gqa` | [GQA Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) |
+| **Grouped Query Attention (GQA)** [(Ainslie _et al._)](https://aclanthology.org/2023.emnlp-main.298/) | Clusters queries into groups, allowing for more efficient computation and enhanced scalability in processing large sets of queries within transformer models. | `--attn-implementation fast_gqa` | [GQA Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) |


+## Upcoming features ($\alpha$-testing)
+* `--bf16` has been decoupled from `--tpu`, and can be used independently to train the model with `bfloat16`.
+* `--torch-compile $mode` can be used in the `interactive` and `generate` methods for faster inference.


# Requirements and Installation
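The knowledge-distillation row in the README above mentions a `forward_kld` loss mixed in with a `lambda` weight. As a rough illustration of that idea only — not the repository's `seq2seq_lm_distillation` criterion; the function name, `temperature` knob, and exact mixing formula below are assumptions — a token-level forward-KLD distillation loss can be sketched in plain PyTorch:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets,
                      lam=0.5, pad_idx=1, temperature=1.0):
    """Mix label cross-entropy with forward KL(teacher || student) over non-pad tokens."""
    vocab = student_logits.size(-1)
    student_logits = student_logits.reshape(-1, vocab)
    teacher_logits = teacher_logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    keep = targets.ne(pad_idx)  # mask out padding positions

    # Standard cross-entropy of the student against the reference tokens
    ce = F.cross_entropy(student_logits, targets, ignore_index=pad_idx)

    # Forward KLD: teacher distribution as target, student log-probs as input
    kld = F.kl_div(
        F.log_softmax(student_logits[keep] / temperature, dim=-1),
        F.softmax(teacher_logits[keep] / temperature, dim=-1),
        reduction="batchmean",
    )
    return (1.0 - lam) * ce + lam * kld
```

With `lam` set to 1.0, as in the README example, this reduces to pure forward KLD against the teacher.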
18 changes: 6 additions & 12 deletions fairseq/dataclass/configs.py
@@ -1070,6 +1070,12 @@ class CommonEvalConfig(FairseqDataclass):
    results_path: Optional[str] = field(
        default=None, metadata={"help": "path to save eval results (optional)"}
    )
+    torch_compile: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "compile PyTorch model for faster execution",
+        },
+    )


@dataclass
@@ -1112,18 +1118,6 @@ class InteractiveConfig(FairseqDataclass):
default="-",
metadata={"help": "file to read from; use - for stdin"},
)
torch_compile: Optional[str] = field(
default=None,
metadata={
"help": "compile PyTorch model for faster execution",
},
)
force_override_max_positions: Optional[str] = field(
default=None,
metadata={
"help": "force override the max_positions specified in the checkpoint. Should be a tuple of integers, ex. (2048, 2048), in the form of a string."
},
)


@dataclass
1 change: 0 additions & 1 deletion fairseq/models/transformer/transformer_base.py
@@ -154,7 +154,6 @@ def build_embedding(
                embed_dim,
                padding_idx=padding_idx,
                hid_dim=factorized_embed_dim,
-                activation=cfg.factorized_embed_activation_fn,
            )
        else:
            emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
14 changes: 0 additions & 14 deletions fairseq/models/transformer/transformer_config.py
@@ -120,20 +120,6 @@ class QuantNoiseConfig(FairseqDataclass):

@dataclass
class TransformerConfig(FairseqDataclass):
-    ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###
-    adapter_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
-        default="relu", metadata={"help": "activation function for adapters"}
-    )
-    factorized_embed_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = (
-        field(
-            default="linear",
-            metadata={
-                "help": "activation function to use for the factorized embedding"
-            },
-        )
-    )
-    ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ###

    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu",
        metadata={"help": "activation function to use"},
4 changes: 1 addition & 3 deletions fairseq/modules/factorized_embedding.py
@@ -21,15 +21,13 @@ def __init__(
        hid_dim=128,
        padding_idx=1,
        bias=False,
-        activation="linear",
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx

        self.up = nn.Linear(hid_dim, embedding_dim, bias=bias)
        self.emb = nn.Embedding(num_embeddings, hid_dim, padding_idx=padding_idx)
-        self.activation_fn = utils.get_activation_fn(activation=activation)

    def forward(self, x):
-        return self.up(self.activation_fn(self.emb(x)))
+        return self.up(self.emb(x))
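For reference, here is a minimal standalone sketch of the factorized (bottleneck) embedding that remains after this change: a small `nn.Embedding` into `hid_dim` followed by a linear up-projection to `embedding_dim`. The class name and sizes below are illustrative assumptions, not a drop-in replacement for the module above.

```python
import torch
import torch.nn as nn

class TinyFactorizedEmbedding(nn.Module):
    """Vocab -> hid_dim lookup, then a linear projection up to embedding_dim."""

    def __init__(self, num_embeddings, embedding_dim, hid_dim=128, padding_idx=1, bias=False):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, hid_dim, padding_idx=padding_idx)
        self.up = nn.Linear(hid_dim, embedding_dim, bias=bias)

    def forward(self, x):
        return self.up(self.emb(x))

tokens = torch.randint(0, 32000, (2, 16))           # (batch, seq_len)
emb = TinyFactorizedEmbedding(32000, 1024, hid_dim=128)
print(emb(tokens).shape)                             # torch.Size([2, 16, 1024])
```

At these sizes the factorization stores roughly 4.2M embedding parameters instead of about 32.8M for a full 32000x1024 table.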
2 changes: 0 additions & 2 deletions fairseq/options.py
@@ -197,8 +197,6 @@ def parse_args_and_arch(
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
-    if args.bf16:
-        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")
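The two deleted lines above are what previously forced `--tpu` on whenever `--bf16` was set. As a small illustrative check outside fairseq (standard PyTorch calls only; the toy `nn.Linear` model is an assumption), `bfloat16` inference on a CUDA GPU might look like:

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512)

# bfloat16 no longer implies TPU/XLA; on recent GPUs it runs natively.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model = model.to(device="cuda", dtype=torch.bfloat16)
    x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        y = model(x)
    print(y.dtype)  # torch.bfloat16
```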

12 changes: 10 additions & 2 deletions fairseq_cli/generate.py
@@ -123,16 +123,24 @@ def _main(cfg: DictConfig, output_file):
    else:
        lms = [None]

+    compile_mode = getattr(cfg.common_eval, "torch_compile", None)

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
-            model.half()
+            model = model.half()
+        if cfg.common.bf16:
+            model = model.to(dtype=torch.bfloat16)
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
-            model.cuda()
+            model = model.cuda()

        model.prepare_for_inference_(cfg)

+        if compile_mode is not None:
+            model = torch.compile(model, mode=compile_mode)

        model.eval()

    # Load alignment dictionary for unknown word replacement
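The new `torch_compile` option is handed straight to `torch.compile` as its `mode` string. A minimal sketch of that call outside fairseq (PyTorch >= 2.0; the toy model and chosen mode are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 8)).eval()

# Valid mode strings include "default", "reduce-overhead", and "max-autotune".
compiled = torch.compile(model, mode="reduce-overhead")

with torch.no_grad():
    out = compiled(torch.randn(1, 256))  # first call triggers compilation
print(out.shape)
```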
20 changes: 10 additions & 10 deletions fairseq_cli/interactive.py
@@ -157,20 +157,23 @@ def main(cfg: FairseqConfig):
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

-    compile_mode = getattr(cfg, "torch_compile", None)
+    compile_mode = getattr(cfg.common_eval, "torch_compile", None)

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
-            model.half()
+            model = model.half()
        if cfg.common.bf16:
            model = model.to(torch.bfloat16)
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
-            model.cuda()
+            model = model.cuda()

        model.prepare_for_inference_(cfg)

        if compile_mode is not None:
-            model = torch.compile(model, model=compile_mode)
+            model = torch.compile(model, mode=compile_mode)

        model.eval()

@@ -199,12 +202,9 @@ def decode_fn(x):
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

-    if not getattr(cfg.interactive, "force_override_max_positions", False):
-        max_positions = utils.resolve_max_positions(
-            task.max_positions(), *[model.max_positions() for model in models]
-        )
-    else:
-        max_positions = eval(cfg.interactive.force_override_max_positions)
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(), *[model.max_positions() for model in models]
+    )

logger.info("Max positions: {}".format(max_positions))

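For intuition, `utils.resolve_max_positions` combines the position limits of the task and of every model, effectively keeping the most restrictive values. A simplified sketch of that idea follows; it is not fairseq's actual implementation, which also handles dictionaries and other edge cases.

```python
def resolve_max_positions_simple(*limits):
    """Element-wise minimum over (source, target) length limits, ignoring None (no limit)."""
    resolved = None
    for limit in limits:
        if limit is None:
            continue
        if resolved is None:
            resolved = tuple(limit)
        else:
            resolved = tuple(min(a, b) for a, b in zip(resolved, limit))
    return resolved

# e.g. the task allows (1024, 1024) but one model only (512, 512)
print(resolve_max_positions_simple((1024, 1024), (512, 512), None))  # (512, 512)
```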
