Commit b5d843a

0 init on residuals, silu backward pass to fp32

Parent: 0b38cc0

File tree

2 files changed: 4 additions, 7 deletions

  mamba/mamba.py
  mamba/mamba_inner_fn.py

mamba/mamba.py

Lines changed: 3 additions & 6 deletions

@@ -125,6 +125,8 @@ def __init__(self, config: MambaConfig, use_cache: bool = False):
 
         self.scan_fn = selective_scan
 
+        self.resid_proj.weight.data.zero_()
+
     def _ssm(
         self,
         x,
@@ -246,14 +248,9 @@ def __init__(self, config: MambaConfig, use_cache=False):
 
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        self.lm_head.weight = self.embed_tokens.weight
-
         self.apply(self._init_weights)
 
-        for name, p in self.named_parameters():
-            if name in ["resid_proj.weight"]:
-                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                p /= math.sqrt(1.0 * config.num_hidden_layers)
+        self.lm_head.weight.data.zero_()
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""

mamba/mamba_inner_fn.py

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def backward(ctx, grad_output):  # type: ignore
 
         dy = grad_output @ out_proj_w
 
-        dres = dy * y * silu_bwd(res)
+        dres = dy * y * silu_bwd(res.float())
 
         dout_proj_w = torch.sum(grad_output.mT @ y_f.to(x_conv_out.dtype), dim=0)
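The backward-pass change casts the saved activation to fp32 before evaluating the SiLU derivative. The repo's silu_bwd is not shown here; assuming it is the analytic derivative of silu(x) = x * sigmoid(x), the sketch below shows the rounding error the cast removes when training in bf16:

import torch

def silu_bwd(x: torch.Tensor) -> torch.Tensor:
    # Assumed definition: d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = torch.sigmoid(x)
    return s * (1 + x * (1 - s))

res = torch.randn(1 << 16, dtype=torch.bfloat16)
lowp = silu_bwd(res)          # before: every intermediate rounded to bf16
fp32 = silu_bwd(res.float())  # after: derivative evaluated in fp32
print((fp32 - lowp.float()).abs().max())  # worst-case bf16 rounding error

Since dy and y stay in the compute dtype, multiplying them by the fp32 derivative promotes dres to fp32 under PyTorch's type-promotion rules, so only this elementwise term pays the extra precision cost.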
133133
