Refactors the loss calculation to pull it out into a free function #1137
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Also adds a test case covering the bug that Hamish pointed out, and fixes it by switching from sample-wise to token-wise normalization. Before:
Now:
Hamish's original description of the bug:
Note
Extracts GRPO loss computation into a free function, introduces metrics utilities and truncated importance sampling, updates training loop accordingly, and adds focused unit tests.
calculate_loss_and_backward(...)and integrates into training loop.maybe_apply_importance_sampling(...)(truncated importance sampling) gated byargs.truncated_importance_sampling_ratio_cap.compare_logprobs(...)to log vLLM vs local logprob diffs and reverse KL.LossStatistics(inopen_instruct/metrics.py) for KL, clipfrac, policy/total loss, ratio, and optional entropy; updates metric aggregation/return values.local_step % accumulation_steps == 0).open_instruct/metrics.py:masked_mean(...)andLossStatisticsfor centralized metric computation.compare_logprobs,maybe_apply_importance_sampling, andcalculate_loss_and_backward; updates GRPO fast tests to cover new paths.CLAUDE.mdwith a comments policy.Written by Cursor Bugbot for commit adb108b. This will update automatically on new commits. Configure here.