diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 37dd9ed8eb..869b127e96 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -980,7 +980,16 @@ def _restore_checkpoint( # Ensure state exists state_dict['state'] = state_dict.get('state', {}) log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}") - + from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard + from torch.distributed._shard.sharded_tensor import ShardedTensor + for k, v in state_dict['state']['model'].items(): + if isinstance(v, ShardedTensor): + dtensor = DTensor.from_local( + v.local_tensor(), + device_mesh=,#get device mesh from state.model, + placements=,#get device mesh placements from state.model, + run_check=False,) + state_dict['state']['model'][k] = torch.nn.Parameter(dtensor) if is_model_deepspeed(state.model): if extracted_checkpoint_folder is None: raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.')