## ViTHelper.py
## This code is from the Transformers co-class of DLStudio:
## https://engineering.purdue.edu/kak/distDLS/
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Stacks how_many_basic_encoders BasicEncoder blocks and applies them to the input in sequence.
class MasterEncoder(nn.Module):
    def __init__(self, max_seq_length, embedding_size, how_many_basic_encoders, num_atten_heads):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.basic_encoder_arr = nn.ModuleList([BasicEncoder(
            max_seq_length, embedding_size, num_atten_heads) for _ in range(how_many_basic_encoders)])    # (A)

    def forward(self, sentence_tensor):
        out_tensor = sentence_tensor
        for i in range(len(self.basic_encoder_arr)):                                                      # (B)
            out_tensor = self.basic_encoder_arr[i](out_tensor)
        return out_tensor

## One pre-LayerNorm encoder block: LayerNorm -> multi-head self-attention -> residual add,
## followed by LayerNorm -> two-layer feed-forward network (FFN) -> residual add.
class BasicEncoder(nn.Module):
    def __init__(self, max_seq_length, embedding_size, num_atten_heads):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_size = embedding_size
        self.qkv_size = self.embedding_size // num_atten_heads
        self.num_atten_heads = num_atten_heads
        self.self_attention_layer = SelfAttention(
            max_seq_length, embedding_size, num_atten_heads)                            # (A)
        self.norm1 = nn.LayerNorm(self.embedding_size)                                  # (C)
        ## The FFN acts on the flattened (max_seq_length * embedding_size) vector and expands it 2x.
        self.W1 = nn.Linear(self.max_seq_length * self.embedding_size,
                            self.max_seq_length * 2 * self.embedding_size)
        self.W2 = nn.Linear(self.max_seq_length * 2 * self.embedding_size,
                            self.max_seq_length * self.embedding_size)
        self.norm2 = nn.LayerNorm(self.embedding_size)                                  # (E)

    def forward(self, sentence_tensor):
        input_for_self_atten = sentence_tensor.float()
        normed_input_self_atten = self.norm1(input_for_self_atten)
        output_self_atten = self.self_attention_layer(
            normed_input_self_atten).to(device)                                         # (F)
        input_for_FFN = output_self_atten + input_for_self_atten                        # residual connection
        normed_input_FFN = self.norm2(input_for_FFN)                                    # (I)
        basic_encoder_out = F.relu(
            self.W1(normed_input_FFN.view(sentence_tensor.shape[0], -1)))               # (K)
        basic_encoder_out = self.W2(basic_encoder_out)                                  # (L)
        basic_encoder_out = basic_encoder_out.view(
            sentence_tensor.shape[0], self.max_seq_length, self.embedding_size)
        basic_encoder_out = basic_encoder_out + input_for_FFN                           # residual connection
        return basic_encoder_out
#################################### Self Attention Code TransformerPreLN ###########################################

## Multi-head self-attention: the embedding is split into num_atten_heads slices of width
## qkv_size = embedding_size // num_atten_heads, one AttentionHead per slice, and the head
## outputs are concatenated back along the embedding dimension.
class SelfAttention(nn.Module):
    def __init__(self, max_seq_length, embedding_size, num_atten_heads):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_size = embedding_size
        self.num_atten_heads = num_atten_heads
        self.qkv_size = self.embedding_size // num_atten_heads
        self.attention_heads_arr = nn.ModuleList([AttentionHead(self.max_seq_length,
                                                  self.qkv_size) for _ in range(num_atten_heads)])   # (A)

    def forward(self, sentence_tensor):                                                              # (B)
        ## Allocate the concatenation buffer on the same device as the per-head outputs.
        concat_out_from_atten_heads = torch.zeros(sentence_tensor.shape[0], self.max_seq_length,
                                                  self.num_atten_heads * self.qkv_size).float().to(device)
        for i in range(self.num_atten_heads):                                                         # (C)
            sentence_tensor_portion = sentence_tensor[:,
                                                      :, i * self.qkv_size: (i+1) * self.qkv_size]
            concat_out_from_atten_heads[:, :, i * self.qkv_size: (i+1) * self.qkv_size] = \
                self.attention_heads_arr[i](sentence_tensor_portion)                                  # (D)
        return concat_out_from_atten_heads
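
## As a worked example of the slicing above (with illustrative numbers, not values fixed by this
## file): embedding_size = 64 and num_atten_heads = 4 give qkv_size = 16, so head i receives the
## slice [:, :, 16*i : 16*(i+1)] of the normalized input, and the four 16-wide head outputs are
## written back side by side to rebuild a 64-wide vector per token.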

## A single attention head that computes scaled dot-product attention over its qkv_size-wide slice.
class AttentionHead(nn.Module):
    def __init__(self, max_seq_length, qkv_size):
        super().__init__()
        self.qkv_size = qkv_size
        self.max_seq_length = max_seq_length
        self.WQ = nn.Linear(max_seq_length * self.qkv_size,
                            max_seq_length * self.qkv_size)                              # (B)
        self.WK = nn.Linear(max_seq_length * self.qkv_size,
                            max_seq_length * self.qkv_size)                              # (C)
        self.WV = nn.Linear(max_seq_length * self.qkv_size,
                            max_seq_length * self.qkv_size)                              # (D)
        ## Normalize over the last axis so that each row of the attention matrix sums to 1.
        self.softmax = nn.Softmax(dim=-1)                                                # (E)

    def forward(self, sentence_portion):                                                 # (F)
        ## Q, K and V are computed from the flattened (max_seq_length * qkv_size) input ...
        Q = self.WQ(sentence_portion.reshape(
            sentence_portion.shape[0], -1).float()).to(device)                           # (G)
        K = self.WK(sentence_portion.reshape(
            sentence_portion.shape[0], -1).float()).to(device)                           # (H)
        V = self.WV(sentence_portion.reshape(
            sentence_portion.shape[0], -1).float()).to(device)                           # (I)
        ## ... and reshaped back to (batch, max_seq_length, qkv_size).
        Q = Q.view(sentence_portion.shape[0],
                   self.max_seq_length, self.qkv_size)                                   # (J)
        K = K.view(sentence_portion.shape[0],
                   self.max_seq_length, self.qkv_size)                                   # (K)
        V = V.view(sentence_portion.shape[0],
                   self.max_seq_length, self.qkv_size)                                   # (L)
        A = K.transpose(2, 1)                                                            # (M)
        QK_dot_prod = Q @ A                                                              # (N)
        rowwise_softmax_normalizations = self.softmax(QK_dot_prod)                       # (O)
        Z = rowwise_softmax_normalizations @ V
        ## The 1/sqrt(qkv_size) scaling is applied to the attention output.
        coeff = 1.0 / torch.sqrt(torch.tensor([self.qkv_size]).float()).to(device)       # (S)
        Z = coeff * Z                                                                     # (T)
        return Z
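

## A minimal usage sketch, assuming the caller has already produced a batch of patch embeddings of
## shape (batch_size, max_seq_length, embedding_size); all sizes below are illustrative assumptions,
## not values prescribed by this file.
if __name__ == "__main__":
    batch_size = 4
    max_seq_length = 17               # e.g. 16 image patches plus one class token
    embedding_size = 64               # must be divisible by num_atten_heads
    num_atten_heads = 4
    how_many_basic_encoders = 2

    master_encoder = MasterEncoder(max_seq_length, embedding_size,
                                   how_many_basic_encoders, num_atten_heads).to(device)
    dummy_patch_embeddings = torch.randn(batch_size, max_seq_length, embedding_size).to(device)
    encoder_out = master_encoder(dummy_patch_embeddings)
    ## The encoder preserves the input shape.
    print(encoder_out.shape)          # torch.Size([4, 17, 64])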