
Commit ab856fd

[checkpointio] fix zero optimizer async save memory (#6151)
* [checkpointio] fix zero optimizer async save memory
* [checkpointio] fit new tensornvme api
1 parent 8ecff0c commit ab856fd

File tree (7 files changed: +57 -42 lines changed)

- colossalai/booster/plugin/low_level_zero_plugin.py
- colossalai/checkpoint_io/checkpoint_io_base.py
- colossalai/checkpoint_io/general_checkpoint_io.py
- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
- colossalai/checkpoint_io/utils.py
- colossalai/zero/low_level/low_level_optim.py
- tests/test_checkpoint_io/test_safetensors_async_io.py
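
The recurring change across these files is the switch to the new tensornvme `AsyncFileWriter` API: the writer is now constructed from a file path and owns the file handle, instead of wrapping a caller-opened file object. A minimal before/after sketch, inferred from the call sites in this diff (the path and `n_entries` value are illustrative, not authoritative tensornvme documentation):

    from tensornvme.async_file_io import AsyncFileWriter

    # Old API: the caller opens the file and must later close writer.fp itself.
    # writer = AsyncFileWriter(fp=open("ckpt.safetensors", "wb", buffering=0),
    #                          n_entries=191, backend="pthread")

    # New API: pass the path; the writer opens and closes the file itself.
    writer = AsyncFileWriter("ckpt.safetensors", n_entries=191, backend="pthread")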

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 8 additions & 8 deletions
@@ -128,22 +128,20 @@ def save_unsharded_optimizer(
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        state_dict = optimizer.state_dict(pinned_state_dicts)
+        state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
                 from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(
-                    fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
-                )
+                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
                 save_nested(f_writer, state_dict)
                 self.async_writers.append(f_writer)
             else:
@@ -192,13 +190,15 @@ def save_sharded_optimizer(
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
+        sharded_state = optimizer.state_dict_shard(
+            max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
+        )

         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
@@ -227,7 +227,7 @@ def save_sharded_optimizer(
                 from colossalai.utils.safetensors import save_nested

                 f_writer = AsyncFileWriter(
-                    fp=open(checkpoint_file_path, "wb", buffering=0),
+                    checkpoint_file_path,
                     n_entries=self.N_WRITE_ENTRIES,
                     backend="pthread",
                 )
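
The memory fix in this plugin hinges on one asymmetry: `state_dict()` performs collective communication, so every rank must call it, but only the master rank needs the result. Passing `only_on_master=True` lets non-master ranks join the all-gathers without allocating pinned buffers or host copies. A sketch of the calling pattern (condensed from the diff above; `setdefault` is a shorthand for the plugin's explicit dict initialization):

    # All ranks must enter state_dict() so its internal all-gathers stay
    # matched, but only the master rank allocates pinned host buffers.
    pinned = self.pinned_state_dicts.setdefault(id(optimizer), {}) if self.coordinator.is_master() else None
    state_dict = optimizer.state_dict(pinned, only_on_master=True)
    if self.coordinator.is_master():
        ...  # only the master rank opens a writer and saves the file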

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@ def __init__(self):
     def _sync_io(self):
         for writer in self.async_writers:
             writer.synchronize()
-            writer.fp.close()
         self.async_writers.clear()

     def _sync_d2h(self):
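
With the path-based writer owning its file handle, `_sync_io` only has to drain outstanding writes; there is no `fp` attribute left to close. A sketch of the intended lifecycle, assuming the new writer releases its handle when the last reference is dropped (the tests below rely on the same behavior via `del f_writer`):

    from tensornvme.async_file_io import AsyncFileWriter

    writer = AsyncFileWriter("shard.safetensors", n_entries=191, backend="pthread")
    self.async_writers.append(writer)
    # ... save_nested()/move_and_save() queue writes onto the writer ...
    for w in self.async_writers:
        w.synchronize()         # block until every queued write has landed
    self.async_writers.clear()  # dropping the last reference closes the file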

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def save_unsharded_model(
         if use_async:
             from tensornvme.async_file_io import AsyncFileWriter

-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
+            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 1 addition & 3 deletions
@@ -690,9 +690,7 @@ def save_unsharded_model(

             from colossalai.utils.safetensors import move_and_save

-            writer = AsyncFileWriter(
-                open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
-            )
+            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)

colossalai/checkpoint_io/utils.py

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
         index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)

-        writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
+        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
         writers.append(writer)

         if pinned_state_dict is not None:

colossalai/zero/low_level/low_level_optim.py

Lines changed: 37 additions & 20 deletions
@@ -776,7 +776,9 @@ def pack_group(group):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
+    def state_dict(
+        self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
+    ) -> Dict:
         """Return a state_dict same with DDP

         Returns:
@@ -785,23 +787,29 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            working_param = self.master_to_working_param[id(param)]
+            pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                zero_state[param] = copy.deepcopy(state)
+            else:
+                zero_state[param] = {}
+
             if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
-            zero_state[param] = copy.deepcopy(state)
+
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    working_param = self.master_to_working_param[id(param)]
-                    pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)

@@ -837,7 +845,10 @@ def load_state_dict(self, state_dict: Dict):
         self.optim.load_state_dict(zero_state_dict)

     def state_dict_shard(
-        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+        self,
+        max_shard_size: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        only_on_master: bool = False,
     ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.
@@ -862,25 +873,31 @@ def state_dict_shard(
             cnt += 1
         for param_idx, states in local_states.items():
             current_block_size = 0
-            current_block = copy.deepcopy(states)
             if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                current_block = copy.deepcopy(states)
+            else:
+                current_block = {}

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
                     current_block_size += calculate_tensor_size(state_tensor)

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
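
The pattern added in both `state_dict` and `state_dict_shard` is the heart of the memory saving: every rank still participates in the all-gather (a collective, so skipping it would desynchronize the group), but only rank 0 of each process group materializes the gathered state on the host. A distilled, standalone sketch using plain `torch.distributed` in place of ColossalAI's nd-parallel helpers (`gather_optim_state` is a hypothetical name, not part of the library):

    import torch
    import torch.distributed as dist

    def gather_optim_state(v: torch.Tensor, pg, numel: int, only_on_master: bool):
        # Every rank contributes its flat shard to the collective.
        gathered = torch.empty(v.numel() * dist.get_world_size(pg), device=v.device, dtype=v.dtype)
        dist.all_gather_into_tensor(gathered, v.flatten(), group=pg)
        if only_on_master and dist.get_rank(pg) != 0:
            return None  # non-master ranks keep nothing: no deepcopy, no CPU copy
        # The master rank trims padding and offloads to (optionally pinned) host memory.
        return gathered[:numel].cpu()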

tests/test_checkpoint_io/test_safetensors_async_io.py

Lines changed: 9 additions & 8 deletions
@@ -10,6 +10,7 @@
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

+
 from colossalai.testing import check_state_dict_equal
 from colossalai.utils import get_current_device

@@ -110,20 +111,20 @@ def test_save_load():
         }

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
         save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_flat(optimizer_saved_path)
         check_state_dict_equal(load_state_dict, optimizer_state_dict)

         optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
         save_nested(f_writer, optimizer_state_dict["state"])
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict_shard = load_flat(optimizer_shard_saved_path)
         check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

@@ -133,21 +134,21 @@ def test_save_load():
             "module.weight2": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
-        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
         save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)

         model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
         model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
         model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
         move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
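
The updated tests double as a usage reference for the new writer lifecycle: construct from a path, queue writes, synchronize, then simply drop the writer. A condensed round trip mirroring the test (assuming, as the test's calls suggest, that `save` comes from `colossalai.utils.safetensors` and `load_file` from `safetensors.torch`):

    import torch
    from safetensors.torch import load_file
    from tensornvme.async_file_io import AsyncFileWriter
    from colossalai.utils.safetensors import save

    state = {"weight": torch.rand(16, 16)}
    writer = AsyncFileWriter("model.safetensors", n_entries=191, backend="pthread")
    save(writer, state)   # queue the tensors for asynchronous writing
    writer.sync_before_step()
    writer.synchronize()  # wait until the data has hit disk
    del writer            # new API: dropping the writer closes the file
    assert torch.equal(load_file("model.safetensors")["weight"], state["weight"])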
