From 61a2110ed32b53eceeaa6308397029320905c20b Mon Sep 17 00:00:00 2001
From: Varun Gumma
Date: Tue, 10 Sep 2024 10:12:26 +0000
Subject: [PATCH] bug fixes

---
 .../models/transformer/transformer_config.py  |  4 ++
 .../models/transformer/transformer_decoder.py |  6 +-
 .../models/transformer/transformer_encoder.py |  9 ++-
 .../models/transformer/transformer_legacy.py  | 18 ++++++
 fairseq/modules/__init__.py                   |  4 +-
 fairseq/modules/rms_norm.py                   | 55 ++++++++++++-------
 fairseq/modules/transformer_layer.py          | 26 ++++++---
 7 files changed, 87 insertions(+), 35 deletions(-)
 mode change 100644 => 100755 fairseq/modules/rms_norm.py

diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py
index 0978394033..6e6c8a6035 100755
--- a/fairseq/models/transformer/transformer_config.py
+++ b/fairseq/models/transformer/transformer_config.py
@@ -141,6 +141,10 @@ class TransformerConfig(FairseqDataclass):
             "help": "dropout probability after activation in FFN.",
         },
     )
+    scale_resids: bool = field(
+        default=False,
+        metadata={"help": "scale the residual connections in the transformer layers (disabled by default)"},
+    )
     adaptive_input: bool = False
     encoder: EncDecBaseConfig = field(default=EncDecBaseConfig)
     # TODO should really be in the encoder config
diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py
index d8acc4d50e..6f5874bcbc 100755
--- a/fairseq/models/transformer/transformer_decoder.py
+++ b/fairseq/models/transformer/transformer_decoder.py
@@ -164,7 +164,11 @@ def __init__(
         self.build_output_projection(cfg, dictionary, embed_tokens)

     def build_normalization(self, dim, rms=False):
-        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
+        return (
+            LayerNorm(dim, export=self.cfg.export)
+            if not rms
+            else RMSNorm(dim, export=self.cfg.export)
+        )

     def build_output_projection(self, cfg, dictionary, embed_tokens):
         if cfg.adaptive_softmax_cutoff is not None:
diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py
index c279044659..9b8f7cc758 100755
--- a/fairseq/models/transformer/transformer_encoder.py
+++ b/fairseq/models/transformer/transformer_encoder.py
@@ -18,6 +18,7 @@
     FairseqDropout,
     LayerDropModuleList,
     LayerNorm,
+    RMSNorm,
     PositionalEmbedding,
     transformer_layer,
 )
@@ -25,8 +26,6 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_

-from fairseq.modules.rms_norm import RMSNorm
-
 device = "cuda" if torch.cuda.is_available() else "cpu"


@@ -130,7 +129,11 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
         self.alibi = None

     def build_normalization(self, dim, rms=False):
-        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
+        return (
+            LayerNorm(dim, export=self.cfg.export)
+            if not rms
+            else RMSNorm(dim, export=self.cfg.export)
+        )

     def build_encoder_layer(self, cfg):
         layer = transformer_layer.TransformerEncoderLayerBase(
diff --git a/fairseq/models/transformer/transformer_legacy.py b/fairseq/models/transformer/transformer_legacy.py
index 56a7c894c2..390dce9313 100755
--- a/fairseq/models/transformer/transformer_legacy.py
+++ b/fairseq/models/transformer/transformer_legacy.py
@@ -311,6 +311,14 @@ def transformer_IT2_dist(args):
     base_architecture(args)


+@register_model_architecture("transformer", "roformer_IT2_dist")
+def roformer_IT2_dist(args):
+    args.attn_implementation = 
getattr(args, "attn_implementation", "fast") + args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", True) + args.rope_args = getattr(args, "rope_args", '{"theta": 10000}') + transformer_IT2_dist(args) + + @register_model_architecture("transformer", "transformer_IT2") def transformer_IT2(args): args.activation_fn = getattr(args, "activation_fn", "gelu") @@ -326,3 +334,13 @@ def transformer_IT2(args): args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) base_architecture(args) + + +@register_model_architecture("transformer", "roformer_IT2") +def roformer_IT2(args): + args.attn_implementation = getattr(args, "attn_implementation", "fast") + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", True + ) + args.rope_args = getattr(args, "rope_args", '{"theta": 10000}') + transformer_IT2(args) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index ffaf95ba0a..30c35551ba 100755 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -25,6 +25,7 @@ from .gumbel_vector_quantizer import GumbelVectorQuantizer from .kmeans_vector_quantizer import KmeansVectorQuantizer from .layer_drop import LayerDropModuleList +from .rms_norm import RMSNorm from .layer_norm import Fp32LayerNorm, LayerNorm from .learned_positional_embedding import LearnedPositionalEmbedding from .lightweight_convolution import LightweightConv, LightweightConv1dTBC @@ -36,7 +37,6 @@ from .fast_multihead_attention import FastMultiheadAttention from .fast_grouped_query_attention import FastGroupedQueryAttention from .positional_embedding import PositionalEmbedding -from .rms_norm import RMSNorm from .same_pad import SamePad, SamePad2d from .scalar_bias import ScalarBias from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding @@ -85,6 +85,7 @@ "KmeansVectorQuantizer", "LayerDropModuleList", "LayerNorm", + "RMSNorm", "LearnedPositionalEmbedding", "LightweightConv1dTBC", "LightweightConv", @@ -110,7 +111,6 @@ "RelPositionMultiHeadedAttention", "RelPositionalEncoding", "RotaryPositionMultiHeadedAttention", - "RMSNorm", "GLU", "MLP", ] diff --git a/fairseq/modules/rms_norm.py b/fairseq/modules/rms_norm.py old mode 100644 new mode 100755 index 4364a0da91..d647dfe65d --- a/fairseq/modules/rms_norm.py +++ b/fairseq/modules/rms_norm.py @@ -1,20 +1,35 @@ -import torch -from torch import nn - - -class RMSNorm(nn.Module): - def __init__(self, normalized_shape, eps=1e-6): - super().__init__() - self.eps = eps - self.normalized_shape = normalized_shape - self.scale = nn.Parameter(torch.ones(normalized_shape)) - - def forward(self, x): - x_fp32 = x.float() - x_normed = ( - x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) - ).type_as(x) - return x_normed * self.scale - - def extra_repr(self): - return f"normalized_shape={self.normalized_shape}, eps={self.eps}" +import torch +import torch.nn as nn + +try: + from apex.normalization import FusedRMSNorm as _FusedRMSNorm + + has_fused_rmsnorm = True + + class FusedRMSNorm(_FusedRMSNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_rmsnorm = False + + +def RMSNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if 
not export and torch.cuda.is_available() and has_fused_rmsnorm:
+        return FusedRMSNorm(
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
+        )
+    return nn.RMSNorm(
+        normalized_shape=normalized_shape,
+        eps=eps,
+        elementwise_affine=elementwise_affine,
+    )
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 1b33010aba..0abcb0afaa 100755
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -22,7 +22,7 @@
     FastGroupedQueryAttention,
 )

-# This module does not support `scale_attn`, `scale_heads`, `scale_fc`, `scale_resids` anymore
+# NOTE: `scale_attn`, `scale_heads`, and `scale_fc` are no longer supported by this module; `scale_resids` support is restored via `residual_connection`


 class TransformerEncoderLayerBase(nn.Module):
@@ -190,10 +190,14 @@ def build_self_attention(self, embed_dim, cfg):
         )

     def build_normalization(self, dim, rms=False):
-        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
+        return (
+            LayerNorm(dim, export=self.cfg.export)
+            if not rms
+            else RMSNorm(dim, export=self.cfg.export)
+        )

-    def residual_connection(self, x, residual):
-        return residual + x
+    def residual_connection(self, x, residual, scale_resid=None):
+        return residual + x if scale_resid is None else (scale_resid * residual + x)

     def upgrade_state_dict_named(self, state_dict, name):
         """
@@ -260,7 +264,7 @@ def forward(
                attn_mask=attn_mask,
            )
         x = self.dropout_module(x)
-        x = self.residual_connection(x, residual)
+        x = self.residual_connection(x, residual, scale_resid=self.sa_scale_resid)
         if not self.normalize_before:
             x = self.self_attn_layer_norm(x)

@@ -279,7 +283,7 @@ def forward(
             fc_result = x

         x = self.dropout_module(x)
-        x = self.residual_connection(x, residual)
+        x = self.residual_connection(x, residual, scale_resid=self.ffn_scale_resid)

         if not self.normalize_before:
             x = self.final_layer_norm(x)
@@ -494,11 +498,15 @@ def build_encoder_attention(self, embed_dim, cfg):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

-    def residual_connection(self, x, residual):
-        return residual + x
+    def residual_connection(self, x, residual, scale_resid=None):
+        return residual + x if scale_resid is None else (scale_resid * residual + x)

     def build_normalization(self, dim, rms=False):
-        return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
+        return (
+            LayerNorm(dim, export=self.cfg.export)
+            if not rms
+            else RMSNorm(dim, export=self.cfg.export)
+        )

     def forward(
         self,
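
Note (reviewer addition, not part of the patch): a minimal sketch of the two behaviours the patch introduces, using plain PyTorch only. It assumes PyTorch >= 2.4 (for torch.nn.RMSNorm, the non-fused fallback used by the new RMSNorm factory) and does not import fairseq or apex; the local residual_connection helper mirrors the patched method but is defined here purely for illustration.

import torch
import torch.nn as nn


def residual_connection(x, residual, scale_resid=None):
    # Mirrors the patched residual_connection: a plain residual add unless a
    # learned per-channel scale is supplied (the `scale_resids` path).
    return residual + x if scale_resid is None else (scale_resid * residual + x)


x = torch.randn(2, 4, 8)          # (batch, time, channels)
residual = torch.randn(2, 4, 8)
scale = nn.Parameter(torch.ones(8))

print(residual_connection(x, residual).shape)         # torch.Size([2, 4, 8])
print(residual_connection(x, residual, scale).shape)  # torch.Size([2, 4, 8])

# Fallback path of the new RMSNorm factory: torch.nn.RMSNorm is returned when
# Apex's FusedRMSNorm is unavailable, CUDA is absent, or export mode is set.
norm = nn.RMSNorm(8, eps=1e-5)
print(norm(x).shape)                                  # torch.Size([2, 4, 8])

With the patch applied, the new architectures should be selectable through fairseq's --arch flag (e.g. --arch roformer_IT2 or roformer_IT2_dist), assuming the attn_implementation and rope_args plumbing referenced by those architectures exists elsewhere in this fork.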