From 390bb36c6d5432dd62f95fb46f71a92932adb076 Mon Sep 17 00:00:00 2001 From: Varun Gumma Date: Sat, 3 Aug 2024 12:41:45 +0000 Subject: [PATCH] removed old implementations --- README.md | 13 ++++++++---- fairseq/dataclass/configs.py | 18 ++++++----------- .../models/transformer/transformer_base.py | 1 - .../models/transformer/transformer_config.py | 14 ------------- fairseq/modules/factorized_embedding.py | 4 +--- fairseq/options.py | 2 -- fairseq_cli/generate.py | 12 +++++++++-- fairseq_cli/interactive.py | 20 +++++++++---------- 8 files changed, 36 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 719af9fcd5..33d1f78983 100755 --- a/README.md +++ b/README.md @@ -21,17 +21,22 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, ` | **Name and Citation** | **Description** | **Flags to Activate** | **Source** | |-----------------------|-----------------------|-----------------------|------------| -| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $teacher_ckpt --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) | -| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - | +| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model. Please check [here](https://github.com/VarunGumma/fairseq/blob/main/fairseq/criterions/seq2seq_lm_distillation.py) for a detailed description of the arguments. | `--teacher-checkpoint-path $path --task seq2seq_lm_distillation --criterion seq2seq_lm_distillation --kd-args '{"strategy": "on_policy", "lambda": 1.0, "loss_type": "forward_kld"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) | +| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking 6 --decoder-recurrent-stacking 6` | - | | **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest. 
| `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --attn-implementation fast --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"theta": 10000, "use_xpos": false, "xpos_scale_base": 512}' --attn-implementation fast --no-token-positional-embeddings --load-checkpoint-liberally` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
| **Gated Linear Unit (GLU)** ([Shazeer](https://arxiv.org/abs/2002.05202)) | A better Feed-Forward-Network variant | `--encoder-use-glu --decoder-use-glu` | [GLU Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | An efficient normalization technique | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"type": "symmetrical"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
-| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
+| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim 128 --decoder-factorized-embed-dim 128` | - |
| **Sanity Validation Steps** | Runs a full pass over the validation set at the beginning of training | `--run-sanity-validation-steps` | - |
| **Efficient Multihead Attention (MHA)** | A [torch-functional variant](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of _MultiHeadAttention_ with an efficient context manager | `--attn-implementation fast`. By default, the value is `fairseq` | - |
-| **Grouped Query Attention (GQA)** [Ainslie _et al._](https://aclanthology.org/2023.emnlp-main.298/) | Clusters queries into groups, allowing for more efficient computation and enhanced scalability in processing large sets of queries within transformer models. | `--attn-implementation fast_gqa` | [GQA Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) |
+| **Grouped Query Attention (GQA)** [(Ainslie _et al._)](https://aclanthology.org/2023.emnlp-main.298/) | Clusters queries into groups, allowing for more efficient computation and enhanced scalability in processing large sets of queries within transformer models. 
| `--attn-implementation fast_gqa` | [GQA Implementation](https://pytorch.org/torchtune/stable/_modules/torchtune/modules/attention.html) | + + +## Upcoming features ($\alpha$-testing) +* `--bf16` has been decoupled from `--tpu`, and can be used independently to train the model with `bfloat16`. +* `--torch-compile $mode` can be used in the `interactive` and `generate` methods for faster inference. # Requirements and Installation diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 5a048afca2..582efdccbc 100755 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -1070,6 +1070,12 @@ class CommonEvalConfig(FairseqDataclass): results_path: Optional[str] = field( default=None, metadata={"help": "path to save eval results (optional)"} ) + torch_compile: Optional[str] = field( + default=None, + metadata={ + "help": "compile PyTorch model for faster execution", + }, + ) @dataclass @@ -1112,18 +1118,6 @@ class InteractiveConfig(FairseqDataclass): default="-", metadata={"help": "file to read from; use - for stdin"}, ) - torch_compile: Optional[str] = field( - default=None, - metadata={ - "help": "compile PyTorch model for faster execution", - }, - ) - force_override_max_positions: Optional[str] = field( - default=None, - metadata={ - "help": "force override the max_positions specified in the checkpoint. Should be a tuple of integers, ex. (2048, 2048), in the form of a string." - }, - ) @dataclass diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py index 3b12c4414e..d9eb5560a2 100755 --- a/fairseq/models/transformer/transformer_base.py +++ b/fairseq/models/transformer/transformer_base.py @@ -154,7 +154,6 @@ def build_embedding( embed_dim, padding_idx=padding_idx, hid_dim=factorized_embed_dim, - activation=cfg.factorized_embed_activation_fn, ) else: emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx) diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py index d30b2821e9..49f3364054 100755 --- a/fairseq/models/transformer/transformer_config.py +++ b/fairseq/models/transformer/transformer_config.py @@ -120,20 +120,6 @@ class QuantNoiseConfig(FairseqDataclass): @dataclass class TransformerConfig(FairseqDataclass): - ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ### - adapter_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( - default="relu", metadata={"help": "activation function for adapters"} - ) - factorized_embed_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = ( - field( - default="linear", - metadata={ - "help": "activation function to use for the factorized embedding" - }, - ) - ) - ### EXPERIMENTAL :: NOT TO BE USED UNTIL TESTED ### - activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( default="relu", metadata={"help": "activation function to use"}, diff --git a/fairseq/modules/factorized_embedding.py b/fairseq/modules/factorized_embedding.py index 065aebd02e..b753e2e849 100755 --- a/fairseq/modules/factorized_embedding.py +++ b/fairseq/modules/factorized_embedding.py @@ -21,7 +21,6 @@ def __init__( hid_dim=128, padding_idx=1, bias=False, - activation="linear", ): super().__init__() self.embedding_dim = embedding_dim @@ -29,7 +28,6 @@ def __init__( self.up = nn.Linear(hid_dim, embedding_dim, bias=bias) self.emb = nn.Embedding(num_embeddings, hid_dim, padding_idx=padding_idx) - self.activation_fn = utils.get_activation_fn(activation=activation) def 
forward(self, x): - return self.up(self.activation_fn(self.emb(x))) + return self.up(self.emb(x)) diff --git a/fairseq/options.py b/fairseq/options.py index edd847b70d..c83c3971eb 100755 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -197,8 +197,6 @@ def parse_args_and_arch( args.bf16 = True args.tpu = getattr(args, "tpu", False) args.bf16 = getattr(args, "bf16", False) - if args.bf16: - args.tpu = True if args.tpu and args.fp16: raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 764148c622..9e7ca35788 100755 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -123,16 +123,24 @@ def _main(cfg: DictConfig, output_file): else: lms = [None] + compile_mode = getattr(cfg.common_eval, "torch_compile", None) + # Optimize ensemble for generation for model in chain(models, lms): if model is None: continue if cfg.common.fp16: - model.half() + model = model.half() + if cfg.common.bf16: + model = model.to(dtype=torch.bfloat16) if use_cuda and not cfg.distributed_training.pipeline_model_parallel: - model.cuda() + model = model.cuda() + model.prepare_for_inference_(cfg) + if compile_mode is not None: + model = torch.compile(model, mode=compile_mode) + model.eval() # Load alignment dictionary for unknown word replacement diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 25bd43b34b..641028a699 100755 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -157,20 +157,23 @@ def main(cfg: FairseqConfig): src_dict = task.source_dictionary tgt_dict = task.target_dictionary - compile_mode = getattr(cfg, "torch_compile", None) + compile_mode = getattr(cfg.common_eval, "torch_compile", None) # Optimize ensemble for generation for model in models: if model is None: continue if cfg.common.fp16: - model.half() + model = model.half() + if cfg.common.bf16: + model = model.to(torch.bfloat16) if use_cuda and not cfg.distributed_training.pipeline_model_parallel: - model.cuda() + model = model.cuda() + model.prepare_for_inference_(cfg) if compile_mode is not None: - model = torch.compile(model, model=compile_mode) + model = torch.compile(model, mode=compile_mode) model.eval() @@ -199,12 +202,9 @@ def decode_fn(x): # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(cfg.generation.replace_unk) - if not getattr(cfg.interactive, "force_override_max_positions", False): - max_positions = utils.resolve_max_positions( - task.max_positions(), *[model.max_positions() for model in models] - ) - else: - max_positions = eval(cfg.interactive.force_override_max_positions) + max_positions = utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ) logger.info("Max positions: {}".format(max_positions))
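
For reference, the sketch below (not part of the patch) mirrors the inference-preparation loop that `fairseq_cli/generate.py` and `fairseq_cli/interactive.py` share after this change; `models`, `cfg`, and `use_cuda` are assumed to come from fairseq's usual checkpoint and config loading, and the helper name is hypothetical.

```python
# Reference sketch only: the per-model preparation order used by the patched
# generate/interactive scripts. `models`, `cfg`, and `use_cuda` are assumed to
# come from fairseq's usual checkpoint/config loading; this helper is hypothetical.
import torch


def prepare_ensemble_for_inference(models, cfg, use_cuda=True):
    # --torch-compile is now read from the common_eval config group
    # (see the new CommonEvalConfig.torch_compile field above).
    compile_mode = getattr(cfg.common_eval, "torch_compile", None)

    prepared = []
    for model in models:
        if model is None:
            continue
        # Cast first, then move to the GPU, then let fairseq finalize the module...
        if cfg.common.fp16:
            model = model.half()
        if cfg.common.bf16:
            model = model.to(dtype=torch.bfloat16)
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model = model.cuda()
        model.prepare_for_inference_(cfg)
        # ...and only then compile, so torch.compile traces the final module.
        if compile_mode is not None:
            # Valid modes include "default", "reduce-overhead", and "max-autotune".
            model = torch.compile(model, mode=compile_mode)
        model.eval()
        prepared.append(model)
    return prepared
```

With the flag wired through `CommonEvalConfig`, an invocation along the lines of `fairseq-generate $data_bin --path $ckpt --torch-compile max-autotune` (placeholders shown as `$` variables) would compile each ensemble member before decoding begins.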