 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint.state_dict import (StateDictOptions,
-                                                     get_model_state_dict,
-                                                     get_optimizer_state_dict,
-                                                     set_model_state_dict,
-                                                     set_optimizer_state_dict)
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+)
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
@@ -50,9 +52,7 @@ def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model

     def state_dict(self) -> None:
-        return {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
-        }
+        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         func = functools.partial(
@@ -70,8 +70,7 @@ def __init__(
         optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
     ) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
-        self.optim = [optim] if isinstance(
-            optim, torch.optim.Optimizer) else optim
+        self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim

     def state_dict(self) -> None:
         func = functools.partial(
@@ -109,8 +108,7 @@ def checkpoint_mp(recv, send):
             state, checkpoint_id = obj
             dcp.save(state, checkpoint_id=checkpoint_id)
             logger.info(
-                "Finish saving the checkpoint in the background process in "
-                f"{time.monotonic() - begin:.2f} seconds."
+                "Finish saving the checkpoint in the background process in " f"{time.monotonic() - begin:.2f} seconds."
             )
     finally:
         logger.info("Destroying the process group.")
@@ -158,19 +156,11 @@ def __init__(

         TODO: This is currently unsolved and needs a fix.
         """
-        assert len(model_parts) == len(
-            optimizers
-        ), "Must pass one optimizer per model part"
-        assert len(model_parts) == len(
-            lr_schedulers
-        ), "Must pass one lr_scheduler per model part"
-
-        assert len(model_parts) == len(
-            optimizers
-        ), "Must pass one optimizer per model part"
-        assert len(model_parts) == len(
-            lr_schedulers
-        ), "Must pass one lr_scheduler per model part"
+        assert len(model_parts) == len(optimizers), "Must pass one optimizer per model part"
+        assert len(model_parts) == len(lr_schedulers), "Must pass one lr_scheduler per model part"
+
+        assert len(model_parts) == len(optimizers), "Must pass one optimizer per model part"
+        assert len(model_parts) == len(lr_schedulers), "Must pass one lr_scheduler per model part"

         self.states = states

@@ -190,11 +180,7 @@ def __init__(
             self.states[f"lr_scheduler_{idx}"] = lr_scheduler

         self.folder = os.path.join(ckpt_config.ckpt_dir)
-        self.interval_type = (
-            IntervalType.SECONDS
-            if ckpt_config.interval_type == "seconds"
-            else IntervalType.STEPS
-        )
+        self.interval_type = IntervalType.SECONDS if ckpt_config.interval_type == "seconds" else IntervalType.STEPS
         self.interval = ckpt_config.interval
         self.begin_time = 0
         self.time_sync_work = None
@@ -231,12 +217,9 @@ def __init__(
             self.staging_id = None
             self.staging_stream = torch.cuda.Stream()
         else:
-            raise ValueError(
-                f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
+            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

-        logger.info(
-            f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
-        )
+        logger.info(f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}")

     def __del__(self):
         if self.enable_checkpoint and self.mp and self.mp.is_alive():
@@ -268,16 +251,12 @@ def _save_last_step(self, curr_step: int) -> None:
             self.states.pop("freqs_cis")

             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
             logger.info(
-                f"Saving a model weights only checkpoint in {self.export_dtype} "
-                f"at last step, step {curr_step}."
+                f"Saving a model weights only checkpoint in {self.export_dtype} " f"at last step, step {curr_step}."
             )
         else:
-            logger.info(
-                f"Saving a full checkpoint at last step, step {curr_step}.")
+            logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

         dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
         self.reset()
@@ -287,18 +266,13 @@ def _should_save(self, curr_step: int, force: bool = False) -> bool:
             return False

         if not force:
-            if self.interval_type == IntervalType.STEPS and not (
-                curr_step % self.interval == 0
-            ):
+            if self.interval_type == IntervalType.STEPS and not (curr_step % self.interval == 0):
                 return False
             if self.interval_type == IntervalType.SECONDS:
-                time_sync_result = (time.monotonic() -
-                                    self.begin_time) >= self.interval
+                time_sync_result = (time.monotonic() - self.begin_time) >= self.interval
                 self.time_sync_result = torch.tensor(int(time_sync_result))
                 if self.time_sync_work is None:
-                    self.time_sync_work = dist.all_reduce(
-                        self.time_sync_result, group=self.pg, async_op=True
-                    )
+                    self.time_sync_work = dist.all_reduce(self.time_sync_result, group=self.pg, async_op=True)
                     return False
                 elif curr_step % 5 == 4:
                     self.time_sync_work.wait()
@@ -319,31 +293,25 @@ def _should_save(self, curr_step: int, force: bool = False) -> bool:

     def _async_wait(self) -> None:
         if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            logger.debug(
-                f"Waiting for the background process to finish, {time.monotonic()=}.:.2f"
-            )
+            logger.debug(f"Waiting for the background process to finish, {time.monotonic()=}.:.2f")
             if not self.mp.is_alive():
-                raise RuntimeError(
-                    "The checkpoint background process is dead.")
+                raise RuntimeError("The checkpoint background process is dead.")
             _ = self.mp_queue_recv.get()
         elif self.async_mode == AsyncMode.ASYNC:
             if self.async_future is not None:
                 self.async_future.result()

     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
         try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict, _create_cpu_state_dict)
+            from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
         except ImportError as e:
             raise ImportError(
                 "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
             ) from e
         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
         if self.cpu_offload_state_dict is None:
             logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True
-            )
+            self.cpu_offload_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)

         logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
         with torch.cuda.stream(self.staging_stream):
@@ -374,9 +342,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
+            self.async_future = dcp.async_save(self.states, checkpoint_id=checkpoint_id, process_group=self.pg)
         else:
             dcp.save(self.states, checkpoint_id=checkpoint_id)
         self.reset()
@@ -388,16 +354,10 @@ def save(self, curr_step: int, force: bool = False) -> None:
         )

     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
+        if self.enable_checkpoint and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM and self.staging:
             logger.debug(f"Waiting for staging, {time.monotonic()=:.2f}.")
             self.staging_stream.synchronize()
-            logger.debug(
-                f"Sending the state dict to the background process, {time.monotonic()=:.2f}."
-            )
+            logger.debug(f"Sending the state dict to the background process, {time.monotonic()=:.2f}.")
             self.mp_queue_send.put((self.staging_state_dict, self.staging_id))
             self.staging = False

@@ -413,8 +373,7 @@ def load(self, step: int = -1) -> bool:
             step_counts = []
             for filename in os.listdir(self.folder):
                 match = re.search(r"step-(\d+)", filename)
-                metadata_probe = os.path.join(
-                    self.folder, filename, ".metadata")
+                metadata_probe = os.path.join(self.folder, filename, ".metadata")
                 if match and os.path.isfile(metadata_probe):
                     step_counts.append(int(match.group(1)))
             if not step_counts:
@@ -429,9 +388,7 @@ def load(self, step: int = -1) -> bool:
             states,
             checkpoint_id=self._create_checkpoint_id(step),
         )
-        logger.info(
-            f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
-        )
+        logger.info(f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds.")
         return True

     def _purge_stale_checkpoints(self):
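
For reference only (not part of the diff above): a minimal sketch of how the Stateful wrappers touched in this file plug into torch.distributed.checkpoint (DCP). The names ToyWrapper and CHECKPOINT_ID, the single unsharded nn.Linear stand-in, and the single-process usage are assumptions for illustration; in a real multi-rank job the process group would already be initialized, as the CheckpointManager here expects.

# Illustrative sketch, not code from this PR.
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_ID = "checkpoints/step-100"  # hypothetical checkpoint directory


class ToyWrapper(Stateful):
    """Same idea as the model wrapper in this file: make a module DCP-saveable."""

    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def state_dict(self):
        # Called by dcp.save(): extract the model's state dict in DCP form.
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict) -> None:
        # Called by dcp.load(): push the loaded values back into the live module.
        set_model_state_dict(self.model, state_dict)


model = nn.Linear(8, 8)  # stand-in for a real model part
states = {"model": ToyWrapper(model)}

dcp.save(states, checkpoint_id=CHECKPOINT_ID)  # writes .metadata plus shard files
dcp.load(states, checkpoint_id=CHECKPOINT_ID)  # restores the module in place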