-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix
checkpointable_layers
Logic (#6881)
**Problem** There's an edge-case in DeepSpeed, where if all three of the following are true: 1. Deepspeed activation checkpointing is applied 2. The user passes `checkpointable_layers` (e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175) 3. The user's model class contains `GPT2ModelPipe` or GPTModelPipe` Then the `checkpointable_layers` will not be activation checkpointed. **Reason** This is because in the current logic, `_is_checkpointable` will short-circuit to just return layers matching `ParallelTransformerLayerPipe` in the case of `self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe')`. See https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653 **Proposed Fixes** I think that `checkpointable_layers` should always be checked for, and added logic to this effect. I also found the documentation for `checkpointable_layers` confusing and contradictory, so I updated the docstring. Lastly, I added a unit test for `checkpointable_layers`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
- Loading branch information
1 parent
a8ede3a
commit 0dbbb70
Showing
2 changed files
with
62 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters