train.py
import tensorflow as tf
import input_layer
import modekeys
import hparam
import os
import HRED
from tensorflow.python.training import saver as saver_lib
from tensorflow.python import debug as tf_debug
import evaluate
import datetime
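
# train.py trains the HRED model on TFRecord data: it restores the latest
# checkpoint when one exists, logs loss and gradient-norm summaries, evaluates
# on valid.tfrecords every hp.eval_step steps, keeps the best checkpoint by the
# validation metric (perplexity), and stops early after 10 evaluations without
# improvement.
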
def main(unused_arg):
    tf.logging.set_verbosity(tf.logging.INFO)
    train()

tf.flags.DEFINE_boolean('debug', False, 'debug mode')
tf.flags.DEFINE_string('model_dir', './model/model1', 'model dir')
tf.flags.DEFINE_string('data_dir', './data', 'data dir')
FLAGS = tf.flags.FLAGS

TRAIN_FILE = os.path.join(os.path.abspath(FLAGS.data_dir), 'train.tfrecords')
MODEL_DIR = FLAGS.model_dir
if MODEL_DIR is None:
    # Fall back to a timestamped directory when no model_dir is supplied.
    timestamp = datetime.datetime.now()
    MODEL_DIR = os.path.join(os.path.abspath('./model'), str(timestamp))
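
# Hyperparameters come from hparam.create_hparam(); this script reads at least
# hp.learning_rate, hp.clip_norm, hp.summary_save_steps, and hp.eval_step.
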
def train():
    hp = hparam.create_hparam()

    train_graph = tf.Graph()
    with train_graph.as_default():
        # Build the input pipeline and the HRED model graph in training mode.
        input_features = input_layer.create_input_layer(filename=TRAIN_FILE, hp=hp, mode=modekeys.TRAIN)
        loss, debug_tensors = HRED.impl(input_features=input_features, mode=modekeys.TRAIN, hp=hp)

        global_step_tensor = tf.Variable(initial_value=0,
                                         trainable=False,
                                         collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES],
                                         name='global_step')
        train_op, grad_norm = create_train_op(loss, hp.learning_rate, global_step_tensor, hp.clip_norm)
        # Tracks the best (lowest) validation metric seen so far; used for early stopping.
        stop_criteria_tensor = tf.Variable(initial_value=10000, trainable=False, name='stop_criteria', dtype=tf.float32)

        tf.summary.scalar(name='train_loss', tensor=loss)
        tf.summary.scalar(name='train_grad_norm', tensor=grad_norm)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(logdir=os.path.join(os.path.abspath(MODEL_DIR), 'summary'))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        if FLAGS.debug:
            # The tfdbg wrapper is currently disabled, so the --debug flag is a no-op.
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess, thread_name_filter="MainThread$")
            # sess.add_tensor_filter(tensor_filter=tf_debug.has_inf_or_nan, filter_name='has_inf_or_nan')
            pass

        saver = tf.train.Saver(max_to_keep=1)
        best_saver = tf.train.Saver(max_to_keep=1)

        # Restore the latest checkpoint if one exists; otherwise initialize all variables.
        checkpoint = saver_lib.latest_checkpoint(MODEL_DIR)
        tf.logging.info('model dir {}'.format(MODEL_DIR))
        tf.logging.info('checkpoint {}'.format(checkpoint))
        if checkpoint:
            tf.logging.info('Restore parameters from {}'.format(checkpoint))
            saver.restore(sess=sess, save_path=checkpoint)
            sess.run(tf.local_variables_initializer())
        else:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        tf.logging.info(msg='Begin training')
        try:
            stop_count = 10  # evaluations without improvement allowed before early stopping
            while not coord.should_stop():
                _, current_loss, summary, global_step = sess.run(
                    fetches=[train_op, loss, summary_op, global_step_tensor])

                if global_step % 100 == 0:
                    tf.logging.info('global step ' + str(global_step) + ' loss: ' + str(current_loss))

                if global_step % hp.summary_save_steps == 0:
                    summary_writer.add_summary(summary=summary, global_step=global_step)

                if global_step % hp.eval_step == 0:
                    # Checkpoint, evaluate on the validation set, and keep the best model so far.
                    saver.save(sess=sess, save_path=os.path.join(MODEL_DIR, 'model.ckpt'), global_step=global_step)
                    eval_file = os.path.join(os.path.abspath(FLAGS.data_dir), 'valid.tfrecords')
                    cur_stop_criteria = evaluate.evaluate(eval_file, MODEL_DIR,
                                                          os.path.join(MODEL_DIR, 'summary/eval'), global_step)
                    stop_criteria = sess.run(stop_criteria_tensor)
                    if cur_stop_criteria < stop_criteria:
                        sess.run(stop_criteria_tensor.assign(cur_stop_criteria))
                        best_model_path = os.path.join(MODEL_DIR, 'best_model', 'model.ckpt')
                        save_path = best_saver.save(sess=sess, save_path=best_model_path, global_step=global_step)
                        tf.logging.info('Save best model to {}'.format(save_path))
                        stop_count = 10
                    else:
                        stop_count -= 1
                        if stop_count == 0:
                            tf.logging.info('Early stop at step {}'.format(global_step))
                            break
        except tf.errors.OutOfRangeError:
            tf.logging.info('Finish training -- epoch limit reached')
        finally:
            tf.logging.info('Best ppl is {}'.format(sess.run(stop_criteria_tensor)))
            coord.request_stop()

        coord.join(threads)

def create_train_op(loss, lr, global_step, clip_norm):
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    grad_var = optimizer.compute_gradients(loss)
    # Clip each gradient tensor to clip_norm before applying the update.
    grad_var = [(tf.clip_by_norm(grad, clip_norm=clip_norm), var) for grad, var in grad_var]
    train_op = optimizer.apply_gradients(grads_and_vars=grad_var, global_step=global_step)
    # Report the global norm of the clipped gradients for the summary.
    grads = [gv[0] for gv in grad_var]
    grad_norm = tf.global_norm(grads)
    return train_op, grad_norm
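
# For reference: create_train_op above clips each gradient tensor independently
# with tf.clip_by_norm. A common alternative is to clip by the global norm of all
# gradients at once. The helper below is a minimal sketch of that variant
# (hypothetical name, not used anywhere in this script), assuming every variable
# receives a gradient; note that the returned grad_norm is the pre-clipping
# global norm computed by tf.clip_by_global_norm.
def create_train_op_global_clip(loss, lr, global_step, clip_norm):
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    grads_and_vars = optimizer.compute_gradients(loss)
    grads, variables = zip(*grads_and_vars)
    # Rescale all gradients jointly so their global norm is at most clip_norm.
    clipped_grads, grad_norm = tf.clip_by_global_norm(grads, clip_norm)
    train_op = optimizer.apply_gradients(list(zip(clipped_grads, variables)), global_step=global_step)
    return train_op, grad_norm
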
if __name__ == '__main__':
    tf.app.run()
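
# Example invocation (a sketch; the flag values shown are the defaults defined above):
#   python train.py --data_dir ./data --model_dir ./model/model1
# Passing --debug=True currently has no effect, since the tfdbg wrapper in train() is commented out.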