Commit 61a2110
bug fixes
VarunGumma committed Sep 10, 2024
1 parent 50e56c9 commit 61a2110
Showing 7 changed files with 87 additions and 35 deletions.
4 changes: 4 additions & 0 deletions fairseq/models/transformer/transformer_config.py
@@ -141,6 +141,10 @@ class TransformerConfig(FairseqDataclass):
"help": "dropout probability after activation in FFN.",
},
)
scale_resids: bool = field(
default=False,
metadata={"help": "scale the residuals in the transformer modules (fine to set False as default)"}
)
adaptive_input: bool = False
encoder: EncDecBaseConfig = field(default=EncDecBaseConfig)
# TODO should really be in the encoder config
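The new flag follows the getattr-override pattern used by the architecture presets in transformer_legacy.py later in this diff. A minimal sketch of a preset that turns it on; the preset name is hypothetical and the function is assumed to sit next to the existing presets, where register_model_architecture and base_architecture are in scope.

# Hypothetical preset (not part of this commit) that enables residual scaling.
@register_model_architecture("transformer", "transformer_scale_resids")
def transformer_scale_resids(args):
    args.scale_resids = getattr(args, "scale_resids", True)
    base_architecture(args)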
6 changes: 5 additions & 1 deletion fairseq/models/transformer/transformer_decoder.py
@@ -164,7 +164,11 @@ def __init__(
self.build_output_projection(cfg, dictionary, embed_tokens)

def build_normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
return (
LayerNorm(dim, export=self.cfg.export)
if not rms
else RMSNorm(dim, export=self.cfg.export)
)

def build_output_projection(self, cfg, dictionary, embed_tokens):
if cfg.adaptive_softmax_cutoff is not None:
9 changes: 6 additions & 3 deletions fairseq/models/transformer/transformer_encoder.py
@@ -18,15 +18,14 @@
FairseqDropout,
LayerDropModuleList,
LayerNorm,
RMSNorm,
PositionalEmbedding,
transformer_layer,
)

from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_

from fairseq.modules.rms_norm import RMSNorm

device = "cuda" if torch.cuda.is_available() else "cpu"


@@ -130,7 +129,11 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
self.alibi = None

def build_normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
return (
LayerNorm(dim, export=self.cfg.export)
if not rms
else RMSNorm(dim, export=self.cfg.export)
)

def build_encoder_layer(self, cfg):
layer = transformer_layer.TransformerEncoderLayerBase(
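Both the encoder and the decoder now build their normalization layers through the same helper. The standalone sketch below restates the selection logic and shows a call; how the `rms` flag is wired up from the config is not part of this diff, so treat the flag handling as an assumption.

import torch
from fairseq.modules import LayerNorm, RMSNorm

# Sketch of the selection implemented by build_normalization: RMSNorm when
# `rms` is True, fairseq's LayerNorm otherwise.
def build_normalization(dim, rms=False, export=False):
    return LayerNorm(dim, export=export) if not rms else RMSNorm(dim, export=export)

norm = build_normalization(512, rms=True)
x = torch.randn(8, 16, 512)
print(norm(x).shape)  # torch.Size([8, 16, 512])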
18 changes: 18 additions & 0 deletions fairseq/models/transformer/transformer_legacy.py
@@ -311,6 +311,14 @@ def transformer_IT2_dist(args):
base_architecture(args)


@register_model_architecture("transformer", "roformer_IT2_dist")
def roformer_IT2_dist(args):
args.attn_implementation = getattr(args, "attn_implementation", "fast")
args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", True)
args.rope_args = getattr(args, "rope_args", '{"theta": 10000}')
transformer_IT2_dist(args)


@register_model_architecture("transformer", "transformer_IT2")
def transformer_IT2(args):
args.activation_fn = getattr(args, "activation_fn", "gelu")
@@ -326,3 +334,13 @@ def transformer_IT2(args):
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
base_architecture(args)


@register_model_architecture("transformer", "roformer_IT2")
def roformer_IT2(args):
args.attn_implementation = getattr(args, "attn_implementation", "fast")
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", True
)
args.rope_args = getattr(args, "rope_args", '{"theta": 10000}')
transformer_IT2(args)
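The two roformer presets switch attention to the "fast" implementation, drop token positional embeddings (rotary embeddings carry position information instead), and pass RoPE settings as a JSON string. A small sketch of how that string decodes; the module that actually consumes `rope_args` is not touched by this commit, so the parsing site is an assumption.

import json

# `rope_args` is stored as a JSON string in the architecture defaults.
# Hypothetical decoding step, mirroring the default set above:
rope_args = json.loads('{"theta": 10000}')
base_theta = rope_args["theta"]  # 10000, the RoPE frequency base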
4 changes: 2 additions & 2 deletions fairseq/modules/__init__.py
@@ -25,6 +25,7 @@
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .layer_drop import LayerDropModuleList
from .rms_norm import RMSNorm
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
@@ -36,7 +37,6 @@
from .fast_multihead_attention import FastMultiheadAttention
from .fast_grouped_query_attention import FastGroupedQueryAttention
from .positional_embedding import PositionalEmbedding
from .rms_norm import RMSNorm
from .same_pad import SamePad, SamePad2d
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -85,6 +85,7 @@
"KmeansVectorQuantizer",
"LayerDropModuleList",
"LayerNorm",
"RMSNorm",
"LearnedPositionalEmbedding",
"LightweightConv1dTBC",
"LightweightConv",
@@ -110,7 +111,6 @@
"RelPositionMultiHeadedAttention",
"RelPositionalEncoding",
"RotaryPositionMultiHeadedAttention",
"RMSNorm",
"GLU",
"MLP",
]
55 changes: 35 additions & 20 deletions fairseq/modules/rms_norm.py
100644 → 100755
@@ -1,20 +1,35 @@
import torch
from torch import nn


class RMSNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.eps = eps
self.normalized_shape = normalized_shape
self.scale = nn.Parameter(torch.ones(normalized_shape))

def forward(self, x):
x_fp32 = x.float()
x_normed = (
x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
).type_as(x)
return x_normed * self.scale

def extra_repr(self):
return f"normalized_shape={self.normalized_shape}, eps={self.eps}"
import torch
import torch.nn as nn

try:
from apex.normalization import FusedRMSNorm as _FusedRMSNorm

has_fused_rmsnorm = True

class FusedRMSNorm(_FusedRMSNorm):
@torch.jit.unused
def forward(self, x):
if not x.is_cuda:
return super().forward(x)
else:
with torch.cuda.device(x.device):
return super().forward(x)

except ImportError:
has_fused_rmsnorm = False


def RMSNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
if torch.jit.is_scripting() or torch.jit.is_tracing():
export = True
if not export and torch.cuda.is_available() and has_fused_rmsnorm:
return FusedRMSNorm(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
return nn.RMSNorm(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
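RMSNorm is now a factory rather than a hand-rolled module: it returns apex's FusedRMSNorm when apex is installed, CUDA is available, and export/tracing is not requested, and otherwise falls back to torch.nn.RMSNorm (available from PyTorch 2.4). A minimal usage sketch under those assumptions:

import torch
from fairseq.modules import RMSNorm

# Returns FusedRMSNorm (apex + CUDA, no export) or torch.nn.RMSNorm otherwise;
# both normalize over the last dimension with a learnable per-channel scale.
norm = RMSNorm(1024, eps=1e-5)
x = torch.randn(4, 10, 1024)
print(type(norm).__name__, norm(x).shape)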
26 changes: 17 additions & 9 deletions fairseq/modules/transformer_layer.py
@@ -22,7 +22,7 @@
FastGroupedQueryAttention,
)

# This module does not support `scale_attn`, `scale_heads`, `scale_fc`, `scale_resids` anymore
# BUG: This module does not support `scale_attn`, `scale_heads`, `scale_fc`, `scale_resids` anymore


class TransformerEncoderLayerBase(nn.Module):
@@ -190,10 +190,14 @@ def build_self_attention(self, embed_dim, cfg):
)

def build_normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
return (
LayerNorm(dim, export=self.cfg.export)
if not rms
else RMSNorm(dim, export=self.cfg.export)
)

def residual_connection(self, x, residual):
return residual + x
def residual_connection(self, x, residual, scale_resid=None):
return residual + x if scale_resid is None else (scale_resid * residual + x)

def upgrade_state_dict_named(self, state_dict, name):
"""
@@ -260,7 +264,7 @@ def forward(
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
x = self.residual_connection(x, residual, scale_resid=self.sa_scale_resid)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)

@@ -279,7 +283,7 @@ def forward(
fc_result = x

x = self.dropout_module(x)
x = self.residual_connection(x, residual)
x = self.residual_connection(x, residual, scale_resid=self.ffn_scale_resid)

if not self.normalize_before:
x = self.final_layer_norm(x)
@@ -494,11 +498,15 @@ def build_encoder_attention(self, embed_dim, cfg):
def prepare_for_onnx_export_(self):
self.onnx_trace = True

def residual_connection(self, x, residual):
return residual + x
def residual_connection(self, x, residual, scale_resid=None):
return residual + x if scale_resid is None else (scale_resid * residual + x)

def build_normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)
return (
LayerNorm(dim, export=self.cfg.export)
if not rms
else RMSNorm(dim, export=self.cfg.export)
)

def forward(
self,
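residual_connection now takes an optional residual scale in both the encoder and decoder layers. The standalone sketch below reproduces the computation; it assumes `sa_scale_resid` and `ffn_scale_resid` are learnable per-channel parameters created when `scale_resids` is enabled, since their initialization is not shown in this diff.

import torch
import torch.nn as nn

def residual_connection(x, residual, scale_resid=None):
    # Plain residual when no scale is configured; otherwise scale the residual
    # branch element-wise before adding the sublayer output.
    return residual + x if scale_resid is None else (scale_resid * residual + x)

embed_dim = 512
scale = nn.Parameter(torch.ones(embed_dim))       # e.g. sa_scale_resid
residual = torch.randn(2, 7, embed_dim)
x = torch.randn(2, 7, embed_dim)
print(residual_connection(x, residual, scale_resid=scale).shape)  # (2, 7, 512)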
