cp: fix: support arbitrary values for checkpointing.metric_name (1291) into r0.4.0
#1449
Conversation
Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
📝 Walkthrough

The PR migrates checkpoint metric naming conventions to use explicit "val:" or "train:" prefixes that indicate the metric's source phase. Configuration files are updated to use the new format, and algorithm implementations now validate and parse the prefixed metric names to select the appropriate metrics source during checkpointing.
Sequence Diagram

```mermaid
sequenceDiagram
    actor Config as Configuration
    participant CkptLogic as Checkpoint Logic
    participant MetricSrc as Metrics Source
    participant SaveState as Save State

    Note over Config,SaveState: Old Behavior
    Config->>CkptLogic: metric_name = "val_loss"
    CkptLogic->>MetricSrc: Check if "val_loss" exists
    alt Found
        MetricSrc-->>CkptLogic: Return metric value
        CkptLogic->>SaveState: Save under "val_loss"
    else Not Found
        MetricSrc-->>CkptLogic: Not found
        CkptLogic->>SaveState: Skip or use fallback
    end

    Note over Config,SaveState: New Behavior
    Config->>CkptLogic: metric_name = "val:val_loss"
    rect rgb(220, 240, 255)
        CkptLogic->>CkptLogic: Parse prefix ("val") & name ("val_loss")
        CkptLogic->>CkptLogic: Validate format (must have prefix)
        CkptLogic->>MetricSrc: Select val_metrics by prefix
    end
    alt Metrics source exists & metric found
        MetricSrc-->>CkptLogic: Return metric value
        CkptLogic->>SaveState: Save under "val:val_loss"
    else Metrics source empty
        MetricSrc-->>CkptLogic: No metrics
        CkptLogic->>SaveState: Warning & remove entry
    else Metric not in source
        MetricSrc-->>CkptLogic: Metric missing
        CkptLogic->>CkptLogic: Raise ValueError
    end
```
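To make the new behavior concrete, here is a minimal Python sketch of the parse-and-validate flow the diagram describes. The function name `resolve_checkpoint_metric` and the `train_metrics`/`val_metrics` dictionaries are illustrative placeholders, not the exact helpers added in this PR.

```python
import warnings
from typing import Optional


def resolve_checkpoint_metric(
    full_metric_name: Optional[str],
    train_metrics: dict[str, float],
    val_metrics: dict[str, float],
) -> Optional[tuple[str, float]]:
    """Resolve a prefixed checkpoint metric name to (name, value) per the new convention."""
    if full_metric_name is None:
        return None

    # Validate format: the name must carry an explicit source prefix.
    if not (full_metric_name.startswith("train:") or full_metric_name.startswith("val:")):
        raise ValueError(
            f"metric_name={full_metric_name!r} must start with 'val:' or 'train:', "
            "e.g. 'val_loss' --> 'val:val_loss'."
        )

    # Parse the prefix and bare metric name, then select the matching metrics source.
    prefix, metric_name = full_metric_name.split(":", 1)
    metrics = val_metrics if prefix == "val" else train_metrics

    if not metrics:
        # Metrics source is empty: warn and skip tracking this metric.
        warnings.warn(f"No {prefix} metrics available; skipping {full_metric_name!r}.")
        return None
    if metric_name not in metrics:
        raise ValueError(f"Metric {metric_name!r} not found in {prefix} metrics.")

    # Save under the fully prefixed name so checkpoints record the source phase.
    return full_metric_name, metrics[metric_name]
```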
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (3 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/configs/rm.yaml (1)
18-18: Align the lower guidance comment with the new prefix requirement. Line 18 documents the `val:`/`train:` prefix format, but the guidance block around Line 158 still shows `metric_name: "validation-..._loss"` without the prefix. Please update that example to include the prefix (e.g., `val:validation-foo_loss`) so readers don't revert to the legacy form.
nemo_rl/algorithms/dpo.py (1)
654-661: Raise a ValueError instead of relying on `assert`. `assert` can be stripped with `PYTHONOPTIMIZE`, so malformed configs could slip through. Please switch to an explicit conditional that raises `ValueError`, and fix the example text while you're there.

```diff
-assert full_metric_name.startswith(
-    "train:"
-) or full_metric_name.startswith("val:"), (
-    f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
-    f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
-    f" If you are using an old config, please updated checkpointing.metric_name to the new format, "
-    f" e.g. 'val_loss --> 'val:validation-default_loss'"
-)
+if not (
+    full_metric_name.startswith("train:")
+    or full_metric_name.startswith("val:")
+):
+    raise ValueError(
+        "checkpointing.metric_name must start with 'val:' or 'train:' "
+        "followed by a metric present in the corresponding metrics dictionary. "
+        "For example, update 'val_loss' to 'val:validation-default_loss'."
+    )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- examples/configs/distillation_math.yaml (1 hunks)
- examples/configs/dpo.yaml (2 hunks)
- examples/configs/grpo_math_1B.yaml (1 hunks)
- examples/configs/grpo_math_1B_megatron.yaml (1 hunks)
- examples/configs/grpo_sliding_puzzle.yaml (1 hunks)
- examples/configs/rm.yaml (1 hunks)
- examples/configs/sft.yaml (1 hunks)
- examples/configs/sft_openmathinstruct2.yaml (1 hunks)
- examples/configs/sft_openmathinstruct2_megatron.yaml (1 hunks)
- examples/configs/sft_vlm_3B.yaml (1 hunks)
- examples/configs/vlm_grpo_3B.yaml (1 hunks)
- examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
- nemo_rl/algorithms/distillation.py (1 hunks)
- nemo_rl/algorithms/dpo.py (1 hunks)
- nemo_rl/algorithms/grpo.py (2 hunks)
- nemo_rl/algorithms/rm.py (1 hunks)
- nemo_rl/algorithms/sft.py (1 hunks)
- nemo_rl/utils/checkpoint.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
- examples/configs/sft_vlm_3B.yaml
- examples/configs/sft.yaml
- examples/configs/grpo_math_1B_megatron.yaml
- examples/configs/sft_openmathinstruct2.yaml
- examples/configs/rm.yaml
- examples/configs/grpo_math_1B.yaml
- examples/configs/vlm_grpo_3B.yaml
- examples/configs/vlm_grpo_3B_megatron.yaml
- examples/configs/dpo.yaml
- examples/configs/distillation_math.yaml
- examples/configs/grpo_sliding_puzzle.yaml
- examples/configs/sft_openmathinstruct2_megatron.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/rm.py
- nemo_rl/algorithms/sft.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/utils/checkpoint.py
- nemo_rl/algorithms/distillation.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired (a brief sketch follows after this section's file list)
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/rm.py
- nemo_rl/algorithms/sft.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/utils/checkpoint.py
- nemo_rl/algorithms/distillation.py
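As a side note on the TypedDict guideline above, here is a minimal sketch of how an optional config key can be expressed with typing.NotRequired; the class and key names below are hypothetical, not the actual NeMo-RL config types.

```python
from typing import NotRequired, TypedDict


class ExampleCheckpointingConfig(TypedDict):
    """Hypothetical checkpointing config section, for illustration only.

    Keys:
        higher_is_better: Whether larger values of the tracked metric are better.
        metric_name: Optional "train:"/"val:"-prefixed metric to track for
            best-checkpoint selection (e.g. "val:val_loss"). The recommended
            default lives in the exemplar YAMLs, not in code.
    """

    higher_is_better: bool  # required key; the YAML config supplies its value
    metric_name: NotRequired[str]  # optional key, expressed via NotRequired
```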
🧠 Learnings (4)
📓 Common learnings
Learnt from: zpqiu
PR: NVIDIA-NeMo/RL#1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:19-26
Timestamp: 2025-09-18T13:26:43.307Z
Learning: In on-policy distillation workflows, validation can use downstream task performance (like math problem solving) as RL-like reward metrics rather than traditional distillation metrics like KL divergence. In this case, "val_reward" with "higher_is_better: true" is the correct checkpoint monitoring configuration.
📚 Learning: 2025-09-18T13:26:43.307Z
Learnt from: zpqiu
PR: NVIDIA-NeMo/RL#1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:19-26
Timestamp: 2025-09-18T13:26:43.307Z
Learning: In on-policy distillation workflows, validation can use downstream task performance (like math problem solving) as RL-like reward metrics rather than traditional distillation metrics like KL divergence. In this case, "val_reward" with "higher_is_better: true" is the correct checkpoint monitoring configuration.
Applied to files:
- examples/configs/sft_vlm_3B.yaml
- examples/configs/sft.yaml
- examples/configs/grpo_math_1B_megatron.yaml
- examples/configs/sft_openmathinstruct2.yaml
- examples/configs/rm.yaml
- examples/configs/grpo_math_1B.yaml
- examples/configs/vlm_grpo_3B.yaml
- examples/configs/vlm_grpo_3B_megatron.yaml
- examples/configs/dpo.yaml
- examples/configs/distillation_math.yaml
- examples/configs/grpo_sliding_puzzle.yaml
- nemo_rl/algorithms/distillation.py
- examples/configs/sft_openmathinstruct2_megatron.yaml
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
PR: NVIDIA-NeMo/RL#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
Applied to files:
nemo_rl/utils/checkpoint.py
📚 Learning: 2025-09-19T03:19:35.875Z
Learnt from: shuo-nvidia
PR: NVIDIA-NeMo/RL#1006
File: tests/unit/algorithms/test_loss_functions.py:1580-1619
Timestamp: 2025-09-19T03:19:35.875Z
Learning: The DistillationLossFn in nemo_rl/algorithms/loss_functions.py does not have k truncation logic - it processes whatever topk size is provided without capping it to vocabulary size or other limits. Large k values in tests will create correspondingly large GPU tensors.
Applied to files:
nemo_rl/algorithms/distillation.py
🪛 Ruff (0.14.2)
nemo_rl/algorithms/dpo.py
673-675: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/rm.py
601-603: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/sft.py
540-542: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/grpo.py
944-946: Avoid specifying long messages outside the exception class
(TRY003)
1769-1771: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/distillation.py
810-812: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (15)
examples/configs/vlm_grpo_3B_megatron.yaml (1)
30-30: Metric naming convention update is consistent and well-documented. The prefix-based metric naming ("val:accuracy") is clearly documented inline, aligning with the broader convention shift across configurations. The format is consistent with other GRPO configs in this PR.
examples/configs/grpo_math_1B.yaml (1)
38-38: Consistent with metric naming convention across GRPO configs. The prefixed metric name ("val:accuracy") and inline documentation maintain consistency with related configurations. The quoted format is valid YAML.
examples/configs/vlm_grpo_3B.yaml (1)
35-35: Consistent application of the prefix-based metric naming convention. The metric configuration aligns with other GRPO variants and includes clear inline documentation.
examples/configs/grpo_math_1B_megatron.yaml (1)
32-32: Metric naming consistent with GRPO configuration convention. The prefixed metric name is properly documented. Note that checkpointing is disabled in this config (line 30), so this metric_name serves primarily as documentation for the expected format when checkpointing is enabled.
examples/configs/sft_openmathinstruct2_megatron.yaml (1)
17-17: Metric naming for SFT is consistent and semantically appropriate. The "val:val_loss" metric with "higher_is_better: false" correctly reflects loss-based monitoring for supervised fine-tuning, properly prefixed per the new convention.
examples/configs/sft_vlm_3B.yaml (1)
19-19: SFT VLM metric configuration follows the established convention. The "val:val_loss" metric with "higher_is_better: false" is properly prefixed and semantically coherent for loss monitoring across SFT variants.
examples/configs/sft_openmathinstruct2.yaml (1)
15-15: SFT base configuration metric naming is consistent and properly documented. The "val:val_loss" metric follows the prefix convention consistently across SFT configurations with appropriate semantic pairing (higher_is_better: false).
examples/configs/distillation_math.yaml (1)
23-23: Metric naming follows new convention, but verify metric availability. The "val:accuracy" metric with "higher_is_better: true" is properly prefixed and documented. However, the retrieved learning indicates distillation workflows typically monitor "val_reward" for downstream task performance. Verify that:
- The distillation algorithm actually emits an "accuracy" metric
- This metric represents the intended downstream task performance signal (e.g., math problem accuracy)
This is a cherry-pick from PR #1291, so implementation validation should have occurred there, but confirming metric availability would reduce risk.
examples/configs/sft.yaml (1)
18-18: Prefixed metric name looks good.
`val:val_loss` matches the new checkpointing semantics.
examples/configs/dpo.yaml (1)
25-25: Metric rename matches the validation namespace. `val:validation-default_loss` lines up with the keys emitted from validation and satisfies the new prefix contract.
examples/configs/grpo_sliding_puzzle.yaml (1)
14-14: Nice switch to an explicit validation metric. Tracking `val:accuracy` with `higher_is_better: true` keeps checkpointing aligned with the success criterion for this recipe.
nemo_rl/utils/checkpoint.py (1)
43-44: Docstring update correctly communicates the prefix contract. This will help config authors avoid the legacy unprefixed metrics.
nemo_rl/algorithms/distillation.py (1)
789-816: Checkpoint metric resolution looks solid. The prefix validation and metric extraction mirror the shared pattern cleanly; no issues spotted here.
nemo_rl/algorithms/sft.py (1)
519-546: SFT checkpoint metric handling is tidy. The new `train:`/`val:` parsing integrates smoothly with the existing save-state updates.
nemo_rl/algorithms/rm.py (1)
580-607: RM checkpoint metric logic aligns with the shared helper pattern. Validation and warning behaviour look correct for both train and val prefixes.
nemo_rl/algorithms/grpo.py (excerpt under review):

```python
full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
    assert full_metric_name.startswith(
        "train:"
    ) or full_metric_name.startswith("val:"), (
        f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n"
        f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
        f" If you are using an old config, please updated checkpointing.metric_name to the new format, "
        f" e.g. 'val_reward --> 'val:reward'"
    )
```
Fix the migration hint for GRPO checkpoint metrics.
The guidance here tells users to migrate val_reward to val:reward, but GRPO’s validation metrics only expose accuracy. Following the suggested string will immediately trigger the new ValueError("Metric reward not found in val metrics"), breaking checkpointing for anyone applying the hint verbatim. Please point to the real key (val:accuracy) and clean up the typo so the migration path actually succeeds.
- f" If you are using an old config, please updated checkpointing.metric_name to the new format, "
- f" e.g. 'val_reward --> 'val:reward'"
+ f" If you are using an old config, please update checkpointing.metric_name to the new format, "
+ f" e.g. 'val_reward' --> 'val:accuracy'"🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 923 to 932, the migration hint
incorrectly suggests converting "val_reward" → "val:reward" even though GRPO
exposes only "accuracy" in validation metrics and the example contains a stray
quote; update the assertion error message to reference the correct validation
metric key (e.g., "val_accuracy" → "val:accuracy") and remove the
formatting/typo so the example is valid and will work when followed.
…291)` into `r0.4.0` (#1449)
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Anna Shors <ashors@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
beep boop [🤖]: Hi @ashors1 👋,
Summary by CodeRabbit
Chores
Updated checkpointing metric naming convention — Checkpointing now requires metrics to be prefixed with either "train:" or "val:" to specify the source (e.g., "val:accuracy" instead of "val_reward"). All example configurations have been updated to reflect the new format.
Enhanced metric validation — The system now validates metric names against the new prefix format and provides clearer error messages when metrics are missing or improperly configured.