|
29 | 29 |
|
30 | 30 | flags.DEFINE_integer('seed', 42, 'Random seed.') |
31 | 31 | flags.DEFINE_integer('per_core_batch_size', 64, 'Batch size per TPU core/GPU.') |
32 | | -flags.DEFINE_float('base_learning_rate', 0.1, |
| 32 | +flags.DEFINE_float('base_learning_rate', 0.05, |
33 | 33 | 'Base learning rate when total batch size is 128. It is ' |
34 | 34 | 'scaled by the ratio of the total batch size to 128.') |
35 | 35 | flags.DEFINE_integer('lr_warmup_epochs', 1, |
|
40 | 40 | 'Epochs to decay learning rate by.') |
41 | 41 | flags.DEFINE_float('l2', 3e-4, 'L2 regularization coefficient.') |
42 | 42 | flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate.') |
| 43 | +flags.DEFINE_integer('num_dropout_samples', 1, |
| 44 | + 'Number of dropout samples to use for prediction.') |
| 45 | + |
43 | 46 | flags.DEFINE_enum('dataset', 'cifar10', |
44 | 47 | enum_values=['cifar10', 'cifar100'], |
45 | 48 | help='Dataset.') |
|
50 | 53 | flags.DEFINE_integer('corruptions_interval', 50, |
51 | 54 | 'Number of epochs between evaluating on the corrupted ' |
52 | 55 | 'test data. Use -1 to never evaluate.') |
53 | | -flags.DEFINE_integer('checkpoint_interval', 25, |
| 56 | +flags.DEFINE_integer('checkpoint_interval', -1, |
54 | 57 | 'Number of epochs between saving checkpoints. Use -1 to ' |
55 | 58 | 'never save checkpoints.') |
56 | 59 | flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE.') |
@@ -342,10 +345,19 @@ def test_step(iterator, dataset_name): |
342 | 345 | def step_fn(inputs): |
343 | 346 | """Per-Replica StepFn.""" |
344 | 347 | images, labels = inputs |
345 | | - logits = model(images, training=False) |
346 | | - if FLAGS.use_bfloat16: |
347 | | - logits = tf.cast(logits, tf.float32) |
348 | | - probs = tf.nn.softmax(logits) |
| 348 | + |
| 349 | + logits_list = [] |
| 350 | + for _ in range(FLAGS.num_dropout_samples): |
| 351 | + logits = model(images, training=False) |
| 352 | + if FLAGS.use_bfloat16: |
| 353 | + logits = tf.cast(logits, tf.float32) |
| 354 | + logits_list.append(logits) |
| 355 | + |
| 356 | + # Logits dimension is (num_samples, batch_size, num_classes). |
| 357 | + logits_list = tf.stack(logits_list, axis=0) |
| 358 | + probs_list = tf.nn.softmax(logits_list) |
| 359 | + probs = tf.reduce_mean(probs_list, axis=0) |
| 360 | + |
349 | 361 | negative_log_likelihood = tf.reduce_mean( |
350 | 362 | tf.keras.losses.sparse_categorical_crossentropy(labels, probs)) |
351 | 363 |
|
|
0 commit comments