[trainer] fix: normalize sft loss by num_tokens in global batch #3994

Merged
wuxibin89 merged 4 commits into main from wuxibin/fix_model_engine_loss on Nov 3, 2025
Conversation

@wuxibin89 (Collaborator)
What does this PR do?

Normalize sft loss by num_tokens in global batch.

TODO

- [ ] Normalize `pg_loss` and `value_loss` in rl trainer pipeline.

Credit to @techkang in #3729.

@gemini-code-assist (bot) left a comment
Code Review

This pull request aims to normalize the SFT loss by the number of tokens in the global batch. While the intention is correct, I've identified a few critical issues in the implementation that will lead to incorrect loss calculation. Firstly, the logic to handle sequence boundaries when rolling the loss_mask in sft_loss has been removed, causing incorrect loss terms at sequence boundaries. Secondly, the calculation of batch_num_tokens in both FSDP and Megatron implementations is flawed because it doesn't account for the tokens that are subsequently masked out in the loss function. This results in an incorrect normalization factor. I have provided specific suggestions to address these critical bugs.
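The normalization issue the review describes can be illustrated with a minimal sketch in plain Python (mirroring the torch logic; the function name and data are hypothetical, not from the PR): the divisor must be the number of *unmasked* tokens in the global batch, not the raw token count, otherwise masked prompt tokens inflate the denominator and shrink the loss.

```python
def token_normalized_loss(per_token_losses, loss_masks):
    """Mean loss over valid tokens in a (toy) global batch.

    per_token_losses / loss_masks: one list per sequence in the batch.
    """
    masked_sum = 0.0
    num_valid_tokens = 0
    for losses, mask in zip(per_token_losses, loss_masks):
        for loss, m in zip(losses, mask):
            masked_sum += loss * m
            num_valid_tokens += m
    # Normalize by valid (unmasked) tokens only, not by total token count.
    return masked_sum / num_valid_tokens

losses = [[1.0, 2.0, 3.0], [4.0, 5.0]]
masks = [[0, 1, 1], [1, 0]]
result = token_normalized_loss(losses, masks)  # (2 + 3 + 4) / 3 = 3.0
```

Dividing by the raw token count (5) here would give 1.8 instead of 3.0, which is the class of bug the review points at.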

loss_mask_flatten = loss_mask.values()

# left-shift the loss mask by one token to align with log_prob
loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
@gemini-code-assist (bot)
critical

The logic to handle sequence boundaries when rolling the loss_mask has been removed. When torch.roll is applied to a flattened tensor of multiple sequences, it incorrectly wraps elements from the beginning of one sequence to the end of the previous one. The original code correctly zeroed out the mask at the last token of each sequence to prevent incorrect loss calculation at these boundaries. This removal will lead to incorrect loss values. Please restore this boundary handling logic.

Suggested change

  loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
+ cu_seqlens = log_prob.offsets()
+ loss_mask_flatten[cu_seqlens[1:] - 1] = 0

@wuxibin89 (Collaborator, Author) replied on Nov 3, 2025:
The last token of each sequence is masked by the first loss_mask value of the next sequence after the left shift.

loss_mask:         [0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 1, 1, 1]
loss_mask_flatten: [0, 1, 1, 1, 0], [0, 0, 1, 1, 0], [1, 1, 1, 0]

And the number of valid tokens is unchanged by the left shift.
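The author's argument can be checked with a small sketch in plain Python (standing in for `torch.roll(x, shifts=-1, dims=0)` on a flattened mask): assuming each sequence's first position is masked, as it is for SFT prompts, the value that wraps onto the previous sequence's last position is always 0, and the valid-token count is preserved.

```python
def roll_left(xs):
    """Equivalent of torch.roll(x, shifts=-1, dims=0) for a Python list:
    shift every element left by one, wrapping the first element to the end."""
    return xs[1:] + xs[:1]

# The masks from the comment above; each sequence starts with a masked (0) prompt token.
masks = [[0, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 1, 1, 1]]
flat = [m for seq in masks for m in seq]
shifted = roll_left(flat)

# The wrapped-in values at sequence boundaries are the 0s from the next
# sequence's prompt, so no explicit boundary zeroing is needed, and the
# number of valid tokens (sum of the mask) stays at 8.
```

This is the reason the `cu_seqlens` boundary fix suggested by the bot is unnecessary under the stated assumption.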

global_bsz = data["global_batch_size"]
loss_scale_factor = local_micro_bsz / (global_bsz / self.get_data_parallel_size())
loss = loss * loss_scale_factor
loss = loss * data["num_micro_batch"] / mpu.get_context_parallel_world_size()
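The scale factors above can be sanity-checked arithmetically. A hedged sketch in plain Python, under the assumption that the backend averages the returned loss over micro-batches and then over data-parallel ranks (context parallelism omitted, i.e. cp_size = 1; all numbers are made up):

```python
dp_size = 2
num_micro_batch = 2
local_micro_bsz = 3
global_bsz = dp_size * num_micro_batch * local_micro_bsz  # 12

# Fake per-sample losses, grouped as [dp_rank][micro_batch][sample].
sample_losses = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
global_mean = sum(x for r in sample_losses for mb in r for x in mb) / global_bsz

recovered = 0.0
for rank in sample_losses:
    rank_loss = 0.0
    for mb in rank:
        loss = sum(mb) / local_micro_bsz  # per-micro-batch mean
        # The two scale factors from the snippet above:
        loss *= local_micro_bsz / (global_bsz / dp_size)
        loss *= num_micro_batch  # cancels the backend's 1/num_micro_batch
        rank_loss += loss
    rank_loss /= num_micro_batch  # backend averages over micro-batches
    recovered += rank_loss
recovered /= dp_size  # backend averages over data-parallel ranks
# recovered equals global_mean: the scaling turns local micro-batch means
# into a correct mean over the whole global batch.
```

Working through one rank: micro-batch means 2 and 5 become 2.0 and 5.0 after scaling, averaging to 3.5; the other rank yields 9.5; the data-parallel average is 6.5, exactly the global mean 78 / 12.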
A collaborator commented:
We should add a note: the sft_loss is used in the FSDP backend (which handles scaling automatically), while it requires manual scaling in the Megatron backend.

@wuxibin89 (Collaborator, Author):
Added a note in verl/workers/roles/utils/losses.py.

wuxibin89 merged commit b49178f into main on Nov 3, 2025
81 of 84 checks passed
wuxibin89 deleted the wuxibin/fix_model_engine_loss branch on November 3, 2025 at 15:58
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
…-project#3994)

### What does this PR do?

Normalize sft loss by num_tokens in global batch.

TODO
- [ ] Normalize `pg_loss` and `value_loss` in rl trainer pipeline.

Credit to @techkang in verl-project#3729.
chenhaiq pushed a commit to The-Hierophant/verl-1 that referenced this pull request Nov 18, 2025
wuwendyy pushed a commit to wuwendyy/verl that referenced this pull request Nov 19, 2025
albertimff pushed a commit to albertimff/verl that referenced this pull request Dec 1, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026