diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 4b21f0c23..286c638a1 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -147,6 +147,7 @@ CKPT_ARGS=( --save /root/GLM-Z1-9B-0414_slime/ # Model save interval (steps) --save-interval 20 + # --max-actor-ckpt-to-keep 5 # Keep only the 5 most recent actor checkpoints ) ``` diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 5a59812e8..122bbbe93 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -130,6 +130,9 @@ When using slime, there are three parameters for loading and saving checkpoints: - `--ref-load`: The Megatron checkpoint for the reference model. - `--load`: The Megatron checkpoint for the actor. If `--load` is not set, or if the specified directory does not exist or does not contain `latest_checkpointed_iteration.txt`, the actor will be initialized from the `--ref-load` checkpoint. - `--save`: The path where the actor's checkpoints are saved. + - `--max-actor-ckpt-to-keep`: Maximum number of actor checkpoints to keep on disk. When exceeded after saving, the oldest checkpoints are deleted. Applies to Megatron checkpoints (`--save`) and HF checkpoints (`--save-hf`). Default: `None` (unlimited). + - `--max-critic-ckpt-to-keep`: Maximum number of critic checkpoints to keep on disk. When exceeded after saving, the oldest checkpoints are deleted. Applies to critic Megatron checkpoints (`--critic-save`). Default: `None` (unlimited). + - `--checkpoint-storage-type`: Checkpoint storage type for cleanup. `shared` (default): all ranks share a filesystem (e.g. FSx) — cleanup runs on global rank 0 only. `local`: each node has its own disk (e.g. NVMe) — cleanup runs on local rank 0 of each node. Note: diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index 98fe2e16d..4d0213ec6 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -145,6 +145,7 @@ CKPT_ARGS=( --save /root/GLM-Z1-9B-0414_slime/ # 模型保存间隔(步数) --save-interval 20 + # --max-actor-ckpt-to-keep 5 # 仅保留最近的 5 个 actor 检查点 ) ``` diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md index 01746514b..652ffbedf 100644 --- a/docs/zh/get_started/usage.md +++ b/docs/zh/get_started/usage.md @@ -134,6 +134,9 @@ torch 格式是 megatron 的老存储格式,里面的结构大约是一些 `mp - `--ref-load`:reference model 用的 megatron ckpt; - `--load`:actor 用的 megatron ckpt,如果没有设置 `--load`,或者设置的目录不存在,目录中没有 `latest_checkpointed_iteration.txt`,都会直接从 `--ref-load` 的 ckpt 进行初始化; - `--save`:actor 保存的路径。 +- `--max-actor-ckpt-to-keep`:磁盘上保留的 actor 检查点最大数量。保存后超出限制时,最旧的检查点将被自动删除。适用于 Megatron 检查点(`--save`)和 HF 检查点(`--save-hf`)。默认值:`None`(不限制)。 +- `--max-critic-ckpt-to-keep`:磁盘上保留的 critic 检查点最大数量。保存后超出限制时,最旧的检查点将被自动删除。适用于 critic Megatron 检查点(`--critic-save`)。默认值:`None`(不限制)。 +- `--checkpoint-storage-type`:检查点存储类型。`shared`(默认):所有 rank 共享同一个文件系统(如 FSx)——清理仅在全局 rank 0 上运行。`local`:每个节点有独立磁盘(如 NVMe)——清理在每个节点的 local rank 0 上运行。 注意: diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index f68d66553..5387680e1 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -23,6 +23,7 @@ from slime.utils.misc import Box from slime.utils.reloadable_process_group import destroy_process_groups, monkey_patch_torch_dist, reload_process_groups from slime.utils.routing_replay import RoutingReplay +from slime.utils.checkpoint_utils import cleanup_old_checkpoints, should_run_cleanup from slime.utils.timer import Timer, inverse_timer, timer, with_defer from slime.utils.types import RolloutBatch @@ -512,6 +513,40 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: log_perf_data(rollout_id, self.args) + def _maybe_cleanup_old_checkpoints(self) -> None: + """Clean up old checkpoints before writing the new one (keep-1 so peak = N).""" + keep = None + if self.role == "actor" and self.args.max_actor_ckpt_to_keep is not None: + keep = self.args.max_actor_ckpt_to_keep + elif self.role == "critic" and self.args.max_critic_ckpt_to_keep is not None: + keep = self.args.max_critic_ckpt_to_keep + + if keep is None: + return + + if not hasattr(self, "_saved_rollout_ids"): + self._saved_rollout_ids = [] + + storage_type = getattr(self.args, "checkpoint_storage_type", "shared") + should_cleanup_megatron, should_cleanup_hf = should_run_cleanup( + storage_type, dist.get_rank(), int(os.environ.get("LOCAL_RANK", 0)), + ) + + if should_cleanup_megatron: + save_dir = self.args.save if self.role == "actor" else self.args.critic_save + cleanup_old_checkpoints( + self._saved_rollout_ids, + keep - 1, + path_fn=lambda rid: os.path.join(save_dir, f"iter_{rid:07d}"), + ) + + if should_cleanup_hf and self.args.save_hf is not None and self.role == "actor": + cleanup_old_checkpoints( + self._saved_rollout_ids, + keep - 1, + path_fn=lambda rid: self.args.save_hf.format(rollout_id=rid), + ) + @timer def save_model(self, rollout_id: int, force_sync: bool = False) -> None: if self.args.debug_rollout_only: @@ -526,6 +561,8 @@ def save_model(self, rollout_id: int, force_sync: bool = False) -> None: maybe_finalize_async_save(blocking=True) + self._maybe_cleanup_old_checkpoints() + save(rollout_id, self.model, self.optimizer, self.opt_param_scheduler) if force_sync and self.args.async_save: @@ -536,6 +573,10 @@ def save_model(self, rollout_id: int, force_sync: bool = False) -> None: save_hf_model(self.args, rollout_id, self.model) + if not hasattr(self, "_saved_rollout_ids"): + self._saved_rollout_ids = [] + self._saved_rollout_ids.append(rollout_id) + if self.args.offload_train: destroy_process_groups() diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index a634d1f00..8fa35acb9 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -745,6 +745,29 @@ def add_algo_arguments(parser): "The model will be saved to `save_hf.format(rollout_id)`. " ), ) + parser.add_argument( + "--max-actor-ckpt-to-keep", + type=int, + default=None, + help=( + "Maximum number of actor checkpoints to keep on disk. " + "When exceeded after saving, the oldest checkpoints are deleted. " + "Applies to Megatron checkpoints (--save) and HF checkpoints (--save-hf). " + "Default: None (unlimited)." + ), + ) + parser.add_argument( + "--checkpoint-storage-type", + type=str, + default="shared", + choices=["shared", "local"], + help=( + "Checkpoint storage type. 'shared' (default): all ranks share a " + "filesystem (e.g. FSx) — cleanup runs on global rank 0 only. " + "'local': each node has its own disk (e.g. NVMe) — cleanup runs " + "on local rank 0 of each node." + ), + ) reset_arg(parser, "--seed", type=int, default=1234) reset_arg(parser, "--clip-grad", type=float, default=1.0) reset_arg(parser, "--calculate-per-token-loss", action="store_true") @@ -753,6 +776,17 @@ def add_algo_arguments(parser): parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps") parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.") parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.") + parser.add_argument( + "--max-critic-ckpt-to-keep", + type=int, + default=None, + help=( + "Maximum number of critic checkpoints to keep on disk. " + "When exceeded after saving, the oldest checkpoints are deleted. " + "Applies to critic Megatron checkpoints (--critic-save). " + "Default: None (unlimited)." + ), + ) parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model") parser.add_argument("--critic-train-only", action="store_true", default=False, help="Only train critic") parser.add_argument( @@ -1579,6 +1613,13 @@ def slime_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.max_actor_ckpt_to_keep is not None: + assert args.max_actor_ckpt_to_keep >= 1, "'--max-actor-ckpt-to-keep' must be >= 1." + assert args.save_interval is not None, "'--save-interval' is required when '--max-actor-ckpt-to-keep' is set." + if args.max_critic_ckpt_to_keep is not None: + assert args.max_critic_ckpt_to_keep >= 1, "'--max-critic-ckpt-to-keep' must be >= 1." + assert args.save_interval is not None, "'--save-interval' is required when '--max-critic-ckpt-to-keep' is set." + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/slime/utils/checkpoint_utils.py b/slime/utils/checkpoint_utils.py new file mode 100644 index 000000000..485ef3c5e --- /dev/null +++ b/slime/utils/checkpoint_utils.py @@ -0,0 +1,57 @@ +"""Utilities for managing checkpoint retention during RL training.""" + +import logging +import os +import shutil +from typing import Callable, List + +logger = logging.getLogger(__name__) + + +def should_run_cleanup(storage_type: str, global_rank: int, local_rank: int) -> tuple: + """Determine which ranks should run checkpoint cleanup. + + Megatron: 'shared' → global rank 0; 'local' → local rank 0 per node. + HF: global rank 0 only (save_hf_pretrained always writes from global rank 0). + + Returns: + (should_cleanup_megatron, should_cleanup_hf) + """ + if storage_type == "local": + should_megatron = local_rank == 0 + else: + should_megatron = global_rank == 0 + should_hf = global_rank == 0 + return should_megatron, should_hf + + +def cleanup_old_checkpoints( + saved_rollout_ids: List[int], + keep: int, + path_fn: Callable[[int], str], +) -> List[str]: + """Delete the oldest checkpoints, keeping only the newest *keep*. + + *saved_rollout_ids* is an ordered list of rollout ids saved during the + current run (oldest first). *path_fn* maps a rollout id to its checkpoint + directory path on disk. + + Returns: + List of deleted directory paths. + """ + if len(saved_rollout_ids) <= keep: + return [] + + to_delete_ids = saved_rollout_ids[: len(saved_rollout_ids) - keep] + + deleted: list[str] = [] + for rid in to_delete_ids: + path = path_fn(rid) + if os.path.isdir(path): + try: + logger.info("Deleting old checkpoint: %s", path) + shutil.rmtree(path) + deleted.append(path) + except OSError: + logger.warning("Failed to delete checkpoint %s, will retry next save", path, exc_info=True) + return deleted diff --git a/tests/utils/test_checkpoint_utils.py b/tests/utils/test_checkpoint_utils.py new file mode 100644 index 000000000..37227d07d --- /dev/null +++ b/tests/utils/test_checkpoint_utils.py @@ -0,0 +1,339 @@ +"""Tests for slime.utils.checkpoint_utils.""" + +import os +import shutil + +import pytest + +from slime.utils.checkpoint_utils import cleanup_old_checkpoints, should_run_cleanup + + +def _megatron_path_fn(save_dir): + """Return a path_fn that maps rollout_id → Megatron iter dir.""" + return lambda rid: os.path.join(save_dir, f"iter_{rid:07d}") + + +def _hf_path_fn(template): + """Return a path_fn that maps rollout_id → HF checkpoint dir.""" + return lambda rid: template.format(rollout_id=rid) + + +def _make_megatron_dirs(tmp_path, iters): + """Create Megatron-style iter_NNNNNNN directories.""" + save_dir = str(tmp_path / "ckpt") + os.makedirs(save_dir, exist_ok=True) + for i in iters: + os.makedirs(os.path.join(save_dir, f"iter_{i:07d}")) + return save_dir + + +def _make_hf_dirs(tmp_path, template, rollout_ids): + """Create HF checkpoint directories from template.""" + for rid in rollout_ids: + os.makedirs(template.format(rollout_id=rid)) + + +# --------------------------------------------------------------------------- +# Core functionality +# --------------------------------------------------------------------------- + + +class TestCleanupOldCheckpoints: + def test_deletes_oldest_keeps_newest(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20, 30, 40, 50]) + deleted = cleanup_old_checkpoints( + [10, 20, 30, 40, 50], keep=2, path_fn=_megatron_path_fn(save_dir), + ) + remaining = sorted(os.listdir(save_dir)) + assert remaining == ["iter_0000040", "iter_0000050"] + assert len(deleted) == 3 + + def test_noop_under_limit(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20]) + deleted = cleanup_old_checkpoints( + [10, 20], keep=3, path_fn=_megatron_path_fn(save_dir), + ) + assert deleted == [] + assert len(os.listdir(save_dir)) == 2 + + def test_noop_at_exact_limit(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20, 30]) + deleted = cleanup_old_checkpoints( + [10, 20, 30], keep=3, path_fn=_megatron_path_fn(save_dir), + ) + assert deleted == [] + + def test_keep_one(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [5, 10, 15]) + deleted = cleanup_old_checkpoints( + [5, 10, 15], keep=1, path_fn=_megatron_path_fn(save_dir), + ) + remaining = os.listdir(save_dir) + assert remaining == ["iter_0000015"] + assert len(deleted) == 2 + + def test_keep_zero_deletes_all(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20, 30]) + cleanup_old_checkpoints( + [10, 20, 30], keep=0, path_fn=_megatron_path_fn(save_dir), + ) + remaining = [e for e in os.listdir(save_dir) if e.startswith("iter_")] + assert remaining == [] + + def test_empty_list(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, []) + deleted = cleanup_old_checkpoints( + [], keep=2, path_fn=_megatron_path_fn(save_dir), + ) + assert deleted == [] + + def test_missing_dir_skipped(self, tmp_path): + """path_fn returns a path that doesn't exist on disk — silently skipped.""" + save_dir = str(tmp_path / "ckpt") + os.makedirs(save_dir) + # Only create dir for rollout_id 10 + os.makedirs(os.path.join(save_dir, "iter_0000010")) + deleted = cleanup_old_checkpoints( + [0, 5, 10], keep=1, path_fn=_megatron_path_fn(save_dir), + ) + # 0 and 5 don't exist on disk, so they're skipped + assert deleted == [] + assert os.path.isdir(os.path.join(save_dir, "iter_0000010")) + + def test_rmtree_failure_does_not_crash(self, tmp_path, monkeypatch): + """If shutil.rmtree raises OSError, cleanup logs a warning and continues.""" + save_dir = _make_megatron_dirs(tmp_path, [10, 20, 30]) + + original_rmtree = shutil.rmtree + call_count = 0 + + def flaky_rmtree(path, *a, **kw): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise OSError("simulated transient error") + return original_rmtree(path, *a, **kw) + + monkeypatch.setattr(shutil, "rmtree", flaky_rmtree) + deleted = cleanup_old_checkpoints( + [10, 20, 30], keep=1, path_fn=_megatron_path_fn(save_dir), + ) + assert len(deleted) == 1 + remaining = sorted(e for e in os.listdir(save_dir) if e.startswith("iter_")) + assert remaining == ["iter_0000010", "iter_0000030"] + + +# --------------------------------------------------------------------------- +# Megatron-specific behavior +# --------------------------------------------------------------------------- + + +class TestMegatronCheckpoints: + def test_preserves_latest_checkpointed_iteration_txt(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20, 30]) + txt_path = os.path.join(save_dir, "latest_checkpointed_iteration.txt") + with open(txt_path, "w") as f: + f.write("30\n") + cleanup_old_checkpoints( + [10, 20, 30], keep=1, path_fn=_megatron_path_fn(save_dir), + ) + assert os.path.isfile(txt_path) + remaining = [e for e in os.listdir(save_dir) if e.startswith("iter_")] + assert remaining == ["iter_0000030"] + + def test_ignores_non_iter_entries(self, tmp_path): + save_dir = _make_megatron_dirs(tmp_path, [10, 20]) + os.makedirs(os.path.join(save_dir, "some_other_dir")) + with open(os.path.join(save_dir, "some_file.txt"), "w") as f: + f.write("hello") + cleanup_old_checkpoints( + [10, 20], keep=1, path_fn=_megatron_path_fn(save_dir), + ) + # Non-matching entries untouched + assert os.path.isdir(os.path.join(save_dir, "some_other_dir")) + assert os.path.isfile(os.path.join(save_dir, "some_file.txt")) + + def test_peak_equals_limit(self, tmp_path): + """Simulate cleanup(keep=limit-1) + save cycle — peak never exceeds limit.""" + save_dir = str(tmp_path / "ckpt") + os.makedirs(save_dir) + limit = 2 + saved_rollout_ids = [] + + for rollout_id in range(5): + cleanup_old_checkpoints( + saved_rollout_ids, keep=limit - 1, path_fn=_megatron_path_fn(save_dir), + ) + os.makedirs(os.path.join(save_dir, f"iter_{rollout_id:07d}")) + saved_rollout_ids.append(rollout_id) + n = len([e for e in os.listdir(save_dir) if e.startswith("iter_")]) + assert n <= limit, f"peak exceeded limit: {n} checkpoints on disk" + + def test_only_cleans_own_run(self, tmp_path): + """Checkpoints from a previous run are not touched.""" + save_dir = _make_megatron_dirs(tmp_path, [1, 2, 3]) # previous run + saved_rollout_ids = [] # current run starts fresh + + # Current run saves rollout 4, 5, 6 with keep=2 + for rollout_id in [4, 5, 6]: + cleanup_old_checkpoints( + saved_rollout_ids, keep=2 - 1, path_fn=_megatron_path_fn(save_dir), + ) + os.makedirs(os.path.join(save_dir, f"iter_{rollout_id:07d}")) + saved_rollout_ids.append(rollout_id) + + # Previous run's checkpoints still intact + assert os.path.isdir(os.path.join(save_dir, "iter_0000001")) + assert os.path.isdir(os.path.join(save_dir, "iter_0000002")) + assert os.path.isdir(os.path.join(save_dir, "iter_0000003")) + # Current run: keep=2, so 5 and 6 remain, 4 deleted + assert not os.path.isdir(os.path.join(save_dir, "iter_0000004")) + assert os.path.isdir(os.path.join(save_dir, "iter_0000005")) + assert os.path.isdir(os.path.join(save_dir, "iter_0000006")) + + +# --------------------------------------------------------------------------- +# HF-specific behavior +# --------------------------------------------------------------------------- + + +class TestHfCheckpoints: + def test_deletes_oldest_by_rollout_id(self, tmp_path): + template = str(tmp_path / "hf_ckpt_{rollout_id}") + rollout_ids = [0, 5, 10, 15, 20] + _make_hf_dirs(tmp_path, template, rollout_ids) + deleted = cleanup_old_checkpoints( + rollout_ids, keep=2, path_fn=_hf_path_fn(template), + ) + assert len(deleted) == 3 + assert os.path.isdir(template.format(rollout_id=15)) + assert os.path.isdir(template.format(rollout_id=20)) + assert not os.path.isdir(template.format(rollout_id=0)) + + def test_noop_under_limit(self, tmp_path): + template = str(tmp_path / "hf_ckpt_{rollout_id}") + rollout_ids = [0, 5] + _make_hf_dirs(tmp_path, template, rollout_ids) + deleted = cleanup_old_checkpoints( + rollout_ids, keep=3, path_fn=_hf_path_fn(template), + ) + assert deleted == [] + + def test_noop_at_exact_limit(self, tmp_path): + template = str(tmp_path / "hf_{rollout_id}") + rollout_ids = [0, 5, 10] + _make_hf_dirs(tmp_path, template, rollout_ids) + deleted = cleanup_old_checkpoints( + rollout_ids, keep=3, path_fn=_hf_path_fn(template), + ) + assert deleted == [] + + def test_none_template_skips(self, tmp_path): + """When save_hf is None, no dirs exist — all skipped.""" + deleted = cleanup_old_checkpoints( + [0, 1, 2], keep=1, path_fn=lambda rid: str(tmp_path / f"nonexistent_{rid}"), + ) + assert deleted == [] + + def test_keep_one(self, tmp_path): + template = str(tmp_path / "hf_{rollout_id}") + rollout_ids = [0, 5, 10] + _make_hf_dirs(tmp_path, template, rollout_ids) + deleted = cleanup_old_checkpoints( + rollout_ids, keep=1, path_fn=_hf_path_fn(template), + ) + assert len(deleted) == 2 + assert not os.path.isdir(template.format(rollout_id=0)) + assert not os.path.isdir(template.format(rollout_id=5)) + assert os.path.isdir(template.format(rollout_id=10)) + + def test_keep_zero_deletes_all(self, tmp_path): + template = str(tmp_path / "hf_{rollout_id}") + rollout_ids = [0, 5, 10] + _make_hf_dirs(tmp_path, template, rollout_ids) + cleanup_old_checkpoints( + rollout_ids, keep=0, path_fn=_hf_path_fn(template), + ) + for rid in rollout_ids: + assert not os.path.isdir(template.format(rollout_id=rid)) + + def test_peak_equals_limit(self, tmp_path): + """Simulate cleanup(keep=limit-1) + save + track cycle — peak never exceeds limit.""" + template = str(tmp_path / "hf_{rollout_id}") + limit = 2 + saved_rollout_ids = [] + + for rollout_id in range(5): + cleanup_old_checkpoints( + saved_rollout_ids, keep=limit - 1, path_fn=_hf_path_fn(template), + ) + os.makedirs(template.format(rollout_id=rollout_id)) + saved_rollout_ids.append(rollout_id) + n = sum(1 for rid in saved_rollout_ids + if os.path.isdir(template.format(rollout_id=rid))) + assert n <= limit, f"peak exceeded limit: {n} HF checkpoints on disk" + + def test_rmtree_failure_does_not_crash(self, tmp_path, monkeypatch): + template = str(tmp_path / "hf_{rollout_id}") + rollout_ids = [0, 5, 10] + _make_hf_dirs(tmp_path, template, rollout_ids) + + original_rmtree = shutil.rmtree + call_count = 0 + + def flaky_rmtree(path, *a, **kw): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise OSError("simulated transient error") + return original_rmtree(path, *a, **kw) + + monkeypatch.setattr(shutil, "rmtree", flaky_rmtree) + deleted = cleanup_old_checkpoints( + rollout_ids, keep=1, path_fn=_hf_path_fn(template), + ) + assert len(deleted) == 1 + assert os.path.isdir(template.format(rollout_id=0)) + assert not os.path.isdir(template.format(rollout_id=5)) + assert os.path.isdir(template.format(rollout_id=10)) + + +# --------------------------------------------------------------------------- +# should_run_cleanup — rank selection logic +# --------------------------------------------------------------------------- + + +class TestShouldRunCleanup: + def test_shared_global_rank_0(self): + megatron, hf = should_run_cleanup("shared", global_rank=0, local_rank=0) + assert megatron is True + assert hf is True + + def test_shared_global_rank_nonzero(self): + megatron, hf = should_run_cleanup("shared", global_rank=3, local_rank=3) + assert megatron is False + assert hf is False + + def test_shared_global_rank_nonzero_local_rank_zero(self): + """On node 1, local_rank=0 but global_rank=8 — shared uses global rank.""" + megatron, hf = should_run_cleanup("shared", global_rank=8, local_rank=0) + assert megatron is False + assert hf is False + + def test_local_local_rank_0_global_rank_0(self): + """Node 0, local_rank=0 — both should cleanup.""" + megatron, hf = should_run_cleanup("local", global_rank=0, local_rank=0) + assert megatron is True + assert hf is True + + def test_local_local_rank_0_global_rank_nonzero(self): + """Node 1, local_rank=0 but global_rank=8 — megatron yes, HF no.""" + megatron, hf = should_run_cleanup("local", global_rank=8, local_rank=0) + assert megatron is True + assert hf is False + + def test_local_local_rank_nonzero(self): + """local_rank=3 — neither should cleanup.""" + megatron, hf = should_run_cleanup("local", global_rank=3, local_rank=3) + assert megatron is False + assert hf is False