From 0b376d1bfd8b757091d43a0fbf7516b16faffac2 Mon Sep 17 00:00:00 2001
From: Benjamin Fattori
Date: Mon, 11 Nov 2024 17:54:39 -0500
Subject: [PATCH] remove unused external imports

---
 hawk/external.py | 69 ------------------------------------------------
 hawk/hawk.py     | 10 +++----
 2 files changed, 3 insertions(+), 76 deletions(-)

diff --git a/hawk/external.py b/hawk/external.py
index 2bada16..c635833 100644
--- a/hawk/external.py
+++ b/hawk/external.py
@@ -2,11 +2,7 @@
 https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/torch/layers.py
 """
 
-import math
-
 import torch
-import torch.nn as nn
-from einops import rearrange
 
 
 def rnn_param_init(
@@ -25,68 +21,3 @@ def rnn_param_init(
             return tensor.neg_().exp_().sub_(1.0).log_()
         else:
             raise NotImplementedError()
-
-
-class BlockDiagonalLinear(nn.Module):
-    """Block-diagonal linear layer."""
-
-    def __init__(
-        self,
-        width: int,
-        num_blocks: int,
-        w_init_variance_scale: float = 1.0,
-        device: str | torch.device | None = None,
-        dtype: torch.dtype | None = None,
-    ):
-        """Initializes the BlockDiagonalLinear.
-
-        Args:
-          width: The number of dimensions of the input and output.
-          num_blocks: The number of diagonal blocks in the layer.
-          w_init_variance_scale: A parameters that scales the variance of the
-            initialization of the weights.
-          device: On what device to initialize parameters. Needed to allow for
-            initializing the module without parameter initialzation.
-          dtype: What dtype to use for initialziation.
-        """
-        super().__init__()
-        self.width = width
-        self.num_blocks = num_blocks
-        self.w_init_variance_scale = w_init_variance_scale
-        self.block_width = self.width // self.num_blocks
-
-        # Parameters.
-        self.w = nn.Parameter(
-            torch.empty(
-                [self.num_blocks, self.block_width, self.block_width],
-                device=device,
-                dtype=dtype,
-            )
-        )
-        self.b = nn.Parameter(
-            torch.empty([self.num_blocks, self.block_width], device=device, dtype=dtype)
-        )
-
-        # Initialization.
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        """Resets the parameters of the module."""
-        self.w_init_(self.w)
-        torch.nn.init.zeros_(self.b)
-
-    def w_init_(self, w: torch.Tensor) -> None:
-        """Initializes the weight `w` of the layer."""
-        std = math.sqrt(self.w_init_variance_scale / self.block_width)
-        torch.nn.init.normal_(w, mean=0.0, std=std)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Calls the BlockDiagonalLinear."""
-        # Split x to blocks.
-        x = rearrange(x, "... (h i) -> ... h i", h=self.num_blocks)
-
-        # Linear layer over each block + bias.
-        y = torch.einsum("... h i, h i j -> ... h j", x, self.w) + self.b
-
-        # Flatten the output.
-        return rearrange(y, "... h j -> ... (h j)", h=self.num_blocks)
diff --git a/hawk/hawk.py b/hawk/hawk.py
index 8def5d2..0a267db 100644
--- a/hawk/hawk.py
+++ b/hawk/hawk.py
@@ -10,7 +10,7 @@
 
 from .cache import RNNCache
 from .causal_conv1d import causal_conv
-from .external import BlockDiagonalLinear, rnn_param_init
+from .external import rnn_param_init
 from .fused_cross_entropy import fused_cross_entropy
 from .scan_fused import fused_linear_scan
 
@@ -63,13 +63,9 @@ def __init__(self, config: HawkConfig, use_cache: bool = False):
         else:
             self.norm = nn.Identity()
 
-        self.rg_lru_input_gate = BlockDiagonalLinear(
-            width=config.recurrent_size, num_blocks=self.config.num_blocks
-        )
+        self.rg_lru_input_gate = nn.Linear(config.recurrent_size, config.recurrent_size)
 
-        self.rg_lru_a_gate = BlockDiagonalLinear(
-            width=config.recurrent_size, num_blocks=self.config.num_blocks
-        )
+        self.rg_lru_a_gate = nn.Linear(config.recurrent_size, config.recurrent_size)
 
         self.rg_lru_a_param = nn.Parameter(
             torch.empty([config.recurrent_size], dtype=torch.float32)