Skip to content

Commit 9d0410d

Browse files
committed
Linter check, add warning message
1 parent 3e342e9 commit 9d0410d

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

src/transformers/trainer_utils.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@
2626
import threading
2727
import time
2828
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
29+
import warnings
2930

3031
import numpy as np
3132

3233
from .utils import (
34+
ADAPTER_SAFE_WEIGHTS_NAME,
35+
ADAPTER_WEIGHTS_NAME,
36+
CONFIG_NAME,
37+
SAFE_WEIGHTS_INDEX_NAME,
38+
SAFE_WEIGHTS_NAME,
39+
WEIGHTS_INDEX_NAME,
40+
WEIGHTS_NAME,
3341
ExplicitEnum,
3442
is_psutil_available,
3543
is_tf_available,
@@ -42,16 +50,9 @@
4250
is_torch_xla_available,
4351
is_torch_xpu_available,
4452
requires_backends,
45-
ADAPTER_CONFIG_NAME,
46-
ADAPTER_SAFE_WEIGHTS_NAME,
47-
ADAPTER_WEIGHTS_NAME,
48-
CONFIG_NAME,
49-
SAFE_WEIGHTS_INDEX_NAME,
50-
SAFE_WEIGHTS_NAME,
51-
WEIGHTS_INDEX_NAME,
52-
WEIGHTS_NAME,
5353
)
5454

55+
5556
FSDP_MODEL_NAME = "pytorch_model_fsdp"
5657
TRAINER_STATE_NAME = "trainer_state.json"
5758

@@ -222,12 +223,25 @@ class TrainOutput(NamedTuple):
222223
def is_valid_checkpoint_dir(folder):
223224
return any(
224225
os.path.isfile(os.path.join(folder, f))
225-
for f in [ WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME , f"{FSDP_MODEL_NAME}.bin"]
226+
for f in [
227+
WEIGHTS_NAME,
228+
SAFE_WEIGHTS_NAME,
229+
WEIGHTS_INDEX_NAME,
230+
SAFE_WEIGHTS_NAME,
231+
SAFE_WEIGHTS_INDEX_NAME,
232+
ADAPTER_WEIGHTS_NAME,
233+
ADAPTER_SAFE_WEIGHTS_NAME,
234+
f"{FSDP_MODEL_NAME}.bin",
235+
]
226236
) and all(
227237
os.path.isfile(os.path.join(folder, f))
228-
for f in [ CONFIG_NAME, TRAINER_STATE_NAME, ]
238+
for f in [
239+
CONFIG_NAME,
240+
TRAINER_STATE_NAME,
241+
]
229242
)
230243

244+
231245
def get_last_checkpoint(folder):
232246
content = os.listdir(folder)
233247
checkpoints = [
@@ -239,10 +253,14 @@ def get_last_checkpoint(folder):
239253
return
240254

241255
for checkpoint in sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True):
242-
if is_valid_checkpoint_dir(checkpoint):
243-
break
256+
if is_valid_checkpoint_dir(checkpoint):
257+
break
258+
else:
259+
warnings.warn("Skipping checkpoint dir {checkpoint} due to missing files", UserWarning, stacklevel=2)
260+
244261
return os.path.join(folder, checkpoint)
245262

263+
246264
class IntervalStrategy(ExplicitEnum):
247265
NO = "no"
248266
STEPS = "steps"

0 commit comments

Comments
 (0)