forked from amir-saniyan/AlexNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path evaluate.py
44 lines (30 loc) · 1.13 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
import tensorflow as tf
from alexnet import AlexNet
from dataset_helper import read_cifar_10
# Geometry of the images fed to the network; the same width/height are
# requested from read_cifar_10 when the dataset is loaded.
INPUT_WIDTH = INPUT_HEIGHT = 70
INPUT_CHANNELS = 3
# CIFAR-10 is a ten-class dataset.
NUM_CLASSES = 10

# Optimizer / regularization settings handed to the AlexNet constructor.
# NOTE(review): presumably these only influence training, but the constructor
# requires them to build the graph; confirm they match the training run.
LEARNING_RATE = 1e-3  # Original value: 0.01
MOMENTUM = 0.9
KEEP_PROB = 0.5  # presumably the dropout keep probability -- confirm in AlexNet

# Not referenced by this evaluation script (kept alongside the other settings).
EPOCHS = 100
# Batch size passed to AlexNet.evaluate for both the train and test splits.
BATCH_SIZE = 128
# --- Data --------------------------------------------------------------------
print('Reading CIFAR-10...')
# Ask the loader for images at the network's input resolution; it hands back
# the train and test splits (inputs and labels) as four values.
X_train, Y_train, X_test, Y_test = read_cifar_10(image_height=INPUT_HEIGHT,
                                                 image_width=INPUT_WIDTH)

# --- Model -------------------------------------------------------------------
# NOTE(review): the hyperparameters presumably must match the configuration the
# checkpoint in ./model was trained with -- confirm against the training script.
alexnet = AlexNet(
    input_width=INPUT_WIDTH,
    input_height=INPUT_HEIGHT,
    input_channels=INPUT_CHANNELS,
    num_classes=NUM_CLASSES,
    learning_rate=LEARNING_RATE,
    momentum=MOMENTUM,
    keep_prob=KEEP_PROB,
)

# --- Evaluation --------------------------------------------------------------
with tf.Session() as session:
    print('Evaluating dataset...')
    print()

    # Give every graph variable an initial value; restore() below is then
    # expected to load the trained weights from ./model (presumably
    # overwriting these initial values).
    session.run(tf.global_variables_initializer())

    print('Loading model...')
    print()
    alexnet.restore(session, './model')

    print('Evaluating...')
    # Score the train split first, then the test split, each with BATCH_SIZE.
    train_accuracy, test_accuracy = [
        alexnet.evaluate(session, images, labels, BATCH_SIZE)
        for images, labels in ((X_train, Y_train), (X_test, Y_test))
    ]

    print('Train Accuracy = {:.3f}'.format(train_accuracy))
    print('Test Accuracy = {:.3f}'.format(test_accuracy))
    print()