# ----- Written by Ahmad Farhan ------
import copy
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    '''
    query, key, value : projections of the input, shape (B, n_head, T, d_k)
    mask : lower-triangular mask that makes future positions unreachable
    '''
    T, d_k = query.size(-2), query.size(-1)
    # scaled dot product between query and key
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # push scores for future positions toward -inf before the softmax
        scores = scores.masked_fill(mask[:, :, :T, :T] == 0, -1e10)
    prob_attn = F.softmax(scores, dim=-1)  # attention probabilities
    if dropout is not None:
        prob_attn = dropout(prob_attn)     # pass through dropout
    # weighted sum of values: each output position mixes all T value vectors
    return torch.matmul(prob_attn, value)
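

# --- Added illustration (not in the original file) ---
# A minimal sketch of how the causal mask expected by attention() can be
# built: a lower-triangular (1, 1, block_size, block_size) tensor of ones and
# zeros that broadcasts against the (B, n_head, T, T) score matrix, so each
# position can attend only to itself and to earlier positions. The helper
# name make_causal_mask is an assumption for illustration.
def make_causal_mask(block_size):
    ones = torch.ones(block_size, block_size)
    return torch.tril(ones).view(1, 1, block_size, block_size)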


class CausalSelfAttention(nn.Module):
    '''
    d_model     : embedding dimension
    n_head      : number of attention heads
    attn_pdrop  : attention dropout probability
    resid_pdrop : dropout probability after the output projection layer
    The causal mask is passed to forward() rather than stored on the module.
    '''

    def __init__(self, d_model, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        self.n_head = n_head
        assert d_model % n_head == 0  # d_model must be divisible by n_head
        self.d_k = d_model // self.n_head
        # projections for query, key, value, plus the output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # Alternative: build the mask once and register it as a buffer so it
        # is stored in the state_dict when saving the model (block_size = context length):
        # subsequent_mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        # self.register_buffer("mask", subsequent_mask)

    def forward(self, x, mask):
        B, T, d_model = x.size()
        # project the input, then split into n_head heads of size d_k
        query, key, value = [l(x).view(B, -1, self.n_head, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (x, x, x))]
        y = attention(query, key, value, mask=mask, dropout=self.attn_drop)
        # merge the heads back into a (B, T, d_model) tensor
        y = y.transpose(1, 2).contiguous().view(B, T, d_model)
        # output projection followed by residual dropout
        return self.resid_drop(self.linears[-1](y))
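

# --- Added usage sketch (not in the original file) ---
# Runs a random batch through CausalSelfAttention using the lower-triangular
# mask built by make_causal_mask above. The shapes and hyperparameters below
# are assumptions chosen only for illustration.
if __name__ == "__main__":
    B, T, d_model, n_head = 2, 16, 64, 4
    layer = CausalSelfAttention(d_model, n_head, attn_pdrop=0.1, resid_pdrop=0.1)
    mask = make_causal_mask(T)         # (1, 1, T, T) causal mask
    x = torch.randn(B, T, d_model)     # random token embeddings
    out = layer(x, mask)
    print(out.shape)                   # torch.Size([2, 16, 64])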