[plugin]hybrid support zero bubble pipeline #6060

Merged on Sep 27, 2024

55 commits
6911938
hybrid support zbv
flybird11111 Sep 12, 2024
e6da1aa
fix
flybird11111 Sep 12, 2024
6d5b32b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
4404775
fix
flybird11111 Sep 12, 2024
e993144
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 12, 2024
3feda3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
6d0122d
fix
flybird11111 Sep 12, 2024
24727de
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 12, 2024
9802a7d
Update zero_bubble_pp.py
flybird11111 Sep 12, 2024
9c59e6c
fix
flybird11111 Sep 13, 2024
37d9623
fix-ci
flybird11111 Sep 13, 2024
5965f8b
fix
flybird11111 Sep 13, 2024
f99fc6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
e169f97
fix
flybird11111 Sep 23, 2024
1684d6d
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 23, 2024
7f78272
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
885ace7
fix
flybird11111 Sep 23, 2024
95c4b31
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 23, 2024
629c76d
fix
flybird11111 Sep 23, 2024
1fcc3a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
88b068f
fix
flybird11111 Sep 23, 2024
6550e80
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 23, 2024
0509712
fix
flybird11111 Sep 23, 2024
28f581f
fix
flybird11111 Sep 23, 2024
f9f04e5
fix
flybird11111 Sep 23, 2024
6cf3ebc
[zerobubble]Support ZeroBubble Pipeline (#6034)
duanjunwen Sep 10, 2024
3dd5d59
hybrid support zbv
flybird11111 Sep 12, 2024
fee18d0
fix
flybird11111 Sep 12, 2024
b93d008
fix
flybird11111 Sep 12, 2024
eef5d83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
b55279c
fix
flybird11111 Sep 12, 2024
3efd8d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
1839030
Update zero_bubble_pp.py
flybird11111 Sep 12, 2024
433c8a9
fix
flybird11111 Sep 13, 2024
3fb1e42
fix-ci
flybird11111 Sep 13, 2024
cd2e34b
fix
flybird11111 Sep 13, 2024
4e0f212
fix
flybird11111 Sep 23, 2024
fa358b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
0bde0bf
fix
flybird11111 Sep 23, 2024
a2f187b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
a9ac438
fix
flybird11111 Sep 23, 2024
8c3cdea
fix
flybird11111 Sep 23, 2024
f07b305
fix
flybird11111 Sep 23, 2024
29ec566
fix
flybird11111 Sep 23, 2024
b57e78d
fix
flybird11111 Sep 23, 2024
083ea31
fix
flybird11111 Sep 27, 2024
83d0766
fix
flybird11111 Sep 27, 2024
d3e83c3
fix
flybird11111 Sep 27, 2024
8ad2b72
Merge branch 'feature/zerobubble' into support-zbv
flybird11111 Sep 27, 2024
9b3c266
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
f288930
fix
flybird11111 Sep 27, 2024
454e236
Merge branch 'support-zbv' of github.com:flybird11111/ColossalAI into…
flybird11111 Sep 27, 2024
93557f5
fix
flybird11111 Sep 27, 2024
90a82e2
fix
flybird11111 Sep 27, 2024
7403580
fix
flybird11111 Sep 27, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -140,7 +140,7 @@ jobs:

  - name: Install Colossal-AI
  run: |
- BUILD_EXT=1 pip install -v -e .
+ BUILD_EXT=1 pip install -v .
  pip install --no-cache-dir -r requirements/requirements-test.txt

  - name: Store Colossal-AI Cache
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -55,7 +55,7 @@ jobs:
  if: steps.check-avai.outputs.avai == 'true'
  run: |
  [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
- BUILD_EXT=1 pip install -v -e .
+ BUILD_EXT=1 pip install -v .
  cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
  pip install --no-cache-dir -r requirements/requirements-test.txt
2 changes: 1 addition & 1 deletion colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -43,7 +43,7 @@ def zero_grad(self):
dtype: torch.dtype

@abstractmethod
- def pre_backward(self, loss: Tensor) -> Tensor:
+ def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
"""Called before backward.

Args:
13 changes: 9 additions & 4 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -85,13 +85,18 @@ def __init__(
master_params.append(master_p)
group["params"] = master_params

- def backward(self, loss: Tensor, *args, **kwargs):
+ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
- loss.backward(*args, **kwargs)
+ loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

- def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
- tensor.backward(grad)
+ torch.autograd.backward(
+ tensors=tensor,
+ grad_tensors=grad,
+ inputs=inputs,
+ retain_graph=retain_graph,
+ )

def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
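
The hunk above forwards the new `inputs` and `retain_graph` arguments to `torch.autograd.backward`, where `inputs` limits which tensors have gradients accumulated into `.grad`. The toy reverse-mode sketch below is plain Python rather than PyTorch, and `Var` and `backward` are hypothetical names made up for illustration. It mimics that behaviour to show why a schedule might run backward twice over the same graph, once for the activation gradient and once for the weight gradient. That two-pass reading is an assumption about how the zero bubble scheduler uses these arguments; this diff only shows the plumbing.

```python
class Var:
    """A scalar node in a tiny computation graph (illustrative, not PyTorch)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # Each parent entry is (parent_node, d(self)/d(parent)).
        self.parents = tuple(parents)

    def __mul__(self, other):
        return Var(self.value * other.value, [(self, other.value), (other, self.value)])

    def __add__(self, other):
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])


def backward(output, grad=1.0, inputs=None):
    """Accumulate gradients into leaf nodes, optionally only into `inputs`.

    Walks every path from `output` to the leaves and sums the chain-rule
    products, which is enough for small graphs like this one.
    """
    stack = [(output, grad)]
    while stack:
        node, upstream = stack.pop()
        if not node.parents:  # leaf node
            if inputs is None or any(node is wanted for wanted in inputs):
                node.grad += upstream
            continue
        for parent, local_grad in node.parents:
            stack.append((parent, upstream * local_grad))


x = Var(3.0)  # stands in for an activation received from the previous stage
w = Var(2.0)  # stands in for a weight owned by this stage
y = w * x

# First pass: only the activation gradient is accumulated.
backward(y, grad=1.0, inputs=[x])
grads_after_first_pass = (x.grad, w.grad)

# Second pass over the same graph: only the weight gradient is accumulated.
backward(y, grad=1.0, inputs=[w])
grads_after_second_pass = (x.grad, w.grad)
```

In PyTorch the first of two such passes would typically need `retain_graph=True`, since the graph is otherwise freed after one backward call; that is presumably why `retain_graph` is threaded through alongside `inputs`.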
4 changes: 2 additions & 2 deletions colossalai/booster/mixed_precision/fp16_torch.py
@@ -46,9 +46,9 @@ def __init__(
growth_interval=growth_interval,
)

- def backward(self, loss: Tensor, *args, **kwargs) -> None:
+ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
- scaled_loss.backward(*args, **kwargs)
+ scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def step(self, *args, **kwargs) -> Optional[float]:
out = self.scaler.step(self.optim, *args, **kwargs)
63 changes: 42 additions & 21 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -28,7 +28,7 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
- from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -288,7 +288,7 @@ def __init__(
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim)

- def backward(self, loss: Tensor, *args, **kwargs):
+ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -306,7 +306,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""

# Call the superclass backward method to compute gradients.
- super().backward(loss, *args, **kwargs)
+ super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
# If gradient synchronization is not required, return.
return

- def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -332,7 +332,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""

# Call the superclass backward method to compute gradients.
- super().backward_by_grad(tensor, grad)
+ super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -512,7 +512,7 @@ def __init__(
max_norm=max_norm,
)

- def backward(self, loss: Tensor, *args, **kwargs):
+ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -529,7 +529,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
None
"""
# Call the superclass backward method to compute gradients.
- super().backward(loss, *args, **kwargs)
+ super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -538,7 +538,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
# If gradient synchronization is not required, return.
return

- def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -554,7 +554,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
None
"""
# Call the superclass backward method to compute gradients.
- super().backward_by_grad(tensor, grad)
+ super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -768,7 +768,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
else:
return

- def backward(self, loss, retain_graph=False):
+ def backward(self, loss, inputs=None, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -784,7 +784,7 @@ def backward(self, loss, retain_graph=False):
None
"""
# Call the superclass backward method to compute gradients.
- super().backward(loss, retain_graph)
+ super().backward(loss, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -793,7 +793,7 @@ def backward(self, loss, retain_graph=False):
# If gradient synchronization is not required, return.
return

- def backward_by_grad(self, tensor, grad):
+ def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -809,7 +809,7 @@ def backward_by_grad(self, tensor, grad):
None
"""
# Call the superclass backward_by_grad method to compute gradients.
- super().backward_by_grad(tensor, grad)
+ super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1013,6 +1013,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
+ scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@@ -1029,6 +1030,9 @@ def __init__(
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

+ assert (
+ not pp_style == "zbv" or scheduler_nodes is not None
+ ), f"scheduler_nodes must not be None when using zero bubble pipeline."
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1088,29 +1092,39 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

self.stage_manager = None
- self.schedule = None
+ self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
- assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
- assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+ assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+ assert (
+ pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+ ), "num_model_chunks must be 1 when using 1f1b"
+ assert (
+ pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+ ), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+ if pp_style == "zbv":
+ self.logger.warning(
+ """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+ )
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
- enable_interleave=(pp_style == "interleaved"),
+ enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+ use_zbv=(pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
- self.schedule = InterleavedSchedule(
+ self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@@ -1119,12 +1133,20 @@ def __init__(
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
- self.schedule = OneForwardOneBackwardSchedule(
+ self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
+ elif pp_style == "zbv":
+ self.scheduler = ZeroBubbleVPipeScheduler(
+ stage_manager=self.stage_manager,
+ schedule=scheduler_nodes,
+ num_model_chunks=num_model_chunks,
+ num_microbatch=num_microbatches,
+ microbatch_size=microbatch_size,
+ )
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
@@ -1236,7 +1258,6 @@ def configure(

# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
-
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1352,7 +1373,7 @@ def execute_pipeline(
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

with ctx, model._wait_all_gather():
- outputs = self.schedule.forward_backward_step(
+ outputs = self.scheduler.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
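
The constraints this file now places on `pp_style`, `num_model_chunks` and `scheduler_nodes` can be gathered into one standalone check. The function below is a sketch that mirrors the assertions in the diff for the case `pp_size > 1`. `check_pp_config` and `is_valid` are hypothetical helper names and are not part of ColossalAI, and the sketch raises `ValueError` where the plugin uses `assert`.

```python
from typing import List, Optional


def check_pp_config(
    pp_style: str,
    num_model_chunks: int,
    scheduler_nodes: Optional[List] = None,
    num_microbatches: Optional[int] = None,
    microbatch_size: Optional[int] = None,
    zero_stage: int = 0,
) -> None:
    """Raise ValueError if the pipeline settings violate the rules in the diff."""
    if pp_style not in ("1f1b", "interleaved", "zbv"):
        raise ValueError("Unsupported pipeline parallelism style")
    if pp_style == "zbv" and scheduler_nodes is None:
        raise ValueError("scheduler_nodes must not be None when using zero bubble pipeline")
    if pp_style == "1f1b" and num_model_chunks != 1:
        raise ValueError("num_model_chunks must be 1 when using 1f1b")
    if pp_style == "interleaved" and num_model_chunks <= 1:
        raise ValueError("number of model chunks must be > 1 when using interleaved")
    if pp_style == "zbv" and num_model_chunks != 2:
        raise ValueError("num_model_chunks must be 2 when using zero bubble pipeline")
    if num_microbatches is None and microbatch_size is None:
        raise ValueError("num_microbatches or microbatch_size must be specified")
    if zero_stage > 1:
        raise ValueError("zero stage must be 0 or 1 when using pipeline parallelism")


def is_valid(**kwargs) -> bool:
    """Return True when check_pp_config accepts the given settings."""
    try:
        check_pp_config(**kwargs)
        return True
    except ValueError:
        return False


# "placeholder" stands in for real scheduler nodes, whose type this diff does not show.
zbv_ok = is_valid(pp_style="zbv", num_model_chunks=2, scheduler_nodes=["placeholder"], num_microbatches=4)
zbv_missing_nodes = is_valid(pp_style="zbv", num_model_chunks=2, num_microbatches=4)
```

Read this way, `"zbv"` is the only style that fixes `num_model_chunks` to exactly 2 and requires a precomputed `scheduler_nodes` list.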

6 changes: 3 additions & 3 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -280,7 +280,7 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

self.stage_manager = None
- self.schedule = None
+ self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
@@ -304,7 +304,7 @@ def __init__(

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
- self.schedule = InterleavedSchedule(
+ self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@@ -313,7 +313,7 @@ def __init__(
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
- self.schedule = OneForwardOneBackwardSchedule(
+ self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
4 changes: 2 additions & 2 deletions colossalai/interface/optimizer.py
@@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs):
"""
self.optim.zero_grad(*args, **kwargs)

- def backward(self, loss: Tensor, *args, **kwargs):
+ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
- loss.backward(*args, **kwargs)
+ loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
6 changes: 5 additions & 1 deletion colossalai/pipeline/stage_manager.py
@@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool:
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
- return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+ # use zero bubble pipeline
+ if self.use_zbv:
+ return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
+ else:
+ return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1

@property
def num_stages(self) -> int:
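
To see what the new branch in `is_last_stage` changes, the logic can be lifted into a free function and evaluated over every (stage, chunk) pair. This is a sketch only: the function takes as parameters what the real method reads from `self`, and the 4-stage, 2-chunk layout is an arbitrary example.

```python
def is_last_stage(
    stage: int,
    num_stages: int,
    model_chunk_id: int,
    num_model_chunks: int,
    is_interleave: bool,
    use_zbv: bool,
    ignore_chunk: bool = False,
) -> bool:
    """Standalone copy of the updated decision logic shown in the diff."""
    if not is_interleave or ignore_chunk:
        return stage == num_stages - 1
    if use_zbv:
        # With the zero bubble V schedule the final chunk sits on stage 0.
        return stage == 0 and model_chunk_id == num_model_chunks - 1
    return stage == num_stages - 1 and model_chunk_id == num_model_chunks - 1


NUM_STAGES, NUM_CHUNKS = 4, 2
pairs = [(s, c) for s in range(NUM_STAGES) for c in range(NUM_CHUNKS)]

interleaved_last = [
    (s, c) for s, c in pairs if is_last_stage(s, NUM_STAGES, c, NUM_CHUNKS, True, False)
]
zbv_last = [
    (s, c) for s, c in pairs if is_last_stage(s, NUM_STAGES, c, NUM_CHUNKS, True, True)
]
# With ignore_chunk=True the answer depends on the stage index only, even for zbv.
zbv_last_ignoring_chunk = [
    s for s in range(NUM_STAGES) if is_last_stage(s, NUM_STAGES, 0, NUM_CHUNKS, True, True, ignore_chunk=True)
]
```

Because `ignore_chunk=True` still reports the highest stage index under `zbv`, callers that pass it cannot rely on this method alone to find where the model's tail lives.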
12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/llama.py
@@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
- if stage_manager.is_last_stage(ignore_chunk=True):
+ if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+ held_layers.append(module.norm)
+ elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)

else:
@@ -351,7 +353,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
- if stage_manager.is_last_stage(ignore_chunk=True):
+ if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+ held_layers.append(self.model.lm_head)
+ elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

@@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
- if stage_manager.is_last_stage(ignore_chunk=True):
+ if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+ held_layers.append(self.model.score)
+ elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers
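
The policy changes above place `norm`, `lm_head` and `score` on the first stage when `use_zbv` is set. One way to see why is to lay the model chunks out in a V shape, with each stage holding one chunk on the way down and one on the way back up. The sketch below assumes that layout (stage `s` of `P` holds chunk positions `s` and `2P - 1 - s`), which follows the usual description of the ZB-V schedule and is not spelled out in this diff. `v_shape_positions` and `stage_holding_head` are hypothetical names.

```python
from typing import Dict, Tuple


def v_shape_positions(num_stages: int) -> Dict[int, Tuple[int, int]]:
    """Map each pipeline stage to the two chunk positions it would hold (assumed layout)."""
    total_positions = 2 * num_stages
    return {stage: (stage, total_positions - 1 - stage) for stage in range(num_stages)}


def stage_holding_head(num_stages: int, use_zbv: bool) -> int:
    """Return the stage that owns the final chunk, and therefore the output head."""
    if not use_zbv:
        # Non-zbv branch of the policy: the head stays on the last stage.
        return num_stages - 1
    last_position = 2 * num_stages - 1
    positions = v_shape_positions(num_stages)
    return next(stage for stage, held in positions.items() if last_position in held)


layout = v_shape_positions(4)
head_stage_zbv = stage_holding_head(4, use_zbv=True)
head_stage_default = stage_holding_head(4, use_zbv=False)
```

Under that assumed layout the final chunk folds back onto stage 0, which matches the `is_first_stage(ignore_chunk=True)` branch added in each `get_held_layers`.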

2 changes: 1 addition & 1 deletion colossalai/zero/gemini/gemini_ddp.py
@@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor):
loss.backward()
self._post_backward()

- def backward_by_grad(self, tensor, grad):
+ def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shouldn't be called in Gemini.")

@staticmethod
6 changes: 4 additions & 2 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor):
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)

- def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+ def backward_by_grad(
+ self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
+ ):
# This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
- grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+ grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
self.module.backward_by_grad(tensor, grad)

def _maybe_move_fp32_params(self):