Skip to content

Commit

Permalink
Fix checkpointable_layers Logic (#6881)
Browse files Browse the repository at this point in the history
**Problem**

There's an edge-case in DeepSpeed, where if all three of the following
are true:
1. Deepspeed activation checkpointing is applied 
2. The user passes `checkpointable_layers` (e.g.
https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175)
3. The user's model class contains `GPT2ModelPipe` or GPTModelPipe`

Then the `checkpointable_layers` will not be activation checkpointed. 

**Reason**

This is because in the current logic, `_is_checkpointable` will
short-circuit to just return layers matching
`ParallelTransformerLayerPipe` in the case of `self.__class__.__name__
in ('GPTModelPipe', 'GPT2ModelPipe')`. See
https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653

**Proposed Fixes**

I think that `checkpointable_layers` should always be checked for, and
added logic to this effect. I also found the documentation for
`checkpointable_layers` confusing and contradictory, so I updated the
docstring. Lastly, I added a unit test for `checkpointable_layers`.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
Quentin-Anthony and loadams authored Jan 4, 2025
1 parent a8ede3a commit 0dbbb70
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
14 changes: 12 additions & 2 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def forward(self, inputs):
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
considered checkpointable. Defaults to None.
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""

Expand Down Expand Up @@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
# For GPT models, checkpoint both transformer layers and any additional
# layers specified in checkpointable_layers (if provided)
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
for f in funcs)

if self.checkpointable_layers is not None:
# For non-GPT models, only checkpoint layers specified in checkpointable_layers
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

# Default behavior: checkpoint any layer that has parameters
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.accelerator import get_accelerator
from copy import deepcopy
from unit.common import DistributedTest
Expand Down Expand Up @@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)


class TestCheckpointableLayersConfig(DistributedTest):
world_size = 1

def test_gpt2_checkpointable_layers(self):
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU accelerator does not support this test yet")

# Create a simple topology for testing
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1)

# Create test classes that we want to checkpoint
class TestTransformerLayer(torch.nn.Module):

def forward(self, x):
return x

class ParallelTransformerLayerPipe(TestTransformerLayer):
pass

class GMLPBlock(TestTransformerLayer):
pass

# Create a mock GPT2 model with different layer types
class TestGPT2ModelPipe(PipelineModule):

def __init__(self):
self.layers_spec = [
LayerSpec(ParallelTransformerLayerPipe),
LayerSpec(GMLPBlock),
LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed
]

super().__init__(layers=self.layers_spec,
topology=topo,
checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])

model = TestGPT2ModelPipe()
model.to(get_accelerator().device_name())

# Build layers manually for testing
layers = [spec.build() for spec in model.layers_spec]

# Test that _is_checkpointable returns correct values
assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe
assert model._is_checkpointable([layers[1]]) == True # GMLPBlock
assert model._is_checkpointable([layers[2]]) == False # Linear layer

0 comments on commit 0dbbb70

Please sign in to comment.