From f7a972aaed9579d52f78a29222443da403892e25 Mon Sep 17 00:00:00 2001
From: Kavya
Date: Wed, 22 Jan 2025 14:57:41 +0530
Subject: [PATCH] Save checkpoint to temporary folder first

Partial or missing files left behind by a failed save throw an error
when the checkpoint is loaded, so write everything to a temporary
directory first and rename it into place once the save has completed.
---
 src/transformers/trainer.py | 51 ++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8179ee9f5306..27f08e2f5554 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -18,6 +18,7 @@
 
 import contextlib
 import copy
+import errno
 import functools
 import glob
 import importlib.metadata
@@ -3210,31 +3211,41 @@ def _save_checkpoint(self, model, trial):
         self.store_flos()
 
         run_dir = self._get_output_dir(trial=trial)
-        output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        checkpoint_dir = os.path.join(run_dir, checkpoint_folder)
+        with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir:
+            self.save_model(output_dir, _internal_call=True)
 
-        if not self.args.save_only_model:
-            # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
-            # Save RNG state
-            self._save_rng_state(output_dir)
+            if not self.args.save_only_model:
+                # Save optimizer and scheduler
+                self._save_optimizer_and_scheduler(output_dir)
+                # Save RNG state
+                self._save_rng_state(output_dir)
 
-        # Save the Trainer state
-        if self.args.should_save:
-            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
-            for cb in [
-                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
-            ]:
-                cb_name = cb.__class__.__name__
-                cb_state = cb.state()
-                if isinstance(self.state.stateful_callbacks[cb_name], list):
-                    self.state.stateful_callbacks[cb_name].append(cb_state)
+            # Save the Trainer state
+            if self.args.should_save:
+                # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
+                for cb in [
+                    cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+                ]:
+                    cb_name = cb.__class__.__name__
+                    cb_state = cb.state()
+                    if isinstance(self.state.stateful_callbacks[cb_name], list):
+                        self.state.stateful_callbacks[cb_name].append(cb_state)
+                    else:
+                        self.state.stateful_callbacks[cb_name] = cb_state
+                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+            try:
+                os.renames(output_dir, checkpoint_dir)
+            except OSError as e:
+                if e.errno in [errno.ENOTEMPTY, errno.EEXIST]:  # Directory/File already exists
+                    shutil.rmtree(checkpoint_dir)
+                    os.renames(output_dir, checkpoint_dir)
                 else:
-                    self.state.stateful_callbacks[cb_name] = cb_state
-        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+                    raise
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(checkpoint_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
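
For reference, a minimal standalone sketch of the stage-then-rename pattern
the patch applies. The helper `save_checkpoint_atomically` and the
caller-supplied `write_files` callback are hypothetical names for
illustration, not transformers APIs; the sketch also swaps
`tempfile.TemporaryDirectory` for `tempfile.mkdtemp` plus an explicit
`finally`, so cleanup after a successful rename is a guaranteed no-op on any
Python version.

    import errno
    import os
    import shutil
    import tempfile

    def save_checkpoint_atomically(run_dir, checkpoint_name, write_files):
        """Stage a checkpoint in a temp dir, then publish it via rename so
        run_dir/checkpoint_name is either complete or absent."""
        final_dir = os.path.join(run_dir, checkpoint_name)
        # Stage inside run_dir so the final rename stays on one filesystem.
        staging = tempfile.mkdtemp(prefix="tmp-checkpoint-", dir=run_dir)
        try:
            write_files(staging)  # may raise; partial files never reach final_dir
            try:
                os.renames(staging, final_dir)
            except OSError as e:
                if e.errno in (errno.ENOTEMPTY, errno.EEXIST):
                    # A stale directory from an earlier attempt: replace it.
                    shutil.rmtree(final_dir)
                    os.renames(staging, final_dir)
                else:
                    raise
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            shutil.rmtree(staging, ignore_errors=True)
        return final_dir

    if __name__ == "__main__":
        run_dir = tempfile.mkdtemp()

        def write_files(d):
            with open(os.path.join(d, "model.bin"), "wb") as f:
                f.write(b"\0")

        print(os.listdir(save_checkpoint_atomically(run_dir, "checkpoint-100", write_files)))

Creating the staging directory inside run_dir keeps the rename on a single
filesystem, where os.rename is atomic on POSIX; that is what makes the
published checkpoint all-or-nothing rather than partially written.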