From 86ea100a2b85ddffe8bc01778d19080049ca1545 Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Thu, 30 Oct 2025 05:58:02 -0700 Subject: [PATCH] fix: support arbitrary values for `checkpointing.metric_name` (#1291) Signed-off-by: ashors1 Co-authored-by: Terry Kong Signed-off-by: NeMo Bot --- examples/configs/distillation_math.yaml | 2 +- examples/configs/dpo.yaml | 4 +- examples/configs/grpo_math_1B.yaml | 2 +- examples/configs/grpo_math_1B_megatron.yaml | 2 +- examples/configs/grpo_sliding_puzzle.yaml | 2 +- examples/configs/rm.yaml | 2 +- examples/configs/sft.yaml | 2 +- examples/configs/sft_openmathinstruct2.yaml | 2 +- .../sft_openmathinstruct2_megatron.yaml | 2 +- examples/configs/sft_vlm_3B.yaml | 2 +- examples/configs/vlm_grpo_3B.yaml | 2 +- examples/configs/vlm_grpo_3B_megatron.yaml | 2 +- nemo_rl/algorithms/distillation.py | 33 ++++++--- nemo_rl/algorithms/dpo.py | 33 +++++++-- nemo_rl/algorithms/grpo.py | 67 ++++++++++++++----- nemo_rl/algorithms/rm.py | 34 +++++++--- nemo_rl/algorithms/sft.py | 39 +++++++---- nemo_rl/utils/checkpoint.py | 2 + 18 files changed, 171 insertions(+), 63 deletions(-) diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index e0b8bcf283..92ae09d8ee 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -20,7 +20,7 @@ loss_fn: checkpointing: enabled: true checkpoint_dir: "checkpoints/distillation-${policy.model_name}" - metric_name: "val_reward" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 6af1405533..72823364cf 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -22,7 +22,7 @@ dpo: checkpointing: enabled: true checkpoint_dir: "results/dpo" - metric_name: "val_loss" + metric_name: "val:validation-default_loss" higher_is_better: false keep_top_k: 3 
save_period: 50 @@ -186,7 +186,7 @@ data: # If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example: # checkpointing: - # metric_name: "validation-_loss" + # metric_name: "val:validation-_loss" # ... logger: diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4c67640257..fbb10ed3a6 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -35,7 +35,7 @@ loss_fn: checkpointing: enabled: true checkpoint_dir: "results/grpo" - metric_name: "val_reward" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 4246fee777..01f33f63d6 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -29,7 +29,7 @@ loss_fn: checkpointing: enabled: false checkpoint_dir: "results/grpo_megatron" - metric_name: "val_reward" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index 725437041b..842318b133 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -11,7 +11,7 @@ grpo: checkpointing: enabled: true checkpoint_dir: "results/grpo-sliding-puzzle" - metric_name: "val_reward" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml index cdbf900fb4..22e22ac5dd 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -15,7 +15,7 @@ rm: checkpointing: enabled: true checkpoint_dir: "results/rm" - metric_name: "val_loss" + 
metric_name: "val:validation-default_loss" # one of "val:" or "train:" followed by the metric name higher_is_better: false keep_top_k: 3 save_period: ${rm.val_period} diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 68f6b674e1..b742846d2d 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -15,7 +15,7 @@ sft: checkpointing: enabled: true checkpoint_dir: "results/sft" - metric_name: "val_loss" ## set to null to save most recent k checkpoints + metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name higher_is_better: false keep_top_k: 3 save_period: 10 diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index a022e04aff..614ecfa98d 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -12,7 +12,7 @@ sft: checkpointing: enabled: true checkpoint_dir: "results/sft_openmathinstruct2" - metric_name: "val_loss" + metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name higher_is_better: false keep_top_k: 100 save_period: 500 diff --git a/examples/configs/sft_openmathinstruct2_megatron.yaml b/examples/configs/sft_openmathinstruct2_megatron.yaml index c09d47297a..696d025976 100644 --- a/examples/configs/sft_openmathinstruct2_megatron.yaml +++ b/examples/configs/sft_openmathinstruct2_megatron.yaml @@ -14,7 +14,7 @@ sft: checkpointing: enabled: true checkpoint_dir: "results/sft_openmathinstruct2" - metric_name: "val_loss" + metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name higher_is_better: false keep_top_k: 100 save_period: 500 diff --git a/examples/configs/sft_vlm_3B.yaml b/examples/configs/sft_vlm_3B.yaml index 185dced165..5615e2f99d 100644 --- a/examples/configs/sft_vlm_3B.yaml +++ b/examples/configs/sft_vlm_3B.yaml @@ -16,7 +16,7 @@ policy: checkpointing: enabled: true checkpoint_dir: "results/sft_${policy.model_name}" - metric_name: 
"val_loss" ## set to null to save most recent k checkpoints + metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name higher_is_better: false keep_top_k: 1 save_period: 10 diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 31005825fd..67e1899c65 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -32,7 +32,7 @@ loss_fn: checkpointing: enabled: true checkpoint_dir: "results/clevr_grpo_${policy.model_name}" - metric_name: "val_reward" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index b73e4dcf52..4f53358fb3 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -27,7 +27,7 @@ loss_fn: checkpointing: enabled: true checkpoint_dir: results/clevr_grpo_${policy.model_name} - metric_name: val_reward + metric_name: val:accuracy # one of "val:" or "train:" followed by the metric name higher_is_better: true keep_top_k: 3 save_period: 10 diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index 4eacebc3eb..18a3b2ecdc 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -786,17 +786,34 @@ def distillation_train( del distillation_save_state["val_reward"] distillation_save_state["consumed_samples"] = consumed_samples - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in distillation_save_state - ): + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the 
corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_reward --> 'val:accuracy'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "Saving most recent k checkpoints instead.", + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", stacklevel=2, ) - master_config["checkpointing"]["metric_name"] = None + if full_metric_name in distillation_save_state: + del distillation_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + distillation_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print( diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 790b4bf9bd..e4bdae5ad5 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -649,15 +649,34 @@ def dpo_train( if val_metrics is not None: dpo_save_state.update(val_metrics) - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in dpo_save_state - ): + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' 
+ f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_loss --> 'val:validation-default_loss'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "This checkpoint will not be saved as top-k." + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in dpo_save_state: + del dpo_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" ) + else: + dpo_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print(f"Saving checkpoint for step {total_steps + 1}...") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b2fd2ab2d9..eab6fd56af 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -920,15 +920,34 @@ def grpo_train( del grpo_save_state["val_reward"] grpo_save_state["consumed_samples"] = consumed_samples - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in grpo_save_state - ): + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 
'val_reward --> 'val:reward'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "This checkpoint will not be saved as top-k." + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in grpo_save_state: + del grpo_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" ) + else: + grpo_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print( @@ -1726,16 +1745,34 @@ def async_grpo_train( del grpo_save_state["val_reward"] grpo_save_state["consumed_samples"] = consumed_samples - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in grpo_save_state - ): + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_reward --> 'val:accuracy'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. 
" - "Saving most recent k checkpoints instead." + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in grpo_save_state: + del grpo_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" ) - master_config["checkpointing"]["metric_name"] = None + else: + grpo_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print(f"Saving checkpoint for step {step + 1}...") diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index de74b9f500..a035a55e98 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -577,16 +577,34 @@ def rm_train( if val_metrics is not None: rm_save_state.update(val_metrics) - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in rm_save_state - ): + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_loss --> 'val:validation-default_loss'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "Saving most recent k checkpoints instead." 
+ f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in rm_save_state: + del rm_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" ) - master_config["checkpointing"]["metric_name"] = None + else: + rm_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print(f"Saving checkpoint for step {total_steps + 1}...") diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 6d251cb222..ae93145b83 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -515,20 +515,35 @@ def sft_train( sft_save_state["total_steps"] = total_steps + 1 sft_save_state["epoch"] = current_epoch sft_save_state["total_valid_tokens"] = total_valid_tokens - if val_metrics is not None: - sft_save_state["val_loss"] = val_metrics["val_loss"] - elif "val_loss" in sft_save_state: - del sft_save_state["val_loss"] - - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in sft_save_state - ): + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 
 'val_loss --> 'val:val_loss'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "This checkpoint will not be saved as top-k." + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in sft_save_state: + del sft_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" ) + else: + sft_save_state[full_metric_name] = metrics_source[ + metric_name + ] with timer.time("checkpointing"): print(f"Saving checkpoint for step {total_steps + 1}...") diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index ca2bab3940..74ce5ac0cb 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -40,6 +40,8 @@ class CheckpointingConfig(TypedDict): enabled (bool): Whether checkpointing is enabled. checkpoint_dir (PathLike): Directory where checkpoints will be saved. metric_name (str | None): Name of the metric to use for determining best checkpoints. + Must be of the form "val:METRIC" or "train:METRIC" to indicate whether + the metric should be taken from the validation or training metrics. higher_is_better (bool): Whether higher values of the metric indicate better performance. keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. model_save_format (str): Format for saving model ("torch_save" or "safetensors").