1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -53,6 +53,7 @@ loss_fn:
truncated_importance_sampling_ratio: null
sequence_level_importance_ratios: false
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)

checkpointing:
enabled: true
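
The comment on the new flag points at a hard constraint that the trainer later asserts. A minimal sketch of that arithmetic, with hypothetical values that are not taken from grpo_math_1B.yaml:

```python
# Illustrative only: these numbers are hypothetical, not from the config above.
# force_on_policy_ratio assumes each rollout is consumed in exactly one
# optimizer step, i.e. the rollout size equals the global train batch size.
num_prompts_per_step = 64
num_generations_per_prompt = 4
train_global_batch_size = 256

rollout_size = num_prompts_per_step * num_generations_per_prompt  # 64 * 4 = 256
assert rollout_size == train_global_batch_size, (
    "force_on_policy_ratio requires train_global_batch_size == "
    "num_prompts_per_step * num_generations_per_prompt"
)
```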
35 changes: 33 additions & 2 deletions nemo_rl/algorithms/grpo.py
@@ -581,6 +581,17 @@ def init_vllm():

loss_fn = ClippedPGLossFn(loss_config)

# Validate force_on_policy_ratio
if loss_config.get("force_on_policy_ratio", False):
assert (
grpo_config["num_prompts_per_step"]
* grpo_config["num_generations_per_prompt"]
== policy_config["train_global_batch_size"]
), (
"force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
)
print(" ✓ force_on_policy_ratio enabled")

# Calculate total setup time
total_setup_time = time.perf_counter() - setup_start_time
worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
@@ -1326,7 +1337,17 @@ def grpo_train(

metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.min(valid_values).item() if valid_values else -1.0
)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.max(valid_values).item() if valid_values else -1.0
)
elif k in {
"lr",
"wd",
"reward",
@@ -2229,7 +2250,17 @@ def async_grpo_train(
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.min(valid_values).item() if valid_values else -1.0
)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.max(valid_values).item() if valid_values else -1.0
)
elif k in {
"lr",
"wd",
"reward",
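Both training loops (`grpo_train` and `async_grpo_train`) reduce the new per-microbatch ratio metrics the same way. A small self-contained sketch of that reduction; the function name and sample values are illustrative, not from the PR:

```python
import numpy as np

def reduce_ratio_metric(values: list[float], mode: str = "min") -> float:
    """Drop inf/-inf sentinels from empty microbatches; fall back to -1.0."""
    valid_values = [x for x in values if not np.isinf(x)]
    if not valid_values:
        return -1.0
    return (np.min(valid_values) if mode == "min" else np.max(valid_values)).item()

print(reduce_ratio_metric([0.97, float("inf"), 1.02], "min"))    # 0.97
print(reduce_ratio_metric([float("inf"), float("inf")], "min"))  # -1.0
```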
59 changes: 56 additions & 3 deletions nemo_rl/algorithms/loss_functions.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, NotRequired, Optional, TypedDict, TypeVar

import torch
@@ -50,6 +51,12 @@ class ClippedPGLossConfig(TypedDict):
# If False (default), correction is applied at the token level as in the
# original GRPO paper.
sequence_level_importance_ratios: NotRequired[bool]
disable_ppo_ratio: NotRequired[bool]
# If True, force the ratio to 1.0 for truly on-policy behavior,
# eliminating any importance sampling effects.
# NOTE: This should only be used when doing exactly one update per rollout
# (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
force_on_policy_ratio: NotRequired[bool]


class ClippedPGLossDataDict(TypedDict):
@@ -74,6 +81,7 @@ class ClippedPGLossFn(LossFunction):
- GRPO - https://arxiv.org/abs/2402.03300
- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071
- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout)

Formula:
L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
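
For the new on-policy mode, a short derivation (added here for clarity; sg denotes stop-gradient, i.e. `.detach()`) of why the forced ratio still yields the usual policy gradient:

```latex
r_t(\theta) = \exp\!\big(\log \pi_\theta(a_t \mid s_t) - \mathrm{sg}[\log \pi_\theta(a_t \mid s_t)]\big) = 1,
\qquad
\nabla_\theta \big(r_t(\theta)\, A_t\big) = A_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)
```

With the ratio pinned to 1, clipping and importance weighting become no-ops, and the objective reduces to the advantage-weighted policy gradient (plus the optional KL term).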
@@ -114,6 +122,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
self.kl_input_clamp_value = cfg["kl_input_clamp_value"]
self.kl_output_clamp_value = cfg["kl_output_clamp_value"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
self.force_on_policy_ratio = cfg.get(
"force_on_policy_ratio", False
) # Force ratio to 1.0
self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
self.use_importance_sampling_correction = cfg[
"use_importance_sampling_correction"
@@ -296,7 +307,13 @@ def __call__(
kl = torch.tensor(0.0)

# Calculate clipped loss function if ppo ratio is enabled.
if not self.disable_ppo_ratio:
if self.force_on_policy_ratio:
# Force ratio to 1.0 for truly on-policy behavior
# Use curr_logprobs twice so ratio=1 but gradients still flow
log_ratios = curr_logprobs - curr_logprobs.detach()
ratios = log_ratios.exp() # = exp(0) = 1.0, but depends on curr_logprobs
ratios_clamped = ratios
elif not self.disable_ppo_ratio:
log_ratios = curr_logprobs - prev_logprobs
if self.sequence_level_importance_ratios:
seq_log_ratio_mean = masked_mean(
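
A minimal, runnable sketch (standalone, not part of the diff) of the detach trick in the `force_on_policy_ratio` branch above: the ratio evaluates to exactly 1.0, yet gradients still flow through `curr_logprobs`, so the loss term behaves like a plain advantage-weighted policy gradient:

```python
import torch

curr_logprobs = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -1.0, 2.0])

log_ratios = curr_logprobs - curr_logprobs.detach()
ratios = log_ratios.exp()            # tensor([1., 1., 1.]), with grad_fn attached
loss = -(ratios * advantages).sum()  # clipping would be a no-op at ratio == 1
loss.backward()

# d/dx exp(x - sg(x)) == 1 at x, so the gradient is simply -advantages:
print(curr_logprobs.grad)            # tensor([-0.5000,  1.0000, -2.0000])
```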
@@ -419,6 +436,22 @@ def __call__(
global_normalization_factor=global_valid_toks,
).item()

# Calculate min/max values for ratios (only for valid tokens)
masked_ratios = ratios.detach()[mask.bool()]
masked_ratios_clamped = ratios_clamped.detach()[mask.bool()]

# Handle edge case where there might be no valid tokens
if masked_ratios.numel() > 0:
probs_ratio_min = masked_ratios.min().item()
probs_ratio_max = masked_ratios.max().item()
probs_ratio_clamped_min = masked_ratios_clamped.min().item()
probs_ratio_clamped_max = masked_ratios_clamped.max().item()
else:
probs_ratio_min = float("inf")
probs_ratio_max = float("-inf")
probs_ratio_clamped_min = float("inf")
probs_ratio_clamped_max = float("-inf")

# If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
# by either sequence or token count, depending on particular metric.
# To get the true metric, you'll need to sum over the microbatch.
@@ -428,6 +461,10 @@ def __call__(
"loss": loss.item(),
"probs_ratio": probs_ratio,
"probs_ratio_clamped": probs_ratio_clamped,
"probs_ratio_min": probs_ratio_min,
"probs_ratio_max": probs_ratio_max,
"probs_ratio_clamped_min": probs_ratio_clamped_min,
"probs_ratio_clamped_max": probs_ratio_clamped_max,
"kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
"token_mult_prob_error": mult_prob_error,
"gen_kl_error": gen_kl_error,
@@ -903,8 +940,24 @@ def __call__(
loss_accum += loss
for k, v in metrics.items():
if k not in metrics_accum:
metrics_accum[k] = 0
metrics_accum[k] += v
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
metrics_accum[k] = float("inf")
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
metrics_accum[k] = float("-inf")
else:
metrics_accum[k] = 0

val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v

# Skip inf/-inf sentinel values (from sequences with no valid tokens)
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
if not math.isinf(val):
metrics_accum[k] = min(metrics_accum[k], val)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
if not math.isinf(val):
metrics_accum[k] = max(metrics_accum[k], val)
else:
metrics_accum[k] += val

return loss_accum, metrics_accum
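
A compact, self-contained sketch of the accumulation above (metric values are made up): min/max ratio metrics are combined with min/max across microbatches, inf/-inf sentinels from microbatches with no valid tokens are skipped, and every other metric is summed as before:

```python
import math

microbatch_metrics = [
    {"loss": 0.31, "probs_ratio_min": 0.95, "probs_ratio_max": 1.04},
    {"loss": 0.28, "probs_ratio_min": math.inf, "probs_ratio_max": -math.inf},  # no valid tokens
    {"loss": 0.33, "probs_ratio_min": 0.91, "probs_ratio_max": 1.10},
]

metrics_accum: dict[str, float] = {}
for metrics in microbatch_metrics:
    for k, v in metrics.items():
        if k == "probs_ratio_min":
            metrics_accum.setdefault(k, math.inf)
            if not math.isinf(v):
                metrics_accum[k] = min(metrics_accum[k], v)
        elif k == "probs_ratio_max":
            metrics_accum.setdefault(k, -math.inf)
            if not math.isinf(v):
                metrics_accum[k] = max(metrics_accum[k], v)
        else:
            metrics_accum[k] = metrics_accum.get(k, 0.0) + v

print(metrics_accum)  # loss ≈ 0.92, probs_ratio_min == 0.91, probs_ratio_max == 1.1
```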
