
Commit

minor fixes to cache, remove post ln option for hawk + add conv width as a config option
fattorib committed Nov 17, 2024
1 parent 73b6e76 commit 44656f3
Showing 2 changed files with 9 additions and 12 deletions.
9 changes: 6 additions & 3 deletions hawk/cache.py
@@ -5,8 +5,11 @@ class RNNCache:
state: torch.Tensor
device: torch.device
current_cache_size: int = 0
temporal_conv_width: int = 4

def __init__(self, state_dim: int, device: torch.device):
def __init__(self, state_dim: int, device: torch.device, temporal_conv_width: int):
self.temporal_conv_width = temporal_conv_width
self.conv_cache_size = self.temporal_conv_width - 1
self.recc_state = torch.full(
(1, 1, state_dim),
fill_value=torch.nan,
@@ -15,7 +18,7 @@ def __init__(self, state_dim: int, device: torch.device):
)

self.conv_state = torch.full(
(1, 3, state_dim), # window_size - 1
(1, self.conv_cache_size, state_dim), # temporal_conv_width - 1
fill_value=torch.nan,
dtype=torch.bfloat16,
device=device,
@@ -27,18 +30,18 @@
def update_cache(self, new_state: torch.Tensor) -> None:
assert new_state.shape[0] == 1
assert new_state.shape[1] == 1
assert new_state.ndim == 3

self.recc_state[...] = new_state
self.current_cache_size = 1

def update_conv_cache(self, new_state: torch.Tensor) -> None:
assert new_state.shape[0] == 1
assert new_state.shape[1] == 1
assert new_state.ndim == 3

self.conv_state = torch.roll(self.conv_state, shifts=-1, dims=1)
self.conv_state[:, -1, :] = new_state

def __repr__(self) -> str:
return f"RNNCache: current_cache_size={self.current_cache_size}, state_dim={self.state_dim}"
return f"RNNCache: {self.current_cache_size=}, {self.state_dim=}, {self.temporal_conv_width=}"
12 changes: 3 additions & 9 deletions hawk/hawk.py
@@ -26,7 +26,7 @@ class HawkConfig:
num_hidden_layers: int
recurrent_size: int
num_blocks: int
post_norm: bool = False
conv_width: int


# ----
@@ -96,11 +96,6 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):
config.hidden_size, 2 * config.recurrent_size, bias=False
)

if config.post_norm:
self.norm = RMSNorm(config.recurrent_size)
else:
self.norm = nn.Identity()

self.rg_lru_input_gate = BlockDiagonalLinear(
width=config.recurrent_size, num_blocks=self.config.num_blocks
)
@@ -119,10 +114,9 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):

self.use_cache = use_cache

self.temporal_width = 4
self.conv1d = Conv1D(
width=config.recurrent_size,
temporal_width=self.temporal_width,
temporal_width=self.config.conv_width,
)

self.reset_parameters()
@@ -138,7 +132,7 @@ def forget_init(self, w: torch.Tensor) -> torch.Tensor:
return rnn_param_init(w, min_rad=0.9, max_rad=0.999)

def epilogue(self, gate, h):
return self.resid_proj(F.gelu(gate) * self.norm(h))
return self.resid_proj(F.gelu(gate) * h)

def inference_prologue(self, x):
# inference-only prologue function
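
With post_norm gone and conv_width promoted to the config, construction now looks roughly like the sketch below — assuming HawkConfig is a dataclass; the hunk above is truncated, so there may be additional fields. All sizes are illustrative, and hidden_size is assumed from the config.hidden_size reference in the diff.

from hawk.hawk import HawkConfig  # assumed import path; the file lives at hawk/hawk.py

config = HawkConfig(
    hidden_size=2048,       # assumed field, referenced as config.hidden_size above
    num_hidden_layers=12,
    recurrent_size=2048,
    num_blocks=8,
    conv_width=4,           # replaces the hard-coded self.temporal_width = 4
)

Passing conv_width=4 reproduces the previous behavior exactly, since Conv1D always received temporal_width=4 before this change.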
