@@ -630,7 +630,7 @@ def fit(self,
 
             tic = time.time()
 
-        self._cleanup(lr_decay_opt_states_reset)
+        self._cleanup(lr_decay_opt_states_reset, process_manager=process_manager)
         logger.info("Training finished. Best checkpoint: %d. Best validation %s: %.6f",
                     self.state.best_checkpoint, early_stopping_metric, self.state.best_metric)
         return self.state.best_metric
@@ -723,7 +723,6 @@ def _update_metrics(self,
             checkpoint_metrics["%s-val" % name] = value
 
         if process_manager is not None:
-            process_manager.wait_to_finish()
             result = process_manager.collect_results()
             if result is not None:
                 decoded_checkpoint, decoder_metrics = result
@@ -749,12 +748,12 @@ def _cleanup(self, lr_decay_opt_states_reset: str, process_manager: Optional['De
         utils.cleanup_params_files(self.model.output_dir, self.max_params_files_to_keep,
                                    self.state.checkpoint, self.state.best_checkpoint)
         if process_manager is not None:
-            process_manager.wait_to_finish()
             result = process_manager.collect_results()
             if result is not None:
                 decoded_checkpoint, decoder_metrics = result
                 self.state.metrics[decoded_checkpoint - 1].update(decoder_metrics)
                 self.tflogger.log_metrics(decoder_metrics, decoded_checkpoint)
+                utils.write_metrics_file(self.state.metrics, self.metrics_fname)
 
         final_training_state_dirname = os.path.join(self.model.output_dir, C.TRAINING_STATE_DIRNAME)
         if os.path.exists(final_training_state_dirname):
@@ -1139,6 +1138,7 @@ def collect_results(self) -> Optional[Tuple[int, Dict[str, float]]]:
             return None
         decoded_checkpoint, decoder_metrics = self.decoder_metric_queue.get()
         assert self.decoder_metric_queue.empty()
+        logger.info("Decoder-%d finished: %s", decoded_checkpoint, decoder_metrics)
         return decoded_checkpoint, decoder_metrics
 
     def wait_to_finish(self):
@@ -1147,14 +1147,15 @@ def wait_to_finish(self):
         if not self.decoder_process.is_alive():
             self.decoder_process = None
             return
-        logger.warning("Waiting for process %s to finish.", self.decoder_process.name)
+        name = self.decoder_process.name
+        logger.warning("Waiting for process %s to finish.", name)
         wait_start = time.time()
         self.decoder_process.join()
         self.decoder_process = None
         wait_time = int(time.time() - wait_start)
-        logger.warning("Had to wait %d seconds for the checkpoint decoder to finish. Consider increasing the "
+        logger.warning("Had to wait %d seconds for the Checkpoint %s to finish. Consider increasing the "
                        "checkpoint frequency (updates between checkpoints, see %s) or reducing the size of the "
-                       "validation samples that are decoded (see %s)." % (wait_time,
+                       "validation samples that are decoded (see %s)." % (wait_time, name,
                                                                           C.TRAIN_ARGS_CHECKPOINT_FREQUENCY,
                                                                           C.TRAIN_ARGS_MONITOR_BLEU))
 
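Note on the change, for readers skimming the hunks: the blocking process_manager.wait_to_finish() calls are removed from _update_metrics and _cleanup, so the trainer now only polls collect_results() for finished checkpoint-decoder metrics, and _cleanup additionally writes any late metrics to the metrics file. The sketch below is a minimal, self-contained illustration of that poll-then-join pattern, not Sockeye's actual implementation; the helper _decode_and_evaluate, the start_decoder signature, and the dummy "bleu-val" metric are assumptions made purely for illustration.

import logging
import multiprocessing
import time
from typing import Dict, Optional, Tuple

logger = logging.getLogger(__name__)


def _decode_and_evaluate(checkpoint: int, queue: multiprocessing.Queue) -> None:
    # Hypothetical stand-in for the real checkpoint decoder: pretend to decode
    # the validation data, then publish (checkpoint, metrics) on the queue.
    time.sleep(1)
    queue.put((checkpoint, {"bleu-val": 0.0}))


class DecoderProcessManager:
    # Simplified sketch: one background decoder process plus a queue for its results.

    def __init__(self) -> None:
        self.decoder_metric_queue = multiprocessing.Queue()  # (checkpoint, metrics) results land here
        self.decoder_process: Optional[multiprocessing.Process] = None

    def start_decoder(self, checkpoint: int) -> None:
        # At most one decoder runs at a time in this sketch.
        assert self.decoder_process is None or not self.decoder_process.is_alive()
        self.decoder_process = multiprocessing.Process(
            target=_decode_and_evaluate,
            args=(checkpoint, self.decoder_metric_queue),
            name="Decoder-%d" % checkpoint)
        self.decoder_process.start()

    def collect_results(self) -> Optional[Tuple[int, Dict[str, float]]]:
        # Non-blocking poll: return a finished result if one is queued, else None.
        if self.decoder_metric_queue.empty():
            return None
        decoded_checkpoint, decoder_metrics = self.decoder_metric_queue.get()
        logger.info("Decoder-%d finished: %s", decoded_checkpoint, decoder_metrics)
        return decoded_checkpoint, decoder_metrics

    def wait_to_finish(self) -> None:
        # Blocking: join the background decoder process if it is still running.
        if self.decoder_process is None:
            return
        if not self.decoder_process.is_alive():
            self.decoder_process = None
            return
        self.decoder_process.join()
        self.decoder_process = None


if __name__ == "__main__":
    manager = DecoderProcessManager()
    manager.start_decoder(checkpoint=1)
    print(manager.collect_results())  # typically None: decoding is still in flight
    manager.wait_to_finish()          # block until the decoder process exits
    print(manager.collect_results())  # now returns the (checkpoint, metrics) tuple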