Commit e98ff85

untie head, 0 inits
1 parent 56a67a0 commit e98ff85

File tree

1 file changed: +3 −4 lines changed

hawk/hawk.py

Lines changed: 3 additions & 4 deletions
@@ -76,7 +76,7 @@ def __init__(self, config: HawkConfig):
 
     def reset_parameters(self):
        lecun_init(self.gate_up_proj.weight, self.hidden_size)
-        lecun_init(self.resid_proj.weight, self.intermediate_size)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.gate_up_proj(x)
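The same swap appears in each block of this diff: the fan-in-scaled lecun_init on the block's output (residual) projection is replaced with a zero init. The definition of lecun_init is not part of this diff; below is a minimal sketch of the pattern, assuming lecun_init is a LeCun-normal init (std = sqrt(1/fan_in)) and using illustrative sizes in place of the real HawkConfig values.

import math
import torch

def lecun_init(weight: torch.Tensor, fan_in: int) -> torch.Tensor:
    """Assumed definition: LeCun-normal init, zero mean, std = sqrt(1 / fan_in)."""
    return torch.nn.init.normal_(weight, mean=0.0, std=math.sqrt(1.0 / fan_in))

# Illustrative sizes; the real values come from HawkConfig.
hidden_size, intermediate_size = 512, 1536
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # fused gate + up (assumed)
resid_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

lecun_init(gate_up_proj.weight, hidden_size)   # unchanged: fan-in scaled
torch.nn.init.zeros_(resid_proj.weight)        # new: output projection starts at zero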
@@ -125,7 +125,7 @@ def reset_parameters(self) -> None:
         self.forget_init(self.rg_lru_a_param)
 
         lecun_init(self.input_xy.weight, self.config.hidden_size)
-        lecun_init(self.resid_proj.weight, self.config.recurrent_size)
+        torch.nn.init.zeros_(self.resid_proj.weight)
 
     def forget_init(self, w: torch.Tensor) -> torch.Tensor:
         """Initializes the `A` parameter of the RG-LRU."""
@@ -253,10 +253,9 @@ def __init__(self, config: HawkConfig, use_cache=False):
 
         self.reset_parameters()
 
-        self.lm_head.weight = self.embed_tokens.weight
-
     def reset_parameters(self):
         lecun_init(self.embed_tokens.weight, self.config.hidden_size)
+        torch.nn.init.zeros_(self.lm_head.weight)
 
     def forward(
         self,
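The third hunk removes the weight tying between lm_head and embed_tokens and zero-initializes the now-independent head instead. A hedged sketch of the before and after, with illustrative sizes in place of the real HawkConfig values:

import torch

vocab_size, hidden_size = 32000, 512   # illustrative
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Previously the head was tied to the embedding (one shared matrix):
#     lm_head.weight = embed_tokens.weight
# Now the two are independent parameters, and the head starts at zero,
# so the initial logits are all zero (uniform over the vocabulary after softmax).
torch.nn.init.zeros_(lm_head.weight)

assert lm_head.weight.data_ptr() != embed_tokens.weight.data_ptr()  # no longer shared storage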
