Skip to content

Commit dbd567d

Browse files
committed
remove load dataloader
1 parent 3439347 commit dbd567d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/zeroband/checkpoint.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def _del__(self):
254254

255255
self.wait_async_save_process()
256256

257-
def load(self, resume_ckpt_path: str, diloco_rank: int | None = None) -> None:
257+
def load(self, resume_ckpt_path: str, diloco_rank: int | None = None, skip_dataloader: bool = False) -> None:
258258
"""
259259
loading should be done after fsdp wrap and optimizer init.
260260
Each rank will load the right shard of the model and optimizer.
@@ -279,11 +279,12 @@ def load(self, resume_ckpt_path: str, diloco_rank: int | None = None) -> None:
279279
for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.parameters()):
280280
param_offloaded.data.copy_(param_model.data)
281281

282-
## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk
283-
with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
284-
rank_state_dict = torch.load(f)
282+
if not skip_dataloader:
283+
## The next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads the state from disk twice.
284+
with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
285+
rank_state_dict = torch.load(f)
285286

286-
self.dataloader.load_state_dict(rank_state_dict["data_loader"])
287+
self.dataloader.load_state_dict(rank_state_dict["data_loader"])
287288

288289
self._init_state()
289290

@@ -308,7 +309,8 @@ def download_and_load_ckpt_from_peers(self, adress: str):
308309
destination=path,
309310
)
310311
dist.barrier()
311-
self.load(resume_ckpt_path=ckpt_path)
312+
self.load(resume_ckpt_path=ckpt_path, skip_dataloader=True)
313+
# we don't want the dataloader states to be loaded as they are not the same on each rank
312314

313315

314316
class CkptLiveServer:

0 commit comments

Comments
 (0)