
Commit 92f3122

Author: Ian Goodfellow
Merge pull request #20 from openai/batch_size
Fix bugs when data size is not multiple of batch size
2 parents: 757877b + 4a60b39

3 files changed: 24 additions, 14 deletions


cleverhans/utils.py (1 addition, 1 deletion)

@@ -73,4 +73,4 @@ def batch_indices(batch_nb, data_length, batch_size):
     start -= shift
     end -= shift
 
-  return start, end
+  return start, end

(The removed and added lines read identically here; the difference is in whitespace, which the rendered page does not preserve.)
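Since the visible change to `batch_indices` is whitespace only, the more useful context is what the helper does at all: judging from the `start -= shift` / `end -= shift` lines above and the comment added to utils_tf.py below that it "repeats some examples", it shifts an overhanging final batch back so the batch stays full. The sketch below is a hypothetical reconstruction under those assumptions, not the repository's exact code.

def batch_indices(batch_nb, data_length, batch_size):
  # Hypothetical reconstruction of the helper, not the repository's code.
  start = int(batch_nb * batch_size)
  end = int((batch_nb + 1) * batch_size)

  # If the batch would run past the end of the data, shift the whole window
  # back so it stays full -- this is what repeats earlier examples.
  if end > data_length:
    shift = end - data_length
    start -= shift
    end -= shift

  return start, end

# With 10 examples and batch size 4, the third batch (8..11) is shifted back
# to 6..9, so examples 6 and 7 also appeared in the second batch (4..7).
print(batch_indices(2, 10, 4))  # (6, 10)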

cleverhans/utils_tf.py (19 additions, 9 deletions)
@@ -40,8 +40,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
   :param X_train: numpy array with training inputs
   :param Y_train: numpy array with training outputs
   :param save: Boolean controling the save operation
-  :param predictions_adv: if set with the adversarial example tensor,
-                          will run adversarial training
+  :param predictions_adv: if set with the adversarial example tensor,
+                          will run adversarial training
   :return: True if model trained
   """
   print "Starting model training using TensorFlow."

(Again a whitespace-only adjustment; the removed and added docstring lines render identically.)
@@ -63,7 +63,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
       print("Epoch " + str(epoch))
 
       # Compute number of batches
-      nb_batches = int(math.ceil(len(X_train) / FLAGS.batch_size))
+      nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
+      assert nb_batches * FLAGS.batch_size >= len(X_train)
 
       prev = time.time()
       for batch in range(nb_batches):
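The `float(...)` cast is the core of the fix: this file targets Python 2 (note the bare `print` statement above), where `/` between two ints floors before `math.ceil` ever sees the value, so the trailing partial batch was simply never scheduled. A small worked example, using the MNIST test-set size of 10000 and an assumed batch size of 128:

import math

n, batch_size = 10000, 128  # 10000 examples; the batch size is an assumption

# Python 2 floors int / int; `//` reproduces that behaviour here.
old_count = int(math.ceil(n // batch_size))        # 78 batches
new_count = int(math.ceil(float(n) / batch_size))  # 79 batches

print(old_count * batch_size)  # 9984  -> the last 16 examples were never used
print(new_count * batch_size)  # 10112 -> covers all 10000, so the new assert holds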
@@ -80,6 +81,7 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
         train_step.run(feed_dict={x: X_train[start:end],
                                   y: Y_train[start:end],
                                   keras.backend.learning_phase(): 1})
+      assert end >= len(X_train) # Check that all examples were used
 
 
     if save:
@@ -112,21 +114,29 @@ def tf_model_eval(sess, x, y, model, X_test, Y_test):
 
   with sess.as_default():
     # Compute number of batches
-    nb_batches = int(math.ceil(len(X_test) / FLAGS.batch_size))
+    nb_batches = int(math.ceil(float(len(X_test)) / FLAGS.batch_size))
+    assert nb_batches * FLAGS.batch_size >= len(X_test)
 
     for batch in range(nb_batches):
       if batch % 100 == 0 and batch > 0:
         print("Batch " + str(batch))
 
-      # Compute batch start and end indices
-      start, end = batch_indices(batch, len(X_test), FLAGS.batch_size)
+      # Must not use the `batch_indices` function here, because it
+      # repeats some examples.
+      # It's acceptable to repeat during training, but not eval.
+      start = batch * FLAGS.batch_size
+      end = min(len(X_test), start + FLAGS.batch_size)
+      cur_batch_size = end - start + 1
 
-      accuracy += acc_value.eval(feed_dict={x: X_test[start:end],
+      # The last batch may be smaller than all others, so we need to
+      # account for variable batch size here
+      accuracy += cur_batch_size * acc_value.eval(feed_dict={x: X_test[start:end],
                                             y: Y_test[start:end],
                                             keras.backend.learning_phase(): 0})
+    assert end >= len(X_test)
 
-  # Divide by number of batches to get final value
-  accuracy /= nb_batches
+  # Divide by number of examples to get final value
+  accuracy /= len(X_test)
 
   return accuracy
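To see why the loop now weights each batch's accuracy by its size and divides by the number of examples rather than the number of batches, consider a hypothetical 10-example test set evaluated with a batch size of 4: the batches hold 4, 4, and 2 examples, and averaging the per-batch accuracies directly overweights the short final batch.

sizes      = [4, 4, 2]        # hypothetical batch sizes; the last batch is short
batch_accs = [1.0, 1.0, 0.5]  # per-batch accuracy: 4 + 4 + 1 = 9 of 10 correct

old_way = sum(batch_accs) / len(batch_accs)                           # 0.8333...
new_way = sum(s * a for s, a in zip(sizes, batch_accs)) / sum(sizes)  # 0.9

print(old_way, new_way)  # only the weighted average matches 9/10 correct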

tests/test_mnist_accuracy.py (4 additions, 4 deletions)
@@ -51,11 +51,11 @@ def main(argv=None):
 
   # Train an MNIST model
   tf_model_train(sess, x, y, predictions, X_train, Y_train)
-
+
   # Evaluate the accuracy of the MNIST model on legitimate test examples
   accuracy = tf_model_eval(sess, x, y, predictions, X_test, Y_test)
-  assert float(accuracy) >= 0.97
-
-
+  assert float(accuracy) >= 0.97, accuracy
+
+
 if __name__ == '__main__':
   app.run()
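The substantive change in the test is the second operand of the `assert`: in Python that expression becomes the argument of the raised `AssertionError`, so a failing run now reports the measured accuracy instead of a bare failure. For illustration, with a made-up value just below the threshold:

accuracy = 0.9612                         # made-up value below the 0.97 bar
assert float(accuracy) >= 0.97, accuracy  # AssertionError: 0.9612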
