-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
71 lines (53 loc) · 2.39 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pickle
## test.py evaluates the model's performance on the test set.
import random
import tensorflow as tf
import preprocessing as prep
from set_params import batch_size
# Test Model
# Set batch size if not already set.
# A falsy value (None or 0) cannot be used to split the data into batches, so
# it is treated the same as a missing name and replaced by the default. The
# NameError branch only fires when the name is absent from module scope.
try:
    if not batch_size:
        batch_size = 64
except NameError:
    batch_size = 64

# Checkpoint prefix; test_model appends '.meta' to it to find the saved graph.
save_model_path = './image_classification'
# Number of randomly chosen test examples whose predictions are displayed.
n_samples = 4
# Number of top-scoring classes requested per displayed example.
top_n_predictions = 3
def test_model():
    """
    Test the saved model against the test dataset.

    Loads the pickled (features, labels) pair from 'preprocess_test.p',
    restores the checkpoint at ``save_model_path`` into a fresh graph, prints
    the mean of the per-batch accuracies, and finally hands ``n_samples``
    randomly chosen examples, together with their top ``top_n_predictions``
    softmax predictions, to ``prep.display_image_predictions``.

    Reads the module-level ``batch_size``, ``save_model_path``, ``n_samples``
    and ``top_n_predictions``. Returns nothing.
    """
    # NOTE: unpickling can execute arbitrary code; only load a file you
    # produced yourself. The context manager closes the file handle as soon
    # as the data has been read.
    with open('preprocess_test.p', mode='rb') as test_file:
        test_features, test_labels = pickle.load(test_file)

    loaded_graph = tf.Graph()

    with tf.Session(graph=loaded_graph) as sess:
        # Load model
        loader = tf.train.import_meta_graph(save_model_path + '.meta')
        loader.restore(sess, save_model_path)

        # Get Tensors from loaded model
        loaded_x = loaded_graph.get_tensor_by_name('x:0')
        loaded_y = loaded_graph.get_tensor_by_name('y:0')
        loaded_keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')
        loaded_logits = loaded_graph.get_tensor_by_name('logits:0')
        loaded_acc = loaded_graph.get_tensor_by_name('accuracy:0')

        # Get accuracy in batches for memory limitations.
        # Every batch contributes equally to the reported mean, whatever its
        # size. keep_prob is fed as 1.0 for evaluation.
        test_batch_acc_total = 0
        test_batch_count = 0
        batches = prep.batch_features_labels(test_features, test_labels, batch_size)
        for test_feature_batch, test_label_batch in batches:
            test_batch_acc_total += sess.run(
                loaded_acc,
                feed_dict={
                    loaded_x: test_feature_batch,
                    loaded_y: test_label_batch,
                    loaded_keep_prob: 1.0,
                })
            test_batch_count += 1

        print('Testing Accuracy: {}\n'.format(test_batch_acc_total / test_batch_count))

        # Print Random Samples
        # Sample (feature, label) pairs together so they stay aligned, then
        # unzip them back into one tuple of features and one of labels.
        sampled_pairs = random.sample(list(zip(test_features, test_labels)), n_samples)
        random_test_features, random_test_labels = tuple(zip(*sampled_pairs))
        random_test_predictions = sess.run(
            tf.nn.top_k(tf.nn.softmax(loaded_logits), top_n_predictions),
            feed_dict={
                loaded_x: random_test_features,
                loaded_y: random_test_labels,
                loaded_keep_prob: 1.0,
            })
        prep.display_image_predictions(random_test_features, random_test_labels, random_test_predictions)
############################################################################################################
# Script entry: these statements run whenever the module is executed or
# imported (there is no __main__ guard). Announce the start, evaluate the
# saved model via test_model(), then report completion.
print('Testing......')
test_model()
print('Testing complete')