@@ -40,8 +40,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
40
40
:param X_train: numpy array with training inputs
41
41
:param Y_train: numpy array with training outputs
42
42
:param save: Boolean controling the save operation
43
- :param predictions_adv: if set with the adversarial example tensor,
44
- will run adversarial training
43
+ :param predictions_adv: if set with the adversarial example tensor,
44
+ will run adversarial training
45
45
:return: True if model trained
46
46
"""
47
47
print "Starting model training using TensorFlow."
@@ -63,7 +63,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
63
63
print ("Epoch " + str (epoch ))
64
64
65
65
# Compute number of batches
66
- nb_batches = int (math .ceil (len (X_train ) / FLAGS .batch_size ))
66
+ nb_batches = int (math .ceil (float (len (X_train )) / FLAGS .batch_size ))
67
+ assert nb_batches * FLAGS .batch_size >= len (X_train )
67
68
68
69
prev = time .time ()
69
70
for batch in range (nb_batches ):
@@ -80,6 +81,7 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
80
81
train_step .run (feed_dict = {x : X_train [start :end ],
81
82
y : Y_train [start :end ],
82
83
keras .backend .learning_phase (): 1 })
84
+ assert end >= len (X_train ) # Check that all examples were used
83
85
84
86
85
87
if save :
@@ -112,21 +114,29 @@ def tf_model_eval(sess, x, y, model, X_test, Y_test):
112
114
113
115
with sess .as_default ():
114
116
# Compute number of batches
115
- nb_batches = int (math .ceil (len (X_test ) / FLAGS .batch_size ))
117
+ nb_batches = int (math .ceil (float (len (X_test )) / FLAGS .batch_size ))
118
+ assert nb_batches * FLAGS .batch_size >= len (X_test )
116
119
117
120
for batch in range (nb_batches ):
118
121
if batch % 100 == 0 and batch > 0 :
119
122
print ("Batch " + str (batch ))
120
123
121
- # Compute batch start and end indices
122
- start , end = batch_indices (batch , len (X_test ), FLAGS .batch_size )
124
+ # Must not use the `batch_indices` function here, because it
125
+ # repeats some examples.
126
+ # It's acceptable to repeat during training, but not eval.
127
+ start = batch * FLAGS .batch_size
128
+ end = min (len (X_test ), start + FLAGS .batch_size )
129
+ cur_batch_size = end - start + 1
123
130
124
- accuracy += acc_value .eval (feed_dict = {x : X_test [start :end ],
131
+ # The last batch may be smaller than all others, so we need to
132
+ # account for variable batch size here
133
+ accuracy += cur_batch_size * acc_value .eval (feed_dict = {x : X_test [start :end ],
125
134
y : Y_test [start :end ],
126
135
keras .backend .learning_phase (): 0 })
136
+ assert end >= len (X_test )
127
137
128
- # Divide by number of batches to get final value
129
- accuracy /= nb_batches
138
+ # Divide by number of examples to get final value
139
+ accuracy /= len ( X_test )
130
140
131
141
return accuracy
132
142
0 commit comments