fix: support arbitrary values for checkpointing.metric_name
#1291
Conversation
Thanks for the comments @samodi-nv! I've addressed them.
Force-pushed from a662a80 to 678dbf3.
📝 Walkthrough

This pull request refactors metric-based checkpointing across the codebase by introducing a namespaced metric format with "train:" or "val:" prefixes. Configuration files are updated to use the new format, and algorithm implementations are enhanced to parse these prefixes, validate metric existence, and handle missing metrics gracefully with warnings.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Config as Config
    participant Algo as Algorithm (SFT/GRPO/etc)
    participant Metrics as Metrics Storage
    participant SaveState as Checkpoint SaveState
    Config->>Algo: metric_name = "val:loss"
    Algo->>Algo: Parse prefix<br/>("val" or "train")
    alt Prefix valid ("val" or "train")
        Algo->>Algo: Extract metric_name
        alt Prefix is "val"
            Algo->>Metrics: Select val_metrics
        else Prefix is "train"
            Algo->>Metrics: Select train metrics
        end
        alt Metric exists
            Metrics-->>Algo: metric_value
            Algo->>SaveState: Store metric under<br/>"val:loss" key
        else Metric missing
            Algo-->>Algo: Emit warning
            Algo->>SaveState: Remove "val:loss"<br/>entry if present
        end
    else Invalid prefix
        Algo-->>Algo: Emit warning
    end
```
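As a rough illustration of the flow above, here is a minimal, self-contained Python sketch. The function and variable names (select_checkpoint_metric, save_state, etc.) are illustrative assumptions, not the exact code in this PR:

```python
import warnings


def select_checkpoint_metric(
    full_metric_name: str,
    train_metrics: dict[str, float],
    val_metrics: dict[str, float],
    save_state: dict[str, float],
) -> None:
    """Resolve a "train:<name>"/"val:<name>" metric and record it for top-k checkpointing."""
    prefix, _, metric_name = full_metric_name.partition(":")
    if prefix not in ("train", "val"):
        warnings.warn(
            f"Invalid checkpoint metric prefix '{prefix}'; expected 'train' or 'val'.",
            stacklevel=2,
        )
        return

    metrics = val_metrics if prefix == "val" else train_metrics
    if metric_name in metrics:
        # Store the value under the fully qualified key, e.g. "val:loss".
        save_state[full_metric_name] = metrics[metric_name]
    else:
        warnings.warn(
            f"Checkpoint metric '{full_metric_name}' not found in {prefix} metrics; "
            "this checkpoint will not be ranked for top-k.",
            stacklevel=2,
        )
        # Drop any stale entry so top-k ranking does not use an old value.
        save_state.pop(full_metric_name, None)


# Example: pick the validation loss for checkpoint ranking.
state: dict[str, float] = {}
select_checkpoint_metric("val:loss", {"loss": 0.9}, {"loss": 0.42}, state)
print(state)  # {'val:loss': 0.42}
```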
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

The changes demonstrate consistent patterns across configuration files (homogeneous updates), but the algorithm implementations introduce heterogeneous logic for metric parsing, validation, and conditional branching based on metric availability. Multiple files require reasoning about the new control flow and error-handling paths.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/configs/rm.yaml (1)
152-155: Update stale example to match new prefix requirement

The comment still shows the pre-change form. Recommend this fix for consistency:

```diff
- # metric_name: "validation-<NameOfValidationDataset1>_loss"
+ # metric_name: "val:validation-<NameOfValidationDataset1>_loss"
```

nemo_rl/algorithms/rm.py (1)
308-312: Wrong config key in assertion (RM uses rm, not dpo).

The assertion references master_config["dpo"]["val_period"]; it should be ["rm"]["val_period"]. This can raise a KeyError or mask a validation misconfiguration.

Apply:

```diff
- assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+ assert val_dataloader is not None or master_config["rm"]["val_period"] == 0, (
      "val_dataloader is None, so dpo.val_period must be 0"
  )
```

nemo_rl/algorithms/grpo.py (1)
1054-1056: Wrong config key in GRPO validate assertion.

Should reference grpo.val_period, not dpo.val_period.

Apply:

```diff
- assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+ assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
      "val_dataloader is None, so dpo.val_period must be 0"
  )
```
🧹 Nitpick comments (3)
nemo_rl/algorithms/sft.py (1)
511-520: Consider more explicit parsing logic for clarity.

The logic at line 519 uses "val" in parts[0] to determine the metric source. While the assertion above ensures the format is correct, using an explicit comparison would be clearer:

```diff
- train_or_val = "val" if "val" in parts[0] else "train"
+ train_or_val = parts[0]  # Already validated to be "val" or "train"
```

This makes the intent clearer and leverages the assertion's validation.
nemo_rl/algorithms/distillation.py (1)
734-759: Parse metric prefix safely and preserve metric names containing colons

Current parsing uses split(":") without maxsplit, which fails for metrics containing colons after the prefix, and relies on substring checks instead of exact comparison. Use split(":", 1) for safe parsing. Apply consistently across all algorithm files:

```diff
- parts = full_metric_name.split(":")
- train_or_val = "val" if "val" in parts[0] else "train"
- metric_name = parts[1]
+ train_or_val, metric_name = full_metric_name.split(":", 1)
+ assert train_or_val in ("train", "val"), (
+     f"Invalid metric prefix '{train_or_val}'. Expected 'train' or 'val'."
+ )
```

Files to update: nemo_rl/algorithms/distillation.py:741, nemo_rl/algorithms/sft.py:518, nemo_rl/algorithms/rm.py:582, nemo_rl/algorithms/dpo.py:654, nemo_rl/algorithms/grpo.py:914, and grpo.py:1714

Optionally, improve the warning message for clarity:

```diff
- warnings.warn(
-     f"You asked to save checkpoints based on {metric_name} but the metric is not found in the {train_or_val} metrics. "
-     "This checkpoint will not be saved as top-k.",
-     stacklevel=2,
- )
+ warnings.warn(
+     f"Checkpoint metric '{full_metric_name}' not found in {train_or_val} metrics; skipping top-k update.",
+     stacklevel=2,
+ )
```

nemo_rl/algorithms/grpo.py (1)
907-931: Deduplicate metric_name parsing via a small utility.

Parsing logic is duplicated across RM/DPO/GRPO (sync + async). Consider a shared helper in nemo_rl/utils/checkpoint.py, e.g., parse_checkpoint_metric_name(full_metric_name) -> tuple[prefix, metric], and reuse it. This reduces drift and enforces a single policy.
If helpful, I can draft the utility and apply call-site changes across modules.
Also applies to: 1709-1731
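If such a helper were added, it might look roughly like the sketch below; the signature and error handling are assumptions for illustration, not an existing function in nemo_rl/utils/checkpoint.py:

```python
def parse_checkpoint_metric_name(full_metric_name: str) -> tuple[str, str]:
    """Split a namespaced checkpoint metric into (prefix, metric_name).

    Args:
        full_metric_name: Metric in the form "train:<name>" or "val:<name>",
            e.g. "val:validation-default_loss".

    Returns:
        A (prefix, metric_name) tuple, where prefix is "train" or "val".

    Raises:
        ValueError: If the prefix is missing or not "train"/"val".
    """
    prefix, sep, metric_name = full_metric_name.partition(":")
    if not sep or prefix not in ("train", "val"):
        raise ValueError(
            f"checkpointing.metric_name must start with 'train:' or 'val:', "
            f"got '{full_metric_name}'"
        )
    return prefix, metric_name
```

Using partition (or split(":", 1)) keeps any additional colons inside the metric name intact, which also covers the colon-handling nitpick above.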
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
examples/configs/distillation_math.yaml (1 hunks)
examples/configs/dpo.yaml (2 hunks)
examples/configs/grpo_math_1B.yaml (1 hunks)
examples/configs/grpo_math_1B_megatron.yaml (1 hunks)
examples/configs/grpo_sliding_puzzle.yaml (1 hunks)
examples/configs/rm.yaml (1 hunks)
examples/configs/sft.yaml (1 hunks)
examples/configs/sft_openmathinstruct2.yaml (1 hunks)
examples/configs/sft_openmathinstruct2_megatron.yaml (1 hunks)
examples/configs/sft_vlm_3B.yaml (1 hunks)
examples/configs/vlm_grpo_3B.yaml (1 hunks)
examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
nemo_rl/algorithms/distillation.py (1 hunks)
nemo_rl/algorithms/dpo.py (1 hunks)
nemo_rl/algorithms/grpo.py (2 hunks)
nemo_rl/algorithms/rm.py (1 hunks)
nemo_rl/algorithms/sft.py (1 hunks)
nemo_rl/utils/checkpoint.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
examples/configs/grpo_math_1B.yaml, examples/configs/sft_openmathinstruct2.yaml, examples/configs/rm.yaml, examples/configs/distillation_math.yaml, examples/configs/vlm_grpo_3B_megatron.yaml, examples/configs/sft_openmathinstruct2_megatron.yaml, examples/configs/sft_vlm_3B.yaml, examples/configs/sft.yaml, examples/configs/grpo_math_1B_megatron.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/grpo_sliding_puzzle.yaml, examples/configs/dpo.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/utils/checkpoint.py, nemo_rl/algorithms/distillation.py, nemo_rl/algorithms/grpo.py, nemo_rl/algorithms/sft.py, nemo_rl/algorithms/dpo.py, nemo_rl/algorithms/rm.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/utils/checkpoint.py, nemo_rl/algorithms/distillation.py, nemo_rl/algorithms/grpo.py, nemo_rl/algorithms/sft.py, nemo_rl/algorithms/dpo.py, nemo_rl/algorithms/rm.py
🧠 Learnings (1)
📚 Learning: 2025-09-18T13:26:43.307Z
Learnt from: zpqiu
PR: NVIDIA-NeMo/RL#1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:19-26
Timestamp: 2025-09-18T13:26:43.307Z
Learning: In on-policy distillation workflows, validation can use downstream task performance (like math problem solving) as RL-like reward metrics rather than traditional distillation metrics like KL divergence. In this case, "val_reward" with "higher_is_better: true" is the correct checkpoint monitoring configuration.
Applied to files:
examples/configs/grpo_math_1B.yaml, examples/configs/sft_openmathinstruct2.yaml, examples/configs/rm.yaml, examples/configs/distillation_math.yaml, examples/configs/vlm_grpo_3B_megatron.yaml, examples/configs/sft_openmathinstruct2_megatron.yaml, examples/configs/sft_vlm_3B.yaml, examples/configs/sft.yaml, examples/configs/grpo_math_1B_megatron.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/grpo_sliding_puzzle.yaml, examples/configs/dpo.yaml
🪛 Ruff (0.14.0)
nemo_rl/algorithms/grpo.py
922-922: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
1722-1722: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
nemo_rl/algorithms/sft.py
531-531: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
nemo_rl/algorithms/dpo.py
662-662: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
nemo_rl/algorithms/rm.py
590-590: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
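For context on these B028 findings, a small illustration of the requested fix; the helper name here is hypothetical, not code from the PR:

```python
import warnings


def warn_missing_checkpoint_metric(full_metric_name: str, train_or_val: str) -> None:
    # stacklevel=2 attributes the warning to the caller (e.g. the training loop)
    # rather than to this line, which is what Ruff's B028 rule asks for.
    warnings.warn(
        f"Checkpoint metric '{full_metric_name}' not found in {train_or_val} metrics; "
        "this checkpoint will not be saved as top-k.",
        stacklevel=2,
    )


warn_missing_checkpoint_metric("val:loss", "val")
```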
🔇 Additional comments (13)
nemo_rl/utils/checkpoint.py (1)
43-44: LGTM! Clear documentation of the required metric format.

The documentation clearly specifies that metric_name must use "val:" or "train:" prefixes, which aligns with the implementation across algorithm files and addresses the requirement from past review comments.
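As a rough sketch of how such a documented config key could look; the CheckpointingConfig name and surrounding fields are assumptions for illustration, not necessarily the actual definitions in checkpoint.py:

```python
from typing import NotRequired, TypedDict


class CheckpointingConfig(TypedDict):
    # Illustrative fields only; the real config in the repo may differ.
    enabled: bool
    checkpoint_dir: str
    # Metric used to rank top-k checkpoints. Must be namespaced with its source,
    # e.g. "val:loss", "train:loss", or "val:reward".
    metric_name: NotRequired[str]
    # True for reward-like metrics, False for loss-like metrics.
    higher_is_better: NotRequired[bool]
    keep_top_k: NotRequired[int]
```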
examples/configs/grpo_math_1B_megatron.yaml (1)
32-36: LGTM! Config correctly adopts the new metric naming convention.

The change properly updates the metric_name to use the "val:" prefix and includes a helpful inline comment. The checkpoint_must_save_by field addition aligns with broader checkpointing patterns.
Based on learnings, "val:reward" with "higher_is_better: true" is the correct configuration for RL-based reward metrics.
examples/configs/vlm_grpo_3B_megatron.yaml (1)
29-33: LGTM! Consistent with the new metric naming convention.

The configuration correctly adopts the "val:" prefix format with appropriate inline documentation.
examples/configs/grpo_sliding_puzzle.yaml (1)
14-18: LGTM! Consistent adoption of the new format.

examples/configs/distillation_math.yaml (1)
22-26: LGTM! Config correctly updated.

The metric_name properly uses the new format. The past review comment about documenting format options in checkpoint.py has been addressed in this PR.
examples/configs/sft_openmathinstruct2.yaml (1)
15-19: LGTM! Correctly configured for loss metric.

The metric_name properly uses the "val:loss" format, and higher_is_better: false is correctly set for loss metrics.

examples/configs/grpo_math_1B.yaml (1)
37-41: LGTM! Final config correctly updated.

The metric_name properly adopts the "val:reward" format with appropriate documentation.
examples/configs/sft_vlm_3B.yaml (1)
19-19: LGTM: namespaced checkpoint metric

The switch to "val:loss" with a clarifying comment matches the new convention; higher_is_better: false remains correct for loss.
examples/configs/sft_openmathinstruct2_megatron.yaml (1)
17-18: LGTM: prefixed metric format

"val:loss" and the inline note align with the new full_metric_name workflow.
examples/configs/dpo.yaml (1)
25-26: LGTM: DPO metric now explicitly from validation set

Using "val:validation-default_loss" and updating the comment example reduces ambiguity and matches the new requirement.
Also applies to: 183-184
examples/configs/vlm_grpo_3B.yaml (1)
34-36: LGTM: reward metric correctly namespaced

"val:reward" with higher_is_better: true matches GRPO usage.
examples/configs/rm.yaml (1)
18-20: LGTM: namespaced loss metric

"val:loss" and higher_is_better: false are consistent.
examples/configs/sft.yaml (1)
18-19: LGTM: prefixed loss metric

"val:loss" matches the new convention; no further changes needed.
lgtm. small comment on warning
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
closes #1261
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
Chores
Documentation