Commit 2356b1c

add activation ckpt (#31)

* add activation ckpt
* allow specifying an int as ac ckpt
1 parent 5ddd12d commit 2356b1c
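
In short, the new `train.ac_ckpt` option enables activation checkpointing: `True` checkpoints every transformer block, while an integer `n` checkpoints every `n`-th block. A minimal sketch of that value-to-interval mapping, mirroring the logic added in `train.py` below (the helper name `resolve_ac_ckpt` is illustrative only, not part of the repo):

def resolve_ac_ckpt(ac_ckpt: bool | int) -> int | None:
    """Map the config value to a checkpointing interval; None means disabled."""
    if not ac_ckpt:
        return None
    # check bool first: True is also an int in Python
    return 1 if isinstance(ac_ckpt, bool) else ac_ckpt

assert resolve_ac_ckpt(True) == 1      # checkpoint every block
assert resolve_ac_ckpt(2) == 2         # checkpoint every 2nd block
assert resolve_ac_ckpt(False) is None  # disabled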

3 files changed: +40 −1 lines
src/zeroband/train.py

Lines changed: 6 additions & 1 deletion
@@ -22,6 +22,7 @@
 from zeroband.comms import ElasticDeviceMesh
 
 from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
+from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model
@@ -57,7 +58,7 @@ class TrainConfig(BaseConfig):
     micro_bs: int
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"
-
+    ac_ckpt: bool | int = False
     log_model_hash: bool = False
 
     memory_profiler: MemoryProfilerConfig | None = None
@@ -145,6 +146,10 @@ def train(config: Config):
         config.data.seq_length,
     )
 
+    if config.train.ac_ckpt:
+        num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
+        apply_ac_ckpt(model, num)
+
     elastic_device_mesh = ElasticDeviceMesh()
 
     model = FSDP(
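
Note that `apply_ac_ckpt` runs on the plain model before it is handed to FSDP (the `model = FSDP(` call just below the new block), following the common pattern of wrapping individual blocks with `checkpoint_wrapper` before the model is sharded.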

src/zeroband/utils/activation_ckpt.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from zeroband.models.llama.model import Transformer
+
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+
+from zeroband.utils.logging import get_logger
+
+
+def apply_ac_ckpt(model: Transformer, num: int):
+    """Apply activation checkpointing to the model.
+    Checkpoints every `num`-th transformer block.
+
+    Example: with `num=2` only half of the layers are checkpointed.
+    """
+    logger = get_logger()
+
+    layers_ckpt = 0
+
+    for layer_idx, (layer_id, transformer_block) in enumerate(model.layers.named_children()):
+        if layer_idx % num == 0:
+            transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
+            model.layers.register_module(layer_id, transformer_block)
+            layers_ckpt += 1
+
+    logger.info(f"Applied activation checkpointing to {layers_ckpt} layers")
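
For context, `checkpoint_wrapper` wraps a module so that its intermediate activations are dropped during the forward pass and recomputed during backward, trading extra compute for lower memory. An illustrative, self-contained sketch with a toy block standing in for a transformer layer (not from the repo; it reuses the same `preserve_rng_state=False` argument as the commit):

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Toy stand-in for a transformer block
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
wrapped = checkpoint_wrapper(block, preserve_rng_state=False)

x = torch.randn(4, 16, requires_grad=True)
loss = wrapped(x).sum()
loss.backward()  # activations inside `block` are recomputed here instead of being stored

Skipping the RNG save/restore with `preserve_rng_state=False` makes recomputation cheaper, at the cost of ops like dropout not replaying with exactly the same random state.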

tests/test_torchrun/test_train.py

Lines changed: 10 additions & 0 deletions
@@ -71,3 +71,13 @@ def test_multi_gpu_diloco_non_full_shard(strategy):
     # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp
     num_gpus = [2, 2]
     _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy])
+
+
+def test_act_ckpt():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt"])
+
+
+def test_act_ckpt_num():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt", "2"])
