@@ -254,7 +254,7 @@ def _del__(self):
254
254
255
255
self .wait_async_save_process ()
256
256
257
- def load (self , resume_ckpt_path : str , diloco_rank : int | None = None ) -> None :
257
+ def load (self , resume_ckpt_path : str , diloco_rank : int | None = None , skip_dataloader : bool = False ) -> None :
258
258
"""
259
259
loading should be done after fsdp wrap and optimizer init.
260
260
Each rank will load the right shard of the model and optimizer.
@@ -279,11 +279,12 @@ def load(self, resume_ckpt_path: str, diloco_rank: int | None = None) -> None:
279
279
for param_offloaded , param_model in zip (self .diloco_offloaded_param_list , self .model .parameters ()):
280
280
param_offloaded .data .copy_ (param_model .data )
281
281
282
- ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk
283
- with open (os .path .join (resume_ckpt_path , f"__{ world_info .local_rank } _0.pt" ), "rb" ) as f :
284
- rank_state_dict = torch .load (f )
282
+ if not skip_dataloader :
283
+ ## the next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads the state from disk twice
284
+ with open (os .path .join (resume_ckpt_path , f"__{ world_info .local_rank } _0.pt" ), "rb" ) as f :
285
+ rank_state_dict = torch .load (f )
285
286
286
- self .dataloader .load_state_dict (rank_state_dict ["data_loader" ])
287
+ self .dataloader .load_state_dict (rank_state_dict ["data_loader" ])
287
288
288
289
self ._init_state ()
289
290
@@ -308,7 +309,8 @@ def download_and_load_ckpt_from_peers(self, adress: str):
308
309
destination = path ,
309
310
)
310
311
dist .barrier ()
311
- self .load (resume_ckpt_path = ckpt_path )
312
+ self .load (resume_ckpt_path = ckpt_path , skip_dataloader = True )
313
+ # we don't want the dataloader states to be loaded as they are not the same on each rank
312
314
313
315
314
316
class CkptLiveServer :
0 commit comments