feat: Add support for IPO and RPO #1388
Conversation
Signed-off-by: Sanjana Ravi <sanjana@inflection.ai>
📝 Walkthrough

This PR extends DPO training to support multiple preference loss objectives (IPO, RPO variants) by adding optional reward fields to data schemas, introducing new DPO configuration options for preference loss selection and ground-truth reward scaling, and implementing corresponding loss computation logic.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/data/collate_fn.py (1)
152-165: Make rewards a torch tensor and validate presence once.

Currently `rewards` is a Python list; `DPOLossFn` uses tensor ops and will error. Also, the assert triggers late if only some samples include rewards.

Apply:
```diff
@@
-        rewards = []
+        rewards: list[float] = []
+        # Ensure all-or-none rewards presence for the batch to avoid silent partials
+        has_rewards_for_all = all(
+            ("reward_chosen" in ds and "reward_rejected" in ds) for ds in data_batch
+        )
@@
-        if "reward_chosen" in datum_spec and "reward_rejected" in datum_spec:
-            rewards.append(datum_spec["reward_chosen"])
-            rewards.append(datum_spec["reward_rejected"])
+        if has_rewards_for_all:
+            rewards.append(float(datum_spec["reward_chosen"]))  # type: ignore[arg-type]
+            rewards.append(float(datum_spec["reward_rejected"]))  # type: ignore[arg-type]
@@
-    if rewards:
-        assert len(rewards) == len(message_log), (
-            f"rewards length ({len(rewards)}) and message_log length ({len(message_log)}) mismatch"
-        )
+    if rewards:
+        if len(rewards) != len(message_log):
+            raise ValueError(
+                f"rewards length ({len(rewards)}) and message_log length ({len(message_log)}) mismatch; "
+                "either provide rewards for every (chosen,rejected) pair in the batch or for none."
+            )
@@
-    if rewards:
-        data["rewards"] = rewards
+    if rewards:
+        # Align dtype/device with input_ids
+        data["rewards"] = torch.tensor(rewards, dtype=torch.float32, device=data["input_ids"].device)
```

Also applies to: 166-168, 204-205
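As a standalone reference, a minimal sketch of this all-or-none pattern (a hypothetical helper, not the repo's exact collate code; `data_batch` holds dicts shaped like `DPODatumSpec`):

```python
import torch

def collate_rewards(data_batch: list[dict]) -> torch.Tensor | None:
    """Collect (chosen, rejected) rewards for a batch, enforcing all-or-none presence."""
    has_all = all("reward_chosen" in d and "reward_rejected" in d for d in data_batch)
    has_any = any("reward_chosen" in d or "reward_rejected" in d for d in data_batch)
    if has_any and not has_all:
        raise ValueError("Provide rewards for every (chosen, rejected) pair or for none.")
    if not has_all:
        return None
    rewards: list[float] = []
    for d in data_batch:  # interleave to match the chosen/rejected message ordering
        rewards.append(float(d["reward_chosen"]))
        rewards.append(float(d["reward_rejected"]))
    return torch.tensor(rewards, dtype=torch.float32)  # shape (2*B,)
```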
nemo_rl/algorithms/loss_functions.py (1)
700-707: Avoid hardcoded `.cuda()`; use the logits' device.

Using `.cuda()` breaks on CPU-only runs and heterogeneous devices. Align with `next_token_logits`.

Apply:
```diff
-        next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
+        next_tokens = data["input_ids"][:, 1:].to(next_token_logits.device)  # Skip first token
```

Consider making the same change in `NLLLoss` for consistency (separate PR ok).
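For reference, a device-agnostic version of the surrounding computation might look like this sketch (a hypothetical standalone helper, not the repo's exact code):

```python
import torch

def next_token_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Per-token log-probs of the observed next tokens, on whatever device logits live on."""
    next_token_logits = logits[:, :-1, :]  # position t predicts token t+1
    next_tokens = input_ids[:, 1:].to(next_token_logits.device)  # no hardcoded .cuda()
    logprobs = torch.log_softmax(next_token_logits, dim=-1)
    return logprobs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)
```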
🧹 Nitpick comments (8)
nemo_rl/data/interfaces.py (1)
50-51: Document new DPODatumSpec fields.

Add concise comments describing semantics to keep the schema self-explanatory.

Apply:
```diff
-    reward_chosen: NotRequired[float]
-    reward_rejected: NotRequired[float]
+    reward_chosen: NotRequired[float]  # Ground-truth reward for the chosen completion (used by RPO)
+    reward_rejected: NotRequired[float]  # Ground-truth reward for the rejected completion (used by RPO)
```
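For illustration, a datum carrying the new optional fields might look like the sketch below (the remaining required `DPODatumSpec` keys are elided; the values are invented):

```python
# Hypothetical datum; other required DPODatumSpec fields are omitted for brevity.
datum = {
    # ... message logs, lengths, idx, etc. ...
    "reward_chosen": 10.0,   # ground-truth reward for the chosen completion (RPO)
    "reward_rejected": 2.5,  # ground-truth reward for the rejected completion (RPO)
}
```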
nemo_rl/algorithms/loss_functions.py (2)

595-609: Replace Greek σ in docstrings to satisfy linters and improve clarity.

Use `sigmoid(x)` instead of `σ(x)`; also render Δ consistently as `Delta`.

Apply (excerpt):
```diff
-        L_pref(θ) = -E[log(σ(β * Δ_r))]
+        L_pref(theta) = -E[log(sigmoid(beta * Delta_r))]
@@
-        L_pref(θ) = E[(Δ_r - (1/(2β))) ^ 2]
+        L_pref(theta) = E[(Delta_r - (1/(2*beta)))^2]
@@
-        L_pref(θ) = E[(Δ_r - Δ_gtr) ^ 2]
+        L_pref(theta) = E[(Delta_r - Delta_gtr)^2]
@@
-    - σ is the sigmoid function
-    - β is the reference_policy_kl_penalty
+    - sigmoid is the logistic function
+    - beta is the reference_policy_kl_penalty
```
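Taken together, the three docstring formulas reduce to simple functions of the implicit-reward margin. A minimal standalone sketch (illustrative only; the `"rpo_sq"` name is an assumption, and this is not the repo's `DPOLossFn`):

```python
import torch
import torch.nn.functional as F

def preference_loss(
    delta_r: torch.Tensor,                    # implicit reward margin (chosen - rejected), shape (B,)
    beta: float,                              # reference_policy_kl_penalty
    kind: str = "dpo",
    delta_gtr: torch.Tensor | None = None,    # ground-truth reward margin (RPO only)
) -> torch.Tensor:
    """Per-pair losses matching the docstring formulas above."""
    if kind == "dpo":
        return -F.logsigmoid(beta * delta_r)
    if kind == "ipo":
        return (delta_r - 1.0 / (2.0 * beta)) ** 2
    if kind == "rpo_sq":  # squared-distance RPO variant (name assumed)
        assert delta_gtr is not None, "RPO losses need ground-truth reward margins"
        return (delta_r - delta_gtr) ** 2
    raise ValueError(f"unknown preference loss: {kind}")
```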
805-807: Remove/adjust outdated TODO.

The comment refers to inheriting from PreferenceLoss, but `DPOLossFn` now inherits from `LossFunction`. Replace it with a current action item or drop the comment.
tests/unit/algorithms/test_loss_functions.py (4)
348-380, 382-416, 418-452, 454-488: Replace unused `metrics_dict` with underscore.

In each of these four tests, the `metrics_dict` variable is unpacked but never used. Replace it with `_` to indicate it's intentionally unused.

Apply this diff in all four places:

```diff
-        loss, metrics_dict = loss_fn(
+        loss, _ = loss_fn(
```

docs/guides/dpo.md (1)
49-49: Clarify reward field requirement and use consistent capitalization.

The comment states "Optional, float" followed by "required for rpo", which may confuse readers. Consider rephrasing to clarify that the field is optional for DPO/IPO but required when using RPO variants. Also, use "RPO" (uppercase) for consistency with other acronym usage in the documentation.

Consider updating the comments to:

```diff
-    "reward": 10.0, // Optional, float - The ground truth reward of the completion (required for rpo)
+    "reward": 10.0, // Optional, float - The ground truth reward of the completion (required for RPO variants)
```

Also applies to: 54-54, 87-87, 97-97
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- docs/guides/dpo.md (5 hunks)
- examples/run_dpo.py (1 hunks)
- nemo_rl/algorithms/loss_functions.py (6 hunks)
- nemo_rl/data/collate_fn.py (3 hunks)
- nemo_rl/data/datasets/preference_datasets/preference_dataset.py (1 hunks)
- nemo_rl/data/interfaces.py (1 hunks)
- tests/unit/algorithms/test_loss_functions.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
docs/**/*.md
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
When a markdown doc under docs/**/*.md is added or renamed, update docs/index.md to include it in the appropriate section
Files:
docs/guides/dpo.md
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py:
- Follow the Google Python Style Guide for all Python code
- Target Python 3.12+ for all Python code in NeMo-RL
- Indent Python code with 4 spaces; do not use tabs
- Python filenames should be snake_case (e.g., some_file.py)
- Class names should be PascalCase
- Function and method names should be snake_case
- Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
- Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
- Constants should be UPPER_SNAKE_CASE
- Avoid shadowing variables declared in an outer scope
- Initialize all externally visible members of a class in the constructor
- For public interfaces used outside a file, prefer docstrings over comments
- Use comments mainly for code within a function or interfaces local to a file
- Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
- Use Google-style docstrings for classes and functions (Sphinx-parseable)
- Avoid using reflection when functionality can be easily achieved without it
- Limit except clauses to the smallest specific set of exceptions possible
- For duck-typing via try/except, keep the try body minimal and use else for main logic
- Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/data/interfaces.py
examples/run_dpo.py
nemo_rl/data/datasets/preference_datasets/preference_dataset.py
nemo_rl/algorithms/loss_functions.py
nemo_rl/data/collate_fn.py
tests/unit/algorithms/test_loss_functions.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py:
- Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
- Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
- Express configuration optionality via TypedDict using typing.NotRequired
- When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default in code
- For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/data/interfaces.py
nemo_rl/data/datasets/preference_datasets/preference_dataset.py
nemo_rl/algorithms/loss_functions.py
nemo_rl/data/collate_fn.py
🧬 Code graph analysis (2)
nemo_rl/algorithms/loss_functions.py (2)
nemo_rl/algorithms/interfaces.py (2)
- LossFunction (28-70)
- LossType (23-25)

nemo_rl/algorithms/utils.py (1)
- masked_mean (134-146)
tests/unit/algorithms/test_loss_functions.py (2)
nemo_rl/algorithms/loss_functions.py (1)
- DPOLossFn (565-869)

nemo_rl/distributed/batched_data_dict.py (1)
- to (825-832)
🪛 Ruff (0.14.0)
nemo_rl/algorithms/loss_functions.py
596: Docstring contains ambiguous σ (GREEK SMALL LETTER SIGMA). Did you mean o (LATIN SMALL LETTER O)? (RUF002)

605 (×6), 608 (×6), 611: same RUF002 warning (ambiguous σ in docstring).
tests/unit/algorithms/test_loss_functions.py
369, 404, 440, 476: Unpacked variable `metrics_dict` is never used. Prefix it with an underscore or any other dummy variable pattern. (RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (3)
nemo_rl/algorithms/loss_functions.py (1)
737-746: Ensure `data["rewards"]` is a torch.Tensor.

RPO branches expect tensor ops; lists will fail. After applying the collate fix, please re-verify shapes as `(2*B,)` or `(2*B, 1)` and dtype float32. You can sanity-check at runtime:

```python
assert isinstance(data["rewards"], torch.Tensor)
assert data["rewards"].ndim in (1, 2)
```
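If both shapes must be tolerated, a small normalization helper (illustrative; where it would live in `DPOLossFn` is up to the implementation) keeps the downstream math simple:

```python
import torch

def normalize_rewards(rewards: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Coerce rewards to a flat float32 vector of shape (2*B,) on the target device."""
    if rewards.ndim == 2:  # e.g. (2*B, 1) -> (2*B,)
        rewards = rewards.squeeze(-1)
    return rewards.to(device=device, dtype=torch.float32)
```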
docs/guides/dpo.md (2)

7-11: LGTM! The "Other Objectives" section clearly introduces IPO and RPO variants with appropriate academic references.

144-145: LGTM! The new parameter documentation is clear, comprehensive, and accurately describes the new DPO configuration options.
Signed-off-by: Sanjana Ravi <sanjana@inflection.ai>
thanks for the contribution @sanjana-inflection. @ashors1 could you review?
What does this PR do?
Add support for IPO and RPO (with forward and backward KL distance, and squared distance)
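For context, one plausible reading of these variants, using the squared form from the review comments above and a Bernoulli-KL construction for the KL distances (β is the KL penalty; η stands in for the ground-truth reward scaling this PR introduces; the exact scaling factors are not verified here):

```latex
\Delta_r = \big(\log\pi_\theta - \log\pi_{\mathrm{ref}}\big)_{\text{chosen}}
         - \big(\log\pi_\theta - \log\pi_{\mathrm{ref}}\big)_{\text{rejected}},
\qquad
\Delta_{\mathrm{gtr}} = \eta \,\big(r^*_{\text{chosen}} - r^*_{\text{rejected}}\big)

p = \sigma(\beta \,\Delta_r), \qquad p^* = \sigma(\Delta_{\mathrm{gtr}})

\mathcal{L}_{\mathrm{fwd\,KL}} = \mathrm{KL}\big(\mathrm{Bern}(p^*) \,\|\, \mathrm{Bern}(p)\big),
\qquad
\mathcal{L}_{\mathrm{bwd\,KL}} = \mathrm{KL}\big(\mathrm{Bern}(p) \,\|\, \mathrm{Bern}(p^*)\big),
\qquad
\mathcal{L}_{\mathrm{sq}} = \big(\Delta_r - \Delta_{\mathrm{gtr}}\big)^2
```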
Issues
Issue #193:
Usage
# Add a code snippet demonstrating how to use this
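A hypothetical sketch of driving the new options (the `preference_loss` and `gt_reward_scale` key names follow this PR's description; treat every key below as an assumption rather than the merged config schema):

```python
from nemo_rl.algorithms.loss_functions import DPOLossFn

# Hypothetical config; key names are assumptions based on this PR's description.
loss_fn = DPOLossFn(
    {
        "reference_policy_kl_penalty": 0.05,
        "preference_loss": "ipo",   # assumed: "dpo", "ipo", or an RPO variant
        "gt_reward_scale": 1.0,     # assumed: scaling for ground-truth reward margins (RPO)
        "preference_loss_weight": 1.0,
        "sft_loss_weight": 0.0,
        "preference_average_log_probs": False,
        "sft_average_log_probs": False,
    }
)
```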
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests