chunked cross entropy loss #2625
Conversation
Nice implementation. Left some comments on parts I'm unsure of.
```python
if reduction == "sum":
    return total_loss
return total_loss / total_elements
```
Should check `if reduction == "mean"` for the second return, and raise a ValueError/NotImplementedError for any other reduction.
Same for the `get_*` function below.
```python
    ignore_index: int = -100,
    **kwargs,
):  # pylint: disable=unused-argument
    reduction = "sum" if num_items_in_batch is not None else "mean"
```
Is this correct? Could you explain the relationship between `num_items_in_batch` and `reduction`?
Does that mean we need to add a test with `num_items_in_batch` for `reduction="sum"`?
This is the stupid gradient accumulation thing that Daniel Han brought up that forced this pattern.
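For context, a minimal sketch of the pattern being discussed, assuming a transformers-style loss signature; the function name and exact arguments here are illustrative, not this PR's code. When the trainer passes `num_items_in_batch`, per-token losses are summed and normalized by that global count, so accumulation over micro-batches with different numbers of unmasked tokens doesn't skew the average:

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels, num_items_in_batch=None, ignore_index=-100):
    # Averaging per micro-batch and then averaging those averages weights
    # micro-batches unevenly; summing here and dividing once by the global
    # token count avoids that.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```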
```python
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
    loss_fn_ce.compute_cross_entropy = torch.compile(
```
Another question: would this work on all hardware? Do we need to check for compatibility?
Yeah, unlike flex attention there's no cudagraphs/autotuning here, so inductor should be pretty hardware-friendly.
I think this looks good - the loss/grad norms lined up OK with the run you tested on? Let's also make sure we add the license/code attribution link.
Rebased. Added the license, added the arg to the docs, and now use TORCH_COMPILE_BACKEND for the backend if available.
Walkthrough

This update introduces a memory-efficient, chunked cross-entropy loss function for large language models. It adds configuration options to enable and tune chunked loss, implements the chunked loss module and patching logic, integrates it into the patch manager, updates import paths, and provides tests to verify correctness against standard cross-entropy.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Config
    participant PatchManager
    participant ChunkedLossModule
    participant TransformersLoss
    Config->>PatchManager: chunked_cross_entropy = True
    PatchManager->>ChunkedLossModule: patch_chunked_ce_loss_fn(num_chunks)
    ChunkedLossModule->>TransformersLoss: Patch ForCausalLMLoss with chunked loss
    TransformersLoss-->>PatchManager: Uses chunked loss during training
```
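To make the walkthrough concrete, here is a minimal sketch of the chunking idea, written against plain `[batch, seq, vocab]` logits. It's an illustration under those assumptions, not the PR's implementation: the actual `CEWithChunkedOutputLoss` takes a list of pre-chunked logits and compiles the per-chunk cross-entropy with `torch.compile`.

```python
import torch
import torch.nn.functional as F

def chunked_ce(
    logits: torch.Tensor,
    labels: torch.Tensor,
    num_chunks: int = 8,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Cross entropy computed one sequence chunk at a time.

    Only one chunk of logits is upcast and run through log-softmax at a
    time, so the large [*, vocab] intermediate scales with
    seq_len / num_chunks instead of seq_len.
    """
    total = logits.new_zeros(())
    for logits_chunk, labels_chunk in zip(
        logits.chunk(num_chunks, dim=1),
        labels.chunk(num_chunks, dim=1),
        strict=True,
    ):
        total = total + F.cross_entropy(
            logits_chunk.reshape(-1, logits_chunk.size(-1)).float(),
            labels_chunk.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
    # normalize by the number of non-ignored label positions
    return total / (labels != ignore_index).sum()
```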
Actionable comments posted: 1
♻️ Duplicate comments (5)
src/axolotl/utils/schemas/config.py (1)

526-537: Configuration schema additions are well-structured. The new configuration fields follow the established pattern with proper typing, default values, and descriptive JSON schema metadata. The field names and descriptions clearly convey their purpose.
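For illustration, the fields might look roughly like this in a Pydantic schema. The `chunked_cross_entropy` name appears in this PR's diagram; the chunk-count field name, types, and descriptions below are assumptions:

```python
from typing import Optional

from pydantic import BaseModel, Field

class ChunkedLossConfigSketch(BaseModel):
    # Hypothetical rendering of the new fields; not copied from the PR.
    chunked_cross_entropy: Optional[bool] = Field(
        default=None,
        json_schema_extra={
            "description": "Use chunked cross entropy loss to reduce peak memory."
        },
    )
    chunked_cross_entropy_num_chunks: Optional[int] = Field(
        default=None,
        json_schema_extra={
            "description": "Number of chunks to split the logits into (assumed default: 8)."
        },
    )
```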
src/axolotl/monkeypatch/loss/chunked.py (4)

38-38: Consider using "mean" as the default reduction to match PyTorch convention. The past review comment correctly noted that PyTorch's CrossEntropyLoss defaults to "mean" reduction. Consider aligning with this convention for consistency.
```diff
-    def forward(
-        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
-    ) -> torch.Tensor:
+    def forward(
+        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="mean"
+    ) -> torch.Tensor:
```
1-3: Add a license block for the BSD3-licensed source. The past review comment correctly identified that a license block is needed since this code is derived from torchtune, which uses the BSD3 license. Add the appropriate license header:
```diff
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """ chunked ce loss """
```
76-78: Use an environment variable for torch.compile backend selection. The backend should be configurable via environment variable to match torchtune's pattern and provide flexibility for different hardware configurations.
```diff
+import os
+
 def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
     loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
+    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
     loss_fn_ce.compute_cross_entropy = torch.compile(
-        loss_fn_ce.compute_cross_entropy, backend="inductor"
+        loss_fn_ce.compute_cross_entropy, backend=backend
     )
     return loss_fn_ce
```
69-71: Add validation for the reduction parameter. The past review comment correctly identified that the code should validate the reduction parameter and raise appropriate errors for unsupported values.
```diff
+    if reduction not in ["sum", "mean"]:
+        raise ValueError(
+            f"Unsupported reduction: {reduction}. Only 'sum' and 'mean' are supported."
+        )
+
     if reduction == "sum":
         return total_loss
     return total_loss / total_elements
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/axolotl/integrations/cut_cross_entropy/__init__.py (1 hunks)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/loss/chunked.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- tests/test_chunked_xentropy.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/loss/chunked.py
- 66-66: `zip()` without an explicit `strict=` parameter. Add explicit value for parameter `strict=` (B905)
⏰ Context from checks skipped due to timeout of 90000ms (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: preview
🔇 Additional comments (5)
src/axolotl/integrations/cut_cross_entropy/__init__.py (1)

74-76: Import change aligns with established patterns. The change to a relative import is consistent with the maintainer's explanation that absolute imports can break in colab environments and relative imports are the standard pattern for plugins.
src/axolotl/loaders/patch_manager.py (2)

82-89: Well-structured patch integration following established patterns. The implementation correctly checks the configuration flag, conditionally imports the patch function, and handles the optional chunk count parameter. The integration follows the same pattern as other patch methods in the class (a minimal sketch of this hook follows after these comments).

53-53: Proper placement in the patch application sequence. The chunked cross entropy patch is appropriately positioned in the pre-model-load patches sequence.
tests/test_chunked_xentropy.py (2)

12-22: Well-designed test fixture for a realistic chunked-loss scenario. The fixture creates a realistic scenario with a large vocabulary (256K) and long sequences (2048 tokens) that would benefit from memory-efficient chunking. The parameters are appropriate for testing the chunked implementation.

25-40: Solid correctness verification against standard cross entropy. The test correctly compares the chunked implementation against PyTorch's standard cross entropy using flattened tensors. The tolerance levels (1e-2) are appropriate for numerical comparison between different computational paths (see the sketch below for the shape of such a check).
```python
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
```
🛠️ Refactor suggestion

Add an explicit strict parameter to the zip() call. Static analysis correctly identifies the missing explicit `strict=` parameter. For safety, use `strict=True` to ensure the logits and labels chunks have the same length.
```diff
-    for logits_chunk, labels_chunk in zip(logits, labels):
+    for logits_chunk, labels_chunk in zip(logits, labels, strict=True):
```
In testing with Qwen3, this reduced single-GPU VRAM from 22-26 GB to 15-18 GB.