26
26
import threading
27
27
import time
28
28
from typing import Any , Dict , List , NamedTuple , Optional , Tuple , Union
29
+ import warnings
29
30
30
31
import numpy as np
31
32
32
33
from .utils import (
34
+ ADAPTER_SAFE_WEIGHTS_NAME ,
35
+ ADAPTER_WEIGHTS_NAME ,
36
+ CONFIG_NAME ,
37
+ SAFE_WEIGHTS_INDEX_NAME ,
38
+ SAFE_WEIGHTS_NAME ,
39
+ WEIGHTS_INDEX_NAME ,
40
+ WEIGHTS_NAME ,
33
41
ExplicitEnum ,
34
42
is_psutil_available ,
35
43
is_tf_available ,
42
50
is_torch_xla_available ,
43
51
is_torch_xpu_available ,
44
52
requires_backends ,
45
- ADAPTER_CONFIG_NAME ,
46
- ADAPTER_SAFE_WEIGHTS_NAME ,
47
- ADAPTER_WEIGHTS_NAME ,
48
- CONFIG_NAME ,
49
- SAFE_WEIGHTS_INDEX_NAME ,
50
- SAFE_WEIGHTS_NAME ,
51
- WEIGHTS_INDEX_NAME ,
52
- WEIGHTS_NAME ,
53
53
)
54
54
55
+
55
56
FSDP_MODEL_NAME = "pytorch_model_fsdp"
56
57
TRAINER_STATE_NAME = "trainer_state.json"
57
58
@@ -222,12 +223,25 @@ class TrainOutput(NamedTuple):
222
223
def is_valid_checkpoint_dir(folder):
    """Return ``True`` if *folder* looks like a complete checkpoint directory.

    A directory is considered valid when it contains at least one known
    model-weight artifact (full weights, safetensors, a sharded index,
    adapter weights, or an FSDP state file) AND every always-required
    bookkeeping file (model config and trainer state).

    Args:
        folder: Path to the candidate checkpoint directory.

    Returns:
        bool: whether the directory passes both checks.
    """
    # NOTE(review): the original list contained SAFE_WEIGHTS_NAME twice;
    # the duplicate membership test has been removed (no behavior change —
    # `any` over the same file twice is equivalent).
    weight_files = [
        WEIGHTS_NAME,
        SAFE_WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        SAFE_WEIGHTS_INDEX_NAME,
        ADAPTER_WEIGHTS_NAME,
        ADAPTER_SAFE_WEIGHTS_NAME,
        f"{FSDP_MODEL_NAME}.bin",
    ]
    required_files = [
        CONFIG_NAME,
        TRAINER_STATE_NAME,
    ]
    return any(
        os.path.isfile(os.path.join(folder, f)) for f in weight_files
    ) and all(
        os.path.isfile(os.path.join(folder, f)) for f in required_files
    )
230
243
244
+
231
245
def get_last_checkpoint (folder ):
232
246
content = os .listdir (folder )
233
247
checkpoints = [
@@ -239,10 +253,14 @@ def get_last_checkpoint(folder):
239
253
return
240
254
241
255
for checkpoint in sorted (checkpoints , key = lambda x : int (_re_checkpoint .search (x ).groups ()[0 ]), reverse = True ):
242
- if is_valid_checkpoint_dir (checkpoint ):
243
- break
256
+ if is_valid_checkpoint_dir (checkpoint ):
257
+ break
258
+ else :
259
+ warnings .warn (f"Skipping checkpoint dir {checkpoint} due to missing files" , UserWarning , stacklevel = 2 )
260
+
244
261
return os .path .join (folder , checkpoint )
245
262
263
+
246
264
class IntervalStrategy (ExplicitEnum ):
247
265
NO = "no"
248
266
STEPS = "steps"
0 commit comments