
Conversation

@tohtana
Contributor

@tohtana tohtana commented Nov 3, 2025

Currently, DeepSpeed's backward API has more constraints than PyTorch's standard backward API.
Here is the usage as described in the documentation:

    loss = model_engine(batch)
    model_engine.backward(loss)

In this example:

  1. The backward API accepts only a (scalar) loss value.
  2. You need to call the engine's backward API rather than the tensor's.

In contrast, in standard PyTorch, you can do:

    output = model(batch)
    output.backward(out_grad)

There are several use cases that rely on this flexibility. For example, combining multiple models or using loss functions defined separately from the main model.
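As a minimal sketch of the second use case (criterion, labels, and batch are hypothetical placeholders), the loss is computed outside the engine:

    # Hypothetical sketch: the loss function is defined separately from the model.
    # model_engine comes from deepspeed.initialize(); criterion is a plain nn.Module.
    output = model_engine(batch)        # forward pass through the DeepSpeed engine
    loss = criterion(output, labels)    # loss computed outside the engine
    loss.backward()                     # plain tensor-level backward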

If you attempt the same pattern with a DeepSpeed engine, some preprocessing and postprocessing steps will be silently skipped, which can lead to incorrect results.

The documentation explains that we can call _backward_epilogue manually (and possibly backward_prologue as well). However, it is easy for users to miss these calls, and passing a non-scalar gradient is still not supported.
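A rough sketch of that documented workaround (method names as mentioned above; exact signatures may differ):

    # Sketch of the manual-call workaround described in the docs.
    output = model_engine(batch)
    loss = criterion(output, labels)
    model_engine.backward_prologue()    # easy to forget
    loss.backward()
    model_engine._backward_epilogue()   # easy to forget; non-scalar grads still unsupported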

This PR introduces the same .backward() behavior as PyTorch, allowing .backward() to be called directly on tensors and supporting non-scalar outputs.
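Based on the description above, a minimal usage sketch of the new behavior:

    # With this change, backward can be called directly on the engine's output.
    output = model_engine(batch)              # output may be non-scalar
    out_grad = torch.ones_like(output)        # gradient w.r.t. the non-scalar output
    output.backward(out_grad)                 # no explicit model_engine.backward() needed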

To implement post-backward hooks, we had to use some torch internal APIs; see the code comments for more details. When those internal APIs are not available, the DeepSpeed engine only accepts the traditional model_engine.backward(loss) call.
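For illustration, one internal mechanism of this kind (not necessarily the one used in this PR) is queuing a callback on PyTorch's autograd engine so that it runs after the backward pass completes:

    import torch
    from torch.autograd import Variable

    def _post_backward():
        # Runs once the whole backward pass has finished (e.g. an engine epilogue).
        print("backward finished")

    def _queue_epilogue(grad):
        # Internal API: queue_callback may only be called while backward is running,
        # so it is queued from inside a gradient hook.
        Variable._execution_engine.queue_callback(_post_backward)
        return grad

    x = torch.randn(4, requires_grad=True)
    x.register_hook(_queue_epilogue)
    (x * 2).sum().backward()    # prints "backward finished" after gradients are computed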

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
@sfc-gh-truwase
Collaborator

@tohtana, this is a very exciting usability improvement. Please remember to update the documentation.

tohtana and others added 18 commits November 6, 2025 16:24
@tohtana tohtana marked this pull request as ready for review November 14, 2025 02:13
@tohtana
Contributor Author

tohtana commented Nov 14, 2025

@sfc-gh-truwase I think this PR is now ready for review, though the latest change in HF Transformers causes an error in test_zero_nesting_init.py::TestNestedParallelInit::test_nested_parallel_init.

@tohtana tohtana enabled auto-merge (squash) November 19, 2025 00:01
@tohtana tohtana merged commit 53e91a0 into deepspeedai:master Nov 19, 2025
12 checks passed