-
Notifications
You must be signed in to change notification settings - Fork 1
/
educational_transformer.py
149 lines (132 loc) · 5.46 KB
/
educational_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch.nn.functional as F
import copy
import torch
import torch.nn as nn
import numpy as np
class MultiHeadedAttention(nn.Module):
    """Multi-headed scaled dot-product attention.

    Keys and values are projected once via ``init_keys``; later calls to
    ``forward`` attend from the given queries to those stored projections.
    The most recent attention weights (after dropout, detached) are kept in
    ``self.alphas``.

    Args:
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
        d_model: Feature dimension of queries, keys and values.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        # Fail fast: a non-divisible split would otherwise only surface later
        # as an obscure ``view`` error inside ``make_chunks``.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.alphas = None
        # Populated by init_keys(); None until then.
        self.proj_key = None
        self.proj_value = None

    def make_chunks(self, x):
        """Split the feature dim into heads: (N, L, D) -> (N, n_heads, L, d_k)."""
        batch_size, seq_len = x.size(0), x.size(1)
        # N, L, D -> N, L, n_heads, d_k
        x = x.view(batch_size, seq_len, self.n_heads, self.d_k)
        # N, L, n_heads, d_k -> N, n_heads, L, d_k
        return x.transpose(1, 2)

    def init_keys(self, key):
        """Project ``key`` into per-head keys and values and store them.

        The same input tensor feeds both the key and the value projection.
        """
        self.proj_key = self.make_chunks(self.linear_key(key))
        self.proj_value = self.make_chunks(self.linear_value(key))

    def _require_keys(self):
        """Raise a clear error if init_keys() has not been called yet."""
        if self.proj_key is None or self.proj_value is None:
            raise RuntimeError(
                "init_keys() must be called before computing attention"
            )

    def score_function(self, query):
        """Return scaled dot-product scores, shape (N, n_heads, Lq, Lk)."""
        self._require_keys()
        proj_query = self.make_chunks(self.linear_query(query))
        # N, n_heads, Lq, d_k x N, n_heads, d_k, Lk -> N, n_heads, Lq, Lk
        dot_products = torch.matmul(proj_query,
                                    self.proj_key.transpose(-2, -1))
        return dot_products / np.sqrt(self.d_k)

    def attn(self, query, mask=None):
        """Return per-head context vectors, shape (N, n_heads, Lq, d_k).

        Scores at positions where ``mask == 0`` are set to -1e9 before the
        softmax, so those positions receive (near) zero weight.
        """
        scores = self.score_function(query)  # N, n_heads, Lq, Lk
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        alphas = F.softmax(scores, dim=-1)  # N, n_heads, Lq, Lk
        alphas = self.dropout(alphas)
        self.alphas = alphas.detach()
        # N, n_heads, Lq, Lk x N, n_heads, Lk, d_k -> N, n_heads, Lq, d_k
        return torch.matmul(alphas, self.proj_value)

    def output_function(self, contexts):
        """Apply the output projection, (N, L, D) -> (N, L, D)."""
        return self.linear_out(contexts)

    def forward(self, query, mask=None):
        """Attend from batch-first ``query`` (N, Lq, D) to the stored keys.

        ``mask`` gets a head axis inserted at dim 1 so every head shares it.
        Returns a tensor of shape (N, Lq, d_model).
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        context = self.attn(query, mask=mask)  # N, n_heads, Lq, d_k
        # N, n_heads, Lq, d_k -> N, Lq, n_heads, d_k
        context = context.transpose(1, 2).contiguous()
        # Merge heads back: N, Lq, n_heads * d_k = N, Lq, d_model
        context = context.view(query.size(0), -1, self.d_model)
        return self.output_function(context)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to scaled inputs.

    ``forward(x)`` returns ``x * sqrt(d_model) + pe[:L]``. The fixed table
    ``pe`` holds sine values in even feature columns and cosine values in
    odd feature columns.

    Args:
        max_len: Longest sequence length the table covers.
        d_model: Feature dimension (may be odd).
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)  # max_len, 1
        # ceil(d_model / 2) frequencies, one per even column.
        angular_speed = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * angular_speed)  # even dimensions
        # There are only floor(d_model / 2) odd columns; trim the frequencies
        # so an odd d_model does not fail with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * angular_speed[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # 1, max_len, d_model

    def forward(self, x):
        """Scale ``x`` (N, L, D) by sqrt(d_model) and add the first L encodings.

        Raises:
            ValueError: If the sequence length L exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        scaled_x = x * np.sqrt(self.d_model)
        return scaled_x + self.pe[:, :seq_len, :]
class SubLayerWrapper(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(layer_norm(x)))``.

    Args:
        d_model: Feature size normalised by the layer norm.
        dropout: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, sublayer, is_self_attn=False, **kwargs):
        """Run ``sublayer`` on the normalised input and add the residual.

        When ``is_self_attn`` is true, ``sublayer.init_keys`` is first called
        with the normalised input. Extra ``kwargs`` go straight to ``sublayer``.
        """
        normalized = self.norm(x)
        if is_self_attn:
            sublayer.init_keys(normalized)
        residual = self.drop(sublayer(normalized, **kwargs))
        return x + residual
class EncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention followed by a feed-forward net.

    Each of the two stages is wrapped in a ``SubLayerWrapper`` (layer norm,
    sub-layer, dropout, residual add).

    Args:
        n_heads: Number of attention heads.
        d_model: Feature dimension.
        ff_units: Hidden width of the feed-forward network.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, n_heads, d_model, ff_units, dropout=0.1):
        super().__init__()
        self.n_heads, self.d_model, self.ff_units = n_heads, d_model, ff_units
        self.self_attn_heads = MultiHeadedAttention(
            n_heads, d_model, dropout=dropout
        )
        feed_forward_stack = [
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_units, d_model),
        ]
        self.ffn = nn.Sequential(*feed_forward_stack)
        # Index 0 wraps self-attention, index 1 wraps the feed-forward net.
        self.sublayers = nn.ModuleList(
            SubLayerWrapper(d_model, dropout) for _ in range(2)
        )

    def forward(self, query, mask=None):
        """Apply self-attention (with optional ``mask``) then the FFN."""
        attn_stage, ffn_stage = self.sublayers
        hidden = attn_stage(
            query,
            sublayer=self.self_attn_heads,
            is_self_attn=True,
            mask=mask,
        )
        return ffn_stage(hidden, sublayer=self.ffn)
class EncoderTransf(nn.Module):
    """Positional encoding, a stack of encoder layers, then a final layer norm.

    Args:
        encoder_layer: Template layer exposing ``d_model``; it is deep-copied
            ``n_layers`` times, so the copies do not share weights.
        n_layers: Number of stacked layers.
        max_len: Maximum sequence length for the positional encoding.
    """

    def __init__(self, encoder_layer, n_layers=1, max_len=100):
        super().__init__()
        self.d_model = encoder_layer.d_model
        self.pe = PositionalEncoding(max_len, self.d_model)
        self.norm = nn.LayerNorm(self.d_model)
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer)
                                     for _ in range(n_layers)])

    def forward(self, query, mask=None):
        """Encode ``query`` and return the layer-normalised output.

        Args:
            query: Input features; the positional encoding treats dim 1 as
                the sequence axis.
            mask: Optional mask. It is given a singleton dim 1 and inverted
                with ``~`` before being passed to every layer. The attention
                in this file then suppresses positions where the passed
                mask == 0. NOTE(review): this suggests a boolean (N, L)
                tensor with True marking positions to ignore; confirm with
                callers. ``None`` disables masking.
        """
        x = self.pe(query)
        # None means "no masking" (it previously crashed on None.unsqueeze).
        # Invert once here instead of once per layer.
        attn_mask = None if mask is None else ~mask.unsqueeze(1)
        for layer in self.layers:
            x = layer(x, attn_mask)
        return self.norm(x)