Commit 990e3e7

jereliu authored and edward-bot committed
Add num_dropout_sample option to Monte Carlo Dropout.
PiperOrigin-RevId: 315518655
1 parent 3fe7451 commit 990e3e7

2 files changed: +22 −8 lines changed


baselines/cifar/README.md (4 additions, 2 deletions)

@@ -6,7 +6,8 @@
 | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
 | Deterministic | 1e-3 / 0.159 | 99.9% / 96.0% | 1e-3 / 0.0231 | 1.05 / 76.1% / 0.153 | 1.2 (8 TPUv2 cores) | 36.5M |
 | BatchEnsemble (size=4) | 0.08 / 0.143 | 99.9% / 96.2% | 5e-5 / 0.0206 | 1.02 / 77.5% / 0.129 | 5.4 (8 TPUv2 cores) | 36.6M |
-| Dropout | 2e-3 / 0.160 | 99.9% / 95.9% | 2e-3 / 0.0241 | 1.27 / 68.8% / 0.166 | 1.2 (8 TPUv2 cores) | 36.5M |
+| Monte Carlo Dropout (size=1) | 2e-3 / 0.160 | 99.9% / 95.9% | 2e-3 / 0.0241 | 1.27 / 68.8% / 0.166 | 1.2 (8 TPUv2 cores) | 36.5M |
+| Monte Carlo Dropout (size=30) | 1e-3 / 0.145 | 99.9% / 96.1% | 1.5e-3 / 0.019 | 1.27 / 70.0% / 0.167 | 1.2 (8 TPUv2 cores) | 36.5M |
 | Ensemble (size=4) | 2e-3 / 0.114 | 99.9% / 96.6% | - / 0.010 | 0.81 / 77.9% / 0.087 | 1.2 (32 TPUv2 cores) | 146M |
 | Variational inference | 1e-3 / 0.211 | 99.9% / 94.7% | 1e-3 / 0.029 | 1.46 / 71.3% / 0.181 | 5.5 (8 TPUv2 cores) | 73M |

@@ -16,7 +17,8 @@
 | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
 | Deterministic<sup>10</sup> | 1e-3 / 0.875 | 99.9% / 79.8% | 2e-3 / 0.0857 | 2.70 / 51.37% / 0.239 | 1.1 (8 TPUv2 cores) | 36.5M |
 | BatchEnsemble (size=4) | 3e-3 / 0.740 | 99.7% / 81.5% | 2e-3 / 0.0561 | 2.49 / 54.1% / 0.191 | 5.5 (8 TPUv2 cores) | 36.6M |
-| Dropout | 1e-2 / 0.830 | 99.9% / 79.6% | 9e-3 / 0.0501 | 2.90 / 42.63%/ 0.202 | 1.1 (8 TPUv2 cores) | 36.5M |
+| Monte Carlo Dropout (size=1) | 1e-2 / 0.830 | 99.9% / 79.6% | 9e-3 / 0.0501 | 2.90 / 42.63% / 0.202 | 1.1 (8 TPUv2 cores) | 36.5M |
+| Monte Carlo Dropout (size=30) | 6e-3 / 0.785 | 99.9% / 80.7% | 5e-3 / 0.0487 | 2.73 / 46.2% / 0.207 | 1.1 (8 TPUv2 cores) | 36.5M |
 | Ensemble (size=4) | 0.003 / 0.666 | 99.9% / 82.7% | - / 0.021 | 2.27 / 54.1% / 0.138 | 1.1 (32 TPUv2 cores) | 146M |
 | Variational inference | 3e-3 / 0.944 | 99.9% / 77.8% | 2e-3 / 0.097 | 3.18 / 48.2% / 0.271 | 5.5 (8 TPUv2 cores) | 73M |

baselines/cifar/dropout.py (18 additions, 6 deletions)

@@ -29,7 +29,7 @@

 flags.DEFINE_integer('seed', 42, 'Random seed.')
 flags.DEFINE_integer('per_core_batch_size', 64, 'Batch size per TPU core/GPU.')
-flags.DEFINE_float('base_learning_rate', 0.1,
+flags.DEFINE_float('base_learning_rate', 0.05,
                    'Base learning rate when total batch size is 128. It is '
                    'scaled by the ratio of the total batch size to 128.')
 flags.DEFINE_integer('lr_warmup_epochs', 1,

@@ -40,6 +40,9 @@
                      'Epochs to decay learning rate by.')
 flags.DEFINE_float('l2', 3e-4, 'L2 regularization coefficient.')
 flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate.')
+flags.DEFINE_integer('num_dropout_samples', 1,
+                     'Number of dropout samples to use for prediction.')
+
 flags.DEFINE_enum('dataset', 'cifar10',
                   enum_values=['cifar10', 'cifar100'],
                   help='Dataset.')

@@ -50,7 +53,7 @@
 flags.DEFINE_integer('corruptions_interval', 50,
                      'Number of epochs between evaluating on the corrupted '
                      'test data. Use -1 to never evaluate.')
-flags.DEFINE_integer('checkpoint_interval', 25,
+flags.DEFINE_integer('checkpoint_interval', -1,
                      'Number of epochs between saving checkpoints. Use -1 to '
                      'never save checkpoints.')
 flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE.')

@@ -342,10 +345,19 @@ def test_step(iterator, dataset_name):
   def step_fn(inputs):
     """Per-Replica StepFn."""
     images, labels = inputs
-    logits = model(images, training=False)
-    if FLAGS.use_bfloat16:
-      logits = tf.cast(logits, tf.float32)
-    probs = tf.nn.softmax(logits)
+
+    logits_list = []
+    for _ in range(FLAGS.num_dropout_samples):
+      logits = model(images, training=False)
+      if FLAGS.use_bfloat16:
+        logits = tf.cast(logits, tf.float32)
+      logits_list.append(logits)
+
+    # Logits dimension is (num_samples, batch_size, num_classes).
+    logits_list = tf.stack(logits_list, axis=0)
+    probs_list = tf.nn.softmax(logits_list)
+    probs = tf.reduce_mean(probs_list, axis=0)
+
     negative_log_likelihood = tf.reduce_mean(
         tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
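The prediction change in test_step can be sketched without TensorFlow: run several stochastic forward passes, softmax each sample's logits, and average in probability space. The helper names below (`mc_dropout_predict`, `noisy_forward`) are hypothetical illustrations, not part of the commit; the toy forward function just jitters fixed logits to stand in for a dropout-enabled model.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mc_dropout_predict(forward_fn, x, num_dropout_samples):
    """Mirror of the test_step logic: softmax each sample's logits,
    then take the mean over samples in probability space."""
    probs_list = [softmax(forward_fn(x)) for _ in range(num_dropout_samples)]
    num_classes = len(probs_list[0])
    return [sum(p[c] for p in probs_list) / num_dropout_samples
            for c in range(num_classes)]

# Toy stand-in for a dropout-enabled model: logits jitter across calls.
random.seed(0)
def noisy_forward(x):
    return [xi + random.gauss(0.0, 0.5) for xi in x]

probs = mc_dropout_predict(noisy_forward, [2.0, 1.0, 0.1],
                           num_dropout_samples=30)
print(probs)  # three class probabilities summing to 1
```

Note that the commit averages after the softmax rather than averaging logits: the mean of per-sample probability vectors is the Monte Carlo estimate of the predictive distribution, which is what the NLL is then computed against.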
