[Pytorch] change fused cross entropy backward grad to fp32 and reduce one read/write pass of logit #2325
Conversation
Greptile Overview
Greptile Summary
Refactored the fused cross entropy implementation to separate forward and backward passes, reducing memory traffic by one read/write pass of logits.
Key Changes:
- Forward pass now computes and saves the `m_d_X_y` tensor (max, sum, target logit, and optional scaled_x_sum) instead of computing gradients
- Backward pass uses the saved `m_d_X_y` to compute gradients in fp32, improving numerical accuracy (see the sketch after this list)
- Removed gradient computation from `cross_entropy_kernel` (renamed to `cross_entropy_forward_kernel`)
- Created a new `cross_entropy_backward_kernel` that computes gradients using the saved statistics
- Context now saves additional parameters (`target`, `m_d_X_y`, `label_smoothing`, `reduce_loss`, `dist_process_group`) needed for the backward pass
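As a rough illustration of the fp32 backward path, here is a minimal PyTorch sketch that rebuilds the gradient from the saved per-row statistics. The function name, the `[max, sum_exp, target_logit]` column layout of `m_d_X_y`, and the per-row `grad_output` are assumptions for illustration only, not the actual Triton kernel:

```python
import torch

def backward_from_stats(logits, target, m_d_X_y, grad_output):
    # Sketch only: recompute softmax gradients in fp32 from the saved
    # per-row statistics instead of re-reading/re-reducing the logits twice.
    m = m_d_X_y[..., 0].float()                    # per-row max
    d = m_d_X_y[..., 1].float()                    # per-row sum of exp(x - m)
    x = logits.float()                             # promote to fp32 for accuracy
    grad = torch.exp(x - m[:, None]) / d[:, None]  # softmax p_i from saved stats
    rows = torch.arange(logits.size(0), device=logits.device)
    grad[rows, target] -= 1.0                      # dL/dx_i = p_i - 1{i == target}
    grad *= grad_output[:, None]                   # chain rule with upstream grad
    return grad.to(logits.dtype)                   # cast back to the input dtype
```

Label smoothing and the distributed (vocab-parallel) case add extra terms on top of this, which is where the issue flagged below applies.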
Issues Found:
- Critical bug in the distributed case with label smoothing: when `world_size > 1` and `label_smoothing > 0`, the code at lines 169-178 uses a hardcoded offset of 3 instead of 4, and doesn't gather/sum `scaled_x_sum` from the other ranks
Confidence Score: 2/5
- This PR has a critical bug in the distributed label smoothing case that will cause incorrect loss computation
- The refactoring properly separates the forward and backward passes and improves numerical precision by computing gradients in fp32. However, there is a critical logic error in `cross_entropy_forward_kernel` (lines 169-178) where the distributed label smoothing case doesn't properly handle the 4th value (`scaled_x_sum`) in the `m_d_X_y` tensor. The offset calculation still uses 3 instead of 4, and `scaled_x_sum` from other ranks is never gathered/accumulated. This will cause incorrect loss values when distributed training and label smoothing are used together. Note that line 356 has `assert False` for the distributed case, suggesting this code path may not be fully tested.
- Pay close attention to `transformer_engine/pytorch/triton/cross_entropy.py`, specifically the distributed label smoothing logic at lines 169-178
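For context on what "gathered and summed" would look like, here is a hedged sketch of merging per-rank `m_d_X_y` rows in a vocab-parallel setting. The helper name, the column layout, and the assumption that non-owning ranks store 0 for the target logit are illustrative and not taken from the Transformer Engine code:

```python
import torch
import torch.distributed as dist

def combine_rank_stats(m_d_X_y, label_smoothing, group=None):
    # Sketch only: merge per-rank softmax statistics across a vocab-parallel
    # process group. Columns assumed: [max, sum_exp, target_logit, scaled_x_sum?].
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(m_d_X_y) for _ in range(world_size)]
    dist.all_gather(gathered, m_d_X_y, group=group)
    stats = torch.stack(gathered)                 # [world_size, rows, 3 or 4]
    m_r, d_r, x_y_r = stats[..., 0], stats[..., 1], stats[..., 2]
    m = m_r.max(dim=0).values                     # global row max over ranks
    d = (d_r * torch.exp(m_r - m)).sum(dim=0)     # rescale each rank's sum_exp
    x_y = x_y_r.sum(dim=0)                        # assume non-owning ranks hold 0
    cols = [m, d, x_y]
    if label_smoothing > 0:
        cols.append(stats[..., 3].sum(dim=0))     # scaled_x_sum must be summed too
    return torch.stack(cols, dim=-1)
```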
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/cross_entropy.py | 4/5 | Added saving of m_d_X_y tensor and context parameters (dist_process_group, label_smoothing, reduce_loss) to pass them to backward pass, enabling proper gradient computation |
| transformer_engine/pytorch/triton/cross_entropy.py | 2/5 | Refactored forward/backward kernels to separate gradient computation from forward pass; bug in distributed label smoothing case where scaled_x_sum is not gathered from other ranks (line 169-178) |
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant CrossEntropyFunction
participant Forward as cross_entropy_forward
participant OnlineSoftmax as online_softmax_kernel
participant ForwardKernel as cross_entropy_forward_kernel
participant Backward as cross_entropy_backward
participant BackwardKernel as cross_entropy_backward_kernel
User->>CrossEntropyFunction: forward(input, target, label_smoothing, etc)
CrossEntropyFunction->>Forward: cross_entropy_forward(...)
Forward->>OnlineSoftmax: Compute m, d, X_y, scaled_x_sum (if label_smoothing>0)
OnlineSoftmax-->>Forward: m_d_X_y tensor (3 or 4 values per row)
Forward->>ForwardKernel: Compute loss using m_d_X_y
ForwardKernel-->>Forward: loss values
Forward-->>CrossEntropyFunction: (loss, m_d_X_y)
CrossEntropyFunction->>CrossEntropyFunction: save_for_backward(input, target, m_d_X_y)
CrossEntropyFunction-->>User: loss
User->>CrossEntropyFunction: backward(grad_output)
CrossEntropyFunction->>Backward: cross_entropy_backward(input, target, m_d_X_y, grad_output, ...)
Backward->>BackwardKernel: Compute gradients using saved m_d_X_y
BackwardKernel->>BackwardKernel: Calculate softmax gradients in fp32
BackwardKernel->>BackwardKernel: Apply grad_output scaling
BackwardKernel->>BackwardKernel: Handle target token specially (dx_y)
BackwardKernel-->>Backward: grad_input (in original dtype)
Backward-->>CrossEntropyFunction: grad_input
CrossEntropyFunction-->>User: (grad_input, None, None, ...)
```
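The `online_softmax_kernel` step above produces the per-row `m` (max) and `d` (sum of exponentials) in a single streaming pass. A minimal sketch of the standard online-softmax recurrence it relies on; the function name and block size are illustrative, not the Triton kernel:

```python
import torch

def online_softmax_stats(row, block=1024):
    # Streaming computation of m = max(row) and d = sum(exp(row - m))
    # without materializing exp(row) for the whole vocabulary at once.
    m = torch.tensor(float("-inf"))
    d = torch.tensor(0.0)
    for start in range(0, row.numel(), block):
        blk = row[start:start + block].float()
        m_new = torch.maximum(m, blk.max())
        # rescale the running sum to the new max, then add this block's terms
        d = d * torch.exp(m - m_new) + torch.exp(blk - m_new).sum()
        m = m_new
    return m, d
```

Per the diagram, the real kernel also extracts the target logit `X_y` (and `scaled_x_sum` when `label_smoothing > 0`) in the same pass.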
Additional Comments (1)
- `transformer_engine/pytorch/triton/cross_entropy.py`, lines 169-178: logic: when `label_smoothing > 0`, the offset should use 4 instead of 3, and `scaled_x_sum` from the other ranks must be gathered and summed
2 files reviewed, 1 comment
@sanandaraj5597 @timmoon10 Could you please review? The previous BF16 backward may lead to divergence in some cases (reported by several customers).
Description
The fused cross entropy kernel in Transformer Engine uses 16-bit floating point (BF16) for the backward pass when the input is in BF16, whereas Megatron's VocabParallelCrossEntropy performs its computations in FP32. This discrepancy may lead to divergence in some cases.
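As a rough, self-contained way to see the precision gap described above, here is a small PyTorch comparison of cross-entropy gradients computed in BF16 versus FP32. The sizes and seed are arbitrary and the exact error depends on the logit distribution and PyTorch version, so treat it as an illustration rather than a reproduction of the reported divergence:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab = 4, 32000                        # illustrative sizes
logits = torch.randn(batch, vocab) * 10        # fp32 reference logits
target = torch.randint(0, vocab, (batch,))

def ce_grad(x):
    # Gradient of the cross-entropy loss w.r.t. the logits, in x's dtype.
    x = x.detach().requires_grad_(True)
    F.cross_entropy(x, target).backward()
    return x.grad.float()

grad_fp32 = ce_grad(logits)                    # gradient math in fp32
grad_bf16 = ce_grad(logits.bfloat16())         # gradient math in bf16
rel_err = (grad_bf16 - grad_fp32).abs().max() / grad_fp32.abs().max()
print(f"max relative gradient error (bf16 vs fp32): {rel_err:.3e}")
```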
This PR also reduces one read of `logits`, which improves performance by up to 1.25x.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: