
Commit

shift tokens is still highly effective when done right before the prenorm. cite Peng et al.
lucidrains committed Jan 13, 2024
1 parent 2a9967f commit 3ba926b
Showing 3 changed files with 38 additions and 2 deletions.
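
For orientation, a hedged usage sketch of the two new constructor flags this commit introduces (`prenorm`, `shift_tokens`). The class name `TaylorSeriesLinearAttn`, the import, and the example dimensions are assumptions inferred from the repository layout, not taken from this diff; only the two flags themselves are confirmed by the changes below.

```python
# a minimal sketch, assuming the module is exported as TaylorSeriesLinearAttn;
# class name and dimensions are illustrative, only `prenorm` and `shift_tokens`
# come from this commit's diff
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn

attn = TaylorSeriesLinearAttn(
    dim = 512,
    dim_head = 16,
    heads = 8,
    prenorm = True,       # new: apply the RMSNorm added below before attention
    shift_tokens = True   # new: RWKV-style token shift, done right before the norm
)

x = torch.randn(1, 2048, 512)
out = attn(x)
assert x.shape == out.shape
```

With both flags on, half of each token's feature channels are taken from the previous token before the prenorm, following Peng et al. (2023).
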
10 changes: 10 additions & 0 deletions README.md
@@ -112,4 +112,14 @@ assert x.shape == out.shape
}
```

```bibtex
@inproceedings{Peng2023RWKVRR,
title = {RWKV: Reinventing RNNs for the Transformer Era},
author = {Bo Peng and Eric Alcaide and Quentin G. Anthony and Alon Albalak and Samuel Arcadinho and Stella Biderman and Huanqi Cao and Xin Cheng and Michael Chung and Matteo Grella and G Kranthikiran and Xuming He and Haowen Hou and Przemyslaw Kazienko and Jan Kocoń and Jiaming Kong and Bartlomiej Koptyra and Hayden Lau and Krishna Sri Ipsit Mantri and Ferdinand Mom and Atsushi Saito and Xiangru Tang and Bolun Wang and Johan Sokrates Wind and Stanislaw Wozniak and Ruichong Zhang and Zhenyuan Zhang and Qihang Zhao and Peng Zhou and Jian Zhu and Rui Zhu},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:258832459}
}
```

*The greatest shortcoming of the human race is man’s inability to understand the exponential function.* - Albert A. Bartlett
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'taylor-series-linear-attention',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'Taylor Series Linear Attention',
author = 'Phil Wang',
28 changes: 27 additions & 1 deletion taylor_series_linear_attention/attention.py
@@ -21,6 +21,22 @@ def exists(v):
def default(v, d):
return v if exists(v) else d
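
# token shift (RWKV, Peng et al. 2023): the second half of the feature channels
# is shifted forward by one position (pad the front of the sequence, trim the end),
# so each token also sees its predecessor's features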

def shift(t):
t, t_shift = t.chunk(2, dim = -1)
t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)
return torch.cat((t, t_shift), dim = -1)

# prenorm
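# rmsnorm: unit-normalize the feature vector, then rescale by sqrt(dim) and a learned per-channel gamma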

class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))

def forward(self, x):
return self.gamma * F.normalize(x, dim = -1) * self.scale

# they use 2nd taylor expansion for exp(x)
# https://arxiv.org/abs/2209.04881
# in a linear attention formulation
@@ -58,12 +74,17 @@ def __init__(
one_headed_kv = False,
rotary_emb = False,
combine_heads = True,
gate_value_heads = False
gate_value_heads = False,
prenorm = False,
shift_tokens = False
):
super().__init__()
self.scale = dim_head ** -0.5
dim_inner = dim_head * heads

self.shift_tokens = shift_tokens
self.norm = RMSNorm(dim) if prenorm else nn.Identity()

self.heads = heads
self.dim_hidden = dim_inner

@@ -126,6 +147,11 @@ def forward(
is_cross_attn = exists(context)
assert not (exists(self.rotary_emb) and is_cross_attn), 'rotary embedding does not work with cross attention'

if self.shift_tokens:
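# token shift applied right before the prenorm, the ordering this commit adopts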
x = shift(x)

x = self.norm(x)

q = self.to_q(x)
k, v = self.to_kv(default(context, x))

