-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathimdb_cnn.py
67 lines (55 loc) · 1.85 KB
/
imdb_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
'''This example demonstrates the use of Conv1D for binary text classification
on the IMDB dataset.
Reaches about 0.89 test accuracy after 2 epochs.
90s/epoch on an Intel i5 2.4GHz CPU.
10s/epoch on a Tesla K40 GPU.
'''
from __future__ import print_function
from keras.preprocessing import sequence
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalMaxPooling1D, Input
from keras.datasets import imdb
# set parameters:
max_features = 5000     # num_words for imdb.load_data and input_dim of Embedding
maxlen = 400            # length every sequence is padded to; also the Input length
batch_size = 32
embedding_dims = 50     # size of each word's embedding vector
filters = 250           # number of Conv1D kernels
kernel_size = 3         # width of each Conv1D kernel (in tokens)
hidden_dims = 250       # units in the hidden Dense layer
epochs = 2

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
print('y_train shape:', y_train.shape)

print('Build model...')
x = Input(shape=(maxlen, ))
# Look up an `embedding_dims`-dimensional vector for each of the `maxlen`
# word indices (indices are < max_features).
embed = Embedding(max_features, embedding_dims, input_length=maxlen)(x)
print(embed.shape)
embed_dropout = Dropout(0.2)(embed)
# 1D convolution: `filters` kernels, each spanning `kernel_size` consecutive
# tokens; 'valid' padding with stride 1.
conv1 = Conv1D(filters,
               kernel_size,
               padding='valid',
               activation='relu',
               strides=1)(embed_dropout)
# Global max over the sequence axis -> one value per filter.
pool = GlobalMaxPooling1D()(conv1)
# Hidden layer: Dense -> Dropout -> ReLU (dropout deliberately sits before
# the activation, matching the original ordering).
hidden = Dense(hidden_dims)(pool)
hidden_dropout = Dropout(0.2)(hidden)
out = Activation('relu')(hidden_dropout)
# Single sigmoid unit, paired with binary_crossentropy below: predicted
# probability that the label is 1.
x_out = Dense(1, activation='sigmoid')(out)
model = Model(x, x_out)
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# summary() prints the layer table itself and returns None; wrapping it in
# print() would emit a stray "None" line.
model.summary()

# NOTE(review): the test split is used as validation data, so the reported
# validation metrics are test metrics. Hold out a separate split from
# x_train if these numbers are used to tune hyperparameters.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))