Skip to content

Commit

Permalink
bug fix in transformer layer
Browse files: browse the repository at this point in the history
  • Loading branch information
VarunGumma committed Sep 10, 2024
1 parent 61a2110 commit 219ce4d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions fairseq/modules/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def build_normalization(self, dim, rms=False):
else RMSNorm(dim, export=self.cfg.export)
)

def residual_connection(self, x, residual, scale_resid=None):
    """Combine ``x`` with ``residual``, optionally scaling ``residual``.

    Args:
        x: Value added to the (optionally scaled) residual.
        residual: Value that ``x`` is added to.
        scale_resid: Optional multiplier applied to ``residual`` before
            the addition. ``None`` means no scaling is applied.

    Returns:
        ``residual + x`` when ``scale_resid`` is ``None``; otherwise
        ``scale_resid * residual + x``.
    """
    # Only ``None`` disables scaling; a falsy multiplier such as 0 is
    # still applied.
    if scale_resid is not None:
        return scale_resid * residual + x
    return residual + x
def residual_connection(self, x, residual):
    """Return ``residual + x``.

    Args:
        x: Right-hand operand of the addition.
        residual: Left-hand operand of the addition.

    Returns:
        The result of ``residual + x``.
    """
    combined = residual + x
    return combined

def upgrade_state_dict_named(self, state_dict, name):
"""
Expand Down Expand Up @@ -264,7 +264,7 @@ def forward(
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual, scale_resid=self.sa_scale_resid)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)

Expand All @@ -283,7 +283,7 @@ def forward(
fc_result = x

x = self.dropout_module(x)
x = self.residual_connection(x, residual, scale_resid=self.ffn_scale_resid)
x = self.residual_connection(x, residual)

if not self.normalize_before:
x = self.final_layer_norm(x)
Expand Down Expand Up @@ -498,8 +498,8 @@ def build_encoder_attention(self, embed_dim, cfg):
def prepare_for_onnx_export_(self):
    """Set the ``onnx_trace`` attribute on this object to ``True``.

    Mutates ``self`` in place and returns ``None``.
    """
    self.onnx_trace = True

def residual_connection(self, x, residual, scale_resid=None):
    """Add ``x`` to ``residual``, scaling ``residual`` first if requested.

    Args:
        x: Value added to the (optionally scaled) residual.
        residual: Value that ``x`` is added to.
        scale_resid: Optional multiplier for ``residual``. When it is
            ``None`` the residual is used unscaled.

    Returns:
        ``residual + x`` if ``scale_resid`` is ``None``, else
        ``scale_resid * residual + x``.
    """
    if scale_resid is None:
        return residual + x
    # Any non-None multiplier (including 0) is applied to the residual.
    scaled = scale_resid * residual
    return scaled + x
def residual_connection(self, x, residual):
    """Return the sum of ``residual`` and ``x`` (``residual`` on the left).

    Args:
        x: Value added to ``residual``.
        residual: Value that ``x`` is added to.

    Returns:
        The result of ``residual + x``.
    """
    total = residual + x
    return total

def build_normalization(self, dim, rms=False):
return (
Expand Down

0 comments on commit 219ce4d

Please sign in to comment.