import torch
import torch.nn as nn
from torch.nn import functional


class Attention(nn.Module):
    '''A single head of masked (causal) self-attention.

    Args:
        dim_embedding: size of the input token embeddings.
        dim_head: size of the key/query/value projections.
        dim_context: maximum sequence length supported by the causal mask.
        prob_dropout: dropout probability applied to the attention weights.
    '''
    def __init__(self, dim_embedding, dim_head, dim_context, prob_dropout):
        super().__init__()
        # Linear projections from the embedding space to the head dimension
        self.key = nn.Linear(dim_embedding, dim_head, bias=False)
        self.query = nn.Linear(dim_embedding, dim_head, bias=False)
        self.value = nn.Linear(dim_embedding, dim_head, bias=False)
        # Lower-triangular causal mask; registered as a buffer so it moves
        # with the module (e.g. to GPU) but is not a trainable parameter
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(dim_context, dim_context))
        )
        self.dropout = nn.Dropout(prob_dropout)

    def forward(self, x):
        # x has shape (batch_size, num_tokens, dim_embedding)
        batch_size, num_tokens, dim_embedding = x.shape
        key = self.key(x)      # (batch_size, num_tokens, dim_head)
        query = self.query(x)  # (batch_size, num_tokens, dim_head)
        # Scaled dot-product scores, Equation 1 in the original paper:
        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
        attention = query @ key.transpose(-2, -1) * (key.shape[-1] ** -0.5)
        # Mask future positions so each token attends only to itself and the past
        masked_attention = attention.masked_fill(
            self.tril[:num_tokens, :num_tokens] == 0, float('-inf')
        )
        masked_attention = functional.softmax(masked_attention, dim=-1)
        masked_attention = self.dropout(masked_attention)
        value = self.value(x)  # (batch_size, num_tokens, dim_head)
        # Weighted sum of values, shape (batch_size, num_tokens, dim_head)
        y = masked_attention @ value
        return y
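

# A minimal usage sketch (not part of the original file): the hyperparameter
# values below are illustrative assumptions, not taken from the repository.
if __name__ == '__main__':
    torch.manual_seed(0)
    head = Attention(dim_embedding=32, dim_head=16, dim_context=8, prob_dropout=0.1)
    x = torch.randn(4, 8, 32)  # (batch_size, num_tokens, dim_embedding)
    out = head(x)
    print(out.shape)           # expected: torch.Size([4, 8, 16])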