Commit 53e91a0
PyTorch-compatible backward API (#7665)
Currently, DeepSpeed's backward API is more constrained than PyTorch's
standard backward API.
Here is the usage as described in the documentation:
```python
loss = model_engine(batch)
model_engine.backward(loss)
```
In this example:
1. The engine accepts only a (scalar) loss value.
1. You must call the engine's backward API rather than `loss.backward()`.
In contrast, in standard PyTorch, you can do:
```python
output = model(batch)
output.backward(out_grad)
```
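For reference, the non-scalar form computes a vector-Jacobian product: `output.backward(out_grad)` accumulates the same parameter gradients as backpropagating the scalar `(output * out_grad).sum()`. The following is a minimal plain-PyTorch sketch without DeepSpeed; the `Linear` model and the tensor shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
batch = torch.randn(2, 4)
out_grad = torch.randn(2, 3)  # upstream gradient, same shape as the output

# Non-scalar backward: pass the upstream gradient explicitly.
output = model(batch)
output.backward(out_grad)
grad_vjp = model.weight.grad.clone()

# Equivalent scalar formulation of the same computation.
model.zero_grad()
(model(batch) * out_grad).sum().backward()
grad_scalar = model.weight.grad.clone()

grads_match = torch.allclose(grad_vjp, grad_scalar)
```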
Several use cases rely on this flexibility, for example combining
multiple models or using loss functions defined separately from the
main model.
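As an illustration of such a use case, here is a small plain-PyTorch sketch in which two models share one loss that is defined outside both of them. The names `encoder`, `head`, and `loss_fn` are hypothetical and do not come from the PR.

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(8, 4)  # first model (hypothetical)
head = torch.nn.Linear(4, 1)     # second model (hypothetical)
loss_fn = torch.nn.MSELoss()     # loss defined outside both models

batch = torch.randn(5, 8)
target = torch.randn(5, 1)

# One backward call on the loss tensor populates gradients in both models.
loss = loss_fn(head(encoder(batch)), target)
loss.backward()

both_have_grads = all(
    p.grad is not None for m in (encoder, head) for p in m.parameters()
)
```

Because `backward()` is called on the tensor, neither model's own backward method is involved, which is why this pattern does not map onto `model_engine.backward(loss)`.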
If you attempt the same pattern with a DeepSpeed engine, some
preprocessing and postprocessing steps will be silently skipped, which
can lead to incorrect results.
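The failure mode can be sketched with a toy wrapper. `ToyEngine` below is a hypothetical stand-in, not DeepSpeed's engine: it only records when its placeholder pre- and post-processing steps run, so the difference between the two call styles is visible.

```python
import torch


class ToyEngine:
    """Hypothetical wrapper whose backward() adds steps around autograd."""

    def __init__(self, module):
        self.module = module
        self.events = []

    def __call__(self, batch):
        return self.module(batch).sum()

    def backward(self, loss):
        self.events.append("prologue")  # placeholder for preprocessing
        loss.backward()
        self.events.append("epilogue")  # placeholder for postprocessing


torch.manual_seed(0)
engine = ToyEngine(torch.nn.Linear(4, 2))
batch = torch.randn(3, 4)

# Engine-style call: both steps run.
engine.backward(engine(batch))
events_via_engine = list(engine.events)

# Tensor-style call: autograd runs, but the wrapper's steps never fire.
engine.events.clear()
engine(batch).backward()
events_via_tensor = list(engine.events)
```

No error is raised in the second case; the gradients are computed, but nothing signals that the surrounding steps were skipped.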
The
[documentation](https://deepspeed.readthedocs.io/en/latest/training.html#jointly-training-models-with-shared-loss)
explains that we can call `_backward_epilogue` manually (and possibly
`backward_prologue` as well). However, these calls are easy for users to
miss, and passing a non-scalar gradient is still not supported.
This PR introduces the same `.backward()` behavior as PyTorch, allowing
`.backward()` to be called directly on tensors and supporting non-scalar
outputs.
To implement post-backward hooks, we had to use some internal PyTorch
APIs. See the
[comments](https://github.com/deepspeedai/DeepSpeed/blob/73f7ff1aab9d1387eb7dd4eca7453a25024533f4/deepspeed/runtime/engine.py#L424)
for more details. When the internal APIs are not available, the DeepSpeed
engine accepts only the traditional call, `model_engine.backward(loss)`.
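The general shape of such a hook, with feature detection and a fallback, might look like the sketch below. It uses `torch.autograd.Variable._execution_engine.queue_callback`, an internal PyTorch API; that this is the API DeepSpeed relies on is an assumption here, and `attach_post_backward` is a hypothetical helper rather than part of either library.

```python
import torch
from torch.autograd import Variable


def attach_post_backward(output, callback):
    """Try to run `callback` once the backward pass through `output` ends.

    Returns False when the internal API is missing so the caller can fall back.
    """
    engine = getattr(Variable, "_execution_engine", None)
    if engine is None or not hasattr(engine, "queue_callback"):
        return False

    def _on_output_grad(grad):
        # queue_callback can only be called while backward is running,
        # so it is invoked from inside a tensor hook.
        engine.queue_callback(callback)

    output.register_hook(_on_output_grad)
    return True


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
events = []

output = model(torch.randn(2, 4))
supported = attach_post_backward(
    output,
    lambda: events.append(("post_backward", model.weight.grad is not None)),
)

if supported:
    output.backward(torch.ones_like(output))  # non-scalar backward on the tensor
else:
    output.sum().backward()  # fallback: scalar loss only, no post-backward hook
```

The callback records whether the weight gradient already exists when it fires, which is the property a post-backward step depends on.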
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>

1 parent 51dc888 commit 53e91a0
File tree
11 files changed, +1726 -129 lines changed
- deepspeed/runtime
  - fp16
  - zenflow
  - zero
- docs/code-docs/source
- tests/unit
  - runtime
  - v1/zero