final RoPE fix, removed xPOS, and autoflaked
VarunGumma committed Jun 28, 2024
1 parent d468f57 commit 9c29523
Showing 134 changed files with 122 additions and 433 deletions.
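Most of the hunks below come from the autoflake pass named in the commit title: autoflake deletes imports and local variables that are never used. As a hedged illustration only, a programmatic call roughly equivalent to the usual `autoflake --in-place --remove-all-unused-imports --remove-unused-variables` CLI is sketched here; the `fix_code` keyword names are an assumption modeled on the CLI flags, not verified against this repository's tooling.

```python
import autoflake  # pip install autoflake

source = "import os\nimport sys\n\nprint(sys.argv)\n"

# Assumed API: keyword names mirror autoflake's CLI flags.
cleaned = autoflake.fix_code(
    source,
    remove_all_unused_imports=True,
    remove_unused_variables=True,
)
print(cleaned)  # `import os` is dropped; `sys` survives because it is used
```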
6 changes: 3 additions & 3 deletions README.md
@@ -25,9 +25,9 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model | `--teacher-checkpoint-path $teacher_ckpt --task translation_with_kd --criterion label_smoothed_cross_entropy_with_kd --kd-args '{"strategy": "word_level"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --use-native-attention --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)), **Extrapolatable Position Embedding (xPOS)** ([Sun _et al_.](https://aclanthology.org/2023.acl-long.816/)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"base": 10000, "use_xpos": false}' --use-native-attention --no-token-positional-embeddings` | [RoPE/xPOS Implementation](https://github.com/microsoft/torchscale/blob/main/torchscale/component/xpos_relative_position.py) |
| **Gated FC** | Add a gating module to the Fully Connected layers in a Transformer | `--encoder-use-gated-fc --decoder-use-gated-fc` | [Gated FC Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | Use RMSNorm instead of LayerNorm in a Transformer | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/meta-llama/llama/blob/main/llama/model.py#L34) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in the self-attention formulation | `--rope-args '{"base": 10000, "learned_freq": false}' --use-native-attention --no-token-positional-embeddings` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) (sketch below) |
| **Gated FFN** ([Shazeer](https://arxiv.org/abs/2002.05202)) | Adds a gating module to the feed-forward block of a Transformer | `--encoder-use-gated-ffn --decoder-use-gated-ffn` | [Gated FFN Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) (sketch below) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | Uses RMSNorm instead of LayerNorm in a Transformer | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/rms_norm.py) (sketch below) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"alibi_asymmetrical": "false"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=H1eA7AEtvS)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
| **Penultimate Linear Transformation Activation** | Adds activation to the penultimate linear transformation before the final projection onto the vocabulary | `--decoder-output-activation-fn $decoder_out_activation_fn` | - |
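Because the headline change here is the final RoPE fix (and the removal of the xPOS variant), a minimal PyTorch sketch of rotary embeddings may be a useful reference. It illustrates Su et al.'s technique with the default `base` of 10000; the function name and shapes are hypothetical, and this is not the fork's actual implementation (see the linked lucidrains module for that).

```python
import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings to `x` of shape (seq_len, dim).

    Channel pairs (2i, 2i+1) at position p are rotated by the angle
    p * base**(-2i/dim), so rotated query-key dot products depend only on
    relative offsets (Su et al., 2021). Requires an even `dim`.
    """
    seq_len, dim = x.shape
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    pos = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)        # (seq_len, dim // 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]            # split into channel pairs
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.reshape(seq_len, dim)       # re-interleave the pairs
```

Applying `rope_rotate` to both queries and keys before the dot product makes each attention score depend only on the relative offset between positions; the `learned_freq` knob in the flag above corresponds to making `inv_freq` a trainable parameter instead of the fixed schedule used here.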
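For the Gated FFN row, here is a sketch of the gating idea from Shazeer's GLU-variants paper, patterned on the Mistral MLP linked above; the SiLU activation and all module names are assumptions rather than this fork's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFFN(nn.Module):
    """Feed-forward block whose up-projection is gated elementwise (SwiGLU-style)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)    # value branch
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)  # back to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate) modulates the value branch, then project back down
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```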
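And for the RMSNorm row: RMSNorm rescales by the root-mean-square of the features with a learned gain, skipping LayerNorm's mean subtraction and bias. A self-contained sketch, again not the repository's code:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """y = x / rms(x) * g, with rms(x) = sqrt(mean(x^2) + eps)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms_inv = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms_inv * self.weight
```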
2 changes: 1 addition & 1 deletion docs/examples/MMPT/mmpt/modules/retri.py
@@ -14,7 +14,7 @@

from collections import defaultdict

from ..utils import get_local_rank, print_on_rank0
from ..utils import get_local_rank


class VectorRetriever(object):
1 change: 0 additions & 1 deletion docs/examples/MMPT/mmpt/processors/models/s3dg.py
@@ -22,7 +22,6 @@
import torch as th
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import re

2 changes: 0 additions & 2 deletions docs/examples/MMPT/mmpt/utils/shardedtensor.py
@@ -2,8 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
import numpy as np


2 changes: 1 addition & 1 deletion docs/examples/MMPT/mmpt_cli/predict.py
@@ -11,7 +11,7 @@
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from mmpt.utils import load_config, set_seed
from mmpt.utils import load_config
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
@@ -5,7 +5,6 @@
import os
import urllib.parse
import json
import pandas as pd

from tqdm import tqdm

1 change: 0 additions & 1 deletion docs/examples/adaptive_span/adaptive_span_loss.py
@@ -6,7 +6,6 @@
import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
@@ -9,7 +9,7 @@
from dataclasses import dataclass

import torch
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDatasetItem,
1 change: 0 additions & 1 deletion docs/examples/byte_level_bpe/gru_transformer.py
@@ -9,7 +9,6 @@
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerEncoder, TransformerModel

1 change: 0 additions & 1 deletion docs/examples/hubert/simple_kmeans/dump_hubert_feature.py
@@ -8,7 +8,6 @@
import sys

import fairseq
import soundfile as sf
import torch
import torch.nn.functional as F

@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import csv
import io
import logging
import os
import os.path as op
1 change: 0 additions & 1 deletion docs/examples/hubert/simple_kmeans/dump_mfcc_feature.py
@@ -7,7 +7,6 @@
import os
import sys

import soundfile as sf
import torch
import torchaudio

1 change: 0 additions & 1 deletion docs/examples/multilingual/data_scripts/binarize.py
@@ -4,7 +4,6 @@
import glob
import argparse
import shutil
import pathlib
import itertools


@@ -7,7 +7,7 @@
import os, sys
import subprocess
import re
from subprocess import check_call, check_output
from subprocess import check_output

WORKDIR_ROOT = os.environ.get("WORKDIR_ROOT", None)

@@ -7,7 +7,6 @@
import os
import glob
import argparse
from utils.dedup import deup
import sys

WORKDIR_ROOT = os.environ.get("WORKDIR_ROOT", None)
@@ -6,7 +6,6 @@

import os
import argparse
import pandas as pd
import sys


@@ -124,7 +123,6 @@ def main():
directions = args.directions.split(",")
directions = sorted(set(directions))

results = []
# print(f'checking where {args.split} split data are in training')
# print(f'direction\tcommon_count\tsrc common\ttgt common\tfrom_size\tto_size')
raw_data = args.folder
@@ -13,7 +13,7 @@
import wget
import sys

from subprocess import check_call, check_output
from subprocess import check_call

# scripts and data locations
CWD = os.getcwd()
@@ -342,7 +342,6 @@ def download_and_extract(download_to, extract_to):
args = parser.parse_args()

import sys
import json

# TED Talks data directory
ted_data_path = args.ted_data_path
@@ -7,7 +7,6 @@
import wget
import re
import multiprocessing as mp
from functools import partial
import pathlib
from collections import OrderedDict

@@ -471,8 +470,6 @@ def concat_into_splits(dl_dataset, src, tgt, extracted_folders, to_folder, debug

def download_multi(dl_folder, extract_folder, urls, num_processes=8, debug=False):
pool = mp.Pool(processes=num_processes)
download_f = partial(download_a_url, dl_folder)
downloaded_files = pool.imap_unordered(download_f, urls)
pool.close()
pool.join()

1 change: 0 additions & 1 deletion docs/examples/paraphraser/paraphrase.py
@@ -4,7 +4,6 @@
import fileinput
import logging
import os
import sys

from fairseq.models.transformer import TransformerModel

@@ -3,7 +3,6 @@
from torch import Tensor

from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
prob_check,
moving_sum,
)
2 changes: 0 additions & 2 deletions docs/examples/speech_recognition/kaldi/kaldi_decoder.py
@@ -9,7 +9,6 @@
import logging
from omegaconf import MISSING
import os
import torch
from typing import Optional
import warnings

@@ -56,7 +55,6 @@ def __init__(
):
try:
from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
from kaldi.base import set_verbose_level
from kaldi.decoder import (
FasterDecoder,
FasterDecoderOptions,
@@ -8,7 +8,6 @@
import re
import sys

import torch
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
4 changes: 1 addition & 3 deletions docs/examples/speech_recognition/w2l_decoder.py
@@ -18,7 +18,6 @@

import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from omegaconf import open_dict
@@ -100,7 +99,6 @@ def __init__(self, args, tgt_dict):

def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
@@ -161,7 +159,7 @@ def __init__(self, args, tgt_dict):
)

if self.asg_transitions is None:
N = 768
pass
# self.asg_transitions = torch.FloatTensor(N, N).zero_()
self.asg_transitions = []

@@ -9,9 +9,7 @@
"""

import math
import sys
from fractions import Fraction
import warnings
from collections import Counter
from nltk.translate.bleu_score import (
modified_precision,
1 change: 0 additions & 1 deletion docs/examples/textless_nlp/gslm/unit2speech/glow.py
@@ -24,7 +24,6 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F
@@ -11,7 +11,6 @@
import soundfile as sf
import time
import torch
from scipy.io.wavfile import read
from .text import SOS_TOK, EOS_TOK


5 changes: 1 addition & 4 deletions docs/examples/wav2vec/unsupervised/models/wav2vec_u.py
@@ -14,7 +14,7 @@
import torch.nn.functional as F
from torch import autograd

from fairseq import checkpoint_utils, utils
from fairseq import utils
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.modules import (
@@ -568,8 +568,6 @@ def forward(
if segment:
features, padding_mask = self.segmenter.pre_segment(features, padding_mask)

orig_size = features.size(0) * features.size(1) - padding_mask.sum()

gen_result = self.generator(features, random_label, padding_mask)

orig_dense_x, token_x = gen_result["dense_x"], gen_result["token_x"]
@@ -610,7 +608,6 @@ def forward(
if self.smoothing_one_sided:
fake_smooth = 0

zero_loss = None
smoothness_loss = None
code_pen = None
mmi_loss = None
2 changes: 1 addition & 1 deletion examples/MMPT/mmpt/modules/retri.py
@@ -14,7 +14,7 @@

from collections import defaultdict

from ..utils import get_local_rank, print_on_rank0
from ..utils import get_local_rank


class VectorRetriever(object):
1 change: 0 additions & 1 deletion examples/MMPT/mmpt/processors/models/s3dg.py
@@ -22,7 +22,6 @@
import torch as th
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import re

2 changes: 0 additions & 2 deletions examples/MMPT/mmpt/utils/shardedtensor.py
@@ -2,8 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
import numpy as np


2 changes: 1 addition & 1 deletion examples/MMPT/mmpt_cli/predict.py
@@ -11,7 +11,7 @@
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from mmpt.utils import load_config, set_seed
from mmpt.utils import load_config
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
@@ -5,7 +5,6 @@
import os
import urllib.parse
import json
import pandas as pd

from tqdm import tqdm

1 change: 0 additions & 1 deletion examples/adaptive_span/adaptive_span_loss.py
@@ -6,7 +6,6 @@
import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import register_criterion
@@ -9,7 +9,7 @@
from dataclasses import dataclass

import torch
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDatasetItem,
2 changes: 1 addition & 1 deletion examples/audio_nlp/nlu/generate_manifests.py
@@ -4,7 +4,7 @@


def get_insl_frame(parse):
out = []
pass

def is_ont_token(tok):
return tok[0] in ["[", "]"]
1 change: 0 additions & 1 deletion examples/byte_level_bpe/gru_transformer.py
@@ -9,7 +9,6 @@
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerEncoder, TransformerModel

2 changes: 1 addition & 1 deletion examples/data2vec/data/add_class_target_dataset.py
@@ -5,7 +5,7 @@

import torch

from fairseq.data import BaseWrapperDataset, data_utils
from fairseq.data import BaseWrapperDataset


class AddClassTargetDataset(BaseWrapperDataset):
1 change: 0 additions & 1 deletion examples/data2vec/models/data2vec2.py
@@ -496,7 +496,6 @@ def forward(
encoder_mask,
)
xs.append(dx)
orig_x = x

assert len(xs) > 0
