Commit
remove unused external imports
fattorib committed Nov 11, 2024
1 parent 1c14052 commit 0b376d1
Showing 2 changed files with 3 additions and 76 deletions.
69 changes: 0 additions & 69 deletions hawk/external.py
@@ -2,11 +2,7 @@
https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/torch/layers.py
"""

import math

import torch
import torch.nn as nn
from einops import rearrange


def rnn_param_init(
@@ -25,68 +21,3 @@ def rnn_param_init(
            return tensor.neg_().exp_().sub_(1.0).log_()
        else:
            raise NotImplementedError()
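
A side note on the retained helper: with the softplus transform, the in-place chain on the return line computes log(exp(-t) - 1), which is exactly the inverse of u -> -softplus(u). A minimal standalone sketch of that identity (illustrative only, not part of this commit):

import torch
import torch.nn.functional as F

u = torch.randn(5)
t = -F.softplus(u)                        # t = -log(1 + exp(u)), always negative
recovered = t.neg().exp().sub(1.0).log()  # non-in-place version of the chain above
print(torch.allclose(recovered, u, atol=1e-4))  # True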


class BlockDiagonalLinear(nn.Module):
    """Block-diagonal linear layer."""

    def __init__(
        self,
        width: int,
        num_blocks: int,
        w_init_variance_scale: float = 1.0,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initializes the BlockDiagonalLinear.
        Args:
            width: The number of dimensions of the input and output.
            num_blocks: The number of diagonal blocks in the layer.
            w_init_variance_scale: A parameter that scales the variance of the
                initialization of the weights.
            device: On what device to initialize parameters. Needed to allow for
                initializing the module without parameter initialization.
            dtype: What dtype to use for initialization.
        """
        super().__init__()
        self.width = width
        self.num_blocks = num_blocks
        self.w_init_variance_scale = w_init_variance_scale
        self.block_width = self.width // self.num_blocks

        # Parameters.
        self.w = nn.Parameter(
            torch.empty(
                [self.num_blocks, self.block_width, self.block_width],
                device=device,
                dtype=dtype,
            )
        )
        self.b = nn.Parameter(
            torch.empty([self.num_blocks, self.block_width], device=device, dtype=dtype)
        )

        # Initialization.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Resets the parameters of the module."""
        self.w_init_(self.w)
        torch.nn.init.zeros_(self.b)

    def w_init_(self, w: torch.Tensor) -> None:
        """Initializes the weight `w` of the layer."""
        std = math.sqrt(self.w_init_variance_scale / self.block_width)
        torch.nn.init.normal_(w, mean=0.0, std=std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calls the BlockDiagonalLinear."""
        # Split x to blocks.
        x = rearrange(x, "... (h i) -> ... h i", h=self.num_blocks)

        # Linear layer over each block + bias.
        y = torch.einsum("... h i, h i j -> ... h j", x, self.w) + self.b

        # Flatten the output.
        return rearrange(y, "... h j -> ... (h j)", h=self.num_blocks)
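
For context on what the deleted layer computes: the blockwise einsum in forward is the same map as multiplying by a single dense block-diagonal weight matrix. A minimal standalone sketch of that equivalence (shapes are illustrative, not taken from the repo's config):

import torch
from einops import rearrange

num_blocks, block_width = 4, 8            # illustrative sizes
width = num_blocks * block_width
x = torch.randn(2, width)
w = torch.randn(num_blocks, block_width, block_width)
b = torch.randn(num_blocks, block_width)

# Blockwise path, mirroring BlockDiagonalLinear.forward.
xb = rearrange(x, "... (h i) -> ... h i", h=num_blocks)
y_blocks = torch.einsum("... h i, h i j -> ... h j", xb, w) + b
y_blocks = rearrange(y_blocks, "... h j -> ... (h j)", h=num_blocks)

# Dense path: one [width, width] matrix that is zero off the diagonal blocks.
y_dense = x @ torch.block_diag(*w) + b.reshape(-1)

print(torch.allclose(y_blocks, y_dense, atol=1e-5))  # True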
10 changes: 3 additions & 7 deletions hawk/hawk.py
@@ -10,7 +10,7 @@

from .cache import RNNCache
from .causal_conv1d import causal_conv
-from .external import BlockDiagonalLinear, rnn_param_init
+from .external import rnn_param_init
from .fused_cross_entropy import fused_cross_entropy
from .scan_fused import fused_linear_scan

@@ -63,13 +63,9 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):
        else:
            self.norm = nn.Identity()

-        self.rg_lru_input_gate = BlockDiagonalLinear(
-            width=config.recurrent_size, num_blocks=self.config.num_blocks
-        )
+        self.rg_lru_input_gate = nn.Linear(config.recurrent_size, config.recurrent_size)

-        self.rg_lru_a_gate = BlockDiagonalLinear(
-            width=config.recurrent_size, num_blocks=self.config.num_blocks
-        )
+        self.rg_lru_a_gate = nn.Linear(config.recurrent_size, config.recurrent_size)

        self.rg_lru_a_param = nn.Parameter(
            torch.empty([config.recurrent_size], dtype=torch.float32)
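
The swap above replaces the block-diagonal input and a gates with full nn.Linear projections over config.recurrent_size. Two consequences worth noting, sketched below with hypothetical sizes (not the repo's defaults): the dense layer carries num_blocks times as many weight parameters as the block-diagonal one, and with a single block the two parameterize the same affine map.

import torch
import torch.nn as nn

recurrent_size, num_blocks = 1024, 8      # hypothetical values
block_width = recurrent_size // num_blocks

# Weight-parameter counts: dense vs block-diagonal.
dense_weights = recurrent_size * recurrent_size
block_weights = num_blocks * block_width * block_width
print(dense_weights // block_weights)     # == num_blocks

# With a single block, the block-diagonal map is an ordinary dense affine
# map, i.e. exactly what nn.Linear computes (weights transposed).
linear = nn.Linear(recurrent_size, recurrent_size)
x = torch.randn(2, recurrent_size)
w = linear.weight.T.unsqueeze(0)          # [1, width, width] block weight
b = linear.bias.unsqueeze(0)              # [1, width] block bias
y = torch.einsum("... h i, h i j -> ... h j", x.unsqueeze(-2), w) + b
print(torch.allclose(y.squeeze(-2), linear(x), atol=1e-5))  # True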
