[trainer] fix: normalize sft loss by num_tokens in global batch#3994
Conversation
Code Review
This pull request aims to normalize the SFT loss by the number of tokens in the global batch. While the intention is correct, I've identified a few critical issues in the implementation that will lead to incorrect loss calculation. Firstly, the logic to handle sequence boundaries when rolling the loss_mask in sft_loss has been removed, causing incorrect loss terms at sequence boundaries. Secondly, the calculation of batch_num_tokens in both FSDP and Megatron implementations is flawed because it doesn't account for the tokens that are subsequently masked out in the loss function. This results in an incorrect normalization factor. I have provided specific suggestions to address these critical bugs.
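To illustrate why the normalization factor matters, here is a minimal pure-Python sketch (illustrative names, plain lists standing in for tensors) contrasting per-micro-batch averaging with normalization by the token count of the whole global batch:

```python
# Sketch: why per-token SFT loss should be normalized by the token count of
# the whole global batch rather than per micro-batch. Pure Python; the
# function names are illustrative, not verl APIs.

def micro_batch_mean(losses_per_micro):
    # Naive: average each micro-batch separately, then average the averages.
    per_mb = [sum(l) / len(l) for l in losses_per_micro]
    return sum(per_mb) / len(per_mb)

def global_token_mean(losses_per_micro):
    # Global: sum all per-token losses, divide by total tokens in the batch.
    total = sum(sum(l) for l in losses_per_micro)
    num_tokens = sum(len(l) for l in losses_per_micro)
    return total / num_tokens

# Micro-batches with unequal token counts weight tokens unequally under the
# naive scheme but equally under global-token normalization.
mbs = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0]]
print(micro_batch_mean(mbs))   # 2.0 — tokens in the short batch are over-weighted
print(global_token_mean(mbs))  # 10 / 6 ≈ 1.667 — every token weighted equally
```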
```python
loss_mask_flatten = loss_mask.values()
# left-shift the loss mask by one token to align with log_prob
loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
```
The logic to handle sequence boundaries when rolling the loss_mask has been removed. When torch.roll is applied to a flattened tensor of multiple sequences, it incorrectly wraps elements from the beginning of one sequence to the end of the previous one. The original code correctly zeroed out the mask at the last token of each sequence to prevent incorrect loss calculation at these boundaries. This removal will lead to incorrect loss values. Please restore this boundary handling logic.
```python
loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
cu_seqlens = log_prob.offsets()
loss_mask_flatten[cu_seqlens[1:] - 1] = 0
```
After the left shift, the last token of each sequence would otherwise pick up the first loss_mask value of the next sequence, so it must be zeroed explicitly.

loss_mask: [0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 1, 1, 1]
loss_mask_flatten: [0, 1, 1, 1, 0], [0, 0, 1, 1, 0], [1, 1, 1, 0]

Note that the number of valid tokens is unchanged by the left shift.
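The wrap-around behavior and the boundary fix can be sketched in pure Python (plain lists standing in for the flattened tensor, list slicing standing in for `torch.roll`):

```python
# Sketch: left-shifting a flattened loss mask of packed sequences. The shift
# wraps the first element of the next sequence onto the last position of the
# previous one, so the last token of every sequence is re-zeroed via the
# cumulative sequence offsets (cu_seqlens). Pure Python, not verl code.

def left_shift_mask(loss_mask, cu_seqlens):
    shifted = loss_mask[1:] + loss_mask[:1]   # equivalent of torch.roll(mask, -1)
    for end in cu_seqlens[1:]:
        shifted[end - 1] = 0                  # zero the last token of each sequence
    return shifted

# Three packed sequences of lengths 5, 5, 4 (matches the example above).
mask = [0, 0, 1, 1, 1,  0, 0, 0, 1, 1,  0, 1, 1, 1]
cu_seqlens = [0, 5, 10, 14]
shifted = left_shift_mask(mask, cu_seqlens)
print(shifted)  # [0, 1, 1, 1, 0,  0, 0, 1, 1, 0,  1, 1, 1, 0]
assert sum(shifted) == sum(mask)  # valid-token count is unchanged
```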
```python
global_bsz = data["global_batch_size"]
loss_scale_factor = local_micro_bsz / (global_bsz / self.get_data_parallel_size())
loss = loss * loss_scale_factor
loss = loss * data["num_micro_batch"] / mpu.get_context_parallel_world_size()
```
We should add a note: the sft_loss is used in the FSDP backend (which handles scaling automatically), while it requires manual scaling in the Megatron backend.
Add note in verl/workers/roles/utils/losses.py
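For reference, a standalone sketch of the Megatron-side scaling with illustrative numbers (the real values come from the data batch and `mpu` in verl; FSDP's backward handles this averaging automatically, so the factor is only applied in the Megatron backend):

```python
# Sketch of the manual loss scaling on the Megatron path. All numbers below
# are placeholders chosen for illustration, not values from verl.

local_micro_bsz = 4    # samples in this micro-batch
global_bsz = 64        # samples across all data-parallel ranks per step
dp_size = 4            # data-parallel world size
num_micro_batch = 4    # micro-batches per data-parallel rank
cp_size = 1            # context-parallel world size

loss = 2.5             # mean loss of this micro-batch (placeholder)

# Scale down to this micro-batch's share of the per-rank batch, then scale
# back up by the number of micro-batches Megatron averages over.
loss_scale_factor = local_micro_bsz / (global_bsz / dp_size)
loss = loss * loss_scale_factor
loss = loss * num_micro_batch / cp_size
print(loss)  # 2.5 * (4 / 16) * 4 / 1 = 2.5
```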
…-project#3994) ### What does this PR do? Normalize sft loss by num_tokens in global batch. TODO - [ ] Normalize `pg_loss` and `value_loss` in rl trainer pipeline. Credit to @techkang in verl-project#3729.
What does this PR do?
Normalize sft loss by num_tokens in global batch.
TODO
- [ ] Normalize `pg_loss` and `value_loss` in rl trainer pipeline.

Credit to @techkang in #3729.