-
Notifications
You must be signed in to change notification settings - Fork 0
/
query_model.py
64 lines (53 loc) · 1.98 KB
/
query_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import print_function

import operator
import random
import sys

import numpy as np
from keras.datasets.data_utils import get_file
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.models import model_from_json
# Fetch the training corpus (get_file is given the file name and its origin
# URL) and fold it to lower case.
path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
# Use a context manager so the corpus file handle is closed deterministically
# instead of being left for the garbage collector.
with open(path) as corpus_file:
    text = corpus_file.read().lower()
print('corpus length:', len(text))

# Vocabulary: every distinct character that occurs in the corpus.
# NOTE(review): indices are assigned in set-iteration order. They presumably
# have to match the order used when the weights were trained, so the ordering
# is deliberately left untouched here -- confirm against the training script.
chars = set(text)
print('total chars:', len(chars))

# Two-way lookup between characters and their one-hot positions. Both maps
# enumerate the same unmodified set, so they are exact inverses of each other.
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

# Number of time steps in the input window fed to the network.
maxlen = 20
# Restore the network architecture from its cached JSON description. If that
# fails for any reason, build the architecture from scratch and cache its JSON
# so the next run can load it.
try:
    print('Loading model...')
    # JSON is text, so open in text mode; `with` guarantees the handle is
    # closed even when parsing fails.
    with open('keras-nietzsche.model.json', 'r') as fh:
        model = model_from_json(fh.read())
except Exception as error:
    # Deliberate best-effort fallback, but say why it was taken so that a real
    # problem (e.g. a corrupt cache file) is not hidden silently.
    print('Could not load model (%s: %s)' % (type(error).__name__, error))
    print('Build model...')
    model = Sequential()
    # Two stacked 512-unit LSTM layers, each followed by dropout; the input is
    # a window of maxlen one-hot character vectors.
    model.add(LSTM(512, return_sequences=True, input_shape=(maxlen, len(chars))))
    model.add(Dropout(0.2))
    model.add(LSTM(512, return_sequences=False))
    model.add(Dropout(0.2))
    # Softmax over the vocabulary: one probability per candidate next character.
    model.add(Dense(len(chars)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    with open('keras-nietzsche.model.json', 'w') as fh:
        fh.write(model.to_json())

# The weights are stored separately from the architecture and are loaded on
# both the "loaded" and the "built" path.
model.load_weights('nietzsche.model.weights.h5')
def sample(a, temperature=1.0):
    """Sample an index from a probability array.

    The probabilities are re-weighted by ``temperature`` (values below 1
    sharpen the distribution towards the most likely entries, values above 1
    flatten it) and a single index is then drawn at random.

    Args:
        a: 1-D array-like of probabilities.
        temperature: positive re-weighting factor; defaults to 1.0, which
            leaves the distribution unchanged.

    Returns:
        The sampled index, as returned by ``np.argmax``.
    """
    # Work in float64: with float32 input the renormalised probabilities can
    # sum to slightly more than 1, which np.random.multinomial rejects.
    a = np.asarray(a, dtype=np.float64)
    # Zero probabilities legitimately map to -inf here (and back to 0 after
    # exp), so silence the divide-by-zero warning for this step only.
    with np.errstate(divide='ignore'):
        scaled = np.log(a) / temperature
    # Shifting by the maximum leaves the softmax result unchanged but keeps
    # exp() from overflowing or underflowing everything at low temperatures.
    scaled = scaled - np.max(scaled)
    weights = np.exp(scaled)
    probabilities = weights / np.sum(weights)
    # One multinomial draw yields a one-hot row; argmax recovers its index.
    return np.argmax(np.random.multinomial(1, probabilities, 1))
# Interactive loop: read a seed string, then generate as many characters as
# the seed is long, feeding every sampled character back in as new context.
try:
    read_line = raw_input  # Python 2
except NameError:
    read_line = input  # Python 3: raw_input was renamed to input

while True:
    try:
        sentence = read_line('Text to spell check: ')
    except (EOFError, KeyboardInterrupt):
        # End of input (Ctrl-D / Ctrl-C at the prompt): leave the loop
        # quietly instead of dying with a traceback.
        sys.stdout.write('\n')
        break

    # The corpus was lower-cased before the vocabulary was built, so fold the
    # input the same way and drop characters that have no vocabulary index
    # (they would otherwise raise KeyError below).
    sentence = ''.join(c for c in sentence.lower() if c in char_indices)
    # The input tensor only has maxlen time steps; keep the most recent
    # characters so that indexing it cannot run out of bounds.
    sentence = sentence[-maxlen:]

    for _ in range(len(sentence)):
        # One-hot encode the current window into a (1, maxlen, vocab) tensor;
        # positions past the end of a short sentence stay all-zero.
        x = np.zeros((1, maxlen, len(chars)))
        for t, char in enumerate(sentence):
            x[0, t, char_indices[char]] = 1.
        preds = model.predict(x, verbose=0)[0]
        # A low temperature keeps the choice close to the most likely character.
        next_index = sample(preds, 0.1)
        next_char = indices_char[next_index]
        # Slide the window: drop the oldest character, append the new one.
        sentence = sentence[1:] + next_char
        sys.stdout.write(next_char)
        sys.stdout.flush()

    # Finish the generated text so the next prompt starts on its own line.
    sys.stdout.write('\n')
    sys.stdout.flush()