-
Notifications
You must be signed in to change notification settings - Fork 0
/
pointer_network.py
66 lines (54 loc) · 2.96 KB
/
pointer_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Reference: https://github.com/guacomolia/ptr_net
# https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import to_var
class PointerNetwork(nn.Module):
    """Pointer Network (Vinyals et al., 2015).

    Encodes an input token sequence, then runs a decoder for a fixed number
    of steps; at each step an additive-attention score over the encoder
    states is emitted as a log-distribution that "points" at input positions.

    Args:
        input_size:     vocabulary size for the input embedding.
        emb_size:       embedding dimension.
        weight_size:    attention (blending) projection dimension W.
        answer_seq_len: number of decoding steps M (output pointers).
        hidden_size:    encoder/decoder hidden dimension H.
        is_GRU:         use GRU/GRUCell if True, else LSTM/LSTMCell.
    """

    def __init__(self, input_size, emb_size, weight_size, answer_seq_len, hidden_size=512, is_GRU=True):
        super(PointerNetwork, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.answer_seq_len = answer_seq_len
        self.weight_size = weight_size
        self.emb_size = emb_size
        self.is_GRU = is_GRU

        self.emb = nn.Embedding(input_size, emb_size)  # embed input token ids
        if is_GRU:
            self.enc = nn.GRU(emb_size, hidden_size, batch_first=True)
            self.dec = nn.GRUCell(emb_size, hidden_size)  # GRUCell's input is always batch first
        else:
            self.enc = nn.LSTM(emb_size, hidden_size, batch_first=True)
            self.dec = nn.LSTMCell(emb_size, hidden_size)  # LSTMCell's input is always batch first

        self.W1 = nn.Linear(hidden_size, weight_size, bias=False)  # blends encoder states
        self.W2 = nn.Linear(hidden_size, weight_size, bias=False)  # blends decoder state
        self.vt = nn.Linear(weight_size, 1, bias=False)            # v^T scoring of tanh blend

    def forward(self, input):
        """Return per-step log-distributions over input positions.

        Args:
            input: LongTensor of token ids, shape (bs, L).

        Returns:
            Tensor of shape (bs, M, L) with log-softmax scores, where
            M = answer_seq_len and L = input sequence length.
        """
        batch_size = input.size(0)
        input = self.emb(input)  # (bs, L, emb_size)

        # Encoding. The final hidden state from the RNN is unused; only the
        # per-step outputs feed the attention.
        encoder_states, _ = self.enc(input)              # (bs, L, H)
        encoder_states = encoder_states.transpose(1, 0)  # (L, bs, H)

        # Decoder state initialization. new_zeros keeps device/dtype in sync
        # with the model (the original wrapped tensors with the deprecated
        # Variable-era `to_var`, which ignored the input's device).
        # NOTE(review): decoder_input stays all-zeros at every step — this
        # reference implementation never feeds the selected token back in.
        decoder_input = encoder_states.new_zeros(batch_size, self.emb_size)   # (bs, emb_size)
        hidden = encoder_states.new_zeros(batch_size, self.hidden_size)       # (bs, H)
        cell_state = encoder_states[-1]  # (bs, H): last-step encoder output seeds the LSTM cell

        probs = []
        # Decoding: one pointer distribution per step, M = answer_seq_len.
        for _ in range(self.answer_seq_len):
            if self.is_GRU:
                hidden = self.dec(decoder_input, hidden)  # (bs, H)
            else:
                hidden, cell_state = self.dec(decoder_input, (hidden, cell_state))  # (bs, H) each

            # Additive attention over all encoder positions; W2(hidden)
            # broadcasts across the L dimension.
            blend1 = self.W1(encoder_states)         # (L, bs, W)
            blend2 = self.W2(hidden)                 # (bs, W)
            blend_sum = torch.tanh(blend1 + blend2)  # (L, bs, W)
            # BUGFIX: squeeze(-1), not bare squeeze() — a bare squeeze would
            # also drop the batch dim when batch_size == 1, breaking the
            # transpose below.
            out = self.vt(blend_sum).squeeze(-1)     # (L, bs)
            out = F.log_softmax(out.transpose(0, 1).contiguous(), -1)  # (bs, L)
            probs.append(out)

        probs = torch.stack(probs, dim=1)  # (bs, M, L)
        return probs