diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index 14a16b07fb..af911b6b6d 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -81,6 +81,11 @@ def download_from_hub( import huggingface_hub.constants as constants previous_flag = constants.HF_HUB_ENABLE_HF_TRANSFER # this may be redundant + + if _HF_TRANSFER_AVAILABLE and not previous_flag: + print("Setting HF_HUB_ENABLE_HF_TRANSFER=1") + constants.HF_HUB_ENABLE_HF_TRANSFER = True + download.HF_HUB_ENABLE_HF_TRANSFER = True directory = checkpoint_dir / repo_id with gated_repo_catcher(repo_id, access_token):