From d5153e91db856e2e9b6b38134fe0d148bb496a3b Mon Sep 17 00:00:00 2001
From: "ben@xps"
Date: Thu, 26 Dec 2024 09:48:29 -0500
Subject: [PATCH] revert untied head and 0 init

---
 hawk/hawk.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/hawk/hawk.py b/hawk/hawk.py
index 294336b..54a8c19 100644
--- a/hawk/hawk.py
+++ b/hawk/hawk.py
@@ -76,7 +76,7 @@ def __init__(self, config: HawkConfig):
 
     def reset_parameters(self):
         lecun_init(self.gate_up_proj.weight, self.hidden_size)
-        torch.nn.init.zeros_(self.resid_proj.weight)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.gate_up_proj(x)
@@ -125,7 +125,7 @@ def reset_parameters(self) -> None:
         self.forget_init(self.rg_lru_a_param)
         lecun_init(self.input_xy.weight, self.config.hidden_size)
-        torch.nn.init.zeros_(self.resid_proj.weight)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forget_init(self, w: torch.Tensor) -> torch.Tensor:
         """Initializes the `A` parameter of the RG-LRU."""
@@ -253,9 +253,10 @@ def __init__(self, config: HawkConfig, use_cache=False):
 
         self.reset_parameters()
 
+        self.lm_head.weight = self.embed_tokens.weight
+
     def reset_parameters(self):
         lecun_init(self.embed_tokens.weight, self.config.hidden_size)
-        torch.nn.init.zeros_(self.lm_head.weight)
 
     def forward(
         self,
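
Note on the tying order: after this patch the LM head is tied to the token
embedding by assigning `self.lm_head.weight = self.embed_tokens.weight` right
after `reset_parameters()`, and the old `torch.nn.init.zeros_(self.lm_head.weight)`
is removed, so the shared tensor keeps the lecun-initialized embedding values
instead of being zeroed. Below is a minimal sketch of the resulting pattern;
the class name `TiedHeadLM`, the `vocab_size` argument, and this `lecun_init`
definition are assumptions for illustration, not the actual Hawk code.

    import torch
    import torch.nn as nn


    def lecun_init(weight: torch.Tensor, fan_in: int) -> None:
        # Assumed definition: LeCun-normal init, variance 1 / fan_in.
        torch.nn.init.normal_(weight, mean=0.0, std=fan_in ** -0.5)


    class TiedHeadLM(nn.Module):
        # Hypothetical minimal model showing the tied-head init order.

        def __init__(self, vocab_size: int, hidden_size: int):
            super().__init__()
            self.hidden_size = hidden_size
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

            self.reset_parameters()

            # Tie AFTER reset_parameters() so the shared weight keeps the
            # lecun init of the embedding (no zeros_ call on lm_head anymore).
            self.lm_head.weight = self.embed_tokens.weight

        def reset_parameters(self):
            lecun_init(self.embed_tokens.weight, self.hidden_size)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            hidden = self.embed_tokens(input_ids)  # the real model runs its Hawk blocks here
            return self.lm_head(hidden)

One quick check that the tie holds:
`model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()`
should be True after construction, and calling `reset_parameters()` again
re-initializes both ends at once since they share storage.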