From e2f8ec97dbda4ec10866b125036d98c71b145912 Mon Sep 17 00:00:00 2001 From: Ghassen Jerfel Date: Fri, 19 Jun 2020 11:21:18 -0700 Subject: [PATCH] Add diversity regularization to Rank1 BNNs. PiperOrigin-RevId: 317344143 --- experimental/rank1_bnns/cifar.py | 73 ++++++++++++++++++++++++++++- experimental/rank1_bnns/imagenet.py | 68 ++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 4 deletions(-) diff --git a/experimental/rank1_bnns/cifar.py b/experimental/rank1_bnns/cifar.py index 11c7f27a..518c249d 100644 --- a/experimental/rank1_bnns/cifar.py +++ b/experimental/rank1_bnns/cifar.py @@ -26,6 +26,8 @@ from baselines.cifar import utils # local file import from experimental.rank1_bnns import cifar_model # local file import from experimental.rank1_bnns import refining # local file import +from edward2.google.rank1_pert.ensemble_keras import utils as be_utils + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -89,7 +91,7 @@ 'training/evaluation summaries are stored.') flags.DEFINE_integer('train_epochs', 250, 'Number of training epochs.') -flags.DEFINE_integer('num_eval_samples', 1, +flags.DEFINE_integer('num_eval_samples', 4, 'Number of model predictions to sample per example at ' 'eval time.') # Refinement flags. @@ -114,6 +116,23 @@ flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores or number of GPUs.') flags.DEFINE_string('tpu', None, 'Name of the TPU. 
Only used if use_gpu is False.') + +flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in ' + '[cosine, dpp_logdet]') +flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant') +flags.DEFINE_bool('use_output_similarity', False, + 'If true, compute similarity on the ensemble outputs.') +flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing', + ['LinearAnnealing', 'ExponentialDecay', 'Fixed'], + 'Diversity coefficient scheduler.') +flags.DEFINE_float('annealing_epochs', 200, + 'Number of epochs over which to linearly anneal.') +flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.') +flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.') +flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.') +flags.DEFINE_integer('diversity_start_epoch', 100, + 'Diversity loss starting epoch.') + FLAGS = flags.FLAGS @@ -218,7 +237,28 @@ def main(argv): optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=0.9, nesterov=True) + + if FLAGS.diversity_scheduler == 'ExponentialDecay': + diversity_schedule = be_utils.ExponentialDecay( + initial_coeff=FLAGS.diversity_coeff, + start_epoch=FLAGS.diversity_start_epoch, + decay_epoch=FLAGS.diversity_decay_epoch, + steps_per_epoch=steps_per_epoch, + decay_rate=FLAGS.diversity_decay_rate, + staircase=True) + + elif FLAGS.diversity_scheduler == 'LinearAnnealing': + diversity_schedule = be_utils.LinearAnnealing( + initial_coeff=FLAGS.diversity_coeff, + annealing_epochs=FLAGS.annealing_epochs, + steps_per_epoch=steps_per_epoch) + else: + diversity_schedule = lambda x: FLAGS.diversity_coeff + metrics = { + 'train/similarity_loss': tf.keras.metrics.Mean(), + 'train/weights_similarity': tf.keras.metrics.Mean(), + 'train/outputs_similarity': tf.keras.metrics.Mean(), 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/loss': tf.keras.metrics.Mean(), @@ -230,6 +270,8 @@ def
main(argv): 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': ed.metrics.ExpectedCalibrationError( num_bins=FLAGS.num_bins), + 'test/weights_similarity': tf.keras.metrics.Mean(), + 'test/outputs_similarity': tf.keras.metrics.Mean(), } if FLAGS.ensemble_size > 1: for i in range(FLAGS.ensemble_size): @@ -286,6 +328,22 @@ def step_fn(inputs): 'bias' in var.name): filtered_variables.append(tf.reshape(var, (-1,))) + print(' > logits shape {}'.format(logits.shape)) + outputs = tf.nn.softmax(logits) + print(' > outputs shape {}'.format(outputs.shape)) + ensemble_outputs_tensor = tf.reshape(outputs,[FLAGS.ensemble_size, -1, outputs.shape[-1]]) + print(' > ensemble_outputs_tensor shape {}'.format(ensemble_outputs_tensor.shape)) + + similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss( + FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations, + FLAGS.similarity_metric, FLAGS.dpp_kernel, + model.trainable_variables, FLAGS.use_output_similarity, ensemble_outputs_tensor) + weights_similarity = be_utils.fast_weights_similarity( + model.trainable_variables, FLAGS.similarity_metric, + FLAGS.dpp_kernel) + outputs_similarity = be_utils.outputs_similarity( + ensemble_outputs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel) + l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss( tf.concat(filtered_variables, axis=0)) kl = sum(model.losses) / train_dataset_size @@ -295,7 +353,7 @@ def step_fn(inputs): kl_loss = kl_scale * kl # Scale the loss given the TPUStrategy will reduce sum all gradients.
- loss = negative_log_likelihood + l2_loss + kl_loss + loss = negative_log_likelihood + l2_loss + kl_loss + similarity_coeff * similarity_loss scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) @@ -325,6 +383,10 @@ def step_fn(inputs): metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, logits) + metrics['train/similarity_loss'].update_state(similarity_coeff * similarity_loss) + metrics['train/weights_similarity'].update_state(weights_similarity) + metrics['train/outputs_similarity'].update_state(outputs_similarity) + strategy.run(step_fn, args=(next(iterator),)) @@ -346,6 +408,8 @@ def step_fn(inputs): if FLAGS.ensemble_size > 1: per_probs = tf.reduce_mean(probs, axis=0) # marginalize samples + outputs_similarity = be_utils.outputs_similarity( + per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel) for i in range(FLAGS.ensemble_size): member_probs = per_probs[i] member_loss = tf.keras.losses.sparse_categorical_crossentropy( @@ -370,6 +434,11 @@ def step_fn(inputs): negative_log_likelihood) metrics['test/accuracy'].update_state(labels, probs) metrics['test/ece'].update_state(labels, probs) + weights_similarity = be_utils.fast_weights_similarity( + model.trainable_variables, FLAGS.similarity_metric, FLAGS.dpp_kernel) + metrics['test/weights_similarity'].update_state(weights_similarity) + metrics['test/outputs_similarity'].update_state(outputs_similarity) + else: corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state( negative_log_likelihood) diff --git a/experimental/rank1_bnns/imagenet.py b/experimental/rank1_bnns/imagenet.py index d80c7900..95f6d699 100644 --- a/experimental/rank1_bnns/imagenet.py +++ b/experimental/rank1_bnns/imagenet.py @@ -25,7 +25,8 @@ import edward2 as ed from baselines.imagenet import utils # local file import from experimental.rank1_bnns import imagenet_model # local file import -import 
tensorflow as tf +from edward2.google.rank1_pert.ensemble_keras import utils as be_utils +import tensorflow.compat.v2 as tf flags.DEFINE_integer('kl_annealing_epochs', 90, 'Number of epochs over which to anneal the KL term to 1.') @@ -83,6 +84,22 @@ flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.') flags.DEFINE_string('tpu', None, 'Name of the TPU. Only used if use_gpu is False.') +flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in ' + '[cosine, dpp_logdet]') +flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant') +flags.DEFINE_bool('use_output_similarity', False, + 'If true, compute similarity on the ensemble outputs.') +flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing', + ['LinearAnnealing', 'ExponentialDecay', 'Fixed'], + 'Diversity coefficient scheduler.') +flags.DEFINE_float('annealing_epochs', 200, + 'Number of epochs over which to linearly anneal.') +flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.') +flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.') +flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.') +flags.DEFINE_integer('diversity_start_epoch', 100, + 'Diversity loss starting epoch.') + FLAGS = flags.FLAGS # Number of images in ImageNet-1k train dataset. 
@@ -184,7 +201,28 @@ def main(argv): optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) + + if FLAGS.diversity_scheduler == 'ExponentialDecay': + diversity_schedule = be_utils.ExponentialDecay( + initial_coeff=FLAGS.diversity_coeff, + start_epoch=FLAGS.diversity_start_epoch, + decay_epoch=FLAGS.diversity_decay_epoch, + steps_per_epoch=steps_per_epoch, + decay_rate=FLAGS.diversity_decay_rate, + staircase=True) + + elif FLAGS.diversity_scheduler == 'LinearAnnealing': + diversity_schedule = be_utils.LinearAnnealing( + initial_coeff=FLAGS.diversity_coeff, + annealing_epochs=FLAGS.annealing_epochs, + steps_per_epoch=steps_per_epoch) + else: + diversity_schedule = lambda x: FLAGS.diversity_coeff + metrics = { + 'train/similarity_loss': tf.keras.metrics.Mean(), + 'train/weights_similarity': tf.keras.metrics.Mean(), + 'train/outputs_similarity': tf.keras.metrics.Mean(), 'train/negative_log_likelihood': tf.keras.metrics.Mean(), 'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'train/loss': tf.keras.metrics.Mean(), @@ -196,6 +234,9 @@ def main(argv): 'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(), 'test/ece': ed.metrics.ExpectedCalibrationError( num_bins=FLAGS.num_bins), + 'test/weights_similarity': tf.keras.metrics.Mean(), + 'test/outputs_similarity': tf.keras.metrics.Mean(), + } if FLAGS.corruptions_interval > 0: corrupt_metrics = {} @@ -261,6 +302,16 @@ def step_fn(inputs): diversity_results = ed.metrics.average_pairwise_diversity( per_probs, FLAGS.ensemble_size) + similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss( + FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations, + FLAGS.similarity_metric, FLAGS.dpp_kernel, + model.trainable_variables, FLAGS.use_output_similarity, per_probs) + weights_similarity = be_utils.fast_weights_similarity( + model.trainable_variables, FLAGS.similarity_metric, + FLAGS.dpp_kernel) + outputs_similarity = be_utils.outputs_similarity( + 
per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel) + negative_log_likelihood = tf.reduce_mean( tf.keras.losses.sparse_categorical_crossentropy(labels, logits, @@ -282,7 +333,7 @@ def step_fn(inputs): kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs kl_scale = tf.minimum(1., kl_scale) kl_loss = kl_scale * kl - loss = negative_log_likelihood + l2_loss + kl_loss + loss = negative_log_likelihood + l2_loss + kl_loss + similarity_coeff * similarity_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. scaled_loss = loss / strategy.num_replicas_in_sync @@ -310,6 +361,11 @@ def step_fn(inputs): metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, logits) + metrics['train/similarity_loss'].update_state(similarity_coeff * + similarity_loss) + metrics['train/weights_similarity'].update_state(weights_similarity) + metrics['train/outputs_similarity'].update_state(outputs_similarity) + if FLAGS.ensemble_size > 1: for k, v in diversity_results.items(): training_diversity['train/' + k].update_state(v) @@ -346,6 +402,14 @@ def step_fn(inputs): if dataset_name == 'clean': if FLAGS.ensemble_size > 1: per_probs = tf.reduce_mean(all_probs, axis=0) # marginalize samples + outputs_similarity = be_utils.outputs_similarity( + per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel) + weights_similarity = be_utils.fast_weights_similarity( + model.trainable_variables, FLAGS.similarity_metric, + FLAGS.dpp_kernel) + metrics['test/weights_similarity'].update_state(weights_similarity) + metrics['test/outputs_similarity'].update_state(outputs_similarity) + diversity_results = ed.metrics.average_pairwise_diversity( per_probs, FLAGS.ensemble_size) for k, v in diversity_results.items():