Skip to content

Commit 7e10d38

Browse files
committed
[checkpointio] support non blocking pin load
1 parent 8369924 commit 7e10d38

15 files changed

+485
-173
lines changed

colossalai/booster/booster.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,14 @@ def enable_lora(
288288

289289
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
290290

291-
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
291+
def load_model(
292+
self,
293+
model: Union[nn.Module, ModelWrapper],
294+
checkpoint: str,
295+
strict: bool = True,
296+
low_cpu_mem_mode: bool = True,
297+
num_threads: int = 1,
298+
) -> None:
292299
"""Load model from checkpoint.
293300
294301
Args:
@@ -298,8 +305,12 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str
298305
strict (bool, optional): whether to strictly enforce that the keys
299306
in :attr:`state_dict` match the keys returned by this module's
300307
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
308+
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
309+
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
301310
"""
302-
self.checkpoint_io.load_model(model, checkpoint, strict)
311+
self.checkpoint_io.load_model(
312+
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
313+
)
303314

304315
def save_model(
305316
self,
@@ -338,18 +349,25 @@ def save_model(
338349
use_async=use_async,
339350
)
340351

341-
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
352+
def load_optimizer(
353+
self,
354+
optimizer: Optimizer,
355+
checkpoint: str,
356+
low_cpu_mem_mode: bool = True,
357+
num_threads: int = 1,
358+
) -> None:
342359
"""Load optimizer from checkpoint.
343360
344361
Args:
345362
optimizer (Optimizer): An optimizer boosted by Booster.
346363
checkpoint (str): Path to the checkpoint. It must be a local path.
347364
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
348-
prefix (str, optional): A prefix added to parameter and buffer
349-
names to compose the keys in state_dict. Defaults to None.
350-
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
365+
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
366+
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
351367
"""
352-
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
368+
self.checkpoint_io.load_optimizer(
369+
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
370+
)
353371

354372
def save_optimizer(
355373
self,

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,22 @@ def save_unsharded_model(
9797
else:
9898
save_state_dict(state_dict, checkpoint, use_safetensors)
9999

100-
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
100+
def load_unsharded_model(
101+
self,
102+
model: GeminiDDP,
103+
checkpoint: str,
104+
strict: bool = True,
105+
low_cpu_mem_mode: bool = True,
106+
num_threads: int = 1,
107+
):
101108
"""
102109
Load model from checkpoint with automatic unwrapping.
103110
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
104111
"""
105112
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
106-
super().load_unsharded_model(model, checkpoint, strict=strict)
113+
super().load_unsharded_model(
114+
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
115+
)
107116

108117
def save_unsharded_optimizer(
109118
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -131,13 +140,17 @@ def save_unsharded_optimizer(
131140
else:
132141
save_state_dict(state_dict, checkpoint, use_safetensors=False)
133142

134-
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
143+
def load_unsharded_optimizer(
144+
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
145+
):
135146
"""
136147
Loading unsharded optimizer from checkpoint file.
137148
For each process, only loading optimizer states of parameters it controls.
138149
"""
139150
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
140-
super().load_unsharded_optimizer(optimizer, checkpoint)
151+
super().load_unsharded_optimizer(
152+
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
153+
)
141154

142155
def save_sharded_model(
143156
self,
@@ -206,13 +219,27 @@ def save_sharded_model(
206219
)
207220

208221
def load_sharded_model(
209-
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
222+
self,
223+
model: GeminiDDP,
224+
checkpoint_index_file: Path,
225+
strict: bool = False,
226+
use_safetensors: bool = False,
227+
low_cpu_mem_mode: bool = True,
228+
num_threads: int = 1,
210229
):
211230
"""
212231
Load shard model, load model from multiple files.
213232
"""
214233
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
215-
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
234+
return super().load_sharded_model(
235+
model,
236+
checkpoint_index_file,
237+
strict,
238+
use_safetensors,
239+
load_sub_module=False,
240+
low_cpu_mem_mode=low_cpu_mem_mode,
241+
num_threads=num_threads,
242+
)
216243

217244
def save_sharded_optimizer(
218245
self,
@@ -289,7 +316,14 @@ def save_sharded_optimizer(
289316
ranks=[0],
290317
)
291318

292-
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
319+
def load_sharded_optimizer(
320+
self,
321+
optimizer: GeminiOptimizer,
322+
checkpoint_index_file: Path,
323+
prefix: str,
324+
low_cpu_mem_mode: bool = True,
325+
num_threads: int = 1,
326+
):
293327
"""
294328
Loading sharded optimizer from checkpoint folder, with index file given.
295329
For each process, only loading optimizer states of parameters it controls.
@@ -322,9 +356,9 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
322356
state_dict_shard = load_flat(shard_file)
323357
else:
324358
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
359+
if not low_cpu_mem_mode:
360+
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
325361
optimizer.load_param_states(state_dict_shard)
326-
del state_dict_shard
327-
gc.collect()
328362

329363
optimizer.optimizer_loading_epilogue()
330364

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from colossalai.accelerator import get_accelerator
2121
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
2222
from colossalai.checkpoint_io.utils import (
23+
create_pinned_state_dict,
2324
get_optimizer_base_filenames,
2425
get_shard_filename,
2526
load_param_groups_into_optimizer,
@@ -145,14 +146,18 @@ def save_unsharded_optimizer(
145146
else:
146147
save_state_dict(state_dict, checkpoint, use_safetensors=False)
147148

148-
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
149+
def load_unsharded_optimizer(
150+
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
151+
):
149152
use_async = checkpoint.endswith(".safetensors")
150153
if use_async:
151154
from colossalai.utils.safetensors import load_flat
152155

153156
checkpoint = load_flat(checkpoint)
154157
else:
155158
checkpoint = load_state_dict(checkpoint)
159+
if not low_cpu_mem_mode:
160+
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
156161
optimizer.load_state_dict(checkpoint)
157162

158163
def save_sharded_optimizer(
@@ -239,7 +244,14 @@ def save_sharded_optimizer(
239244
ranks=[0],
240245
)
241246

242-
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
247+
def load_sharded_optimizer(
248+
self,
249+
optimizer: OptimizerWrapper,
250+
index_file_path: str,
251+
prefix: str,
252+
low_cpu_mem_mode: bool = True,
253+
num_threads: int = 1,
254+
):
243255
"""Load sharded optimizer with the given path to index file.
244256
245257
Args:
@@ -283,14 +295,28 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
283295
if padding_size > 0:
284296
v = torch.nn.functional.pad(v, [0, padding_size])
285297
v_list = v.split(v.numel() // self.coordinator.world_size)
286-
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
298+
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
299+
if low_cpu_mem_mode:
300+
state_dict[param_idx][k] = state_dict[param_idx][k].clone()
301+
302+
if not low_cpu_mem_mode:
303+
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
287304
load_states_into_optimizer(optimizer, state_dict, id_map)
288305
sharded_optimizer_loading_epilogue(optimizer)
289306

290-
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
307+
def load_unsharded_model(
308+
self,
309+
model: ModelWrapper,
310+
checkpoint: str,
311+
strict: bool = True,
312+
low_cpu_mem_mode: bool = True,
313+
num_threads: int = 1,
314+
):
291315
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
292316
model._force_wait_all_gather()
293-
super().load_unsharded_model(model, checkpoint, strict)
317+
super().load_unsharded_model(
318+
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
319+
)
294320
model.update_master_params()
295321

296322
def load_sharded_model(
@@ -300,10 +326,20 @@ def load_sharded_model(
300326
strict: bool = False,
301327
use_safetensors: bool = False,
302328
load_sub_module: bool = True,
329+
low_cpu_mem_mode: bool = True,
330+
num_threads: int = 1,
303331
):
304332
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
305333
model._force_wait_all_gather()
306-
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
334+
super().load_sharded_model(
335+
model,
336+
checkpoint_index_file,
337+
strict,
338+
use_safetensors,
339+
load_sub_module,
340+
low_cpu_mem_mode=low_cpu_mem_mode,
341+
num_threads=num_threads,
342+
)
307343
model.update_master_params()
308344

309345
def save_unsharded_model(

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@ def __init__(self) -> None:
2626
self.coordinator = DistCoordinator()
2727
self.logger = get_dist_logger()
2828

29-
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
29+
def load_unsharded_model(
30+
self,
31+
model: ModelWrapper,
32+
checkpoint: str,
33+
strict: bool = True,
34+
low_cpu_mem_mode: bool = True,
35+
num_threads: int = 1,
36+
):
3037
"""
3138
Load model from checkpoint.
3239
"""
3340
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
34-
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
41+
super().load_unsharded_model(
42+
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
43+
)
3544

3645
def save_unsharded_model(
3746
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@@ -45,12 +54,16 @@ def save_unsharded_model(
4554
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
4655
)
4756

48-
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
57+
def load_unsharded_optimizer(
58+
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
59+
):
4960
"""
5061
Load optimizer from checkpoint.
5162
"""
5263
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
53-
super().load_unsharded_optimizer(optimizer, checkpoint)
64+
super().load_unsharded_optimizer(
65+
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
66+
)
5467

5568
def save_unsharded_optimizer(
5669
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -101,12 +114,22 @@ def load_sharded_model(
101114
strict: bool = False,
102115
use_safetensors: bool = False,
103116
load_sub_module: bool = True,
117+
low_cpu_mem_mode: bool = True,
118+
num_threads: int = 1,
104119
):
105120
"""
106121
Load model from sharded checkpoint.
107122
"""
108123
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
109-
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
124+
super().load_sharded_model(
125+
model.unwrap(),
126+
checkpoint_index_file,
127+
strict,
128+
use_safetensors,
129+
load_sub_module,
130+
low_cpu_mem_mode=low_cpu_mem_mode,
131+
num_threads=num_threads,
132+
)
110133

111134
def save_sharded_optimizer(
112135
self,
@@ -131,12 +154,16 @@ def load_sharded_optimizer(
131154
optimizer: Optimizer,
132155
index_file_path: str,
133156
prefix: Optional[str] = None,
157+
low_cpu_mem_mode: bool = True,
158+
num_threads: int = 1,
134159
):
135160
"""
136161
Load optimizer from sharded checkpoint.
137162
"""
138163
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
139-
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
164+
super().load_sharded_optimizer(
165+
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
166+
)
140167

141168
def save_lora_as_pretrained(
142169
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,17 @@ def __init__(self) -> None:
4343
self.coordinator = DistCoordinator()
4444
self.logger = get_dist_logger()
4545

46-
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
46+
def load_unsharded_model(
47+
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
48+
):
4749
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
4850
model = model.unwrap()
4951
checkpoint = utils.load_state_dict(checkpoint)
5052
model.load_state_dict(checkpoint)
5153

52-
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
54+
def load_unsharded_optimizer(
55+
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
56+
):
5357
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
5458
if checkpoint.endswith(".safetensors"):
5559
checkpoint = load_flat(checkpoint, seperator=".")
@@ -232,6 +236,8 @@ def load_sharded_model(
232236
strict: bool = False,
233237
use_safetensors: bool = False,
234238
load_sub_module: bool = True,
239+
low_cpu_mem_mode: bool = True,
240+
num_threads: int = 1,
235241
):
236242
"""
237243
Load model to checkpoint but only on master process.
@@ -354,7 +360,14 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
354360
f"index located at {save_index_file}."
355361
)
356362

357-
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
363+
def load_sharded_optimizer(
364+
self,
365+
optimizer: Optimizer,
366+
index_file_path: str,
367+
size_per_shard: int,
368+
low_cpu_mem_mode: bool = True,
369+
num_threads: int = 1,
370+
):
358371
"""
359372
Load optimizer to checkpoint but only on master process.
360373
"""

0 commit comments

Comments
 (0)