-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
ch15_part2.py
412 lines (262 loc) · 10.6 KB
/
ch15_part2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# coding: utf-8
import re
import sys
from collections import Counter, OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.datasets import IMDB
from torchtext.vocab import vocab

# python_environment_check.py lives one folder up: extend sys.path *before*
# importing it, otherwise the import fails when the script is run from its
# own folder.
sys.path.insert(0, '..')
from python_environment_check import check_packages  # noqa: E402
# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples
# ## Package version checks
# Make the parent folder importable; the helper module that provides
# `check_packages` is loaded from there.
sys.path.insert(0, '..')
# Recommended minimum versions of the key third-party packages:
d = dict(torch='1.8.0', torchtext='0.10.0')
check_packages(d)
# # Chapter 15: Modeling Sequential Data Using Recurrent Neural Networks (Part 2/3)
# **Outline**
#
# - [Implementing RNNs for sequence modeling in PyTorch](#Implementing-RNNs-for-sequence-modeling-in-PyTorch)
# - [Project one -- predicting the sentiment of IMDb movie reviews](#Project-one----predicting-the-sentiment-of-IMDb-movie-reviews)
# - [Preparing the movie review data](#Preparing-the-movie-review-data)
# - [Embedding layers for sentence encoding](#Embedding-layers-for-sentence-encoding)
# - [Building an RNN model](#Building-an-RNN-model)
# - [Building an RNN model for the sentiment analysis task](#Building-an-RNN-model-for-the-sentiment-analysis-task)
# - [More on the bidirectional RNN](#More-on-the-bidirectional-RNN)
# # Implementing RNNs for sequence modeling in PyTorch
#
# ## Project one: predicting the sentiment of IMDb movie reviews
#
# ### Preparing the movie review data
#
#
# !pip install torchtext
# Step 1: load and create the datasets
# Both splits are materialized as lists because they are iterated several
# times below (vocabulary building, several DataLoader passes).
# NOTE(review): depending on the torchtext version the raw IMDB splits may be
# one-shot iterators and/or lack __len__ -- listing them sidesteps both.
train_dataset = list(IMDB(split='train'))
test_dataset = list(IMDB(split='test'))
torch.manual_seed(1)
# Hold out a fixed-size validation set; the training size is derived from the
# data instead of being hard-coded (gives [20000, 5000] for 25000 reviews).
num_valid = 5000
train_dataset, valid_dataset = random_split(
    train_dataset, [len(train_dataset) - num_valid, num_valid])
## Step 2: find unique tokens (words)
token_counts = Counter()
# Patterns used by `tokenizer`, compiled once. Raw strings avoid the
# invalid-escape warnings that '\W', '\)' and '\(' trigger in plain literals
# (the regexes themselves are unchanged).
_HTML_TAG_RE = re.compile(r'<[^>]*>')
_EMOTICON_RE = re.compile(r'(?::|;|=)(?:-)?(?:\)|\(|D|P)')
_NON_WORD_RE = re.compile(r'[\W]+')
def tokenizer(text):
    """Split a raw review into lower-case word tokens followed by emoticons.

    HTML tags are removed; emoticons such as ``:)``, ``;-(`` or ``:D`` are
    collected (with any ``-`` nose dropped) and appended after the words;
    every run of non-word characters separates words.

    Args:
        text: the raw review string.

    Returns:
        A list of token strings (empty for empty input).
    """
    text = _HTML_TAG_RE.sub('', text)
    # Match emoticons on the original-case text: the pattern's D/P
    # alternatives are upper-case and can never match lower-cased input.
    emoticons = _EMOTICON_RE.findall(text)
    words = _NON_WORD_RE.sub(' ', text.lower())
    # The explicit ' ' keeps the last word and the first emoticon apart
    # (without it 'nice :) movie' tokenized to ['nice', 'movie:)']).
    return (words + ' ' + ' '.join(emoticons).replace('-', '')).split()
# Count how often each token occurs in the training reviews.
for label, line in train_dataset:
    tokens = tokenizer(line)
    token_counts.update(tokens)
print('Vocab-size:', len(token_counts))
## Step 3: encoding each unique token into integers
# most_common() with no argument is exactly a descending-count sort in which
# equal counts keep first-seen order, so the index assignment is unchanged.
sorted_by_freq_tuples = token_counts.most_common()
ordered_dict = OrderedDict(sorted_by_freq_tuples)
# Use the torchtext factory under its own alias so that binding `vocab` to the
# resulting Vocab object does not shadow the factory (which made this step
# fail when re-executed).
from torchtext.vocab import vocab as build_vocab
vocab = build_vocab(ordered_dict)
# Index 0 is reserved for padding (matches padding_idx=0 in the embedding
# layers) and index 1 for unknown tokens.
vocab.insert_token("<pad>", 0)
vocab.insert_token("<unk>", 1)
vocab.set_default_index(1)  # out-of-vocabulary lookups return <unk>'s index
print([vocab[token] for token in ['this', 'is', 'an', 'example']])
## Step 3-A: define the functions for transformation
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing later when batches are moved to a hard-coded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def text_pipeline(x):
    """Map a raw review string to a list of vocabulary indices."""
    return [vocab[token] for token in tokenizer(x)]
def label_pipeline(x):
    """Map a raw IMDB label to a float target: 1.0 if positive, else 0.0.

    The string label 'pos' is accepted as before. NOTE(review): newer
    torchtext releases appear to yield integer labels (2 = positive,
    1 = negative), so the integer 2 is accepted as positive too -- confirm
    against the installed version.
    """
    return 1. if x in ('pos', 2) else 0.
## Step 3-B: wrap the encode and transformation function
def collate_batch(batch):
    """Turn a list of (raw_label, raw_text) pairs into model-ready tensors.

    Each text is encoded with `text_pipeline` and each label with
    `label_pipeline`; the encoded sequences are right-padded with 0 into one
    batch-first tensor.

    Returns:
        (padded token ids, labels, original sequence lengths), all moved to
        the module-level `device`.
    """
    label_values = []
    token_tensors = []
    for raw_label, raw_text in batch:
        label_values.append(label_pipeline(raw_label))
        token_tensors.append(
            torch.tensor(text_pipeline(raw_text), dtype=torch.int64))
    seq_lengths = torch.tensor([t.size(0) for t in token_tensors])
    padded = nn.utils.rnn.pad_sequence(token_tensors, batch_first=True)
    return (padded.to(device),
            torch.tensor(label_values).to(device),
            seq_lengths.to(device))
## Take a small batch
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False,
                        collate_fn=collate_batch)
text_batch, label_batch, length_batch = next(iter(dataloader))
# Equivalent to four separate print() calls: one object per line.
print(text_batch, label_batch, length_batch, text_batch.shape, sep='\n')
## Step 4: batching the datasets
batch_size = 32
train_dl = DataLoader(train_dataset, collate_fn=collate_batch,
                      batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(valid_dataset, collate_fn=collate_batch,
                      batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_dataset, collate_fn=collate_batch,
                     batch_size=batch_size, shuffle=False)
# ### Embedding layers for sentence encoding
#
#
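# * `input_dim`: number of words, i.e. maximum integer index + 1 (`num_embeddings` in `nn.Embedding`).
# * `output_dim`: size of the dense vector each word is mapped to (`embedding_dim` in `nn.Embedding`).
# * `input_length`: the length of (padded) sequence
# * for example, `'This is an example' -> [0, 0, 0, 0, 0, 0, 3, 1, 8, 9]`
# => input_length is 10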
#
#
#
# * When called, the layer takes integer values as input;
# the embedding layer converts each integer into a float vector of size `[output_dim]`
# * If input shape is `[BATCH_SIZE]`, output shape will be `[BATCH_SIZE, output_dim]`
# * If input shape is `[BATCH_SIZE, 10]`, output shape will be `[BATCH_SIZE, 10, output_dim]`
# Toy embedding table: 10 rows of 3-dimensional vectors; row 0 is the
# padding row (padding_idx=0).
embedding = nn.Embedding(10, 3, padding_idx=0)
# Two samples of four token indices each; the trailing 0 in the second
# sample is padding.
text_encoded_input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 0]],
                                  dtype=torch.long)
print(embedding(text_encoded_input))
# ### Building an RNN model
#
# * **RNN layers:**
# * `nn.RNN(input_size, hidden_size, num_layers=1)`
# * `nn.LSTM(..)`
# * `nn.GRU(..)`
# * `nn.RNN(input_size, hidden_size, num_layers=1, bidirectional=True)`
#
#
## An example of building an RNN model
## with a simple RNN layer
# Two stacked recurrent layers followed by a single linear output unit
class RNN(nn.Module):
    """Stacked ``nn.RNN`` whose top-layer final hidden state feeds a linear unit.

    Input is batch-first, ``(batch, seq_len, input_size)``; the output has
    shape ``(batch, 1)``.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the recurrent hidden state.
        num_layers: number of stacked recurrent layers (default 2, the value
            that used to be hard-coded).

    ``nn.GRU`` takes the same arguments and returns the same
    ``(output, h_n)`` pair, so it can replace ``nn.RNN`` directly; ``nn.LSTM``
    returns ``(output, (h_n, c_n))`` and would need ``forward`` adjusted.
    """

    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        self.rnn = nn.RNN(input_size,
                          hidden_size,
                          num_layers=num_layers,
                          batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # hidden: (num_layers, batch, hidden_size); keep only the top layer.
        _, hidden = self.rnn(x)
        out = hidden[-1, :, :]
        return self.fc(out)
model = RNN(64, 32)
print(model)
# Forward a random batch (5 sequences, 3 time steps, 64 features) and show the
# result: a bare expression statement displays nothing outside a notebook.
print(model(torch.randn(5, 3, 64)))
# ### Building an RNN model for the sentiment analysis task
class RNN(nn.Module):
    """LSTM sentiment classifier over padded, variable-length token sequences.

    Pipeline: token ids -> embedding (row 0 = padding) -> packed single-layer
    LSTM -> last layer's final hidden state -> Linear -> ReLU -> Linear ->
    sigmoid.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding vector size.
        rnn_hidden_size: LSTM hidden-state size.
        fc_hidden_size: width of the intermediate fully connected layer.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc1 = nn.Linear(rnn_hidden_size, fc_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text, lengths):
        """Return sigmoid scores of shape (batch, 1).

        ``text`` holds batch-first padded token ids; ``lengths`` holds the
        true length of each row (moved to the CPU before packing).
        """
        embedded = self.embedding(text)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu().numpy(),
            enforce_sorted=False, batch_first=True)
        _, (h_n, _) = self.rnn(packed)
        features = h_n[-1, :, :]
        return self.sigmoid(self.fc2(self.relu(self.fc1(features))))
# Model hyper-parameters.
vocab_size = len(vocab)
embed_dim, rnn_hidden_size, fc_hidden_size = 20, 64, 64
# Seed right before construction so the initial weights are reproducible.
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size).to(device)
def train(dataloader, *, net=None, opt=None, criterion=None):
    """Run one training epoch and return (accuracy, mean loss) per sample.

    By default the module-level ``model``, ``optimizer`` and ``loss_fn`` are
    used (looked up at call time, so later rebinding is honoured); pass
    ``net``/``opt``/``criterion`` to train other objects instead.

    Args:
        dataloader: iterable of ``(text_batch, label_batch, lengths)`` tuples
            such as those produced by ``collate_batch``.

    Returns:
        ``(accuracy, loss)`` averaged over the samples actually seen. Counting
        seen samples (instead of ``len(dataloader.dataset)``) gives the same
        result for a full pass, and also works for datasets without
        ``__len__`` or for plain iterables of batches.

    Raises:
        ValueError: if the dataloader yields no samples (e.g. an iterable
            dataset that was already consumed).
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion
    net.train()
    total_acc, total_loss, n_seen = 0, 0, 0
    for text_batch, label_batch, lengths in dataloader:
        opt.zero_grad()
        pred = net(text_batch, lengths)[:, 0]
        loss = criterion(pred, label_batch)
        loss.backward()
        opt.step()
        batch_n = label_batch.size(0)
        # A prediction counts as positive at probability >= 0.5.
        total_acc += ((pred >= 0.5).float() == label_batch).float().sum().item()
        total_loss += loss.item() * batch_n  # criterion returns a batch mean
        n_seen += batch_n
    if n_seen == 0:
        raise ValueError('dataloader yielded no samples')
    return total_acc / n_seen, total_loss / n_seen
def evaluate(dataloader, *, net=None, criterion=None):
    """Evaluate without gradient tracking; return (accuracy, mean loss) per sample.

    By default the module-level ``model`` and ``loss_fn`` are used (looked up
    at call time); pass ``net``/``criterion`` to evaluate other objects.

    Args:
        dataloader: iterable of ``(text_batch, label_batch, lengths)`` tuples
            such as those produced by ``collate_batch``.

    Returns:
        ``(accuracy, loss)`` averaged over the samples actually seen (same as
        dividing by ``len(dataloader.dataset)`` for a full pass, but also
        valid when the dataset has no ``__len__``).

    Raises:
        ValueError: if the dataloader yields no samples -- previously an
            already-consumed iterable dataset silently scored 0 accuracy.
    """
    net = model if net is None else net
    criterion = loss_fn if criterion is None else criterion
    net.eval()
    total_acc, total_loss, n_seen = 0, 0, 0
    with torch.no_grad():
        for text_batch, label_batch, lengths in dataloader:
            pred = net(text_batch, lengths)[:, 0]
            loss = criterion(pred, label_batch)
            batch_n = label_batch.size(0)
            total_acc += ((pred >= 0.5).float() == label_batch).float().sum().item()
            total_loss += loss.item() * batch_n  # criterion returns a batch mean
            n_seen += batch_n
    if n_seen == 0:
        raise ValueError('dataloader yielded no samples')
    return total_acc / n_seen, total_loss / n_seen
# Binary cross-entropy on the sigmoid outputs, optimised with Adam.
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
# Re-seed the global RNG so the epoch-wise shuffling of train_dl repeats.
torch.manual_seed(1)
for epoch in range(num_epochs):
    acc_train, loss_train = train(train_dl)
    acc_valid, loss_valid = evaluate(valid_dl)
    print(f'Epoch {epoch} accuracy: {acc_train:.4f} '
          f'val_accuracy: {acc_valid:.4f}')
acc_test, _ = evaluate(test_dl)
print(f'test_accuracy: {acc_test:.4f}')
# #### More on the bidirectional RNN
# * **Trying bidirectional recurrent layer**
class RNN(nn.Module):
    """Bidirectional-LSTM sentiment classifier over padded token sequences.

    The final hidden states of the forward and backward directions are
    concatenated (hence ``2 * rnn_hidden_size`` inputs to ``fc1``) and passed
    through Linear -> ReLU -> Linear -> sigmoid.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding vector size.
        rnn_hidden_size: hidden-state size of each LSTM direction.
        fc_hidden_size: width of the intermediate fully connected layer.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size,
                           batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(2 * rnn_hidden_size, fc_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text, lengths):
        """Return sigmoid scores of shape (batch, 1).

        ``text`` holds batch-first padded token ids; ``lengths`` holds the
        true length of each row (moved to the CPU before packing).
        """
        embedded = self.embedding(text)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu().numpy(),
            enforce_sorted=False, batch_first=True)
        _, (h_n, _) = self.rnn(packed)
        # The last two entries of h_n are the final states of the two directions.
        both_directions = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        return self.sigmoid(self.fc2(self.relu(self.fc1(both_directions))))
# Fresh, reproducibly initialised bidirectional model.
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size).to(device)
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
num_epochs = 10
# Re-seed the global RNG so the epoch-wise shuffling of train_dl repeats.
torch.manual_seed(1)
for epoch in range(num_epochs):
    acc_train, loss_train = train(train_dl)
    acc_valid, loss_valid = evaluate(valid_dl)
    print(f'Epoch {epoch} accuracy: {acc_train:.4f} '
          f'val_accuracy: {acc_valid:.4f}')
# Rebuild the test split and its loader before the final evaluation.
# NOTE(review): presumably needed because the raw IMDB test split behind the
# earlier test_dl can only be iterated once -- confirm for the installed
# torchtext version.
test_dataset = IMDB(split='test')
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                     collate_fn=collate_batch)
acc_test, _ = evaluate(test_dl)
print(f'test_accuracy: {acc_test:.4f}')
# ## Optional exercise:
#
# ### Uni-directional SimpleRNN with full-length sequences
#
# ---
#
#
# Readers may ignore the next cell.
#