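"""Generate a short quote by chaining per-topic language models.

The seed sequence comes either from the first topic's quote file or from a
user-supplied sentence (-s/--seed); each remaining topic's model then
extends the text in turn.
"""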
import argparse

TOPIC_CHOICES = ("'death', 'family', 'funny', 'freedom', 'life', 'love', "
                 "'happiness', 'science', 'success', 'politics'")

parser = argparse.ArgumentParser()
parser.add_argument("topics", type=str, nargs='+',
                    help="select at least 2 topics from %s" % TOPIC_CHOICES)
parser.add_argument("-s", "--seed", type=str,
                    help="input a custom seed sentence")
parser.add_argument("-n", "--num", type=int,
                    help="number of input and output sentences in the quote (1-3)")
args = parser.parse_args()

if len(args.topics) < 2 and not args.seed:
    parser.error("select at least 2 topics from %s" % TOPIC_CHOICES)

n = 1
if args.num:
    if not 1 <= args.num <= 3:
        parser.error("choose a number of sentences between 1 and 3")
    n = args.num
# Heavy imports are deferred until after argument validation. Note that
# TF_CPP_MIN_LOG_LEVEL must be set *before* TensorFlow/Keras is imported,
# otherwise the log suppression has no effect.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys

import numpy as np
from keras.preprocessing.text import Tokenizer

from Model import Model
# Load the full quote corpus used to fit the tokenizer.
with open('data/all.txt', 'r') as quotefile:
    quotes = quotefile.readlines()
# Refit the tokenizer on the corpus so word indices match the trained models;
# Keras reserves index 0, hence the +1 on the vocabulary size.
t = Tokenizer(filters='')
t.fit_on_texts(quotes)
vocab_size = len(t.word_index) + 1
# Load the saved index -> word mapping (a pickled dict stored in a .npy file);
# allow_pickle=True is required for object arrays on recent NumPy versions.
index_word = np.load('data/index_word.npy', allow_pickle=True).item()
topics = args.topics
# Build the seed sequence: either sample a window from the first topic's
# quotes, or encode the user-supplied seed sentence (to accommodate a
# custom seed).
model_topics = []
seed_doc = []
seedlen = 50
maxlen = 50
sentence = ""
if not args.seed:
    with open('data/%s.txt' % topics[0], 'r') as seedfile:
        seedquotes = seedfile.readlines()
    encoded_docs = t.texts_to_sequences(seedquotes)
    # Seed from the first line of the topic file; this assumes that line
    # is longer than seedlen tokens.
    seed_doc = encoded_docs[0]
    start_index = np.random.randint(0, len(seed_doc) - seedlen - 1)
    sentence = seed_doc[start_index: start_index + seedlen]
    model_topics = topics[1:]  # the first topic only supplies the seed
else:
    # Encode the custom seed word by word, dropping out-of-vocabulary words.
    sentence = t.texts_to_sequences(args.seed.split(' '))
    sentence = list(filter(None, sentence))
    sentence = np.asarray(sentence).flatten()
    model_topics = topics
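# Model (defined in this repo's Model.py, not shown here) is assumed to
# expose a load_model() method returning a Keras model whose predict()
# yields next-word probabilities over the tokenizer's vocabulary.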
# Load one trained model per remaining topic.
model_list = []
for topic in model_topics:
    topic_model = Model(vocab_size, topic)
    model_list.append(topic_model.load_model())
def on_epoch_end(sentence, model, maxlen=10):
    """Greedily extend `sentence` by `maxlen` words using `model`; print and
    return the last/first `n` sentence fragments of the seed and the output."""
    predicted = ''
    original_sentence = ' '.join(str(index_word[word]) for word in sentence)
    for i in range(maxlen):
        x_pred = np.reshape(sentence, (1, -1))
        preds = model.predict(x_pred, verbose=0)[0]
        next_index = np.argmax(preds)  # greedy decoding: always take the most likely word
        next_word = index_word[next_index]
        sentence = np.append(sentence, next_index)
        predicted = predicted + next_word + ' '
        if i % max(1, maxlen // 4) == 0:  # crude progress bar; max() avoids modulo by zero
            sys.stdout.write("-")
            sys.stdout.flush()
    sys.stdout.write("\n")
    print('----- Input seed: %s' % ','.join(original_sentence.split('.')[-n:]))
    print('----- Output: %s' % ','.join(predicted.split('.')[0:n]))
    sys.stdout.write("-----\n")
    return original_sentence.split('.')[-n:] + predicted.split('.')[0:n]
# Chain the models: feed each model's output back in as the next seed,
# re-encoding the generated text into token ids between models.
for model in model_list:
    sentence = on_epoch_end(sentence, model, maxlen)
    sentence = [sent.split(' ') for sent in sentence]
    # texts_to_sequences returns one id list per fragment; concatenate them
    # into a single flat sequence (ndarray.flatten() breaks on ragged sublists).
    sentence = np.asarray([idx for seq in t.texts_to_sequences(sentence) for idx in seq])
    #sentence = sentence[maxlen:] #
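# Example invocations (assuming the trained models and data/ files are present):
#   python generate.py funny love
#   python generate.py science success -n 2
#   python generate.py love -s "life is a long road"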