model.py
import torch
import torch.nn as nn
import torchvision.models as models


# ----------- Encoder ------------
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        # freeze the pretrained ResNet parameters (no gradient updates)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # drop the final fully connected layer and keep the convolutional backbone
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)  # (bs, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # (bs, 2048)
        features = self.embed(features)  # (bs, embed_size)
        return features


# --------- Decoder ----------
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        """
        Args:
            embed_size: final embedding size of the CNN encoder
            hidden_size: hidden size of the LSTM
            vocab_size: size of the vocabulary
            num_layers: number of layers of the LSTM
        """
        super(DecoderRNN, self).__init__()
        # hidden dimension of the LSTM
        self.hidden_dim = hidden_size
        # map each word index to a dense word embedding of size embed_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        # LSTM that consumes the embedded sequence (batch dimension first)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        # linear layer that maps each LSTM hidden state to vocabulary scores
        self.linear = nn.Linear(hidden_size, vocab_size)
        # placeholder for the (hidden, cell) state; overwritten in forward()
        self.hidden = (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))

    def forward(self, features, captions):
        """
        Args:
            features: image features tensor. shape is (bs, embed_size)
            captions: captions tensor. shape is (bs, cap_length)
        Returns:
            outputs: scores of the linear layer
        """
        # drop the <end> token from the captions and embed the remaining words
        cap_embedding = self.embed(
            captions[:, :-1]
        )  # (bs, cap_length) -> (bs, cap_length-1, embed_size)
        # prepend the image features to the caption embeddings so the encoded
        # image acts as the input at t=0:
        # (bs, embed_size) -> (bs, 1, embed_size) concat (bs, cap_length-1, embed_size)
        # -> (bs, cap_length, embed_size)
        embeddings = torch.cat((features.unsqueeze(dim=1), cap_embedding), dim=1)
        # lstm_out holds the hidden state at every time step; self.hidden holds
        # the (hidden, cell) state after the last time step
        lstm_out, self.hidden = self.lstm(
            embeddings
        )  # lstm_out: (bs, cap_length, hidden_size)
        outputs = self.linear(lstm_out)  # (bs, cap_length, vocab_size)
        return outputs

    def sample(self, inputs, states=None, max_len=20):
        """
        Accepts a pre-processed image tensor (inputs) and returns a predicted
        sentence (list of word indices of length at most max_len).
        Args:
            inputs: shape is (1, 1, embed_size)
            states: initial hidden state of the LSTM
            max_len: maximum length of the predicted sentence
        Returns:
            res: list of predicted word indices
        """
        res = []
        # feed the previous prediction and hidden state back into the LSTM
        # one step at a time to build up the caption
        for i in range(max_len):
            lstm_out, states = self.lstm(
                inputs, states
            )  # lstm_out: (1, 1, hidden_size)
            outputs = self.linear(lstm_out.squeeze(dim=1))  # outputs: (1, vocab_size)
            _, predicted_idx = outputs.max(dim=1)  # predicted_idx: (1,)
            res.append(predicted_idx.item())
            # stop once the <end> token (index 1 in the vocabulary) is predicted
            if predicted_idx == 1:
                break
            # prepare the input for the next iteration
            inputs = self.embed(predicted_idx)  # (1, embed_size)
            inputs = inputs.unsqueeze(1)  # (1, 1, embed_size)
        return res
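

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: it wires the
    # encoder and decoder together on dummy data to show the tensor shapes.
    # The values below (embed_size, hidden_size, vocab_size, the random batch,
    # and the assumption that <end> has index 1) are illustrative placeholders,
    # not values defined elsewhere in this file.
    embed_size, hidden_size, vocab_size = 256, 512, 1000
    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    images = torch.randn(4, 3, 224, 224)  # (bs, 3, H, W)
    captions = torch.randint(0, vocab_size, (4, 12))  # (bs, cap_length)

    with torch.no_grad():
        features = encoder(images)  # (4, 256)
        scores = decoder(features, captions)  # (4, 12, 1000)
        # greedy decoding for a single image; expects input of shape (1, 1, embed_size)
        caption_ids = decoder.sample(features[:1].unsqueeze(1))

    print(scores.shape, caption_ids)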