Skip to content

chunked cross entropy loss #2625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 24, 2025
Merged

chunked cross entropy loss #2625

merged 4 commits into from
Jun 24, 2025

Conversation

winglian
Copy link
Collaborator

@winglian winglian commented May 3, 2025

in testing with qwen3, reduced single gpu vram from 22-26GB -> 15-18GB

Summary by CodeRabbit

  • New Features
    • Introduced an option to enable memory-efficient chunked cross-entropy loss during training, with configurable chunk sizes.
    • Added new configuration fields to control chunked cross-entropy behavior.
  • Bug Fixes
    • Improved import handling for plugin integration.
  • Tests
    • Added tests to verify the correctness of the chunked cross-entropy loss implementation.

Copy link

codecov bot commented May 3, 2025

Codecov Report

Attention: Patch coverage is 80.35714% with 11 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/monkeypatch/loss/chunked.py 84.78% 7 Missing ⚠️
src/axolotl/loaders/patch_manager.py 42.85% 4 Missing ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Collaborator

@NanoCode012 NanoCode012 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice implementation. Left some comments on parts I'm unsure of

Comment on lines +69 to +71
if reduction == "sum":
return total_loss
return total_loss / total_elements
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should check if mean for second return and raise value/notimplemented error for other reduction

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for in the get_* function below.

ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Could you explain the relationship between num_items_in_batch and reduction?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean, we need to add a test with num_items_in_batch for reduction=sum?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the stupid gradient accumulation think that Daniel Han brought up that forced this pattern.

@winglian winglian added this to the Axolotl v0.10.0 milestone May 7, 2025

def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another question, would this work for all hardware? Do we need to check for compatibility?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah unlike flex attention there's no cudagraphs/autotuning here so inductor should be pretty hardware friendly

Copy link
Contributor

@SalmanMohammadi SalmanMohammadi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good - the loss/grad norms lined up OK with the run you tested on?

Let's also make sure we add the license/code attribution link


def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah unlike flex attention there's no cudagraphs/autotuning here so inductor should be pretty hardware friendly

@NanoCode012
Copy link
Collaborator

Rebased. Added license, arg to docs, and use TORCH_COMPILE_BACKEND for backend if available

Copy link
Contributor

github-actions bot commented May 14, 2025

@github-actions github-actions bot temporarily deployed to preview May 14, 2025 10:12 Inactive
Copy link

coderabbitai bot commented Jun 23, 2025

Walkthrough

This update introduces a memory-efficient, chunked cross-entropy loss function for large language models. It adds configuration options to enable and tune chunked loss, implements the chunked loss module and patching logic, integrates it into the patch manager, updates import paths, and provides tests to verify correctness against standard cross-entropy.

Changes

File(s) Change Summary
src/axolotl/monkeypatch/loss/chunked.py Added chunked cross-entropy loss class, patching, and helper functions for memory efficiency.
src/axolotl/loaders/patch_manager.py Added method to conditionally apply chunked cross-entropy patch based on config.
src/axolotl/utils/schemas/config.py Added config fields for enabling chunked cross-entropy and specifying number of chunks.
src/axolotl/integrations/cut_cross_entropy/init.py Changed import of cce_patch to use relative import syntax.
tests/test_chunked_xentropy.py Added tests comparing chunked cross-entropy loss to standard loss for correctness.

Sequence Diagram(s)

sequenceDiagram
    participant Config
    participant PatchManager
    participant ChunkedLossModule
    participant TransformersLoss

    Config->>PatchManager: chunked_cross_entropy = True
    PatchManager->>ChunkedLossModule: patch_chunked_ce_loss_fn(num_chunks)
    ChunkedLossModule->>TransformersLoss: Patch ForCausalLMLoss with chunked loss
    TransformersLoss-->>PatchManager: Uses chunked loss during training
Loading

Poem

In the land of memory, chunks now rule the day,
Cross-entropy loss, in pieces, finds its way.
Fewer bytes are needed, the models run with glee,
A patch here, a config there, as smooth as can be.
With tests to confirm, our code hops ahead—
🐇✨ Chunked loss is here, let training be widespread!

✨ Finishing Touches
  • 📝 Generate Docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions github-actions bot temporarily deployed to preview June 23, 2025 20:36 Inactive
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (5)
src/axolotl/utils/schemas/config.py (1)

526-537: Configuration schema additions are well-structured.

The new configuration fields follow the established pattern with proper typing, default values, and descriptive JSON schema metadata. The field names and descriptions clearly convey their purpose.

src/axolotl/monkeypatch/loss/chunked.py (4)

38-38: Consider using "mean" as default reduction to match PyTorch convention.

The past review comment correctly noted that PyTorch's CrossEntropyLoss defaults to "mean" reduction. Consider aligning with this convention for consistency.

-    def forward(
-        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
-    ) -> torch.Tensor:
+    def forward(
+        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="mean"
+    ) -> torch.Tensor:

1-3: Add license block for BSD3 licensed source.

The past review comment correctly identified that a license block is needed since this code is derived from torchtune which uses BSD3 license.

Add the appropriate license header:

+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 chunked ce loss
 """

76-78: Use environment variable for torch.compile backend selection.

The backend should be configurable via environment variable to match torchtune's pattern and provide flexibility for different hardware configurations.

+import os
+
 def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
     loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
+    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
     loss_fn_ce.compute_cross_entropy = torch.compile(
-        loss_fn_ce.compute_cross_entropy, backend="inductor"
+        loss_fn_ce.compute_cross_entropy, backend=backend
     )
     return loss_fn_ce

69-71: Add validation for reduction parameter.

The past review comment correctly identified that the code should validate the reduction parameter and raise appropriate errors for unsupported values.

+        if reduction not in ["sum", "mean"]:
+            raise ValueError(f"Unsupported reduction: {reduction}. Only 'sum' and 'mean' are supported.")
+
         if reduction == "sum":
             return total_loss
         return total_loss / total_elements
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1d8f500 and 4692413.

📒 Files selected for processing (5)
  • src/axolotl/integrations/cut_cross_entropy/__init__.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/loss/chunked.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/test_chunked_xentropy.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/loss/chunked.py

66-66: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: preview
🔇 Additional comments (5)
src/axolotl/integrations/cut_cross_entropy/__init__.py (1)

74-76: Import change aligns with established patterns.

The change to relative import is consistent with the maintainer's explanation that absolute imports can break in colab environments and relative imports are the standard pattern for plugins.

src/axolotl/loaders/patch_manager.py (2)

82-89: Well-structured patch integration following established patterns.

The implementation correctly checks the configuration flag, conditionally imports the patch function, and handles the optional chunk count parameter. The integration follows the same pattern as other patch methods in the class.


53-53: Proper placement in the patch application sequence.

The chunked cross entropy patch is appropriately positioned in the pre-model load patches sequence.

tests/test_chunked_xentropy.py (2)

12-22: Well-designed test fixture for realistic chunked loss scenario.

The fixture creates a realistic scenario with large vocabulary (256K) and long sequences (2048 tokens) that would benefit from memory-efficient chunking. The parameters are appropriate for testing the chunked implementation.


25-40: Solid correctness verification against standard cross entropy.

The test correctly compares the chunked implementation against PyTorch's standard cross entropy using flattened tensors. The tolerance levels (1e-2) are appropriate for numerical comparison between different computational paths.


# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add explicit strict parameter to zip() call.

Static analysis correctly identifies missing explicit strict= parameter. For safety, use strict=True to ensure logits and labels chunks have the same length.

-        for logits_chunk, labels_chunk in zip(logits, labels):
+        for logits_chunk, labels_chunk in zip(logits, labels, strict=True):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for logits_chunk, labels_chunk in zip(logits, labels):
for logits_chunk, labels_chunk in zip(logits, labels, strict=True):
🧰 Tools
🪛 Ruff (0.11.9)

66-66: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/loss/chunked.py at line 66, the zip() call lacks the
explicit strict parameter. Modify the zip() call to include strict=True to
enforce that logits and labels chunks have the same length, improving safety and
correctness.

@winglian winglian merged commit 12c8268 into main Jun 24, 2025
15 of 16 checks passed
@winglian winglian deleted the chunked-cel branch June 24, 2025 03:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants