-
Notifications
You must be signed in to change notification settings - Fork 937
/
Copy pathcompose_poem.py
92 lines (75 loc) · 3.07 KB
/
compose_poem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
# file: compose_poem.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np
# Sentinel characters: `start_token` is fed to the model to begin sampling,
# and generation stops as soon as `end_token` is sampled.
start_token, end_token = 'B', 'E'

# Directory searched for the newest checkpoint, and the corpus file whose
# vocabulary is rebuilt at generation time.
model_dir, corpus_file = './model/', './data/poems.txt'

# Learning rate handed to rnn_model. NOTE(review): presumably unused when the
# model is built for inference (output_data=None) — confirm in poems.model.
lr = 2e-4
def to_word(predict, vocabs):
    """Sample one vocabulary entry from a model's output scores.

    Args:
        predict: array-like whose first row (``predict[0]``) holds
            non-negative, not necessarily normalized scores over token ids.
            The caller's array is not modified.
        vocabs: sequence mapping token id -> character.

    Returns:
        ``vocabs[i]`` for the sampled id ``i``. Ids at or beyond
        ``len(vocabs)`` map to ``vocabs[-1]``.
    """
    # Work on a float64 copy: normalizing in place would write into the
    # caller's array, and integer input could not be divided in place.
    probs = np.array(predict[0], dtype=np.float64)
    probs /= probs.sum()
    sample = np.random.choice(len(probs), p=probs)
    # The score row may be wider than the vocabulary (NOTE(review): presumably
    # rnn_model emits vocab_size + 1 classes — confirm). The previous `>` test
    # let sample == len(vocabs) through to an IndexError.
    if sample >= len(vocabs):
        return vocabs[-1]
    return vocabs[sample]
def gen_poem(begin_word):
    """Generate one poem by sampling the trained character-level RNN.

    Rebuilds the corpus vocabulary with ``process_poems``, builds the LSTM
    inference graph, restores the newest checkpoint in ``model_dir`` and
    samples characters one at a time until ``end_token`` is sampled or 25
    characters have been produced.

    Args:
        begin_word: first character of the poem. If empty, the first
            character is sampled from the model's output after ``start_token``.

    Returns:
        The generated text (``str``), beginning with ``begin_word`` when one
        was given. ``end_token`` itself is never included.

    Raises:
        KeyError: if ``begin_word`` is non-empty and not a key of the corpus
            character-to-id map (the model cannot be fed it).
        ValueError: if ``model_dir`` contains no checkpoint.
    """
    batch_size = 1
    # Hard cap on generated characters (the loop stops after 25).
    max_len = 25
    print('## loading corpus from %s' % corpus_file)
    _, word_int_map, vocabularies = process_poems(corpus_file)
    # Fail fast, before the slow graph build and restore: an unseen or
    # multi-character string would otherwise raise KeyError mid-generation.
    if begin_word and begin_word not in word_int_map:
        raise KeyError('%r is not in the corpus vocabulary' % (begin_word,))

    # A fresh graph keeps repeated calls in one process from stacking
    # duplicate variables in the default graph (which breaks the restore).
    with tf.Graph().as_default():
        input_data = tf.placeholder(tf.int32, [batch_size, None])
        # batch_size matches the placeholder above (previously 64).
        end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                               vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                               batch_size=batch_size, learning_rate=lr)
        saver = tf.train.Saver(tf.global_variables())
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        with tf.Session() as sess:
            sess.run(init_op)
            checkpoint = tf.train.latest_checkpoint(model_dir)
            if checkpoint is None:
                raise ValueError('no checkpoint found in %s' % model_dir)
            saver.restore(sess, checkpoint)

            # Prime the network with the start sentinel.
            x = np.array([list(map(word_int_map.get, start_token))])
            predict, last_state = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={input_data: x})
            word = begin_word or to_word(predict, vocabularies)

            chars = []
            while word != end_token:
                chars.append(word)
                if len(chars) >= max_len:
                    break
                # Feed the last character back in, carrying the RNN state forward.
                x = np.array([[word_int_map[word]]])
                predict, last_state = sess.run(
                    [end_points['prediction'], end_points['last_state']],
                    feed_dict={input_data: x, end_points['initial_state']: last_state})
                word = to_word(predict, vocabularies)
    return ''.join(chars)
def pretty_print_poem(poem_):
    """Print each '。'-delimited sentence of ``poem_`` on its own line.

    Pieces of 10 characters or fewer are skipped. This includes the empty
    piece after a trailing '。'. Each printed line gets its '。' back.
    """
    long_enough = (piece for piece in poem_.split('。') if len(piece) > 10)
    for sentence in long_enough:
        print(sentence + '。')
if __name__ == '__main__':
    # Interactive entry point: ask for the first character, then generate and
    # print a poem. Surrounding whitespace is stripped so a stray space does
    # not become an out-of-vocabulary start character.
    begin_char = input('## (输入 quit 退出)请输入第一个字 please input the first character: ').strip()
    # Fall through on 'quit' instead of calling the site-module `exit()`,
    # which is meant for the interactive shell and may be absent (python -S).
    if begin_char != 'quit':
        poem = gen_poem(begin_char)
        pretty_print_poem(poem_=poem)