3
3
import os
4
4
from concurrent .futures import ProcessPoolExecutor
5
5
from contextlib import contextmanager
6
+ import importlib .util
6
7
from pathlib import Path
7
8
from typing import List , Optional , Tuple
8
9
@@ -56,6 +57,8 @@ def download_from_hub(
56
57
return
57
58
58
59
from huggingface_hub import snapshot_download
60
+ if importlib .util .find_spec ("hf_transfer" ) is None :
61
+ print ("It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`" )
59
62
60
63
download_files = ["tokenizer*" , "generation_config.json" , "config.json" ]
61
64
if not tokenizer_only :
@@ -70,22 +73,14 @@ def download_from_hub(
70
73
else :
71
74
raise ValueError (f"Couldn't find weight files for { repo_id } " )
72
75
73
- # Get and set env variable to improve download speed
74
- user_env_value = os .environ .get ("HF_HUB_ENABLE_HF_TRANSFER" )
75
-
76
- if user_env_value is None :
77
- os .environ ["HF_HUB_ENABLE_HF_TRANSFER" ] = "1"
78
- print ("Setting HF_HUB_ENABLE_HF_TRANSFER=1 by default" )
79
-
80
76
import huggingface_hub ._snapshot_download as download
81
77
import huggingface_hub .constants as constants
82
78
83
- previous_flag = constants .HF_HUB_ENABLE_HF_TRANSFER # this may be redundant
84
-
85
- if _HF_TRANSFER_AVAILABLE and not previous_flag :
86
- print ("Setting HF_HUB_ENABLE_HF_TRANSFER=1" )
87
- constants .HF_HUB_ENABLE_HF_TRANSFER = True
88
- download .HF_HUB_ENABLE_HF_TRANSFER = True
79
+ previous = constants .HF_HUB_ENABLE_HF_TRANSFER
80
+ if _HF_TRANSFER_AVAILABLE and not previous :
81
+ print ("Setting HF_HUB_ENABLE_HF_TRANSFER=1" )
82
+ constants .HF_HUB_ENABLE_HF_TRANSFER = True
83
+ download .HF_HUB_ENABLE_HF_TRANSFER = True
89
84
90
85
directory = checkpoint_dir / repo_id
91
86
with gated_repo_catcher (repo_id , access_token ):
@@ -96,8 +91,8 @@ def download_from_hub(
96
91
token = access_token ,
97
92
)
98
93
99
- constants .HF_HUB_ENABLE_HF_TRANSFER = previous_flag
100
- download .HF_HUB_ENABLE_HF_TRANSFER = previous_flag
94
+ constants .HF_HUB_ENABLE_HF_TRANSFER = previous
95
+ download .HF_HUB_ENABLE_HF_TRANSFER = previous
101
96
102
97
if convert_checkpoint and not tokenizer_only :
103
98
print ("Converting checkpoint files to LitGPT format." )
0 commit comments