73 changes: 71 additions & 2 deletions experimental/rank1_bnns/cifar.py
@@ -26,6 +26,8 @@
from baselines.cifar import utils # local file import
from experimental.rank1_bnns import cifar_model # local file import
from experimental.rank1_bnns import refining # local file import
from edward2.google.rank1_pert.ensemble_keras import utils as be_utils

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
@@ -89,7 +91,7 @@
'training/evaluation summaries are stored.')
flags.DEFINE_integer('train_epochs', 250, 'Number of training epochs.')

flags.DEFINE_integer('num_eval_samples', 1,
flags.DEFINE_integer('num_eval_samples', 4,
'Number of model predictions to sample per example at '
'eval time.')
# Refinement flags.
@@ -114,6 +116,23 @@
flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
'Name of the TPU. Only used if use_gpu is False.')

flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
'[cosine, dpp_logdet]')
flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant')
flags.DEFINE_bool('use_output_similarity', False,
'If true, compute similarity on the ensemble outputs.')
flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
'Diversity coefficient scheduler.')
flags.DEFINE_float('annealing_epochs', 200,
'Number of epochs over which to linearly anneal.')
flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
flags.DEFINE_integer('diversity_start_epoch', 100,
'Diversity loss starting epoch.')

FLAGS = flags.FLAGS
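
For orientation, a hypothetical invocation exercising the new diversity flags; every value below is illustrative and not taken from this PR:

python experimental/rank1_bnns/cifar.py \
  --similarity_metric=cosine \
  --diversity_scheduler=LinearAnnealing \
  --diversity_coeff=0.01 \
  --annealing_epochs=200 \
  --use_output_similarity=false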


@@ -218,7 +237,28 @@ def main(argv):
optimizer = tf.keras.optimizers.SGD(lr_schedule,
momentum=0.9,
nesterov=True)

if FLAGS.diversity_scheduler == 'ExponentialDecay':
diversity_schedule = be_utils.ExponentialDecay(
initial_coeff=FLAGS.diversity_coeff,
start_epoch=FLAGS.diversity_start_epoch,
decay_epoch=FLAGS.diversity_decay_epoch,
steps_per_epoch=steps_per_epoch,
decay_rate=FLAGS.diversity_decay_rate,
staircase=True)

elif FLAGS.diversity_scheduler == 'LinearAnnealing':
diversity_schedule = be_utils.LinearAnnealing(
initial_coeff=FLAGS.diversity_coeff,
annealing_epochs=FLAGS.annealing_epochs,
steps_per_epoch=steps_per_epoch)
else:
diversity_schedule = lambda x: FLAGS.diversity_coeff
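
be_utils lives under edward2.google and its scheduler classes are not shown in this diff, so the snippet below is only a minimal sketch of the assumed semantics of the three diversity_scheduler options: a step-indexed coefficient schedule, with LinearAnnealing assumed to ramp up toward diversity_coeff (mirroring how kl_annealing_epochs ramps the KL scale in this file) and ExponentialDecay assumed to decay it after diversity_start_epoch.

# Illustrative sketch only, not from this PR: assumed scheduler behavior.
def linear_annealing_sketch(initial_coeff, annealing_epochs, steps_per_epoch):
  # Assumption: ramp linearly from 0 up to initial_coeff over annealing_epochs.
  total_steps = float(annealing_epochs * steps_per_epoch)
  return lambda step: initial_coeff * min(float(step) / total_steps, 1.0)

def exponential_decay_sketch(initial_coeff, start_epoch, decay_epoch,
                             steps_per_epoch, decay_rate):
  # Assumption: hold initial_coeff until start_epoch, then multiply it by
  # decay_rate once every decay_epoch epochs (staircase).
  def schedule(step):
    epoch = step // steps_per_epoch
    if epoch < start_epoch:
      return initial_coeff
    return initial_coeff * decay_rate**((epoch - start_epoch) // decay_epoch)
  return schedule
# The 'Fixed' option corresponds to the constant lambda above.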

metrics = {
'train/similarity_loss': tf.keras.metrics.Mean(),
'train/weights_similarity': tf.keras.metrics.Mean(),
'train/outputs_similarity': tf.keras.metrics.Mean(),
'train/negative_log_likelihood': tf.keras.metrics.Mean(),
'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'train/loss': tf.keras.metrics.Mean(),
@@ -230,6 +270,8 @@ def main(argv):
'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'test/ece': ed.metrics.ExpectedCalibrationError(
num_bins=FLAGS.num_bins),
'test/weights_similarity': tf.keras.metrics.Mean(),
'test/outputs_similarity': tf.keras.metrics.Mean(),
}
if FLAGS.ensemble_size > 1:
for i in range(FLAGS.ensemble_size):
@@ -286,6 +328,22 @@ def step_fn(inputs):
'bias' in var.name):
filtered_variables.append(tf.reshape(var, (-1,)))

# Group the flattened softmax outputs back into per-member predictions:
# [ensemble_size * batch, classes] -> [ensemble_size, batch, classes].
print(' > logits shape {}'.format(logits.shape))
outputs = tf.nn.softmax(logits)
print(' > outputs shape {}'.format(outputs.shape))
ensemble_outputs_tensor = tf.reshape(
    outputs, [FLAGS.ensemble_size, -1, outputs.shape[-1]])
print(' > ensemble_outputs_tensor shape {}'.format(
    ensemble_outputs_tensor.shape))

similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
FLAGS.similarity_metric, FLAGS.dpp_kernel,
model.trainable_variables, FLAGS.use_output_similarity, ensemble_outputs_tensor)
weights_similarity = be_utils.fast_weights_similarity(
model.trainable_variables, FLAGS.similarity_metric,
FLAGS.dpp_kernel)
outputs_similarity = be_utils.outputs_similarity(
ensemble_outputs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
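
The internals of scaled_similarity_loss, fast_weights_similarity and outputs_similarity are not part of this diff; under similarity_metric='cosine' a plausible reading, sketched below, is the average pairwise cosine similarity between per-member vectors (flattened member weights, or flattened member softmax outputs).

# Illustrative sketch only, not from this PR.
def mean_pairwise_cosine_similarity(member_vectors):
  # member_vectors: [ensemble_size, dim], one flattened vector per member.
  normed = tf.math.l2_normalize(member_vectors, axis=-1)
  sims = tf.matmul(normed, normed, transpose_b=True)  # [k, k] cosine matrix
  k = tf.shape(sims)[0]
  off_diag = tf.reduce_sum(sims) - tf.reduce_sum(tf.linalg.diag_part(sims))
  return off_diag / tf.cast(k * (k - 1), sims.dtype)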

l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
kl = sum(model.losses) / train_dataset_size
@@ -295,7 +353,7 @@ def step_fn(inputs):
kl_loss = kl_scale * kl

# Scale the loss given the TPUStrategy will reduce sum all gradients.
loss = negative_log_likelihood + l2_loss + kl_loss
loss = (negative_log_likelihood + l2_loss + kl_loss +
        similarity_coeff * similarity_loss)
scaled_loss = loss / strategy.num_replicas_in_sync

grads = tape.gradient(scaled_loss, model.trainable_variables)
@@ -325,6 +383,10 @@ def step_fn(inputs):
metrics['train/kl'].update_state(kl)
metrics['train/kl_scale'].update_state(kl_scale)
metrics['train/accuracy'].update_state(labels, logits)
metrics['train/similarity_loss'].update_state(
    similarity_coeff * similarity_loss)
metrics['train/weights_similarity'].update_state(weights_similarity)
metrics['train/outputs_similarity'].update_state(outputs_similarity)


strategy.run(step_fn, args=(next(iterator),))

@@ -346,6 +408,8 @@ def step_fn(inputs):

if FLAGS.ensemble_size > 1:
per_probs = tf.reduce_mean(probs, axis=0) # marginalize samples
outputs_similarity = be_utils.outputs_similarity(
per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
for i in range(FLAGS.ensemble_size):
member_probs = per_probs[i]
member_loss = tf.keras.losses.sparse_categorical_crossentropy(
@@ -370,6 +434,11 @@ def step_fn(inputs):
negative_log_likelihood)
metrics['test/accuracy'].update_state(labels, probs)
metrics['test/ece'].update_state(labels, probs)
weights_similarity = be_utils.fast_weights_similarity(
model.trainable_variables, FLAGS.similarity_metric, FLAGS.dpp_kernel)
metrics['test/weights_similarity'].update_state(weights_similarity)
metrics['test/outputs_similarity'].update_state(outputs_similarity)

else:
corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
negative_log_likelihood)
68 changes: 66 additions & 2 deletions experimental/rank1_bnns/imagenet.py
@@ -25,7 +25,8 @@
import edward2 as ed
from baselines.imagenet import utils # local file import
from experimental.rank1_bnns import imagenet_model # local file import
import tensorflow as tf
from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
import tensorflow.compat.v2 as tf

flags.DEFINE_integer('kl_annealing_epochs', 90,
'Number of epochs over which to anneal the KL term to 1.')
@@ -83,6 +84,22 @@
flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
'Name of the TPU. Only used if use_gpu is False.')
flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
'[cosine, dpp_logdet]')
flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant')
flags.DEFINE_bool('use_output_similarity', False,
'If true, compute similarity on the ensemble outputs.')
flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
'Diversity coefficient scheduler.')
flags.DEFINE_float('annealing_epochs', 200,
'Number of epochs over which to linearly anneal.')
flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
flags.DEFINE_integer('diversity_start_epoch', 100,
'Diversity loss starting epoch.')

FLAGS = flags.FLAGS

# Number of images in ImageNet-1k train dataset.
@@ -184,7 +201,28 @@ def main(argv):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
momentum=0.9,
nesterov=True)

if FLAGS.diversity_scheduler == 'ExponentialDecay':
diversity_schedule = be_utils.ExponentialDecay(
initial_coeff=FLAGS.diversity_coeff,
start_epoch=FLAGS.diversity_start_epoch,
decay_epoch=FLAGS.diversity_decay_epoch,
steps_per_epoch=steps_per_epoch,
decay_rate=FLAGS.diversity_decay_rate,
staircase=True)

elif FLAGS.diversity_scheduler == 'LinearAnnealing':
diversity_schedule = be_utils.LinearAnnealing(
initial_coeff=FLAGS.diversity_coeff,
annealing_epochs=FLAGS.annealing_epochs,
steps_per_epoch=steps_per_epoch)
else:
diversity_schedule = lambda x: FLAGS.diversity_coeff

metrics = {
'train/similarity_loss': tf.keras.metrics.Mean(),
'train/weights_similarity': tf.keras.metrics.Mean(),
'train/outputs_similarity': tf.keras.metrics.Mean(),
'train/negative_log_likelihood': tf.keras.metrics.Mean(),
'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'train/loss': tf.keras.metrics.Mean(),
@@ -196,6 +234,9 @@ def main(argv):
'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'test/ece': ed.metrics.ExpectedCalibrationError(
num_bins=FLAGS.num_bins),
'test/weights_similarity': tf.keras.metrics.Mean(),
'test/outputs_similarity': tf.keras.metrics.Mean(),
}
if FLAGS.corruptions_interval > 0:
corrupt_metrics = {}
@@ -261,6 +302,16 @@ def step_fn(inputs):
diversity_results = ed.metrics.average_pairwise_diversity(
per_probs, FLAGS.ensemble_size)

similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
FLAGS.similarity_metric, FLAGS.dpp_kernel,
model.trainable_variables, FLAGS.use_output_similarity, per_probs)
weights_similarity = be_utils.fast_weights_similarity(
model.trainable_variables, FLAGS.similarity_metric,
FLAGS.dpp_kernel)
outputs_similarity = be_utils.outputs_similarity(
per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
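
As with the cosine case in cifar.py, the dpp_logdet metric is not implemented in this diff; the flag names suggest diversity is scored by the log-determinant of a dpp_kernel Gram matrix over per-member vectors, which for the 'linear' kernel might look roughly like the sketch below.

# Illustrative sketch only, not from this PR.
def dpp_logdet_sketch(member_vectors, jitter=1e-5):
  # member_vectors: [ensemble_size, dim]; a larger log-determinant means the
  # members span a larger volume, i.e. the ensemble is more diverse.
  normed = tf.math.l2_normalize(member_vectors, axis=-1)
  gram = tf.matmul(normed, normed, transpose_b=True)  # linear kernel
  gram += jitter * tf.eye(tf.shape(gram)[0], dtype=gram.dtype)  # stabilize
  return tf.linalg.logdet(gram)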

negative_log_likelihood = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(labels,
logits,
@@ -282,7 +333,7 @@ def step_fn(inputs):
kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
kl_scale = tf.minimum(1., kl_scale)
kl_loss = kl_scale * kl
loss = negative_log_likelihood + l2_loss + kl_loss
loss = (negative_log_likelihood + l2_loss + kl_loss +
        similarity_coeff * similarity_loss)
# Scale the loss given the TPUStrategy will reduce sum all gradients.
scaled_loss = loss / strategy.num_replicas_in_sync

@@ -310,6 +361,11 @@ def step_fn(inputs):
metrics['train/kl'].update_state(kl)
metrics['train/kl_scale'].update_state(kl_scale)
metrics['train/accuracy'].update_state(labels, logits)
metrics['train/similarity_loss'].update_state(similarity_coeff *
similarity_loss)
metrics['train/weights_similarity'].update_state(weights_similarity)
metrics['train/outputs_similarity'].update_state(outputs_similarity)

if FLAGS.ensemble_size > 1:
for k, v in diversity_results.items():
training_diversity['train/' + k].update_state(v)
@@ -346,6 +402,14 @@ def step_fn(inputs):
if dataset_name == 'clean':
if FLAGS.ensemble_size > 1:
per_probs = tf.reduce_mean(all_probs, axis=0) # marginalize samples
outputs_similarity = be_utils.outputs_similarity(
per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
weights_similarity = be_utils.fast_weights_similarity(
model.trainable_variables, FLAGS.similarity_metric,
FLAGS.dpp_kernel)
metrics['test/weights_similarity'].update_state(weights_similarity)
metrics['test/outputs_similarity'].update_state(outputs_similarity)

diversity_results = ed.metrics.average_pairwise_diversity(
per_probs, FLAGS.ensemble_size)
for k, v in diversity_results.items():