2626from baselines .cifar import utils # local file import
2727from experimental .rank1_bnns import cifar_model # local file import
2828from experimental .rank1_bnns import refining # local file import
29+ from edward2 .google .rank1_pert .ensemble_keras import utils as be_utils
30+
2931import numpy as np
3032import tensorflow as tf
3133import tensorflow_datasets as tfds
8991 'training/evaluation summaries are stored.' )
flags.DEFINE_integer('train_epochs', 250, 'Number of training epochs.')

# NOTE: the stale pre-change definition of `num_eval_samples` (default 1) was
# removed; registering the same flag twice raises absl's DuplicateFlagError.
flags.DEFINE_integer('num_eval_samples', 4,
                     'Number of model predictions to sample per example at '
                     'eval time.')
# Refinement flags.
flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
                    'Name of the TPU. Only used if use_gpu is False.')

# Diversity-regularization flags: control the similarity penalty added to the
# training loss and the schedule of its coefficient.
flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
                    '[cosine, dpp_logdet].')
flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
flags.DEFINE_bool('use_output_similarity', False,
                  'If true, compute similarity on the ensemble outputs.')
flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
                  'Diversity coefficient scheduler.')
flags.DEFINE_float('annealing_epochs', 200,
                   'Number of epochs over which to linearly anneal.')
flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
flags.DEFINE_integer('diversity_start_epoch', 100,
                     'Diversity loss starting epoch.')

FLAGS = flags.FLAGS
119138
@@ -218,7 +237,28 @@ def main(argv):
218237 optimizer = tf .keras .optimizers .SGD (lr_schedule ,
219238 momentum = 0.9 ,
220239 nesterov = True )
240+
241+ if FLAGS .diversity_scheduler == 'ExponentialDecay' :
242+ diversity_schedule = be_utils .ExponentialDecay (
243+ initial_coeff = FLAGS .diversity_coeff ,
244+ start_epoch = FLAGS .diversity_start_epoch ,
245+ decay_epoch = FLAGS .diversity_decay_epoch ,
246+ steps_per_epoch = steps_per_epoch ,
247+ decay_rate = FLAGS .diversity_decay_rate ,
248+ staircase = True )
249+
250+ elif FLAGS .diversity_scheduler == 'LinearAnnealing' :
251+ diversity_schedule = be_utils .LinearAnnealing (
252+ initial_coeff = FLAGS .diversity_coeff ,
253+ annealing_epochs = FLAGS .annealing_epochs ,
254+ steps_per_epoch = steps_per_epoch )
255+ else :
256+ diversity_schedule = lambda x : FLAGS .diversity_coeff
257+
221258 metrics = {
259+ 'train/similarity_loss' : tf .keras .metrics .Mean (),
260+ 'train/weights_similarity' : tf .keras .metrics .Mean (),
261+ 'train/outputs_similarity' : tf .keras .metrics .Mean (),
222262 'train/negative_log_likelihood' : tf .keras .metrics .Mean (),
223263 'train/accuracy' : tf .keras .metrics .SparseCategoricalAccuracy (),
224264 'train/loss' : tf .keras .metrics .Mean (),
@@ -230,6 +270,8 @@ def main(argv):
230270 'test/accuracy' : tf .keras .metrics .SparseCategoricalAccuracy (),
231271 'test/ece' : ed .metrics .ExpectedCalibrationError (
232272 num_bins = FLAGS .num_bins ),
273+ 'test/weights_similarity' : tf .keras .metrics .Mean (),
274+ 'test/outputs_similarity' : tf .keras .metrics .Mean (),
233275 }
234276 if FLAGS .ensemble_size > 1 :
235277 for i in range (FLAGS .ensemble_size ):
@@ -286,6 +328,22 @@ def step_fn(inputs):
286328 'bias' in var .name ):
287329 filtered_variables .append (tf .reshape (var , (- 1 ,)))
288330
331+ print (' > logits shape {}' .format (logits .shape ))
332+ outputs = tf .nn .softmax (logits )
333+ print (' > otuputs shape {}' .format (outputs .shape ))
334+ ensemble_outputs_tensor = tf .reshape (outputs ,[FLAGS .ensemble_size , - 1 , outputs .shape [- 1 ]])
335+ print (' > ensemble_outputs_tensor shape {}' .format (ensemble_outputs_tensor .shape ))
336+
337+ similarity_coeff , similarity_loss = be_utils .scaled_similarity_loss (
338+ FLAGS .diversity_coeff , diversity_schedule , optimizer .iterations ,
339+ FLAGS .similarity_metric , FLAGS .dpp_kernel ,
340+ model .trainable_variables , FLAGS .use_output_similarity , ensemble_outputs_tensor )
341+ weights_similarity = be_utils .fast_weights_similarity (
342+ model .trainable_variables , FLAGS .similarity_metric ,
343+ FLAGS .dpp_kernel )
344+ outputs_similarity = be_utils .outputs_similarity (
345+ ensemble_outputs_tensor , FLAGS .similarity_metric , FLAGS .dpp_kernel )
346+
289347 l2_loss = FLAGS .l2 * 2 * tf .nn .l2_loss (
290348 tf .concat (filtered_variables , axis = 0 ))
291349 kl = sum (model .losses ) / train_dataset_size
@@ -295,7 +353,7 @@ def step_fn(inputs):
295353 kl_loss = kl_scale * kl
296354
297355 # Scale the loss given the TPUStrategy will reduce sum all gradients.
298- loss = negative_log_likelihood + l2_loss + kl_loss
356+ loss = negative_log_likelihood + l2_loss + kl_loss + similarity_coeff * similarity_loss
299357 scaled_loss = loss / strategy .num_replicas_in_sync
300358
301359 grads = tape .gradient (scaled_loss , model .trainable_variables )
@@ -325,6 +383,10 @@ def step_fn(inputs):
325383 metrics ['train/kl' ].update_state (kl )
326384 metrics ['train/kl_scale' ].update_state (kl_scale )
327385 metrics ['train/accuracy' ].update_state (labels , logits )
386+ metrics ['train/similarity_loss' ].update_state (similarity_coeff * similarity_loss )
387+ metrics ['train/weights_similarity' ].update_state (weights_similarity )
388+ metrics ['train/outputs_similarity' ].update_state (outputs_similarity )
389+
328390
329391 strategy .run (step_fn , args = (next (iterator ),))
330392
@@ -346,6 +408,8 @@ def step_fn(inputs):
346408
347409 if FLAGS .ensemble_size > 1 :
348410 per_probs = tf .reduce_mean (probs , axis = 0 ) # marginalize samples
411+ outputs_similarity = be_utils .outputs_similarity (
412+ per_probs , FLAGS .similarity_metric , FLAGS .dpp_kernel )
349413 for i in range (FLAGS .ensemble_size ):
350414 member_probs = per_probs [i ]
351415 member_loss = tf .keras .losses .sparse_categorical_crossentropy (
@@ -370,6 +434,11 @@ def step_fn(inputs):
370434 negative_log_likelihood )
371435 metrics ['test/accuracy' ].update_state (labels , probs )
372436 metrics ['test/ece' ].update_state (labels , probs )
437+ weights_similarity = be_utils .fast_weights_similarity (
438+ model .trainable_variables , FLAGS .similarity_metric , FLAGS .dpp_kernel )
439+ metrics ['test/weights_similarity' ].update_state (weights_similarity )
440+ metrics ['test/outputs_similarity' ].update_state (outputs_similarity )
441+
373442 else :
374443 corrupt_metrics ['test/nll_{}' .format (dataset_name )].update_state (
375444 negative_log_likelihood )
0 commit comments