Skip to content

Commit 12fab9b

Browse files
authored
Ensure last checkpoint decoder results are written to metrics fail when cleaning up training (#368)
1 parent dd7933d commit 12fab9b

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa
1010

1111
Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
1212

13+
## [1.18.5]
14+
### Fixed
15+
- Fixed a problem with trainer not waiting for the last checkpoint decoder (#367).
16+
1317
## [1.18.4]
1418
### Added
1519
- Added options to control training length w.r.t number of updates/batches or number of samples:

sockeye/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
__version__ = '1.18.4'
14+
__version__ = '1.18.5'

sockeye/training.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def fit(self,
630630

631631
tic = time.time()
632632

633-
self._cleanup(lr_decay_opt_states_reset)
633+
self._cleanup(lr_decay_opt_states_reset, process_manager=process_manager)
634634
logger.info("Training finished. Best checkpoint: %d. Best validation %s: %.6f",
635635
self.state.best_checkpoint, early_stopping_metric, self.state.best_metric)
636636
return self.state.best_metric
@@ -723,7 +723,6 @@ def _update_metrics(self,
723723
checkpoint_metrics["%s-val" % name] = value
724724

725725
if process_manager is not None:
726-
process_manager.wait_to_finish()
727726
result = process_manager.collect_results()
728727
if result is not None:
729728
decoded_checkpoint, decoder_metrics = result
@@ -749,12 +748,12 @@ def _cleanup(self, lr_decay_opt_states_reset: str, process_manager: Optional['De
749748
utils.cleanup_params_files(self.model.output_dir, self.max_params_files_to_keep,
750749
self.state.checkpoint, self.state.best_checkpoint)
751750
if process_manager is not None:
752-
process_manager.wait_to_finish()
753751
result = process_manager.collect_results()
754752
if result is not None:
755753
decoded_checkpoint, decoder_metrics = result
756754
self.state.metrics[decoded_checkpoint - 1].update(decoder_metrics)
757755
self.tflogger.log_metrics(decoder_metrics, decoded_checkpoint)
756+
utils.write_metrics_file(self.state.metrics, self.metrics_fname)
758757

759758
final_training_state_dirname = os.path.join(self.model.output_dir, C.TRAINING_STATE_DIRNAME)
760759
if os.path.exists(final_training_state_dirname):
@@ -1139,6 +1138,7 @@ def collect_results(self) -> Optional[Tuple[int, Dict[str, float]]]:
11391138
return None
11401139
decoded_checkpoint, decoder_metrics = self.decoder_metric_queue.get()
11411140
assert self.decoder_metric_queue.empty()
1141+
logger.info("Decoder-%d finished: %s", decoded_checkpoint, decoder_metrics)
11421142
return decoded_checkpoint, decoder_metrics
11431143

11441144
def wait_to_finish(self):
@@ -1147,14 +1147,15 @@ def wait_to_finish(self):
11471147
if not self.decoder_process.is_alive():
11481148
self.decoder_process = None
11491149
return
1150-
logger.warning("Waiting for process %s to finish.", self.decoder_process.name)
1150+
name = self.decoder_process.name
1151+
logger.warning("Waiting for process %s to finish.", name)
11511152
wait_start = time.time()
11521153
self.decoder_process.join()
11531154
self.decoder_process = None
11541155
wait_time = int(time.time() - wait_start)
1155-
logger.warning("Had to wait %d seconds for the checkpoint decoder to finish. Consider increasing the "
1156+
logger.warning("Had to wait %d seconds for the Checkpoint %s to finish. Consider increasing the "
11561157
"checkpoint frequency (updates between checkpoints, see %s) or reducing the size of the "
1157-
"validation samples that are decoded (see %s)." % (wait_time,
1158+
"validation samples that are decoded (see %s)." % (wait_time, name,
11581159
C.TRAIN_ARGS_CHECKPOINT_FREQUENCY,
11591160
C.TRAIN_ARGS_MONITOR_BLEU))
11601161

0 commit comments

Comments
 (0)