Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ CKPT_ARGS=(
--save /root/GLM-Z1-9B-0414_slime/
# Model save interval (steps)
--save-interval 20
# --max-actor-ckpt-to-keep 5 # Keep only the 5 most recent actor checkpoints
)
```

Expand Down
3 changes: 3 additions & 0 deletions docs/en/get_started/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ When using slime, there are three parameters for loading and saving checkpoints:
- `--ref-load`: The Megatron checkpoint for the reference model.
- `--load`: The Megatron checkpoint for the actor. If `--load` is not set, or if the specified directory does not exist or does not contain `latest_checkpointed_iteration.txt`, the actor will be initialized from the `--ref-load` checkpoint.
- `--save`: The path where the actor's checkpoints are saved.
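- `--max-actor-ckpt-to-keep`: Maximum number of actor checkpoints to keep on disk. Before each save, the oldest checkpoints written during the current run are deleted so that, once the new checkpoint is written, no more than this many remain. Applies to Megatron checkpoints (`--save`) and HF checkpoints (`--save-hf`). Requires `--save-interval` to be set. Default: `None` (unlimited).
- `--max-critic-ckpt-to-keep`: Maximum number of critic checkpoints to keep on disk. Before each save, the oldest checkpoints written during the current run are deleted so that, once the new checkpoint is written, no more than this many remain. Applies to critic Megatron checkpoints (`--critic-save`). Requires `--save-interval` to be set. Default: `None` (unlimited).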
- `--checkpoint-storage-type`: Checkpoint storage type for cleanup. `shared` (default): all ranks share a filesystem (e.g. FSx) — cleanup runs on global rank 0 only. `local`: each node has its own disk (e.g. NVMe) — cleanup runs on local rank 0 of each node.

Note:

Expand Down
1 change: 1 addition & 0 deletions docs/zh/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ CKPT_ARGS=(
--save /root/GLM-Z1-9B-0414_slime/
# 模型保存间隔(步数)
--save-interval 20
# --max-actor-ckpt-to-keep 5 # 仅保留最近的 5 个 actor 检查点
)
```

Expand Down
3 changes: 3 additions & 0 deletions docs/zh/get_started/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ torch 格式是 megatron 的老存储格式,里面的结构大约是一些 `mp
- `--ref-load`:reference model 用的 megatron ckpt;
- `--load`:actor 用的 megatron ckpt,如果没有设置 `--load`,或者设置的目录不存在,目录中没有 `latest_checkpointed_iteration.txt`,都会直接从 `--ref-load` 的 ckpt 进行初始化;
- `--save`:actor 保存的路径。
- `--max-actor-ckpt-to-keep`:磁盘上保留的 actor 检查点最大数量。每次保存前会自动删除本次运行中最旧的检查点,使新检查点写入后磁盘上的数量不超过该上限。适用于 Megatron 检查点(`--save`)和 HF 检查点(`--save-hf`)。需要同时设置 `--save-interval`。默认值:`None`(不限制)。
- `--max-critic-ckpt-to-keep`:磁盘上保留的 critic 检查点最大数量。每次保存前会自动删除本次运行中最旧的检查点,使新检查点写入后磁盘上的数量不超过该上限。适用于 critic Megatron 检查点(`--critic-save`)。需要同时设置 `--save-interval`。默认值:`None`(不限制)。
- `--checkpoint-storage-type`:检查点存储类型。`shared`(默认):所有 rank 共享同一个文件系统(如 FSx)——清理仅在全局 rank 0 上运行。`local`:每个节点有独立磁盘(如 NVMe)——清理在每个节点的 local rank 0 上运行。

注意:

Expand Down
41 changes: 41 additions & 0 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from slime.utils.misc import Box
from slime.utils.reloadable_process_group import destroy_process_groups, monkey_patch_torch_dist, reload_process_groups
from slime.utils.routing_replay import RoutingReplay
from slime.utils.checkpoint_utils import cleanup_old_checkpoints, should_run_cleanup
from slime.utils.timer import Timer, inverse_timer, timer, with_defer
from slime.utils.types import RolloutBatch

Expand Down Expand Up @@ -512,6 +513,40 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None:

log_perf_data(rollout_id, self.args)

def _maybe_cleanup_old_checkpoints(self) -> None:
    """Prune previously saved checkpoints ahead of the next save.

    The retention limit comes from ``max_actor_ckpt_to_keep`` or
    ``max_critic_ckpt_to_keep`` depending on ``self.role``; when the
    applicable limit is ``None`` nothing happens. Because this runs before
    the new checkpoint is written, ``limit - 1`` old checkpoints are kept so
    that the count on disk is back at ``limit`` after the upcoming save.

    Only rollout ids recorded in ``self._saved_rollout_ids`` are candidates
    for deletion.
    """
    if self.role == "actor":
        limit = self.args.max_actor_ckpt_to_keep
    elif self.role == "critic":
        limit = self.args.max_critic_ckpt_to_keep
    else:
        limit = None

    if limit is None:
        return

    # The id list is created lazily; it may not exist before the first save.
    if not hasattr(self, "_saved_rollout_ids"):
        self._saved_rollout_ids = []

    # Leave room for the checkpoint that is about to be written.
    retain = limit - 1

    # LOCAL_RANK falls back to 0 when the environment variable is unset.
    megatron_turn, hf_turn = should_run_cleanup(
        getattr(self.args, "checkpoint_storage_type", "shared"),
        dist.get_rank(),
        int(os.environ.get("LOCAL_RANK", 0)),
    )

    if megatron_turn:
        if self.role == "actor":
            ckpt_root = self.args.save
        else:
            ckpt_root = self.args.critic_save

        def megatron_path(rid):
            return os.path.join(ckpt_root, f"iter_{rid:07d}")

        cleanup_old_checkpoints(self._saved_rollout_ids, retain, path_fn=megatron_path)

    # HF checkpoints are only pruned for the actor, and only when --save-hf is set.
    if not hf_turn or self.args.save_hf is None or self.role != "actor":
        return

    def hf_path(rid):
        return self.args.save_hf.format(rollout_id=rid)

    cleanup_old_checkpoints(self._saved_rollout_ids, retain, path_fn=hf_path)

@timer
def save_model(self, rollout_id: int, force_sync: bool = False) -> None:
if self.args.debug_rollout_only:
Expand All @@ -526,6 +561,8 @@ def save_model(self, rollout_id: int, force_sync: bool = False) -> None:

maybe_finalize_async_save(blocking=True)

self._maybe_cleanup_old_checkpoints()

save(rollout_id, self.model, self.optimizer, self.opt_param_scheduler)

if force_sync and self.args.async_save:
Expand All @@ -536,6 +573,10 @@ def save_model(self, rollout_id: int, force_sync: bool = False) -> None:

save_hf_model(self.args, rollout_id, self.model)

if not hasattr(self, "_saved_rollout_ids"):
self._saved_rollout_ids = []
self._saved_rollout_ids.append(rollout_id)

if self.args.offload_train:
destroy_process_groups()

Expand Down
41 changes: 41 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,29 @@ def add_algo_arguments(parser):
"The model will be saved to `save_hf.format(rollout_id)`. "
),
)
parser.add_argument(
"--max-actor-ckpt-to-keep",
type=int,
default=None,
help=(
"Maximum number of actor checkpoints to keep on disk. "
"When exceeded after saving, the oldest checkpoints are deleted. "
"Applies to Megatron checkpoints (--save) and HF checkpoints (--save-hf). "
"Default: None (unlimited)."
),
)
parser.add_argument(
"--checkpoint-storage-type",
type=str,
default="shared",
choices=["shared", "local"],
help=(
"Checkpoint storage type. 'shared' (default): all ranks share a "
"filesystem (e.g. FSx) — cleanup runs on global rank 0 only. "
"'local': each node has its own disk (e.g. NVMe) — cleanup runs "
"on local rank 0 of each node."
),
)
reset_arg(parser, "--seed", type=int, default=1234)
reset_arg(parser, "--clip-grad", type=float, default=1.0)
reset_arg(parser, "--calculate-per-token-loss", action="store_true")
Expand All @@ -753,6 +776,17 @@ def add_algo_arguments(parser):
parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps")
parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.")
parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.")
parser.add_argument(
"--max-critic-ckpt-to-keep",
type=int,
default=None,
help=(
"Maximum number of critic checkpoints to keep on disk. "
"When exceeded after saving, the oldest checkpoints are deleted. "
"Applies to critic Megatron checkpoints (--critic-save). "
"Default: None (unlimited)."
),
)
parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model")
parser.add_argument("--critic-train-only", action="store_true", default=False, help="Only train critic")
parser.add_argument(
Expand Down Expand Up @@ -1579,6 +1613,13 @@ def slime_validate_args(args):
if args.save_interval is not None:
assert args.save is not None, "'--save' is required when save_interval is set."

if args.max_actor_ckpt_to_keep is not None:
assert args.max_actor_ckpt_to_keep >= 1, "'--max-actor-ckpt-to-keep' must be >= 1."
assert args.save_interval is not None, "'--save-interval' is required when '--max-actor-ckpt-to-keep' is set."
if args.max_critic_ckpt_to_keep is not None:
assert args.max_critic_ckpt_to_keep >= 1, "'--max-critic-ckpt-to-keep' must be >= 1."
assert args.save_interval is not None, "'--save-interval' is required when '--max-critic-ckpt-to-keep' is set."

assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set"

if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]:
Expand Down
57 changes: 57 additions & 0 deletions slime/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Utilities for managing checkpoint retention during RL training."""

import logging
import os
import shutil
from typing import Callable, List

logger = logging.getLogger(__name__)


def should_run_cleanup(storage_type: str, global_rank: int, local_rank: int) -> tuple:
    """Decide whether the calling rank is responsible for checkpoint cleanup.

    Megatron checkpoints: with ``storage_type == "local"`` the local rank 0
    of every node cleans up; for any other storage type only global rank 0
    does.

    HF checkpoints: always global rank 0 only (save_hf_pretrained always
    writes from global rank 0).

    Args:
        storage_type: Checkpoint storage type, e.g. ``"shared"`` or ``"local"``.
        global_rank: Rank of the caller across all processes.
        local_rank: Rank of the caller within its node.

    Returns:
        (should_cleanup_megatron, should_cleanup_hf)
    """
    is_global_leader = global_rank == 0
    megatron_cleanup = (local_rank == 0) if storage_type == "local" else is_global_leader
    return megatron_cleanup, is_global_leader


def cleanup_old_checkpoints(
    saved_rollout_ids: List[int],
    keep: int,
    path_fn: Callable[[int], str],
) -> List[str]:
    """Delete the oldest checkpoints, keeping only the newest *keep*.

    Args:
        saved_rollout_ids: Rollout ids saved during the current run, ordered
            oldest first.
        keep: Number of newest entries of *saved_rollout_ids* to retain.
        path_fn: Maps a rollout id to its checkpoint directory path on disk.

    Returns:
        List of directory paths that were actually removed.

    A directory is never removed if it is also the path of a retained id.
    This matters when *path_fn* is not injective (for example a path template
    without a per-rollout placeholder) or when the same rollout id was
    recorded more than once: deleting it would destroy a checkpoint that is
    supposed to survive.

    Paths that are not existing directories are skipped. Deletion failures
    are logged and do not raise; since the id stays in *saved_rollout_ids*,
    the same path is attempted again on the next call.
    """
    surplus = len(saved_rollout_ids) - keep
    if surplus <= 0:
        return []

    # Paths belonging to the checkpoints we were asked to keep.
    protected = {path_fn(rid) for rid in saved_rollout_ids[surplus:]}

    deleted: List[str] = []
    for rid in saved_rollout_ids[:surplus]:
        path = path_fn(rid)
        if path in protected or not os.path.isdir(path):
            continue
        try:
            logger.info("Deleting old checkpoint: %s", path)
            shutil.rmtree(path)
        except OSError:
            logger.warning("Failed to delete checkpoint %s, will retry next save", path, exc_info=True)
        else:
            deleted.append(path)
    return deleted
Loading