From 94bb251df4aecabdb391915bb7042cdf69c17b40 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Tue, 28 May 2024 14:17:13 -0700 Subject: [PATCH] snippet --- composer/utils/checkpoint.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 37dd9ed8eb..869b127e96 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -980,7 +980,16 @@ def _restore_checkpoint( # Ensure state exists state_dict['state'] = state_dict.get('state', {}) log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}") - + from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard + from torch.distributed._shard.sharded_tensor import ShardedTensor + for k, v in state_dict['state']['model'].items(): + if isinstance(v, ShardedTensor): + dtensor = DTensor.from_local( + v.local_tensor(), + device_mesh=,#get device mesh from state.model, + placements=,#get device mesh placements from state.model, + run_check=False,) + state_dict['state']['model'][k] = torch.nn.Parameter(dtensor) if is_model_deepspeed(state.model): if extracted_checkpoint_folder is None: raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.')