-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
74 lines (57 loc) · 2.41 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# coding=utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import os
import time
import readData
import vae_model
# Fix both NumPy and TF seeds so data shuffling and weight initialization
# are reproducible across runs.
np.random.seed(0)
tf.set_random_seed(0)

# Command-line configuration (TF1-style flags).
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('buckets', 'D:\\DL\\aliyun\\VAE\\data', '数据目录')
# BUG FIX: original used 'D:\DL\\...' — '\D' is an invalid escape sequence
# (deprecated in Python 3); '\\' yields the identical path value explicitly.
tf.flags.DEFINE_string('checkpointDir', 'D:\\DL\\aliyun\\VAE\\model', '模型保存路径')
tf.flags.DEFINE_string('summaryDir', 'D:\\DL\\aliyun\\VAE\\logs', 'tensorboard保存路径')
tf.flags.DEFINE_integer('batch_size', 256, 'Batch Size')
tf.flags.DEFINE_float('learning_rate', 1e-3, 'learning rate')
tf.flags.DEFINE_integer('display_step', 500, 'Display step')
# BUG FIX: train_step was DEFINE_float, but it is consumed as
# range(FLAGS.train_step), which requires an int (TypeError on Python 3).
tf.flags.DEFINE_integer('train_step', 100000, 'Train step')
# Path to the training TFRecord file under the data directory.
train_file_path = os.path.join(FLAGS.buckets, 'train.tfrecords')
# read data
# Queue-based input pipeline: the reader yields image batches of
# FLAGS.batch_size (presumably flattened 28x28 MNIST images, matching the
# placeholder below — confirm against readData.MnistReader).
reader = readData.MnistReader(train_file_path, batch_size=FLAGS.batch_size)
image_batch = reader.read_image_batch()
# Feed target: one flattened 28*28 image per row.
x = tf.placeholder(dtype=tf.float32, shape=[FLAGS.batch_size, 28*28])
# Build the VAE graph (network, then loss + optimizer ops) around the placeholder.
vae = vae_model.VariationalAutoEncoder(data=x,
learning_rate=FLAGS.learning_rate,
batch_size=FLAGS.batch_size)
vae.create_vae_network()
print('create VAE model network')
vae.create_loss_optimizer()
# Session, TensorBoard writer (records the graph), and checkpoint saver.
sess = tf.InteractiveSession()
summary = tf.summary.FileWriter(FLAGS.summaryDir, graph=sess.graph)
# NOTE: only trainable variables are checkpointed (no optimizer slots beyond
# trainables, no global step).
saver = tf.train.Saver(var_list=tf.trainable_variables())
# Initialize variables, then start the input-queue threads under a coordinator.
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# BUG FIX: tf.summary.merge_all() creates a new op in the graph on every
# call; the original invoked it inside the loop, growing the graph without
# bound. Build the merged-summary op once, before training starts.
merged_summary = tf.summary.merge_all()
try:
    # int(...) guards against the flag having been declared as a float.
    for step in range(int(FLAGS.train_step)):
        if coord.should_stop():
            break
        # Pull one batch from the input queue, then run a training step on it.
        train_image_batch = sess.run(image_batch)
        _, losses = sess.run(fetches=[vae.optimizer, vae.losses], feed_dict={x: train_image_batch})
        if step % FLAGS.display_step == 0:
            # merge_all() returns None when the graph defines no summaries;
            # skip writing rather than crash in that case.
            if merged_summary is not None:
                summary.add_summary(sess.run(merged_summary, feed_dict={x: train_image_batch}), step)
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
            print('Step:', '%05d' % step, 'losses:', '{:.8f}'.format(losses))
except tf.errors.OutOfRangeError:
    # Raised by the input queue when the data source is exhausted.
    print('train done')
finally:
    # Always stop the queue threads and checkpoint the model, even on error.
    coord.request_stop()
    saver.save(sess=sess, save_path=os.path.join(FLAGS.checkpointDir, 'vae.model'))
coord.join(threads)
sess.close()
summary.close()
print('train done!')