From 44656f350f02db4eb80adfc245a3e1da3a61d7fc Mon Sep 17 00:00:00 2001
From: Benjamin Fattori
Date: Sun, 17 Nov 2024 17:03:59 -0500
Subject: [PATCH] minor fixes to cache, remove post ln option for hawk + add
 conv width as a config option

---
 hawk/cache.py |  9 ++++++---
 hawk/hawk.py  | 12 +++---------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/hawk/cache.py b/hawk/cache.py
index db4650d..37d5eb8 100644
--- a/hawk/cache.py
+++ b/hawk/cache.py
@@ -5,8 +5,11 @@ class RNNCache:
     state: torch.Tensor
     device: torch.device
     current_cache_size: int = 0
+    temporal_conv_width: int = 4
 
-    def __init__(self, state_dim: int, device: torch.device):
+    def __init__(self, state_dim: int, device: torch.device, temporal_conv_width: int):
+        self.temporal_conv_width = temporal_conv_width
+        self.conv_cache_size = temporal_conv_width - 1
         self.recc_state = torch.full(
             (1, 1, state_dim),
             fill_value=torch.nan,
@@ -15,7 +18,7 @@ def __init__(self, state_dim: int, device: torch.device):
         )
 
         self.conv_state = torch.full(
-            (1, 3, state_dim),  # window_size - 1
+            (1, self.conv_cache_size, state_dim),  # temporal_conv_width - 1
             fill_value=torch.nan,
             dtype=torch.bfloat16,
             device=device,
@@ -41,4 +44,4 @@ def update_conv_cache(self, new_state: torch.Tensor) -> None:
         self.conv_state[:, -1, :] = new_state
 
     def __repr__(self) -> str:
-        return f"RNNCache: current_cache_size={self.current_cache_size}, state_dim={self.state_dim}"
+        return f"RNNCache: {self.current_cache_size=}, {self.state_dim=}, {self.temporal_conv_width=}"
diff --git a/hawk/hawk.py b/hawk/hawk.py
index af1dd38..a6ede9f 100644
--- a/hawk/hawk.py
+++ b/hawk/hawk.py
@@ -26,7 +26,7 @@ class HawkConfig:
     num_hidden_layers: int
     recurrent_size: int
     num_blocks: int
-    post_norm: bool = False
+    conv_width: int
 
     # ----
 
@@ -96,11 +96,6 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):
             config.hidden_size, 2 * config.recurrent_size, bias=False
         )
 
-        if config.post_norm:
-            self.norm = RMSNorm(config.recurrent_size)
-        else:
-            self.norm = nn.Identity()
-
         self.rg_lru_input_gate = BlockDiagonalLinear(
             width=config.recurrent_size, num_blocks=self.config.num_blocks
         )
@@ -119,10 +114,9 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):
         self.use_cache = use_cache
 
-        self.temporal_width = 4
         self.conv1d = Conv1D(
             width=config.recurrent_size,
-            temporal_width=self.temporal_width,
+            temporal_width=self.config.conv_width,
        )
 
         self.reset_parameters()
 
@@ -138,7 +132,7 @@ def forget_init(self, w: torch.Tensor) -> torch.Tensor:
         return rnn_param_init(w, min_rad=0.9, max_rad=0.999)
 
     def epilogue(self, gate, h):
-        return self.resid_proj(F.gelu(gate) * self.norm(h))
+        return self.resid_proj(F.gelu(gate) * h)
 
     def inference_prologue(self, x):
         # inference-only prologue function
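
Reviewer note: this patch changes the RNNCache constructor signature, so any
call site constructing the cache now has to thread the configured width
through. A minimal sketch of the expected wiring, assuming a `config:
HawkConfig` and a state dim equal to the recurrent size (names not shown in
this patch are assumptions, not code from the repo):

    import torch

    # Hypothetical call site: pass the configured conv width so the conv
    # cache is sized to conv_width - 1 slots instead of the hard-coded 3.
    cache = RNNCache(
        state_dim=config.recurrent_size,
        device=torch.device("cuda"),
        temporal_conv_width=config.conv_width,
    )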