diff --git a/mamba/mamba.py b/mamba/mamba.py index 6b861c3..221c2da 100644 --- a/mamba/mamba.py +++ b/mamba/mamba.py @@ -123,10 +123,6 @@ def __init__(self, config: MambaConfig, use_cache: bool = False): padding=self.temporal_width - 1, ) - self.scan_fn = selective_scan - - self.resid_proj.weight.data.zero_() - def _ssm( self, x, @@ -250,7 +246,12 @@ def __init__(self, config: MambaConfig, use_cache=False): self.apply(self._init_weights) - self.lm_head.weight.data.zero_() + self.lm_head.weight = self.embed_tokens.weight + + for name, p in self.named_parameters(): + if name in ["resid_proj.weight"]: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + p /= math.sqrt(1.0 * config.num_hidden_layers) def _init_weights(self, module: nn.Module): """Initialize the weights"""