
Commit

[PROGRESS]
Kye committed Feb 9, 2024
1 parent dc257ff commit 36e8b1b
Showing 3 changed files with 207 additions and 67 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -1,6 +1,9 @@
[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# HedgeHog
Implementation of the model "Hedgehog" from the paper "The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry". The paper trains small MLPs to mimic the softmax attention of a transformer, and supposedly hits SOTA on WikiText among sub-quadratic models. I, too, have been thinking about replacing softmax with MLPs. This past month we saw dozens of papers on Mamba and convolutions, but MLPs might have undiscovered powers.
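
The core trick, roughly: learn a small feature map `phi(x) = [exp(Wx), exp(-Wx)]` that imitates what softmax does to queries and keys, then compute attention linearly in sequence length. Here is a minimal sketch of that idea (illustrative only; the names, shapes, and API are assumptions, not this repo's interface):

```python
import torch
from torch import nn


class FeatureMap(nn.Module):
    """Hedgehog-style feature map: phi(x) = [exp(Wx), exp(-Wx)]."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return torch.cat([x.exp(), (-x).exp()], dim=-1)


def linear_attention(q, k, v, feature_map):
    # q, k, v: (batch, seq_len, dim)
    q, k = feature_map(q), feature_map(k)
    kv = torch.einsum("bnd,bne->bde", k, v)                      # summarize keys/values once
    z = (q * k.sum(dim=1, keepdim=True)).sum(-1, keepdim=True)   # per-query normalizer
    return torch.einsum("bnd,bde->bne", q, kv) / z


phi = FeatureMap(64)
q = k = v = torch.randn(2, 128, 64)
out = linear_attention(q, k, v, phi)  # (2, 128, 64)
```

Because keys and values are summarized once into `kv`, the cost grows linearly with sequence length instead of quadratically as in softmax attention.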





19 changes: 19 additions & 0 deletions example.py
@@ -0,0 +1,19 @@
import torch
from hedgehog.main import Hedgehog

# Create tokens
x = torch.randint(0, 100, (1, 100))

# Create model
model = Hedgehog(
    num_tokens=100,
    dim=512,
    head_dim=512,
)

# Forward
out = model(x)


# Out
print(out)
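# Expected result (assuming dim=512 and the current softmax output head in
# hedgehog/main.py): a tensor of shape (1, 100, 512), one feature vector per
# input token.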
252 changes: 185 additions & 67 deletions hedgehog/main.py
@@ -1,117 +1,235 @@
import torch
from torch import nn, Tensor
from einops import rearrange
from zeta.nn import FeedForward

def softmax_attn(
    q: Tensor,
    k: Tensor,
):
    scale = q.shape[-1] ** -0.5
    qk = torch.einsum("bhmd, bhnd -> bhmn", q, k) * scale
    return torch.softmax(qk, dim=-1)


def quadratic_linear_attn(
    q: Tensor,
    k: Tensor,
):
    # einsum indices mirror softmax_attn: a (batch, heads, q_len, k_len) attention map
    qk = torch.einsum("bhmd, bhnd -> bhmn", q, k)
    return qk / qk.sum(dim=-1, keepdim=True)
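
# Usage sketch (illustrative; the shapes below are assumptions, not an API contract):
# both maps take (batch, heads, seq_len, head_dim) queries and keys and return a
# (batch, heads, seq_len, seq_len) attention matrix, e.g.
#   q, k = torch.randn(1, 8, 128, 64), torch.randn(1, 8, 128, 64)
#   softmax_attn(q, k)           # the softmax attention weights to be mimicked
#   quadratic_linear_attn(q, k)  # row-normalized linear-attention weights
# In practice quadratic_linear_attn expects non-negative, feature-mapped q and k.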



class HedgeHogModule(nn.Module):
    """
    HedgeHogModule is a PyTorch module that applies a linear transformation
    followed by an exponential feature map to the input tensor.

    Args:
        dim (int): The dimension of the input tensor.
        activation (str, optional): The activation function to be applied.
            Defaults to "exp".

    Attributes:
        dim (int): The dimension of the input tensor.
        activation (str): The activation function to be applied.
        layer (nn.Linear): The linear transformation layer.

    Methods:
        init_weights_: Initializes the weights of the linear layer.
        forward: Performs a forward pass through the module.
    """

    def __init__(
        self,
        dim: int,
        activation: str = "exp",
    ):
        super().__init__()
        self.dim = dim
        self.activation = activation
        self.layer = nn.Linear(dim, dim)
        self.init_weights_()

    def init_weights_(self):
        nn.init.eye_(self.layer.weight)
        nn.init.zeros_(self.layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Performs a forward pass through the module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The feature-mapped tensor [exp(Wx), exp(-Wx)], which doubles
            the last (feature) dimension.
        """
        x = self.layer(x)  # Shape: (batch, seq_len, dim)
        # Concatenate along the feature dimension so downstream shapes stay aligned
        return torch.cat([torch.exp(x), torch.exp(-x)], dim=-1)



class HedgeHogAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        base_attn=None,
        training: bool = True,
        output_attentions: bool = False,
        qk_norm: bool = False,
        *args,
        **kwargs,
    ):
        """
        HedgeHogAttention module that learns trainable feature maps for attention.

        Args:
            dim (int): The input dimension of the module.
            head_dim (int): The dimension of each attention head.
            base_attn: An optional frozen base attention module to mimic.
            training (bool, optional): Whether the module is in training mode. Defaults to True.
            output_attentions (bool, optional): Whether to output attention weights. Defaults to False.
            qk_norm (bool, optional): Whether to apply LayerNorm to queries and keys. Defaults to False.
        """
        super().__init__()
        self.base_attn = base_attn
        self.dim = dim
        self.head_dim = head_dim
        self.training = training
        self.qk_norm = qk_norm
        self.output_attentions = output_attentions

        # Trainable feature maps for q, k, v
        self.mlp_q = HedgeHogModule(dim)
        self.mlp_k = HedgeHogModule(dim)
        self.mlp_v = HedgeHogModule(dim)

        # Freeze the base attention (if one is provided) outside of training
        if self.base_attn is not None and not self.training:
            for p in self.base_attn.parameters():
                p.requires_grad = False

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # The feature maps double the last dimension; this projection back to dim
        # (an assumption made here so the residual connections in Hedgehog line up,
        # not something dictated by the paper) keeps the block runnable.
        self.out_proj = nn.Linear(dim * 2, dim)

        # Optional q/k normalization
        if qk_norm:
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the HedgeHogAttention module.

        Args:
            x (Tensor): The input tensor of shape (batch, seq_len, dim).

        Returns:
            Tensor: The combined feature-mapped q, k, and v, projected back to dim.
        """
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        if self.qk_norm:
            q, k = self.norm(q), self.norm(k)

        # Apply the trainable feature maps
        q = self.mlp_q(q)
        k = self.mlp_k(k)
        v = self.mlp_v(v)

        concat = q + k + v
        print(f"concat shape: {concat.shape}")

        return self.out_proj(concat)


class Hedgehog(nn.Module):
    """
    Hedgehog module for performing attention-based computations.

    Args:
        num_tokens (int): Number of tokens in the input vocabulary.
        dim (int): Dimension of the input.
        heads (int, optional): Number of attention heads. Defaults to 8.
        depth (int, optional): Number of layers. Defaults to 4.
        head_dim (int, optional): Dimension of each attention head. Defaults to 64.
        mult (int, optional): Multiplier for the feedforward layer. Defaults to 4.
        dropout (float, optional): Dropout probability. Defaults to 0.1.

    Attributes:
        dim (int): Dimension of the input.
        heads (int): Number of attention heads.
        depth (int): Number of layers.
        head_dim (int): Dimension of each attention head.
        mult (int): Multiplier for the feedforward layer.
        dropout (float): Dropout probability.
        layers (nn.ModuleList): List of attention and feedforward layers.
        emb (nn.Embedding): Embedding layer.
        norm (nn.LayerNorm): Layer normalization.
        to_out (nn.Sequential): Sequential layer for output transformation.
    """

    def __init__(
        self,
        num_tokens: int,
        dim: int,
        heads: int = 8,
        depth: int = 4,
        head_dim: int = 64,
        mult: int = 4,
        dropout: float = 0.1,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.dim = dim
        self.heads = heads
        self.depth = depth
        self.head_dim = head_dim
        self.mult = mult
        self.dropout = dropout

        # Layers
        self.layers = nn.ModuleList([])

        # Embedding
        self.emb = nn.Embedding(num_tokens, dim)

        # Add both the attention and the feedforward for each layer
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        HedgeHogAttention(
                            dim=dim,
                            head_dim=head_dim,
                        ),
                        FeedForward(
                            dim=dim,
                            mult=mult,
                            dropout=dropout,
                            *args,
                            **kwargs,
                        ),
                    ]
                )
            )

        # Norm
        self.norm = nn.LayerNorm(dim)

        # Output head
        self.to_out = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, dim), nn.Softmax(dim=-1)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.emb(x)
        print(f"x embedding shape: {x.shape}")
        for attn, ff in self.layers:
            x = attn(x) + x
            print(f"x attn shape: {x.shape}")
            x = ff(x) + x
        return self.to_out(x)
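

# Illustrative only: a hypothetical softmax-mimicry training step. The paper
# trains the feature-map MLPs so their linear attention matches the softmax
# attention of a frozen base layer; the loss choice and shapes below are
# assumptions, not this repo's training code.
if __name__ == "__main__":
    q = torch.randn(1, 8, 128, 64, requires_grad=True)
    k = torch.randn(1, 8, 128, 64, requires_grad=True)

    true_attn = softmax_attn(q.detach(), k.detach())     # frozen softmax target
    pred_attn = quadratic_linear_attn(q.exp(), k.exp())  # stand-in for the MLP feature maps
    loss = torch.nn.functional.kl_div(
        pred_attn.clamp_min(1e-6).log(), true_attn, reduction="batchmean"
    )
    loss.backward()
    print(f"mimicry loss: {loss.item():.4f}")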
