[trainer] feat: add per-round logprob mismatch metrics for multi-turn training #5229
+210
−2
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Add per-round logprob mismatch metrics for multi-turn RL training. In multi-turn trajectories, the
response_maskcontains contiguous segments of 1s for each round of model generation, separated by 0s for environment tokens (e.g., images). This PR detects those segments and computes per-round mean absolute logprob difference between rollout and actor, making it easy to identify which round diverges most.This extends the existing debug metrics (#1712, #2808) without changing the existing API or adding any new dependencies.
Checklist Before Starting
[{modules}] {type}: {description}Test
Existing unit test passes unchanged:
Tested with multi-turn VLM RL training (Qwen2.5-VL) on 8-turn GUI agent trajectories.
API and Usage Example
No API changes.
calculate_debug_metrics(data)signature is unchanged. The returned dict now includes additionalper_round/prefixed keys:Design & Code Changes
verl/utils/debug/metrics.py(+120 lines, single file change):_find_contiguous_segments(mask_1d): Finds contiguous segments of 1s in response_mask to identify round boundaries_calculate_per_round_metrics(train_log_probs, rollout_log_probs, response_mask): Computes mean absolute logprob diff per round, aggregated across the batchcalculate_debug_metrics(): Now calls_calculate_per_round_metricsand includes the resultsChecklist Before Submitting
ci-requestchannel