revert untied head and 0 init
fattorib committed Dec 26, 2024
1 parent e98ff85 commit d5153e9
Showing 1 changed file with 4 additions and 3 deletions.
hawk/hawk.py (7 changes: 4 additions & 3 deletions)
@@ -76,7 +76,7 @@ def __init__(self, config: HawkConfig):
 
     def reset_parameters(self):
         lecun_init(self.gate_up_proj.weight, self.hidden_size)
-        torch.nn.init.zeros_(self.resid_proj.weight)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.gate_up_proj(x)
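
The diff calls `lecun_init` without showing its definition. A common reading (assumed here, not taken from the repo) is LeCun-normal initialization, which draws weights from a normal distribution with variance 1/fan_in; zero-initializing `resid_proj` on top of that makes each block's residual branch a no-op at step 0, a common stabilization trick. A minimal sketch under that assumption:

```python
import torch

def lecun_init(w: torch.Tensor, fan_in: int) -> torch.Tensor:
    """Hypothetical sketch of LeCun-normal init: std = 1 / sqrt(fan_in)."""
    with torch.no_grad():
        return w.normal_(mean=0.0, std=fan_in ** -0.5)
```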
@@ -125,7 +125,7 @@ def reset_parameters(self) -> None:
         self.forget_init(self.rg_lru_a_param)
 
         lecun_init(self.input_xy.weight, self.config.hidden_size)
-        torch.nn.init.zeros_(self.resid_proj.weight)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forget_init(self, w: torch.Tensor) -> torch.Tensor:
         """Initializes the `A` parameter of the RG-LRU."""
@@ -253,9 +253,10 @@ def __init__(self, config: HawkConfig, use_cache=False):
 
         self.reset_parameters()
 
-        self.lm_head.weight = self.embed_tokens.weight
 
     def reset_parameters(self):
         lecun_init(self.embed_tokens.weight, self.config.hidden_size)
+        torch.nn.init.zeros_(self.lm_head.weight)
 
     def forward(
         self,
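The net effect of the last hunk is to undo weight tying: before this commit, `lm_head` shared its weight tensor with `embed_tokens`; afterwards it is an independent, zero-initialized projection. With a zero head the model emits uniform logits, so the initial cross-entropy loss is exactly ln(vocab_size), which makes early training curves easy to sanity-check. A minimal illustration of the two configurations (module names follow the diff; sizes are made up):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 32_000, 1024
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tied head (the state this commit reverts): one shared tensor, so
# gradients from the output projection also update the embedding.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()

# Untied head with zero init (the state after this commit): the head
# is its own parameter, and the model starts from uniform logits.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
torch.nn.init.zeros_(lm_head.weight)
```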
