
Commit 2c18902

Merge branch 'master' into add_torch_compile_test_configs

2 parents dddb92d + 3631712

File tree: 12 files changed, +184 −6 lines


deepspeed/datastates/README.md

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
# DataStates-LLM checkpointing engine

This feature is not enabled by default. To enable it, set the following options in ds_config.json and install the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).

```
{
    ... other deepspeed config options,
    "datastates_ckpt": {
        "host_cache_size": 16
    }
}
```
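For orientation (not part of the diff): once that block is in ds_config.json, user code does not change, since DeepSpeed selects the checkpoint engine from the config (see utils.py below). A minimal sketch, assuming a standard `deepspeed.initialize` setup; the model, optimizer settings, and paths are placeholders:

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # placeholder module

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "datastates_ckpt": {"host_cache_size": 16},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Checkpoint writes now flow through the DataStates engine; if the
# datastates package is missing, the selector logs an error and falls
# back to torch.save (see utils.py below).
engine.save_checkpoint("./checkpoints", tag="step100")
```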

deepspeed/datastates/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

deepspeed/datastates/config.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigObject
import copy

DATASTATES_CHECKPOINTING = "datastates_ckpt"
DATASTATES_CHECKPOINTING_ENABLED = False


class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedDataStatesConfig, self).__init__()

        # Enabled when "datastates_ckpt" is present with any value other than an explicit False.
        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
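Note the enablement semantics: any value other than an explicit `False` under the `"datastates_ckpt"` key enables the feature, so an empty `{}` block counts as enabled. A quick sketch of how the flag evaluates, assuming this commit's module layout:

```python
from deepspeed.datastates.config import DeepSpeedDataStatesConfig

cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": {"host_cache_size": 16}})
assert cfg.enabled and cfg.config == {"host_cache_size": 16}

cfg = DeepSpeedDataStatesConfig({})  # key absent -> defaults to disabled
assert not cfg.enabled and cfg.config is None

cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": False})  # explicit opt-out
assert not cfg.enabled
```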

deepspeed/runtime/checkpoint_engine/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -8,4 +8,5 @@
 from .torch_checkpoint_engine import TorchCheckpointEngine
 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .checkpoint_engine import CheckpointCommitInfo
+from .datastates_checkpoint_engine import DataStatesCheckpointEngine
 from .utils import create_checkpoint_engine

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 3 additions & 0 deletions

@@ -58,3 +58,6 @@ def get_commit_info(self):
 
     def cleanup(self):
         pass
+
+    def preserves_storage_sharing(self):
+        return True
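Context for this new hook (editorial, not part of the diff): `torch.save` serializes a tensor's entire backing storage, so views and tied weights inflate checkpoints unless the state_dict is first passed through `clone_tensors_for_torch_save` (from `deepspeed.checkpoint.utils`, as used in engine.py below). An engine that writes tensors out-of-band, like DataStates, returns False here to skip that extra copy. A small illustration of the effect:

```python
import os
import torch
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

base = torch.zeros(1 << 20)   # ~4 MiB of storage
view = base[:10]              # a 10-element view sharing base's storage

torch.save({"w": view}, "shared.pt")  # serializes the whole shared storage
torch.save(clone_tensors_for_torch_save({"w": view}), "cloned.pt")  # 10 floats only

# shared.pt is orders of magnitude larger than cloned.pt
print(os.path.getsize("shared.pt"), os.path.getsize("cloned.pt"))
```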
deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine, CheckpointCommitInfo

ENGINE_NAME = "DataStatesCheckpointEngine"


class DataStatesCheckpointEngine(CheckpointEngine):

    def __init__(self, deepspeed_config, rank):
        super().__init__(deepspeed_config)
        self.commit_info = None
        self.ckpt_engine = None
        try:
            from datastates import CheckpointEngine as DataStatesEngine
            self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)
        except ImportError:
            raise RuntimeError("Please install DataStates from https://github.com/DataStates/datastates-llm.")
        except Exception as e:
            raise RuntimeError(f"An error occurred while initializing DataStates Checkpoint Engine: {e}")

    def __del__(self):
        self.cleanup()

    def create(self, info: CheckpointCommitInfo):
        # Record the pending checkpoint; writes are issued via save() and
        # persisted on commit().
        self.commit_info = info
        return None

    def save(self, state_dict, path: str):
        return self.ckpt_engine.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return self.ckpt_engine.load(path, map_location)

    def commit(self, info: CheckpointCommitInfo):
        if info is None:
            return
        assert info == self.commit_info
        # Block until the asynchronous writes for this checkpoint are durable.
        self.ckpt_engine.wait(persist=True)
        self.commit_info = None
        return True

    def cleanup(self):
        self.commit(self.commit_info)
        if self.ckpt_engine:
            self.ckpt_engine.wait(persist=True)
            del self.ckpt_engine

    def is_decoupled(self):
        return True

    def preserves_storage_sharing(self):
        return False
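A hedged sketch of the intended lifecycle, mirroring how the decoupled API in this diff is driven; `ds_config`, `model`, and the `CheckpointCommitInfo` value (`info`) are placeholders outside this diff, and the `datastates` package must be installed:

```python
from deepspeed.runtime.checkpoint_engine import DataStatesCheckpointEngine

# ds_config: a parsed DeepSpeed config; info: a CheckpointCommitInfo for
# the pending tag; model: any torch.nn.Module. All placeholders here.
engine = DataStatesCheckpointEngine(deepspeed_config=ds_config, rank=0)

engine.create(info)  # register the pending checkpoint
engine.save(model.state_dict(), "ckpt/step100/model_states.pt")  # returns quickly; copy is async
engine.commit(info)  # wait(persist=True): blocks until the data is durable
engine.cleanup()
```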

deepspeed/runtime/checkpoint_engine/utils.py

Lines changed: 11 additions & 1 deletion

@@ -6,7 +6,7 @@
 from deepspeed.runtime.model_checkpointing.constants import *
 from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
 from deepspeed.utils import logger
-
+from deepspeed import comm as dist
 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .fast_checkpoint_engine import FastCheckpointEngine
 from .torch_checkpoint_engine import TorchCheckpointEngine

@@ -35,4 +35,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
     else:
         return NebulaCheckpointEngine(config_params=config_params.nebula_config)
 
+    if config_params.datastates_config.enabled:
+        try:
+            from .datastates_checkpoint_engine import DataStatesCheckpointEngine
+            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
+        except ImportError as err:
+            logger.error(
+                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
+            )
+            return TorchCheckpointEngine(config_params)
+
     return TorchCheckpointEngine(config_params)

deepspeed/runtime/config.py

Lines changed: 2 additions & 0 deletions

@@ -52,6 +52,7 @@
 from ..profiling.config import DeepSpeedFlopsProfilerConfig
 from ..autotuning.config import DeepSpeedAutotuningConfig
 from ..nebula.config import DeepSpeedNebulaConfig
+from ..datastates.config import DeepSpeedDataStatesConfig
 
 from ..compression.config import get_compression_config, get_quantize_enabled
 from ..compression.constants import *

@@ -859,6 +860,7 @@ def _initialize_params(self, param_dict):
         self.dataloader_drop_last = get_dataloader_drop_last(param_dict)
 
         self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)
         self.checkpoint_config = get_checkpoint_config(param_dict)
 
         self.weight_quantization_config = WeightQuantConfig(

deepspeed/runtime/engine.py

Lines changed: 9 additions & 3 deletions

@@ -3612,7 +3612,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
             moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
             if self.random_ltd_enabled():
                 expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
-            saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
+            saveable_state_dict = expert_state_dict
+            if self.checkpoint_engine.preserves_storage_sharing():
+                saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
             self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
             moe_layer_id += 1

@@ -3634,7 +3636,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         # TODO: why use BufferedWriter not the path
         file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-        saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
+        saveable_state_dict = optimizer_state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
         self.checkpoint_engine.save(saveable_state_dict, file_path)
 
         # Load flow uses below saved file for model parameters, RNG and more

@@ -3674,7 +3678,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         state.update(client_state)
         logger.info(f'Saving model checkpoint: {save_path}')
-        saveable_state_dict = clone_tensors_for_torch_save(state)
+        saveable_state_dict = state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(state)
         self.checkpoint_engine.save(saveable_state_dict, save_path)
 
     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):

deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion

@@ -621,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
         layer_list = self.forward_funcs[start:end]
 
         checkpoint_engine.makedirs(save_dir, exist_ok=True)
+        should_clone = checkpoint_engine.preserves_storage_sharing()
         for idx, layer in enumerate(layer_list):
             model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):

@@ -630,7 +631,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
             if exclude_frozen_params:
                 for n in self._get_frozen_parameter_names(layer):
                     del orig_state_dict[n]
-            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+            final_state_dict = orig_state_dict
+            if should_clone:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
             checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)
 
     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
