Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpointable_layers Logic #6881

Merged
merged 8 commits into from
Jan 4, 2025
14 changes: 12 additions & 2 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def forward(self, inputs):
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
considered checkpointable. Defaults to None.
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""

Expand Down Expand Up @@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
# For GPT models, checkpoint both transformer layers and any additional
# layers specified in checkpointable_layers (if provided)
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
for f in funcs)

if self.checkpointable_layers is not None:
# For non-GPT models, only checkpoint layers specified in checkpointable_layers
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

# Default behavior: checkpoint any layer that has parameters
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.accelerator import get_accelerator
from copy import deepcopy
from unit.common import DistributedTest
Expand Down Expand Up @@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)


class TestCheckpointableLayersConfig(DistributedTest):
loadams marked this conversation as resolved.
Show resolved Hide resolved
world_size = 1

def test_gpt2_checkpointable_layers(self):
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU accelerator does not support this test yet")

# Create a simple topology for testing
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1)

# Create test classes that we want to checkpoint
class TestTransformerLayer(torch.nn.Module):

def forward(self, x):
return x

class ParallelTransformerLayerPipe(TestTransformerLayer):
pass

class GMLPBlock(TestTransformerLayer):
pass

# Create a mock GPT2 model with different layer types
class TestGPT2ModelPipe(PipelineModule):

def __init__(self):
self.layers_spec = [
LayerSpec(ParallelTransformerLayerPipe),
LayerSpec(GMLPBlock),
LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed
]

super().__init__(layers=self.layers_spec,
topology=topo,
checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])

model = TestGPT2ModelPipe()
model.to(get_accelerator().device_name())

# Build layers manually for testing
layers = [spec.build() for spec in model.layers_spec]

# Test that _is_checkpointable returns correct values
assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe
assert model._is_checkpointable([layers[1]]) == True # GMLPBlock
assert model._is_checkpointable([layers[2]]) == False # Linear layer
Loading