
Commit 53e91a0

PyTorch-compatible backward API (#7665)
Currently DeepSpeed's backward API has more constraints than PyTorch's normal backward API. Here is the usage as described in the documentation:

```python
loss = model_engine(batch)
model_engine.backward(loss)
```

In this example:

1. Only a (scalar) loss value is accepted.
2. The engine's backward API must be called.

In contrast, in standard PyTorch you can do:

```python
output = model(batch)
output.backward(out_grad)
```

There are several use cases that rely on this flexibility, for example combining multiple models or using loss functions defined separately from the main model. If you attempt the same pattern with a DeepSpeed engine, some preprocessing and postprocessing steps are silently skipped, which can lead to incorrect results.

The [documentation](https://deepspeed.readthedocs.io/en/latest/training.html#jointly-training-models-with-shared-loss) explains that we can call `_backward_epilogue` manually (possibly `backward_prologue` as well). However, it is easy for users to miss these calls, and passing a non-scalar gradient is still not supported.

This PR introduces the same `.backward()` behavior as PyTorch, allowing `.backward()` to be called directly on tensors and supporting non-scalar outputs. To implement post-backward hooks, we had to use some torch internal APIs; see the [comments](https://github.com/deepspeedai/DeepSpeed/blob/73f7ff1aab9d1387eb7dd4eca7453a25024533f4/deepspeed/runtime/engine.py#L424) for details. When the internal APIs are not available, the DeepSpeed engine only accepts the traditional `model_engine.backward(loss)` form.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
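By way of illustration (not part of the commit itself), here is a minimal sketch of the shared-loss pattern this change enables; `model_a`, `model_b`, `ds_config`, and `data_loader` are assumed to be defined elsewhere:

```python
import torch.nn.functional as F
import deepspeed

# Hypothetical setup: two independently wrapped engines that share one loss.
engine_a, _, _, _ = deepspeed.initialize(model=model_a,
                                          model_parameters=model_a.parameters(),
                                          config=ds_config)
engine_b, _, _, _ = deepspeed.initialize(model=model_b,
                                          model_parameters=model_b.parameters(),
                                          config=ds_config)

for batch, labels in data_loader:
    hidden = engine_a(batch)
    logits = engine_b(hidden)
    loss = F.cross_entropy(logits, labels)

    # With this PR, backward() can be called on the tensor itself, as in plain
    # PyTorch; the engines' backward prologue/epilogue logic still runs.
    loss.backward()

    engine_a.step()
    engine_b.step()
```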
1 parent 51dc888 · commit 53e91a0

File tree

11 files changed (+1726 −129 lines)


deepspeed/runtime/base_optimizer.py

Lines changed: 77 additions & 0 deletions
```diff
@@ -5,11 +5,13 @@
 
 import os
 import torch
+from typing import Any
 
 from deepspeed.utils import logger
 from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
 from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
 from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized
+from deepspeed.runtime.utils import maybe_loss_for_backward
 
 
 class DeepSpeedOptimizer(object):
@@ -18,6 +20,11 @@ class DeepSpeedOptimizer(object):
 
 class ZeROOptimizer(DeepSpeedOptimizer):
 
+    def __init__(self):
+        self._remaining_grad_acc_hooks = 0
+        self._grad_acc_post_hooks = []
+        self._backward_active_depth = 0
+
     def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
@@ -79,3 +86,73 @@ def get_param_comm_dtype(self, param):
             return get_comm_dtype(param)
         else:
             return self.communication_data_type
+
+    def needs_scaler(self) -> bool:
+        """
+        Check if this optimizer requires loss scaling for correct backward pass.
+
+        Returns True if any of the following conditions are met:
+        - Custom loss scaler is enabled
+        - torch.autocast gradient scaler is active (fp16 only)
+        - Dynamic loss scaling is enabled (fp16 with DeepSpeed's loss scaler)
+
+        Returns False for bf16 or fp32, which don't require gradient scaling.
+        """
+        return (self.custom_loss_scaler or self.torch_autocast_gradscaler is not None
+                or (hasattr(self, 'dynamic_loss_scale') and self.dynamic_loss_scale))
+
+    def scale_if_loss(self, value: Any) -> Any:
+        """
+        Applies loss scaling to the input value if it is a loss tensor.
+        """
+        if maybe_loss_for_backward(value):
+            if self.custom_loss_scaler:
+                return self.external_loss_scale * value
+            if self.torch_autocast_gradscaler:
+                return self.torch_autocast_gradscaler.scale(value)
+            return self.loss_scaler.scale_loss(value)
+
+        return value
+
+    def backward_prologue(self):
+        pass
+
+    def backward_epilogue(self, **kwargs):
+        pass
+
+    def backward(self, loss, **kwargs):
+        assert maybe_loss_for_backward(loss), "Optimizer's backward() only accepts a scalar tensor"
+
+        scaled_loss = self.backward_prologue(loss)
+        retain_graph = kwargs.pop('retain_graph', False)
+        self.enter_backward()
+        scaled_loss.backward(retain_graph=retain_graph)
+        self.backward_epilogue()
+        self.exit_backward()
+
+    def register_grad_acc_post_hook(self, hook):
+        self._grad_acc_post_hooks.append(hook)
+
+    def unregister_grad_acc_post_hooks(self):
+        self._grad_acc_post_hooks = []
+
+    def run_grad_acc_post_hooks(self):
+        # Custom autograd Functions (e.g., TiledFusedLogitsLoss) can invoke
+        # `torch.autograd.backward()` from their *forward* pass before the user
+        # ever calls `engine.backward(loss)`. Those early backward calls still
+        # trigger ZeRO's grad hooks, but we must not run the engine's
+        # post-backward logic (which reduces/clears grads) until the outer/user
+        # backward is active. The depth guard filters out only those pre-user
+        # invocations while still allowing backward calls that happen during
+        # the real user backward.
+        if self._backward_active_depth == 0:
+            return
+        for hook in self._grad_acc_post_hooks:
+            hook()
+
+    def enter_backward(self):
+        self._backward_active_depth += 1
+
+    def exit_backward(self):
+        if self._backward_active_depth > 0:
+            self._backward_active_depth -= 1
```
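To make the depth guard concrete, the following is a small self-contained sketch (illustration only, not DeepSpeed code) that reproduces the `enter_backward` / `exit_backward` / `run_grad_acc_post_hooks` interplay in plain Python:

```python
class DepthGuardedHooks:
    """Toy reproduction of ZeROOptimizer's depth guard, for illustration."""

    def __init__(self):
        self._grad_acc_post_hooks = []
        self._backward_active_depth = 0

    def register_grad_acc_post_hook(self, hook):
        self._grad_acc_post_hooks.append(hook)

    def enter_backward(self):
        self._backward_active_depth += 1

    def exit_backward(self):
        if self._backward_active_depth > 0:
            self._backward_active_depth -= 1

    def run_grad_acc_post_hooks(self):
        # Ignore invocations that happen before the user's backward is active,
        # e.g. a backward triggered from a custom autograd Function's forward.
        if self._backward_active_depth == 0:
            return
        for hook in self._grad_acc_post_hooks:
            hook()


guard = DepthGuardedHooks()
guard.register_grad_acc_post_hook(lambda: print("post-backward hook ran"))

guard.run_grad_acc_post_hooks()  # silent: no user backward is active yet

guard.enter_backward()
guard.run_grad_acc_post_hooks()  # prints: hook runs inside the user backward
guard.exit_backward()
```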

deepspeed/runtime/bf16_optimizer.py

Lines changed: 2 additions & 10 deletions
```diff
@@ -316,18 +316,10 @@ def step(self, closure=None):
 
         self.clear_hp_grads()
 
-    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
-        """Perform a backward pass and copy the low-precision gradients to the
-        high-precision copy.
-
-        We copy/accumulate to the high-precision grads now to prevent accumulating in the
-        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1)
-
-        The low-precision grads are deallocated during this procedure.
-        """
+    def backward_prologue(self):
         self.clear_lp_grads()
-        loss.backward(retain_graph=retain_graph, **bwd_kwargs)
 
+    def backward_epilogue(self, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         if update_hp_grads:
             self.update_hp_grads(clear_lp_grads=clear_lp_grads)
 
```
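The rationale in the removed docstring (copy/accumulate into the high-precision grads rather than accumulating in bf16 across accumulation steps) is easy to see numerically. The snippet below is a standalone illustration of that precision argument, not DeepSpeed code:

```python
import torch

# Accumulate 1000 small gradient contributions of ~1e-3 each.
contributions = torch.full((1000,), 1e-3)

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)  # accumulate directly in bf16
acc_fp32 = torch.zeros((), dtype=torch.float32)   # accumulate in a high-precision copy
for c in contributions:
    lp = c.to(torch.bfloat16)  # the "low-precision grad"
    acc_bf16 += lp             # small updates get rounded away as the sum grows
    acc_fp32 += lp.float()     # copying to fp32 first preserves them

print(acc_bf16.item(), acc_fp32.item())  # bf16 stalls well short of the ~1.0 fp32 result
```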
