Commit 989e4fb

Merge pull request #358 from leseb/opt-rm-checkpoints
feat: retain only last checkpoint directory
2 parents: 59c4611 + 0e8605a

4 files changed: 27 additions, 2 deletions


README.md

Lines changed: 1 addition & 0 deletions

@@ -105,6 +105,7 @@ for training jobs. There are a number of options you can specify, such as settin
 | fsdp_options | The settings for controlling FSDP when it's selected as the distributed backend. |
 | distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". |
 | disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. |
+| keep_last_checkpoint_only | Determines whether we should only keep the last checkpoint directory - the previous checkpoint directory is always overwritten. The checkpoint directory is called `last_epoch`. |
 
 ### `DeepSpeedOptions`
 
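
For context, the new README row toggles between two on-disk layouts. A minimal standalone sketch (not from the commit; the `last_epoch` and `samples_{n}` names come from the diff, the sample counts are invented):

# Minimal sketch of the naming behavior the README row describes.
def checkpoint_subdir(keep_last_checkpoint_only: bool, samples_seen: int) -> str:
    return "last_epoch" if keep_last_checkpoint_only else f"samples_{samples_seen}"

saves = [2048, 4096, 6144]  # invented sample counts
print([checkpoint_subdir(False, s) for s in saves])
# ['samples_2048', 'samples_4096', 'samples_6144'] - one directory per save accumulates
print({checkpoint_subdir(True, s) for s in saves})
# {'last_epoch'} - a single directory, overwritten on each save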

src/instructlab/training/config.py

Lines changed: 5 additions & 0 deletions

@@ -206,3 +206,8 @@ class TrainingArgs(BaseModel):
 
     # This field defines whether or not data processing will occur inside of `run_training()`
     process_data: Optional[bool] = True
+
+    # This field specifies whether only the last checkpoint should be retained. When set to true, it
+    # will overwrite the previous checkpoint directory, keeping only one directory called
+    # "last_epoch". This works alongside the '--checkpoint_at_epoch' flag.
+    keep_last_checkpoint_only: Optional[bool] = False
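
To illustrate the new field's semantics, here is a standalone Pydantic model mirroring the pattern above (an illustrative stand-in, not the real TrainingArgs class, which has many other fields):

# Stand-in model; only keep_last_checkpoint_only mirrors the real field.
from typing import Optional

from pydantic import BaseModel


class CheckpointArgs(BaseModel):
    # Defaults to False, so existing configs keep the old keep-every-checkpoint behavior.
    keep_last_checkpoint_only: Optional[bool] = False


print(CheckpointArgs().keep_last_checkpoint_only)  # False
print(CheckpointArgs(keep_last_checkpoint_only=True).keep_last_checkpoint_only)  # True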

src/instructlab/training/main_ds.py

Lines changed: 12 additions & 0 deletions

@@ -707,6 +707,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         f"--max_batch_len={train_args.max_batch_len}",
         f"--seed={train_args.random_seed}",
         f"--chat-tmpl-path={train_args.chat_tmpl_path}",
+        f"--keep_last_checkpoint_only={train_args.keep_last_checkpoint_only}",
     ]
 
     if train_args.checkpoint_at_epoch:
@@ -787,6 +788,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}"
     )
 
+    if train_args.keep_last_checkpoint_only:
+        command.append("--keep_last_checkpoint_only")
+
     print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -962,6 +966,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         ),
     )
     parser.add_argument("--disable_flash_attn", action="store_true")
+    parser.add_argument(
+        "--keep_last_checkpoint_only",
+        action="store_true",
+        help=(
+            "Keep only the last checkpoint directory - overwrite the previous ones. Useful for saving disk space."
+            "The last checkpoint will be saved as 'last_epoch'."
+        ),
+    )
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
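
Since the new CLI option uses argparse's store_true action, the bare flag is either present or absent on the command line, which is why run_training appends it conditionally. A standalone sketch of that pattern (not the module itself):

# Standalone sketch of the store_true flag added above; not main_ds.py itself.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--keep_last_checkpoint_only",
    action="store_true",
    help="Keep only the last checkpoint directory, saved as 'last_epoch'.",
)

print(parser.parse_args([]).keep_last_checkpoint_only)  # False
print(parser.parse_args(["--keep_last_checkpoint_only"]).keep_last_checkpoint_only)  # True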

src/instructlab/training/utils.py

Lines changed: 9 additions & 2 deletions

@@ -925,8 +925,13 @@ def save_hf_format_accelerate(
     samples_seen,
     is_lora=False,
 ):
+    # Build the subdirectory name
+    subdir = (
+        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
+    )
+
     log_rank_0(
-        f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
+        f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
         to_print=True,
     )
     start = time.time()
@@ -936,7 +941,9 @@ def save_hf_format_accelerate(
     else:
         convert_dolomite = True
 
-    final_output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
+    # Build the final output directory path
+    final_output_dir = Path(args.output_dir) / "hf_format" / subdir
+
     if args.use_dolomite and convert_dolomite:
         tmpdir = TemporaryDirectory("w")  # pylint: disable=consider-using-with
         output_dir = Path(tmpdir.name)
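
To make the path construction concrete, here is a self-contained sketch of the logic save_hf_format_accelerate now uses (the hf_format layout and directory names come from the diff; the output directory and sample count are invented):

# Self-contained sketch of the checkpoint path logic above; example values invented.
from pathlib import Path


def final_output_dir(output_dir: str, samples_seen: int, keep_last_checkpoint_only: bool) -> Path:
    # With keep_last_checkpoint_only set, every save resolves to the same
    # "last_epoch" directory, so the previous checkpoint is overwritten in place.
    subdir = "last_epoch" if keep_last_checkpoint_only else f"samples_{samples_seen}"
    return Path(output_dir) / "hf_format" / subdir


print(final_output_dir("ckpt", 4096, False))  # ckpt/hf_format/samples_4096
print(final_output_dir("ckpt", 4096, True))   # ckpt/hf_format/last_epoch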
