-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrestore_model.py
68 lines (55 loc) · 2.45 KB
/
restore_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 12 20:37:16 2018
@author: wu
"""
import tensorflow as tf
import tfrecord
import alexnet
# --- Configuration -----------------------------------------------------------
BATCH_SIZE = 40         # examples per batch; also fixes the placeholder batch dim
IMG_W = 100             # 2nd dim of the input placeholder
IMG_H = 100             # 3rd dim of the input placeholder
N_CLASSES = 6           # number of classes passed to alexnet.alexNet
learning_rate = 0.0001  # not used by this evaluation script; kept for compatibility
epoch = 1000            # the loop below evaluates epoch + 1 batches (steps 0..epoch)
# NOTE(review): passed unchanged to alexnet.alexNet. If it is a keep-probability
# that alexNet also applies at inference time, evaluation should use 1.0 —
# confirm against the alexnet module.
dropout = 0.5
val_filename = '/home/wu/TF_Project/action/sample_TFrecord/val1.tfrecords'
model_dir = '/home/wu/TF_Project/action/model_tfrecord_sample/'

# Restore the latest checkpoint from `model_dir` and report loss/accuracy on
# shuffled batches drawn from the validation TFRecord file.
with tf.Graph().as_default():
    # Input pipeline: decode validation records and group them into shuffled
    # batches. batch_size must equal the placeholder batch dimension below.
    val_img, val_label = tfrecord.read_and_decode(val_filename)
    val_batch, val_label_batch = tf.train.shuffle_batch(
        [val_img, val_label],
        batch_size=BATCH_SIZE,
        capacity=2000,
        min_after_dequeue=1000)

    # Batches are fetched into Python and fed back through fixed-shape
    # placeholders.
    x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
    y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

    train_model = alexnet.alexNet(x, dropout, N_CLASSES)
    logits = train_model.fc3
    loss = alexnet.losses(logits, y_)
    acc = alexnet.evaluation(logits, y_)

    # Evaluation only: no optimizer op is built or run, so the restored
    # weights are never updated with validation data.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Nothing in this script initialises variables, so evaluation
            # cannot run without a checkpoint to restore from.
            print('No checkpoint found in %s' % model_dir)
        else:
            print(ckpt.model_checkpoint_path)
            # Restore exactly the checkpoint that was located (and whose step
            # is reported below), not a hard-coded file.
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            print("This is %s training global step(s) model" % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            try:
                for step in range(epoch + 1):
                    if coord.should_stop():
                        break
                    val_images, val_labels = sess.run([val_batch, val_label_batch])
                    val_loss, val_acc = sess.run(
                        [loss, acc], feed_dict={x: val_images, y_: val_labels})
                    if step % 50 == 0:
                        print('Step %d, val loss = %.2f, val accuracy = %.2f%%'
                              % (step, val_loss, val_acc * 100.0))
            except tf.errors.OutOfRangeError:
                print('Done reading')
            finally:
                # Always stop and join the queue-runner threads.
                coord.request_stop()
                coord.join(threads)