|
32 | 32 | import tempfile
|
33 | 33 | import time
|
34 | 34 | import warnings
|
| 35 | +import errno |
35 | 36 | from collections.abc import Mapping
|
36 | 37 | from pathlib import Path
|
37 | 38 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
@@ -3210,31 +3211,41 @@ def _save_checkpoint(self, model, trial):
|
3210 | 3211 | self.store_flos()
|
3211 | 3212 |
|
3212 | 3213 | run_dir = self._get_output_dir(trial=trial)
|
3213 |
| - output_dir = os.path.join(run_dir, checkpoint_folder) |
3214 |
| - self.save_model(output_dir, _internal_call=True) |
| 3214 | + checkpoint_dir = os.path.join(run_dir, checkpoint_folder) |
| 3215 | + with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir: |
| 3216 | + self.save_model(output_dir, _internal_call=True) |
3215 | 3217 |
|
3216 |
| - if not self.args.save_only_model: |
3217 |
| - # Save optimizer and scheduler |
3218 |
| - self._save_optimizer_and_scheduler(output_dir) |
3219 |
| - # Save RNG state |
3220 |
| - self._save_rng_state(output_dir) |
| 3218 | + if not self.args.save_only_model: |
| 3219 | + # Save optimizer and scheduler |
| 3220 | + self._save_optimizer_and_scheduler(output_dir) |
| 3221 | + # Save RNG state |
| 3222 | + self._save_rng_state(output_dir) |
3221 | 3223 |
|
3222 |
| - # Save the Trainer state |
3223 |
| - if self.args.should_save: |
3224 |
| - # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently |
3225 |
| - for cb in [ |
3226 |
| - cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) |
3227 |
| - ]: |
3228 |
| - cb_name = cb.__class__.__name__ |
3229 |
| - cb_state = cb.state() |
3230 |
| - if isinstance(self.state.stateful_callbacks[cb_name], list): |
3231 |
| - self.state.stateful_callbacks[cb_name].append(cb_state) |
| 3224 | + # Save the Trainer state |
| 3225 | + if self.args.should_save: |
| 3226 | + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently |
| 3227 | + for cb in [ |
| 3228 | + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) |
| 3229 | + ]: |
| 3230 | + cb_name = cb.__class__.__name__ |
| 3231 | + cb_state = cb.state() |
| 3232 | + if isinstance(self.state.stateful_callbacks[cb_name], list): |
| 3233 | + self.state.stateful_callbacks[cb_name].append(cb_state) |
| 3234 | + else: |
| 3235 | + self.state.stateful_callbacks[cb_name] = cb_state |
| 3236 | + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) |
| 3237 | + |
| 3238 | + try: |
| 3239 | + os.renames(output_dir, checkpoint_dir) |
| 3240 | + except OSError as e: |
| 3241 | + if e.errno in [errno.ENOTEMPTY, errno.EEXIST]: # Directory/File already exists |
| 3242 | + shutil.rmtree(checkpoint_dir) |
| 3243 | + os.renames(output_dir, checkpoint_dir) |
3232 | 3244 | else:
|
3233 |
| - self.state.stateful_callbacks[cb_name] = cb_state |
3234 |
| - self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) |
| 3245 | + raise |
3235 | 3246 |
|
3236 | 3247 | if self.args.push_to_hub:
|
3237 |
| - self._push_from_checkpoint(output_dir) |
| 3248 | + self._push_from_checkpoint(checkpoint_dir) |
3238 | 3249 |
|
3239 | 3250 | # Maybe delete some older checkpoints.
|
3240 | 3251 | if self.args.should_save:
|
|
0 commit comments