

@RandMist RandMist commented Oct 31, 2025

Description

The fused cross entropy kernel in Transformer Engine uses 16-bit floating point (BF16) for the backward pass when the input is in BF16, whereas Megatron's VocabParallelCrossEntropy performs its computations in FP32. This discrepancy may lead to training divergence in some cases.

This PR also removes one read of the logits, which improves performance by up to 1.25x.
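
For illustration, a minimal PyTorch-level sketch (not the Triton kernel touched by this PR) of the precision difference described above; the function name, shapes, and values are made up:

```python
import torch

def ce_grad(logits: torch.Tensor, target: torch.Tensor, compute_dtype: torch.dtype) -> torch.Tensor:
    """Gradient of the mean cross-entropy loss w.r.t. the logits, accumulated in compute_dtype."""
    x = logits.to(compute_dtype)
    grad = torch.softmax(x, dim=-1)                 # softmax in compute_dtype
    grad[torch.arange(x.shape[0]), target] -= 1.0   # subtract the one-hot target
    return (grad / x.shape[0]).to(logits.dtype)     # cast back to the input dtype

logits = torch.randn(8, 32000, dtype=torch.bfloat16)
target = torch.randint(0, 32000, (8,))

g_bf16 = ce_grad(logits, target, torch.bfloat16)  # roughly what the old BF16 backward did
g_fp32 = ce_grad(logits, target, torch.float32)   # what this PR switches to
print((g_bf16.float() - g_fp32.float()).abs().max())
```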

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Changed the fused cross entropy backward gradient computation to fp32 for consistency with Megatron's VocabParallelCrossEntropy.
  • Optimized the computation logic to remove one read/write pass over the logits.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Refactored the fused cross entropy implementation to separate the forward and backward passes, reducing memory traffic by one read/write pass over the logits.

Key Changes:

  • Forward pass now computes and saves the m_d_X_y tensor (max, sum, target logit, and optional scaled_x_sum) instead of computing gradients
  • Backward pass uses the saved m_d_X_y to compute gradients in fp32, improving numerical accuracy (see the sketch after this list)
  • Removed gradient computation from cross_entropy_kernel (renamed to cross_entropy_forward_kernel)
  • Created a new cross_entropy_backward_kernel that computes gradients using the saved statistics
  • Context now saves additional parameters (target, m_d_X_y, label_smoothing, reduce_loss, dist_process_group) needed for the backward pass
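
For reference, a minimal PyTorch-level sketch of the structure described above, assuming a single rank, no label smoothing, and a reduced (mean) loss. The m_d_X_y layout follows the summary; the actual implementation uses Triton kernels and handles the distributed and label-smoothing paths as well.

```python
import torch

class FusedCrossEntropySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, target):
        x = logits.float()
        m = x.max(dim=-1).values                            # row max
        d = torch.exp(x - m.unsqueeze(1)).sum(dim=-1)       # softmax denominator
        x_y = x.gather(1, target.unsqueeze(1)).squeeze(1)   # target logit
        m_d_X_y = torch.stack([m, d, x_y], dim=1)           # 3 values per row
        ctx.save_for_backward(logits, target, m_d_X_y)
        return (m + torch.log(d) - x_y).mean()              # -log softmax at the target

    @staticmethod
    def backward(ctx, grad_output):
        logits, target, m_d_X_y = ctx.saved_tensors
        m, d, x_y = m_d_X_y.unbind(dim=1)
        x = logits.float()                                  # gradients computed in fp32
        grad = torch.exp(x - m.unsqueeze(1)) / d.unsqueeze(1)
        grad[torch.arange(x.shape[0]), target] -= 1.0
        grad = grad * grad_output / x.shape[0]
        return grad.to(logits.dtype), None                  # cast back to the input dtype
```

Usage would look like FusedCrossEntropySketch.apply(logits.requires_grad_(), target); the real forward also covers reduce_loss=False and the dist_process_group case listed above.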

Issues Found:

  • Critical bug in the distributed case with label smoothing: when world_size > 1 and label_smoothing > 0, the code at lines 169-178 uses a hardcoded offset of 3 instead of 4 and does not gather/sum scaled_x_sum from other ranks

Confidence Score: 2/5

  • This PR has a critical bug in the distributed label smoothing case that will cause incorrect loss computation
  • The refactoring properly separates forward/backward passes and improves numerical precision by computing gradients in fp32. However, there's a critical logic error in cross_entropy_forward_kernel (lines 169-178) where the distributed label smoothing case doesn't properly handle the 4th value (scaled_x_sum) in the m_d_X_y tensor. The offset calculation still uses 3 instead of 4, and scaled_x_sum from other ranks is never gathered/accumulated. This will cause incorrect loss values when using both distributed training and label smoothing together. Note that line 356 has assert False for the distributed case, suggesting this code path may not be fully tested.
  • Pay close attention to transformer_engine/pytorch/triton/cross_entropy.py, specifically the distributed label smoothing logic at lines 169-178

Important Files Changed

File Analysis

  • transformer_engine/pytorch/cross_entropy.py (score 4/5): Added saving of the m_d_X_y tensor and context parameters (dist_process_group, label_smoothing, reduce_loss) to pass them to the backward pass, enabling proper gradient computation
  • transformer_engine/pytorch/triton/cross_entropy.py (score 2/5): Refactored the forward/backward kernels to separate gradient computation from the forward pass; bug in the distributed label smoothing case where scaled_x_sum is not gathered from other ranks (lines 169-178)

Sequence Diagram

sequenceDiagram
    participant User
    participant CrossEntropyFunction
    participant Forward as cross_entropy_forward
    participant OnlineSoftmax as online_softmax_kernel
    participant ForwardKernel as cross_entropy_forward_kernel
    participant Backward as cross_entropy_backward
    participant BackwardKernel as cross_entropy_backward_kernel

    User->>CrossEntropyFunction: forward(input, target, label_smoothing, etc)
    CrossEntropyFunction->>Forward: cross_entropy_forward(...)
    Forward->>OnlineSoftmax: Compute m, d, X_y, scaled_x_sum (if label_smoothing>0)
    OnlineSoftmax-->>Forward: m_d_X_y tensor (3 or 4 values per row)
    Forward->>ForwardKernel: Compute loss using m_d_X_y
    ForwardKernel-->>Forward: loss values
    Forward-->>CrossEntropyFunction: (loss, m_d_X_y)
    CrossEntropyFunction->>CrossEntropyFunction: save_for_backward(input, target, m_d_X_y)
    CrossEntropyFunction-->>User: loss

    User->>CrossEntropyFunction: backward(grad_output)
    CrossEntropyFunction->>Backward: cross_entropy_backward(input, target, m_d_X_y, grad_output, ...)
    Backward->>BackwardKernel: Compute gradients using saved m_d_X_y
    BackwardKernel->>BackwardKernel: Calculate softmax gradients in fp32
    BackwardKernel->>BackwardKernel: Apply grad_output scaling
    BackwardKernel->>BackwardKernel: Handle target token specially (dx_y)
    BackwardKernel-->>Backward: grad_input (in original dtype)
    Backward-->>CrossEntropyFunction: grad_input
    CrossEntropyFunction-->>User: (grad_input, None, None, ...)

Additional Comments (1)

  1. transformer_engine/pytorch/triton/cross_entropy.py, lines 169-178 (link)

    Logic: when label_smoothing > 0, the offset should be 4 instead of 3, and scaled_x_sum from other ranks must be gathered and summed; see the sketch below.
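
To make the point concrete, a hedged sketch of the missing handling in plain PyTorch (not the actual Triton/host code); combine_scaled_x_sum is a hypothetical helper, and the names m_d_X_y and dist_process_group follow the review summary:

```python
import torch
import torch.distributed as dist

def combine_scaled_x_sum(m_d_X_y: torch.Tensor, label_smoothing: float,
                         dist_process_group) -> torch.Tensor:
    """Hypothetical helper: combine per-rank statistics when label smoothing is on."""
    n_stats = 4 if label_smoothing > 0.0 else 3   # offset must account for the 4th value
    assert m_d_X_y.shape[1] == n_stats
    if dist_process_group is not None and label_smoothing > 0.0:
        # scaled_x_sum (column 3) is a partial sum over the local vocab shard,
        # so it must also be summed across the vocab-parallel ranks.
        scaled_x_sum = m_d_X_y[:, 3].contiguous()
        dist.all_reduce(scaled_x_sum, op=dist.ReduceOp.SUM, group=dist_process_group)
        m_d_X_y[:, 3] = scaled_x_sum
    return m_d_X_y
```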

2 files reviewed, 1 comment


@RandMist RandMist changed the title change fused cross entropy backward grad to fp32 and reduce one read/… [PYTORCH] change fused cross entropy backward grad to fp32 and reduce one read/… Oct 31, 2025
@RandMist RandMist changed the title [PYTORCH] change fused cross entropy backward grad to fp32 and reduce one read/… [Pytorch] change fused cross entropy backward grad to fp32 and reduce one read/… Oct 31, 2025
@yaox12 yaox12 requested a review from timmoon10 October 31, 2025 09:10

yaox12 commented Oct 31, 2025

@sanandaraj5597 @timmoon10 Could you please review? The previous BF16 backward may lead to divergence in some cases (reported by several customers).


yaox12 commented Oct 31, 2025

@RandMist You need to sign off your commits (git commit -s). See this.

