Add step-based scheduling for GradientAccumulationScheduler#21583

Open
adityaroy10 wants to merge 3 commits into Lightning-AI:master from
adityaroy10:feature/stepwise-gradient-accumulation-scheduler

Conversation


@adityaroy10 adityaroy10 commented Mar 12, 2026

Fixes #21534

Made-with: Cursor

What does this PR do?

Adds step-based scheduling for GradientAccumulationScheduler via a new mode parameter ("epoch" or "step"). When mode="step", the scheduling dictionary keys are interpreted as global training steps instead of epochs, enabling finer-grained control for single-epoch pretraining and long runs where epoch-based scheduling is not sufficient.

  • New parameter: mode: Literal["epoch", "step"] = "epoch" (default "epoch" for backward compatibility).
  • Epoch mode (default): Unchanged behavior; keys are zero-indexed epochs; updates in on_train_epoch_start.
  • Step mode: Keys are global steps; accumulation factor is updated at each batch in on_train_batch_start from trainer.global_step.
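The step-mode lookup described above can be sketched in plain Python. This is a hypothetical standalone helper illustrating the intended behavior (the accumulation factor is the value for the largest scheduling key at or below the current global step), not the actual Lightning implementation from this PR:

```python
# Hypothetical helper mirroring the step-mode semantics described in this PR:
# scheduling keys are global training steps, and the active accumulation
# factor is the value of the largest key <= the current global step.
def accumulation_factor(scheduling: dict[int, int], global_step: int) -> int:
    factor = 1  # default before the first scheduled step
    for step in sorted(scheduling):
        if global_step >= step:
            factor = scheduling[step]
    return factor

# Accumulate 8 batches for steps 0-999, 4 for steps 1000-4999, then 1.
schedule = {0: 8, 1000: 4, 5000: 1}
print(accumulation_factor(schedule, 0))     # 8
print(accumulation_factor(schedule, 2500))  # 4
print(accumulation_factor(schedule, 9000))  # 1
```

With the proposed API, the equivalent callback would be constructed as `GradientAccumulationScheduler(scheduling={0: 8, 1000: 4, 5000: 1}, mode="step")`, with the lookup performed in `on_train_batch_start` from `trainer.global_step`.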

Motivation: Issue #21534; see also downstream use cases such as batch size scheduling in coral-nlp/bertblocks (issue #51). Variable accumulation (epoch or step) remains unsupported by strategies such as DeepSpeed; existing checks are unchanged.

Dependencies: None.

Does your PR introduce any breaking changes? No. Default mode="epoch" preserves existing behavior.

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary) — Updated docs/source-pytorch/common/gradient_accumulation.rst and callback docstring.
  • Did you write any new necessary tests? (not for typos and docs) — Added tests for invalid mode, step-mode validation message, and step-mode scheduling behavior.
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request? — None.
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors) — Added under [unreleased] → Added.

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following checklist:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21583.org.readthedocs.build/en/21583/

@github-actions github-actions bot added docs Documentation related pl Generic label for PyTorch Lightning package labels Mar 12, 2026

codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 95.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 79%. Comparing base (283ce77) to head (6ba4dd6).
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (283ce77) and HEAD (6ba4dd6). Click for more details.

HEAD has 2199 uploads less than BASE
Flag BASE (283ce77) HEAD (6ba4dd6)
cpu 524 32
python 48 3
lightning_fabric 141 0
pytest 261 0
lightning 239 14
python3.11 95 6
python3.12 142 9
python3.10 48 3
python3.13 48 3
python3.12.7 143 8
pytorch2.3 24 3
pytest-full 263 32
pytorch2.2.2 23 3
pytorch_lightning 144 18
pytorch2.6 24 3
pytorch2.1 48 6
pytorch2.9 24 3
pytorch2.5.1 24 2
pytorch2.8 48 6
pytorch2.4.1 24 3
pytorch2.7 24 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21583     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       24078    24030     -48     
=========================================
- Hits        20863    18964   -1899     
- Misses       3215     5066   +1851     


Labels

docs Documentation related pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add stepwise scheduling for GradientAccumulationScheduler

1 participant