# hstu.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PointwiseAggregatedAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # TODO: add a relative attention bias based on time; rab_p below covers position only
        self.rab_p = RelativeAttentionBias(num_heads, relative_attention_num_buckets=32,
                                           relative_attention_max_distance=128)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)  # (batch, num_heads, seq_len, head_dim)

    def forward(self, v, k, q, mask=None):
        batch_size = q.shape[0]
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        attention_scores = torch.matmul(q, k.transpose(-2, -1))
        # Scaled alternative:
        # attention_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        rab = self.rab_p(q.shape[2], k.shape[2], device=q.device)
        att_w_bias = attention_scores + rab
        weights = F.silu(att_w_bias)
        if mask is not None:
            # Zero out masked positions after the SiLU; with pointwise (non-softmax)
            # attention, multiplying by a broadcastable {0, 1} mask is sufficient,
            # and additive -inf masking would produce NaNs through SiLU.
            weights = weights * mask
        av = weights @ v
        return av.transpose(1, 2).flatten(2)  # merge heads: (batch, seq_len, d_model)
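
# Design note: unlike standard softmax attention, the biased scores above pass
# through an elementwise SiLU, so the attention weights are not normalized
# across keys; this is the "pointwise aggregated" formulation the class is
# named for.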

class RelativeAttentionBias(nn.Module):
    def __init__(self, num_heads, relative_attention_num_buckets, relative_attention_max_distance=128):
        super().__init__()
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.relative_attention_bias = nn.Embedding(relative_attention_num_buckets, num_heads)

    def forward(self, query_length, key_length, device=None):
        """Compute binned relative position bias."""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=False,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
    # Adapted from Hugging Face's T5 implementation:
    # https://github.com/huggingface/transformers/blob/6cdbd73e01a9719bfaec07d91fd108e8d932bbbb/src/transformers/models/t5/modeling_t5.py#L384
    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

class HSTUBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.f1 = nn.Linear(d_model, d_model * 4)  # transform, then split into u, v, q, k
        self.pointwise_attn = PointwiseAggregatedAttention(d_model, num_heads)
        self.f2 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def split(self, x):
        u, v, q, k = x.chunk(4, dim=-1)
        return u, v, q, k

    def forward(self, x):
        # Pointwise projection
        x_proj = F.silu(self.f1(x))
        u, v, q, k = self.split(x_proj)
        # Spatial aggregation
        av = self.pointwise_attn(v, k, q)
        # Pointwise transformation, gated elementwise by u
        y = self.f2(self.dropout(self.norm(av * u)))
        return y
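
# Design note: f1 produces 4 * d_model features split into u (gate), v (values),
# q (queries), and k (keys); the elementwise gate u roughly plays the role that
# the feed-forward sublayer plays in a standard Transformer block.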

class GenRec(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([HSTUBlock(d_model, num_heads, dropout) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x) + x  # residual connection around each block
        return x

if __name__ == "__main__":
    model = GenRec(d_model=52, num_heads=2, num_layers=3)
    input_shape = (32, 10, 52)  # (batch, seq_len, d_model)
    x = torch.rand(*input_shape)
    assert model(x).shape == input_shape
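
    # A couple of extra smoke checks (a minimal sketch; names and shapes follow
    # from the modules above):
    rab = RelativeAttentionBias(num_heads=2, relative_attention_num_buckets=32)
    assert rab(10, 10).shape == (1, 2, 10, 10)  # (1, num_heads, query_len, key_len)

    # Causal usage of the optional mask: a lower-triangular {0, 1} mask zeroes
    # attention to future positions (see the masking convention in forward).
    attn = PointwiseAggregatedAttention(d_model=52, num_heads=2)
    causal_mask = torch.tril(torch.ones(10, 10))  # broadcasts over batch and heads
    assert attn(x, x, x, mask=causal_mask).shape == input_shape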