Skip to content

Commit 5965f8b

Browse files
committed
fix
[pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci — fix
1 parent 37d9623 commit 5965f8b

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

colossalai/pipeline/schedule/zero_bubble_pp.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import nullcontext
12
from functools import partial
23
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
34

@@ -11,9 +12,6 @@
1112
from colossalai.pipeline.p2p import PipelineP2PCommunication
1213
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
1314
from colossalai.pipeline.stage_manager import PipelineStageManager
14-
from colossalai.zero.low_level import LowLevelZeroOptimizer
15-
from contextlib import nullcontext
16-
1715
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
1816
from .base import PipelineSchedule
1917

@@ -487,7 +485,13 @@ def backward_b_step(
487485
assert output_obj_grad is None
488486

489487
input_obj_ = input_obj["hidden_states"]
490-
ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext()
488+
489+
# Attempt to disable gradient synchronization when using the LowLevelZeroPlugin.
490+
try:
491+
ctx = optimizer.no_sync()
492+
except Exception as e:
493+
ctx = nullcontext()
494+
491495
with ctx:
492496
if output_obj_grad is None:
493497
optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)

0 commit comments

Comments
 (0)