@finbarrtimbers finbarrtimbers commented Nov 3, 2025

Also adds a test case covering the bug that Hamish pointed out, and fixes it by switching from sample-wise to token-wise normalization. Before:

loss = masked_mean(
    pg_loss_max + (args.beta * kl),
    response_masks_bool,
    args.masked_mean_axis,
    args.masked_mean_denominator,
)
loss = loss / accumulation_steps

Now:

total_loss = pg_loss_max + (args.beta * kl)
loss_sum = (total_loss * response_masks_bool).sum()
denominator = (
    args.masked_mean_denominator if args.masked_mean_denominator is not None else response_masks_bool.sum()
)
loss = loss_sum / denominator
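
To make the difference concrete, here is a minimal, self-contained sketch (not the repo's code; tensor names mirror the snippets above, and the loss values are stand-ins) showing that the two normalizations disagree whenever responses have different lengths:

import torch

torch.manual_seed(0)
per_token_loss = torch.randn(2, 4) ** 2  # stand-in for pg_loss_max + beta * kl
response_masks_bool = torch.tensor([[1, 1, 1, 1],   # 4-token response
                                    [1, 0, 0, 0]],  # 1-token response
                                   dtype=torch.float32)
accumulation_steps = 2

# Old (sample-wise): each microbatch (here, each row at bsz=1) is averaged
# over its own tokens, then gradient accumulation averages the microbatches.
sample_wise = sum(
    (per_token_loss[i] * response_masks_bool[i]).sum() / response_masks_bool[i].sum()
    for i in range(2)
) / accumulation_steps

# New (token-wise): sum over all tokens once, divide by the global token count.
token_wise = (per_token_loss * response_masks_bool).sum() / response_masks_bool.sum()

print(sample_wise.item(), token_wise.item())  # differ: sample-wise upweights tokens in short responses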

Hamish's original description of the bug:

One thing I realised while writing some stuff up on the weekend - I think our loss computation for the RL is actually technically a little wrong.
We do gradient accumulation to simulate higher batch sizes, but in reality we basically always do real bsz=1.
Gradient accumulation means we average the gradients for each item in the batch, so we get something like a sample-wise loss (image 1).
But in practice we should probably do a token-level loss like DAPO (image 2), or sum + divide by a constant like Dr. GRPO?
I think this requires keeping track of the number of tokens in each batch. You swap the loss to take a sum, and then compute the loss as:
loss = (loss * grad_acc_steps * num_gpus) / num_total_tokens
Basically, divide the loss by the total number of tokens in the batch; we multiply by grad_acc_steps to account for the grad-acc averaging, and num_gpus to account for the fact that each GPU is doing this separately and we are averaging across them… if that makes sense.
There's an explanation here.

[Image 1: sample-wise loss formula; Image 2: token-level (DAPO) loss formula]
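
To sanity-check that scaling, a hedged numeric sketch (names illustrative, not from the repo): assume a constant per-token loss of 1.0, so the correct token-level mean is exactly 1.0, and assume the framework averages over accumulation steps and mean-all-reduces across GPUs (as DDP does by default).

grad_acc_steps, num_gpus = 4, 2
token_counts = [[3, 7, 5, 1], [2, 2, 8, 4]]  # tokens per microbatch, per GPU
num_total_tokens = sum(sum(row) for row in token_counts)  # 32

per_gpu_losses = []
for gpu_counts in token_counts:
    accumulated = 0.0
    for n in gpu_counts:
        loss_sum = float(n)  # sum of per-token losses (all 1.0 here)
        scaled = loss_sum * grad_acc_steps * num_gpus / num_total_tokens
        accumulated += scaled / grad_acc_steps  # grad accumulation averages over steps
    per_gpu_losses.append(accumulated)

global_loss = sum(per_gpu_losses) / num_gpus  # all-reduce mean across GPUs
print(global_loss)  # 1.0: the multipliers cancel both averages, recovering the token-level mean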

Note

Extracts GRPO loss computation into a free function, introduces metrics utilities and truncated importance sampling, updates training loop accordingly, and adds focused unit tests.

  • Training/Algo (GRPO):
    • Extracts loss computation to calculate_loss_and_backward(...) and integrates into training loop.
    • Adds maybe_apply_importance_sampling(...) (truncated importance sampling) gated by args.truncated_importance_sampling_ratio_cap (see the sketch after this list).
    • Introduces compare_logprobs(...) to log vLLM vs local logprob diffs and reverse KL.
    • Replaces ad-hoc metrics tracking with LossStatistics (in open_instruct/metrics.py) for KL, clipfrac, policy/total loss, ratio, and optional entropy; updates metric aggregation/return values.
    • Refactors old/vLLM logprobs handling and optimizer stepping condition (local_step % accumulation_steps == 0).
  • New Module:
    • open_instruct/metrics.py: masked_mean(...) and LossStatistics for centralized metric computation.
  • Tests:
    • Adds unit tests for compare_logprobs, maybe_apply_importance_sampling, and calculate_loss_and_backward; updates GRPO fast tests to cover new paths.
  • Docs:
    • Updates CLAUDE.md with a comments policy.
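
For reference, a minimal sketch of the truncated importance sampling idea behind maybe_apply_importance_sampling(...). The actual function in this PR may differ in signature and wiring; treat this as an illustration of the standard technique (capping the learner/vLLM probability ratio), not the repo's implementation:

import torch

def truncated_importance_weight(local_logprobs: torch.Tensor,
                                vllm_logprobs: torch.Tensor,
                                cap: float | None) -> torch.Tensor:
    """Per-token importance weight pi_local / pi_vllm, truncated at `cap`.

    Illustrative only; the PR gates this behavior on
    args.truncated_importance_sampling_ratio_cap.
    """
    ratio = torch.exp(local_logprobs - vllm_logprobs)
    if cap is not None:
        ratio = torch.clamp(ratio, max=cap)
    # Detach so the weight corrects for off-policy sampling without
    # itself receiving gradients.
    return ratio.detach()

# Usage sketch:
# per_token_loss = per_token_loss * truncated_importance_weight(
#     local_logprobs, vllm_logprobs, args.truncated_importance_sampling_ratio_cap)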

Written by Cursor Bugbot for commit adb108b.

@finbarrtimbers changed the title from "Refactors the loss calculation to make it testable." to "Refactors the loss calculation to pull it out into a free function" on Nov 6, 2025
@finbarrtimbers marked this pull request as ready for review on November 7, 2025 21:54