Commit c367a5f

GhassenJ authored and edward-bot committed
Add diversity regularization to Rank1 BNNs.
PiperOrigin-RevId: 317344143
1 parent 990e3e7 commit c367a5f

File tree

2 files changed: +137 -4 lines changed

  experimental/rank1_bnns/cifar.py
  experimental/rank1_bnns/imagenet.py


experimental/rank1_bnns/cifar.py

Lines changed: 71 additions & 2 deletions
@@ -26,6 +26,8 @@
 from baselines.cifar import utils  # local file import
 from experimental.rank1_bnns import cifar_model  # local file import
 from experimental.rank1_bnns import refining  # local file import
+from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
+
 import numpy as np
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -89,7 +91,7 @@
     'training/evaluation summaries are stored.')
 flags.DEFINE_integer('train_epochs', 250, 'Number of training epochs.')

-flags.DEFINE_integer('num_eval_samples', 1,
+flags.DEFINE_integer('num_eval_samples', 4,
                      'Number of model predictions to sample per example at '
                      'eval time.')
 # Refinement flags.
@@ -114,6 +116,23 @@
 flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores or number of GPUs.')
 flags.DEFINE_string('tpu', None,
                     'Name of the TPU. Only used if use_gpu is False.')
+
+flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
+                    '[cosine, dpp_logdet].')
+flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
+flags.DEFINE_bool('use_output_similarity', False,
+                  'If true, compute similarity on the ensemble outputs.')
+flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
+                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
+                  'Diversity coefficient scheduler.')
+flags.DEFINE_float('annealing_epochs', 200,
+                   'Number of epochs over which to linearly anneal.')
+flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
+flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
+flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
+flags.DEFINE_integer('diversity_start_epoch', 100,
+                     'Diversity loss starting epoch.')
+
 FLAGS = flags.FLAGS
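Note on the two similarity_metric options: both penalize ensemble members for being alike, either via average pairwise cosine similarity or via a (negated) DPP log-determinant. be_utils lives in the internal edward2.google.rank1_pert.ensemble_keras package, so the snippet below is only a minimal sketch of what such metrics can look like; the helper names and details are illustrative, not the internal implementation.

import tensorflow as tf

def cosine_similarity(members):
  """Average pairwise cosine similarity over a [K, D] member matrix."""
  normed = tf.math.l2_normalize(members, axis=-1)
  gram = tf.matmul(normed, normed, transpose_b=True)  # [K, K]
  k = tf.shape(gram)[0]
  # Zero the diagonal so self-similarity is excluded from the average.
  off_diag = gram - tf.linalg.diag(tf.linalg.diag_part(gram))
  return tf.reduce_sum(off_diag) / tf.cast(k * (k - 1), gram.dtype)

def dpp_logdet_similarity(members):
  """Negated log-det of a linear-kernel Gram matrix: near-parallel members
  shrink the determinant, so the penalty grows as diversity drops."""
  normed = tf.math.l2_normalize(members, axis=-1)
  kernel = tf.matmul(normed, normed, transpose_b=True)
  kernel += 1e-5 * tf.eye(tf.shape(kernel)[0])  # jitter for invertibility
  return -tf.linalg.logdet(kernel)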
@@ -218,7 +237,28 @@ def main(argv):
   optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                       momentum=0.9,
                                       nesterov=True)
+
+  if FLAGS.diversity_scheduler == 'ExponentialDecay':
+    diversity_schedule = be_utils.ExponentialDecay(
+        initial_coeff=FLAGS.diversity_coeff,
+        start_epoch=FLAGS.diversity_start_epoch,
+        decay_epoch=FLAGS.diversity_decay_epoch,
+        steps_per_epoch=steps_per_epoch,
+        decay_rate=FLAGS.diversity_decay_rate,
+        staircase=True)
+
+  elif FLAGS.diversity_scheduler == 'LinearAnnealing':
+    diversity_schedule = be_utils.LinearAnnealing(
+        initial_coeff=FLAGS.diversity_coeff,
+        annealing_epochs=FLAGS.annealing_epochs,
+        steps_per_epoch=steps_per_epoch)
+  else:
+    diversity_schedule = lambda x: FLAGS.diversity_coeff
+
   metrics = {
+      'train/similarity_loss': tf.keras.metrics.Mean(),
+      'train/weights_similarity': tf.keras.metrics.Mean(),
+      'train/outputs_similarity': tf.keras.metrics.Mean(),
       'train/negative_log_likelihood': tf.keras.metrics.Mean(),
       'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
       'train/loss': tf.keras.metrics.Mean(),
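Both schedulers gate the diversity coefficient by global step. The classes below are a hypothetical sketch matching the keyword arguments at these call sites; the real be_utils implementations are internal and may differ (for instance, in whether LinearAnnealing ramps the coefficient up or down).

import tensorflow as tf

class LinearAnnealing:
  """One plausible reading: ramp linearly up to initial_coeff over
  annealing_epochs (the internal class may ramp down instead)."""

  def __init__(self, initial_coeff, annealing_epochs, steps_per_epoch):
    self.initial_coeff = initial_coeff
    self.total_steps = float(annealing_epochs * steps_per_epoch)

  def __call__(self, step):
    frac = tf.minimum(tf.cast(step, tf.float32) / self.total_steps, 1.)
    return self.initial_coeff * frac

class ExponentialDecay:
  """Zero before start_epoch, then initial_coeff decayed by decay_rate
  every decay_epoch epochs (staircase=True decays in discrete jumps)."""

  def __init__(self, initial_coeff, start_epoch, decay_epoch, steps_per_epoch,
               decay_rate, staircase=True):
    self.initial_coeff = initial_coeff
    self.start_step = float(start_epoch * steps_per_epoch)
    self.decay_steps = float(decay_epoch * steps_per_epoch)
    self.decay_rate = decay_rate
    self.staircase = staircase

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    progress = tf.maximum(step - self.start_step, 0.) / self.decay_steps
    if self.staircase:
      progress = tf.floor(progress)
    coeff = self.initial_coeff * self.decay_rate ** progress
    return tf.where(step < self.start_step, 0., coeff)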
@@ -230,6 +270,8 @@ def main(argv):
       'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
       'test/ece': ed.metrics.ExpectedCalibrationError(
           num_bins=FLAGS.num_bins),
+      'test/weights_similarity': tf.keras.metrics.Mean(),
+      'test/outputs_similarity': tf.keras.metrics.Mean(),
   }
   if FLAGS.ensemble_size > 1:
     for i in range(FLAGS.ensemble_size):
@@ -286,6 +328,22 @@ def step_fn(inputs):
               'bias' in var.name):
             filtered_variables.append(tf.reshape(var, (-1,)))

+        print(' > logits shape {}'.format(logits.shape))
+        outputs = tf.nn.softmax(logits)
+        print(' > outputs shape {}'.format(outputs.shape))
+        ensemble_outputs_tensor = tf.reshape(outputs, [FLAGS.ensemble_size, -1, outputs.shape[-1]])
+        print(' > ensemble_outputs_tensor shape {}'.format(ensemble_outputs_tensor.shape))
+
+        similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
+            FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
+            FLAGS.similarity_metric, FLAGS.dpp_kernel,
+            model.trainable_variables, FLAGS.use_output_similarity, ensemble_outputs_tensor)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        outputs_similarity = be_utils.outputs_similarity(
+            ensemble_outputs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+
         l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
             tf.concat(filtered_variables, axis=0))
         kl = sum(model.losses) / train_dataset_size
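The reshape assumes the ensemble members' predictions are stacked along the leading batch axis, giving one [batch, classes] slice per member. A quick standalone check with assumed example sizes:

import tensorflow as tf

ensemble_size, batch_size, num_classes = 4, 64, 10  # assumed example sizes
logits = tf.random.normal([ensemble_size * batch_size, num_classes])
outputs = tf.nn.softmax(logits)
ensemble_outputs_tensor = tf.reshape(
    outputs, [ensemble_size, -1, outputs.shape[-1]])
print(ensemble_outputs_tensor.shape)  # (4, 64, 10)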
@@ -295,7 +353,7 @@ def step_fn(inputs):
         kl_loss = kl_scale * kl

         # Scale the loss given the TPUStrategy will reduce sum all gradients.
-        loss = negative_log_likelihood + l2_loss + kl_loss
+        loss = negative_log_likelihood + l2_loss + kl_loss + similarity_coeff * similarity_loss
         scaled_loss = loss / strategy.num_replicas_in_sync

       grads = tape.gradient(scaled_loss, model.trainable_variables)
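With this one-line change the per-replica training objective becomes

    loss = negative_log_likelihood + l2_loss + kl_scale * kl
           + similarity_coeff * similarity_loss

where similarity_coeff comes from the configured scheduler (so the penalty can be annealed in, decayed, or held fixed) and similarity_loss is computed over the rank-1 weights or, with use_output_similarity, over the ensemble outputs. Since scaled_similarity_loss receives the schedule and optimizer.iterations, the coefficient presumably drops to zero whenever the scheduler says so (for example, before diversity_start_epoch under ExponentialDecay).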
@@ -325,6 +383,10 @@ def step_fn(inputs):
       metrics['train/kl'].update_state(kl)
       metrics['train/kl_scale'].update_state(kl_scale)
       metrics['train/accuracy'].update_state(labels, logits)
+      metrics['train/similarity_loss'].update_state(similarity_coeff * similarity_loss)
+      metrics['train/weights_similarity'].update_state(weights_similarity)
+      metrics['train/outputs_similarity'].update_state(outputs_similarity)
+
     strategy.run(step_fn, args=(next(iterator),))

@@ -346,6 +408,8 @@ def step_fn(inputs):

       if FLAGS.ensemble_size > 1:
         per_probs = tf.reduce_mean(probs, axis=0)  # marginalize samples
+        outputs_similarity = be_utils.outputs_similarity(
+            per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
         for i in range(FLAGS.ensemble_size):
           member_probs = per_probs[i]
           member_loss = tf.keras.losses.sparse_categorical_crossentropy(
@@ -370,6 +434,11 @@ def step_fn(inputs):
             negative_log_likelihood)
         metrics['test/accuracy'].update_state(labels, probs)
         metrics['test/ece'].update_state(labels, probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+        metrics['test/weights_similarity'].update_state(weights_similarity)
+        metrics['test/outputs_similarity'].update_state(outputs_similarity)
+
       else:
         corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
             negative_log_likelihood)
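Taken together, the new flags make the regularizer opt-in: diversity_coeff defaults to 0., so training is unchanged unless it is raised. A hypothetical invocation enabling it (script path and values illustrative only):

python experimental/rank1_bnns/cifar.py \
  --similarity_metric=dpp_logdet \
  --diversity_scheduler=ExponentialDecay \
  --diversity_coeff=0.01 \
  --diversity_start_epoch=100 \
  --diversity_decay_epoch=4 \
  --diversity_decay_rate=0.97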

experimental/rank1_bnns/imagenet.py

Lines changed: 66 additions & 2 deletions
@@ -25,7 +25,8 @@
 import edward2 as ed
 from baselines.imagenet import utils  # local file import
 from experimental.rank1_bnns import imagenet_model  # local file import
-import tensorflow as tf
+from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
+import tensorflow.compat.v2 as tf

 flags.DEFINE_integer('kl_annealing_epochs', 90,
                      'Number of epochs over which to anneal the KL term to 1.')
@@ -83,6 +84,22 @@
 flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
 flags.DEFINE_string('tpu', None,
                     'Name of the TPU. Only used if use_gpu is False.')
+flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
+                    '[cosine, dpp_logdet].')
+flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
+flags.DEFINE_bool('use_output_similarity', False,
+                  'If true, compute similarity on the ensemble outputs.')
+flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
+                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
+                  'Diversity coefficient scheduler.')
+flags.DEFINE_float('annealing_epochs', 200,
+                   'Number of epochs over which to linearly anneal.')
+flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
+flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
+flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
+flags.DEFINE_integer('diversity_start_epoch', 100,
+                     'Diversity loss starting epoch.')
+
 FLAGS = flags.FLAGS

 # Number of images in ImageNet-1k train dataset.
@@ -184,7 +201,28 @@ def main(argv):
   optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                       momentum=0.9,
                                       nesterov=True)
+
+  if FLAGS.diversity_scheduler == 'ExponentialDecay':
+    diversity_schedule = be_utils.ExponentialDecay(
+        initial_coeff=FLAGS.diversity_coeff,
+        start_epoch=FLAGS.diversity_start_epoch,
+        decay_epoch=FLAGS.diversity_decay_epoch,
+        steps_per_epoch=steps_per_epoch,
+        decay_rate=FLAGS.diversity_decay_rate,
+        staircase=True)
+
+  elif FLAGS.diversity_scheduler == 'LinearAnnealing':
+    diversity_schedule = be_utils.LinearAnnealing(
+        initial_coeff=FLAGS.diversity_coeff,
+        annealing_epochs=FLAGS.annealing_epochs,
+        steps_per_epoch=steps_per_epoch)
+  else:
+    diversity_schedule = lambda x: FLAGS.diversity_coeff
+
   metrics = {
+      'train/similarity_loss': tf.keras.metrics.Mean(),
+      'train/weights_similarity': tf.keras.metrics.Mean(),
+      'train/outputs_similarity': tf.keras.metrics.Mean(),
       'train/negative_log_likelihood': tf.keras.metrics.Mean(),
       'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
       'train/loss': tf.keras.metrics.Mean(),
@@ -196,6 +234,9 @@ def main(argv):
       'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
       'test/ece': ed.metrics.ExpectedCalibrationError(
           num_bins=FLAGS.num_bins),
+      'test/weights_similarity': tf.keras.metrics.Mean(),
+      'test/outputs_similarity': tf.keras.metrics.Mean(),
+
   }
   if FLAGS.corruptions_interval > 0:
     corrupt_metrics = {}
@@ -261,6 +302,16 @@ def step_fn(inputs):
       diversity_results = ed.metrics.average_pairwise_diversity(
           per_probs, FLAGS.ensemble_size)

+      similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
+          FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
+          FLAGS.similarity_metric, FLAGS.dpp_kernel,
+          model.trainable_variables, FLAGS.use_output_similarity, per_probs)
+      weights_similarity = be_utils.fast_weights_similarity(
+          model.trainable_variables, FLAGS.similarity_metric,
+          FLAGS.dpp_kernel)
+      outputs_similarity = be_utils.outputs_similarity(
+          per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+
       negative_log_likelihood = tf.reduce_mean(
           tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                           logits,
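Unlike the CIFAR script, which reshapes flat softmax outputs into an [ensemble_size, batch, classes] tensor before calling the similarity helpers, here per_probs appears to already carry the ensemble dimension, so it is passed to scaled_similarity_loss and outputs_similarity directly.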
@@ -282,7 +333,7 @@ def step_fn(inputs):
       kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
       kl_scale = tf.minimum(1., kl_scale)
       kl_loss = kl_scale * kl
-      loss = negative_log_likelihood + l2_loss + kl_loss
+      loss = negative_log_likelihood + l2_loss + kl_loss + similarity_coeff * similarity_loss
       # Scale the loss given the TPUStrategy will reduce sum all gradients.
       scaled_loss = loss / strategy.num_replicas_in_sync
@@ -310,6 +361,11 @@ def step_fn(inputs):
     metrics['train/kl'].update_state(kl)
     metrics['train/kl_scale'].update_state(kl_scale)
     metrics['train/accuracy'].update_state(labels, logits)
+    metrics['train/similarity_loss'].update_state(similarity_coeff *
+                                                  similarity_loss)
+    metrics['train/weights_similarity'].update_state(weights_similarity)
+    metrics['train/outputs_similarity'].update_state(outputs_similarity)
+
     if FLAGS.ensemble_size > 1:
       for k, v in diversity_results.items():
         training_diversity['train/' + k].update_state(v)
@@ -346,6 +402,14 @@ def step_fn(inputs):
       if dataset_name == 'clean':
         if FLAGS.ensemble_size > 1:
           per_probs = tf.reduce_mean(all_probs, axis=0)  # marginalize samples
+          outputs_similarity = be_utils.outputs_similarity(
+              per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+          weights_similarity = be_utils.fast_weights_similarity(
+              model.trainable_variables, FLAGS.similarity_metric,
+              FLAGS.dpp_kernel)
+          metrics['test/weights_similarity'].update_state(weights_similarity)
+          metrics['test/outputs_similarity'].update_state(outputs_similarity)
+
           diversity_results = ed.metrics.average_pairwise_diversity(
               per_probs, FLAGS.ensemble_size)
           for k, v in diversity_results.items():
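As in the CIFAR script, the test-time weights and outputs similarity values feed monitoring metrics only; the similarity penalty enters the loss solely during training.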
