diff --git a/model.py b/model.py index 71a7e30..30cd43b 100644 --- a/model.py +++ b/model.py @@ -1,7 +1,7 @@ import os from keras.models import Sequential, load_model -from keras.layers import LSTM, Dropout, TimeDistributed, Dense, Activation, Embedding +from keras.layers import LSTM, Dropout, TimeDistributed, Dense, Activation, Embedding, CuDNNLSTM MODEL_DIR = './model' @@ -17,7 +17,7 @@ def build_model(batch_size, seq_len, vocab_size): model = Sequential() model.add(Embedding(vocab_size, 512, batch_input_shape=(batch_size, seq_len))) for i in range(3): - model.add(LSTM(256, return_sequences=True, stateful=True)) + model.add(CuDNNLSTM(256, return_sequences=True, stateful=True)) model.add(Dropout(0.2)) model.add(TimeDistributed(Dense(vocab_size)))