Skip to content

Commit f7a972a

Browse files
committed
Save checkpoint to a temporary folder first
Partial or missing files left behind by a failed save cause an error when the checkpoint is loaded
1 parent 8e4cedd commit f7a972a

File tree

1 file changed

+31
-20
lines changed

1 file changed

+31
-20
lines changed

src/transformers/trainer.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import contextlib
2020
import copy
21+
import errno
2122
import functools
2223
import glob
2324
import importlib.metadata
@@ -3210,31 +3211,41 @@ def _save_checkpoint(self, model, trial):
32103211
self.store_flos()
32113212

32123213
run_dir = self._get_output_dir(trial=trial)
3213-
output_dir = os.path.join(run_dir, checkpoint_folder)
3214-
self.save_model(output_dir, _internal_call=True)
3214+
checkpoint_dir = os.path.join(run_dir, checkpoint_folder)
3215+
with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir:
3216+
self.save_model(output_dir, _internal_call=True)
32153217

3216-
if not self.args.save_only_model:
3217-
# Save optimizer and scheduler
3218-
self._save_optimizer_and_scheduler(output_dir)
3219-
# Save RNG state
3220-
self._save_rng_state(output_dir)
3218+
if not self.args.save_only_model:
3219+
# Save optimizer and scheduler
3220+
self._save_optimizer_and_scheduler(output_dir)
3221+
# Save RNG state
3222+
self._save_rng_state(output_dir)
32213223

3222-
# Save the Trainer state
3223-
if self.args.should_save:
3224-
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
3225-
for cb in [
3226-
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
3227-
]:
3228-
cb_name = cb.__class__.__name__
3229-
cb_state = cb.state()
3230-
if isinstance(self.state.stateful_callbacks[cb_name], list):
3231-
self.state.stateful_callbacks[cb_name].append(cb_state)
3224+
# Save the Trainer state
3225+
if self.args.should_save:
3226+
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
3227+
for cb in [
3228+
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
3229+
]:
3230+
cb_name = cb.__class__.__name__
3231+
cb_state = cb.state()
3232+
if isinstance(self.state.stateful_callbacks[cb_name], list):
3233+
self.state.stateful_callbacks[cb_name].append(cb_state)
3234+
else:
3235+
self.state.stateful_callbacks[cb_name] = cb_state
3236+
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
3237+
3238+
try:
3239+
os.renames(output_dir, checkpoint_dir)
3240+
except OSError as e:
3241+
if e.errno in [errno.ENOTEMPTY, errno.EEXIST]: # Directory/File already exists
3242+
shutil.rmtree(checkpoint_dir)
3243+
os.renames(output_dir, checkpoint_dir)
32323244
else:
3233-
self.state.stateful_callbacks[cb_name] = cb_state
3234-
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
3245+
raise
32353246

32363247
if self.args.push_to_hub:
3237-
self._push_from_checkpoint(output_dir)
3248+
self._push_from_checkpoint(checkpoint_dir)
32383249

32393250
# Maybe delete some older checkpoints.
32403251
if self.args.should_save:

0 commit comments

Comments
 (0)