Skip to content

Commit d1103db

Browse files
committed
update pr
1 parent ce5d0fe commit d1103db

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

litgpt/scripts/download.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from concurrent.futures import ProcessPoolExecutor
55
from contextlib import contextmanager
6+
import importlib.util
67
from pathlib import Path
78
from typing import List, Optional, Tuple
89

@@ -56,6 +57,8 @@ def download_from_hub(
5657
return
5758

5859
from huggingface_hub import snapshot_download
60+
if importlib.util.find_spec("hf_transfer") is None:
61+
print("It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`")
5962

6063
download_files = ["tokenizer*", "generation_config.json", "config.json"]
6164
if not tokenizer_only:
@@ -70,22 +73,14 @@ def download_from_hub(
7073
else:
7174
raise ValueError(f"Couldn't find weight files for {repo_id}")
7275

73-
# Get and set env variable to improve download speed
74-
user_env_value = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
75-
76-
if user_env_value is None:
77-
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
78-
print("Setting HF_HUB_ENABLE_HF_TRANSFER=1 by default")
79-
8076
import huggingface_hub._snapshot_download as download
8177
import huggingface_hub.constants as constants
8278

83-
previous_flag = constants.HF_HUB_ENABLE_HF_TRANSFER # this may be redundant
84-
85-
if _HF_TRANSFER_AVAILABLE and not previous_flag:
86-
print("Setting HF_HUB_ENABLE_HF_TRANSFER=1")
87-
constants.HF_HUB_ENABLE_HF_TRANSFER = True
88-
download.HF_HUB_ENABLE_HF_TRANSFER = True
79+
previous = constants.HF_HUB_ENABLE_HF_TRANSFER
80+
if _HF_TRANSFER_AVAILABLE and not previous:
81+
print("Setting HF_HUB_ENABLE_HF_TRANSFER=1")
82+
constants.HF_HUB_ENABLE_HF_TRANSFER = True
83+
download.HF_HUB_ENABLE_HF_TRANSFER = True
8984

9085
directory = checkpoint_dir / repo_id
9186
with gated_repo_catcher(repo_id, access_token):
@@ -96,8 +91,8 @@ def download_from_hub(
9691
token=access_token,
9792
)
9893

99-
constants.HF_HUB_ENABLE_HF_TRANSFER = previous_flag
100-
download.HF_HUB_ENABLE_HF_TRANSFER = previous_flag
94+
constants.HF_HUB_ENABLE_HF_TRANSFER = previous
95+
download.HF_HUB_ENABLE_HF_TRANSFER = previous
10196

10297
if convert_checkpoint and not tokenizer_only:
10398
print("Converting checkpoint files to LitGPT format.")

0 commit comments

Comments
 (0)