Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/configs/distillation_math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ loss_fn:
checkpointing:
enabled: true
checkpoint_dir: "checkpoints/distillation-${policy.model_name}"
metric_name: "val_reward"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
4 changes: 2 additions & 2 deletions examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dpo:
checkpointing:
enabled: true
checkpoint_dir: "results/dpo"
metric_name: "val_loss"
metric_name: "val:validation-default_loss"
higher_is_better: false
keep_top_k: 3
save_period: 50
Expand Down Expand Up @@ -186,7 +186,7 @@ data:

# If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example:
# checkpointing:
# metric_name: "validation-<NameOfValidationDataset1>_loss"
# metric_name: "val:validation-<NameOfValidationDataset1>_loss"
# ...

logger:
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ loss_fn:
checkpointing:
enabled: true
checkpoint_dir: "results/grpo"
metric_name: "val_reward"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ loss_fn:
checkpointing:
enabled: false
checkpoint_dir: "results/grpo_megatron"
metric_name: "val_reward"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_sliding_puzzle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ grpo:
checkpointing:
enabled: true
checkpoint_dir: "results/grpo-sliding-puzzle"
metric_name: "val_reward"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ rm:
checkpointing:
enabled: true
checkpoint_dir: "results/rm"
metric_name: "val_loss"
metric_name: "val:validation-default_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 3
save_period: ${rm.val_period}
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ sft:
checkpointing:
enabled: true
checkpoint_dir: "results/sft"
metric_name: "val_loss" ## set to null to save most recent k checkpoints
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 3
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/sft_openmathinstruct2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ sft:
checkpointing:
enabled: true
checkpoint_dir: "results/sft_openmathinstruct2"
metric_name: "val_loss"
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 100
save_period: 500
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/sft_openmathinstruct2_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ sft:
checkpointing:
enabled: true
checkpoint_dir: "results/sft_openmathinstruct2"
metric_name: "val_loss"
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 100
save_period: 500
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/sft_vlm_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ policy:
checkpointing:
enabled: true
checkpoint_dir: "results/sft_${policy.model_name}"
metric_name: "val_loss" ## set to null to save most recent k checkpoints
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 1
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/vlm_grpo_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ loss_fn:
checkpointing:
enabled: true
checkpoint_dir: "results/clevr_grpo_${policy.model_name}"
metric_name: "val_reward"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/vlm_grpo_3B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ loss_fn:
checkpointing:
enabled: true
checkpoint_dir: results/clevr_grpo_${policy.model_name}
metric_name: val_reward
metric_name: val:accuracy # one of "val:" or "train:" followed by the metric name
higher_is_better: true
keep_top_k: 3
save_period: 10
Expand Down
33 changes: 25 additions & 8 deletions nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,17 +786,34 @@ def distillation_train(
del distillation_save_state["val_reward"]
distillation_save_state["consumed_samples"] = consumed_samples

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in distillation_save_state
):
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_reward' --> 'val:accuracy'"
)
prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"Saving most recent k checkpoints instead.",
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
master_config["checkpointing"]["metric_name"] = None
if full_metric_name in distillation_save_state:
del distillation_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
else:
distillation_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(
Expand Down
33 changes: 26 additions & 7 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,34 @@ def dpo_train(
if val_metrics is not None:
dpo_save_state.update(val_metrics)

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in dpo_save_state
):
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_loss' --> 'val:validation-default_loss'"
)
prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"This checkpoint will not be saved as top-k."
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
if full_metric_name in dpo_save_state:
del dpo_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
else:
dpo_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(f"Saving checkpoint for step {total_steps + 1}...")
Expand Down
67 changes: 52 additions & 15 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,15 +920,34 @@ def grpo_train(
del grpo_save_state["val_reward"]
grpo_save_state["consumed_samples"] = consumed_samples

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in grpo_save_state
):
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_reward' --> 'val:accuracy'"
)
Comment on lines +923 to +932
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix the migration hint for GRPO checkpoint metrics.

The guidance here tells users to migrate val_reward to val:reward, but GRPO’s validation metrics only expose accuracy. Following the suggested string will immediately trigger the new ValueError("Metric reward not found in val metrics"), breaking checkpointing for anyone applying the hint verbatim. Please point to the real key (val:accuracy) and clean up the typo so the migration path actually succeeds.

-                            f"  If you are using an old config, please updated checkpointing.metric_name to the new format, "
-                            f" e.g. 'val_reward --> 'val:reward'"
+                            f"  If you are using an old config, please update checkpointing.metric_name to the new format, "
+                            f" e.g. 'val_reward' --> 'val:accuracy'"
πŸ€– Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 923 to 932, the migration hint
incorrectly suggests converting "val_reward" β†’ "val:reward" even though GRPO
exposes only "accuracy" in validation metrics and the example contains a stray
quote; update the assertion error message to reference the correct validation
metric key (e.g., "val_accuracy" β†’ "val:accuracy") and remove the
formatting/typo so the example is valid and will work when followed.

prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"This checkpoint will not be saved as top-k."
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
if full_metric_name in grpo_save_state:
del grpo_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
else:
grpo_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(
Expand Down Expand Up @@ -1726,16 +1745,34 @@ def async_grpo_train(
del grpo_save_state["val_reward"]
grpo_save_state["consumed_samples"] = consumed_samples

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in grpo_save_state
):
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_reward' --> 'val:accuracy'"
)
prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"Saving most recent k checkpoints instead."
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
if full_metric_name in grpo_save_state:
del grpo_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
master_config["checkpointing"]["metric_name"] = None
else:
grpo_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(f"Saving checkpoint for step {step + 1}...")
Expand Down
34 changes: 26 additions & 8 deletions nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,16 +577,34 @@ def rm_train(
if val_metrics is not None:
rm_save_state.update(val_metrics)

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in rm_save_state
):
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_loss' --> 'val:validation-default_loss'"
)
prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"Saving most recent k checkpoints instead."
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
if full_metric_name in rm_save_state:
del rm_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
master_config["checkpointing"]["metric_name"] = None
else:
rm_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(f"Saving checkpoint for step {total_steps + 1}...")
Expand Down
39 changes: 27 additions & 12 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,20 +515,35 @@ def sft_train(
sft_save_state["total_steps"] = total_steps + 1
sft_save_state["epoch"] = current_epoch
sft_save_state["total_valid_tokens"] = total_valid_tokens
if val_metrics is not None:
sft_save_state["val_loss"] = val_metrics["val_loss"]
elif "val_loss" in sft_save_state:
del sft_save_state["val_loss"]

if master_config["checkpointing"]["metric_name"] is not None:
if (
master_config["checkpointing"]["metric_name"]
not in sft_save_state
):

full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
assert full_metric_name.startswith(
"train:"
) or full_metric_name.startswith("val:"), (
f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
f" If you are using an old config, please update checkpointing.metric_name to the new format, "
f" e.g. 'val_loss' --> 'val:val_loss'"
)
prefix, metric_name = full_metric_name.split(":", 1)
metrics_source = metrics if prefix == "train" else val_metrics
if not metrics_source:
warnings.warn(
f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
"This checkpoint will not be saved as top-k."
f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. "
"This checkpoint will not be saved as top-k.",
stacklevel=2,
)
if full_metric_name in sft_save_state:
del sft_save_state[full_metric_name]
elif metric_name not in metrics_source:
raise ValueError(
f"Metric {metric_name} not found in {prefix} metrics"
)
else:
sft_save_state[full_metric_name] = metrics_source[
metric_name
]

with timer.time("checkpointing"):
print(f"Saving checkpoint for step {total_steps + 1}...")
Expand Down
Loading
Loading