diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 0abcb0afaa..feb2e68d15 100755
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -196,8 +196,8 @@ def build_normalization(self, dim, rms=False):
             else RMSNorm(dim, export=self.cfg.export)
         )
 
-    def residual_connection(self, x, residual, scale_resid=None):
-        return residual + x if scale_resid is None else (scale_resid * residual + x)
+    def residual_connection(self, x, residual):
+        return residual + x
 
     def upgrade_state_dict_named(self, state_dict, name):
         """
@@ -264,7 +264,7 @@ def forward(
             attn_mask=attn_mask,
         )
         x = self.dropout_module(x)
-        x = self.residual_connection(x, residual, scale_resid=self.sa_scale_resid)
+        x = self.residual_connection(x, residual)
         if not self.normalize_before:
             x = self.self_attn_layer_norm(x)
 
@@ -283,7 +283,7 @@ def forward(
             fc_result = x
 
         x = self.dropout_module(x)
-        x = self.residual_connection(x, residual, scale_resid=self.ffn_scale_resid)
+        x = self.residual_connection(x, residual)
         if not self.normalize_before:
             x = self.final_layer_norm(x)
 
@@ -498,8 +498,8 @@ def build_encoder_attention(self, embed_dim, cfg):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
-    def residual_connection(self, x, residual, scale_resid=None):
-        return residual + x if scale_resid is None else (scale_resid * residual + x)
+    def residual_connection(self, x, residual):
+        return residual + x
 
     def build_normalization(self, dim, rms=False):
         return (