-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtraining.py
105 lines (82 loc) · 3.51 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
## training.py builds the classifier graph, trains it, and saves the trained model.
## Training on a GPU is strongly recommended; training on a CPU is very slow.
import pickle
import build_network
import tensorflow as tf
import preprocessing as prep
from set_params import batch_size, epochs, keep_probability
############################################################################################################
# Train the Neural Network
# Single Optimization
def train_neural_network(session, optimizer, keep_probability, feature_batch, label_batch):
    """
    Run a single optimization step on one batch of images and labels.
    : session: Current TensorFlow session
    : optimizer: TensorFlow optimizer op to execute
    : keep_probability: value fed to the keep_prob placeholder during training
    : feature_batch: Batch of Numpy image data
    : label_batch: Batch of Numpy label data
    """
    # x, y and keep_prob are the module-level placeholders created further down in this file.
    batch_feed = {x: feature_batch, y: label_batch, keep_prob: keep_probability}
    session.run(optimizer, feed_dict=batch_feed)
# Show Stats
def print_stats(session, feature_batch, label_batch, cost, accuracy):
    """
    Print the loss on one training batch and the accuracy on the validation set.
    : session: Current TensorFlow session
    : feature_batch: Batch of Numpy image data
    : label_batch: Batch of Numpy label data
    : cost: TensorFlow cost tensor
    : accuracy: TensorFlow accuracy tensor
    """
    # Both evaluations feed keep_prob = 1.0 (presumably disabling dropout in conv_net).
    # valid_features / valid_labels are the module-level validation arrays loaded below.
    train_feed = {x: feature_batch, y: label_batch, keep_prob: 1.0}
    loss = session.run(cost, feed_dict=train_feed)
    valid_feed = {x: valid_features, y: valid_labels, keep_prob: 1.0}
    valid_acc = session.run(accuracy, feed_dict=valid_feed)
    print('Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(loss, valid_acc))
############################################################################################################
# Set up tensorflow placeholders
# Start from a clean default graph so re-running this script does not duplicate ops.
tf.reset_default_graph()
# Inputs: 32x32x3 images and 10-class labels
x = build_network.neural_net_image_input((32, 32, 3))
y = build_network.neural_net_label_input(10)
keep_prob = build_network.neural_net_keep_prob_input()
# Model
logits = build_network.conv_net(x, keep_prob)
# Name logits Tensor, so that it can be loaded from disk after training
logits = tf.identity(logits, name='logits')
# Loss and Optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer().minimize(cost)
# Accuracy: fraction of samples whose arg-max prediction matches the arg-max label
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
# Load validation set.
# The `with` block closes the file handle (the original bare open() leaked it).
# NOTE: pickle can execute arbitrary code while loading; only load files this project generated.
with open('preprocess_validation.p', mode='rb') as valid_file:
    valid_features, valid_labels = pickle.load(valid_file)
############################################################################################################
# Train the Model
save_model_path = './image_classification'
# Number of preprocessed training batch files visited per epoch (loaded via
# prep.load_preprocess_training_batch with ids 1..n_batches). It never changes
# between epochs, so it is set once here rather than inside the epoch loop.
n_batches = 5
print('Training...')
with tf.Session() as sess:
    # Initializing the variables
    sess.run(tf.global_variables_initializer())
    # Training cycle
    for epoch in range(epochs):
        # Loop over all batches
        for batch_i in range(1, n_batches + 1):
            for batch_features, batch_labels in prep.load_preprocess_training_batch(batch_i, batch_size):
                train_neural_network(sess, optimizer, keep_probability, batch_features, batch_labels)
            # Report stats using the last mini-batch of this file plus the full validation set.
            print('Epoch {:>2}, CIFAR-10 Batch {}: '.format(epoch + 1, batch_i), end='')
            print_stats(sess, batch_features, batch_labels, cost, accuracy)
    # Save Model (must happen while the session is still open)
    saver = tf.train.Saver()
    save_path = saver.save(sess, save_model_path)
print('Training complete')