Skip to content

Commit

Permalink
remove unused attr, revert to original weight init
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Nov 9, 2024
1 parent b5d843a commit b0d8d98
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mamba/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,6 @@ def __init__(self, config: MambaConfig, use_cache: bool = False):
padding=self.temporal_width - 1,
)

self.scan_fn = selective_scan

self.resid_proj.weight.data.zero_()

def _ssm(
self,
x,
Expand Down Expand Up @@ -250,7 +246,12 @@ def __init__(self, config: MambaConfig, use_cache=False):

self.apply(self._init_weights)

self.lm_head.weight.data.zero_()
self.lm_head.weight = self.embed_tokens.weight

for name, p in self.named_parameters():
if name in ["resid_proj.weight"]:
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
p /= math.sqrt(1.0 * config.num_hidden_layers)

def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
Expand Down

0 comments on commit b0d8d98

Please sign in to comment.