Skip to content

Commit

Permalink
Merge pull request #40 from deepghs/dev/resume
Browse files Browse the repository at this point in the history
dev(narugo): optimize batch download
  • Loading branch information
narugo1992 authored Aug 14, 2024
2 parents 5aebcaf + ddc0e9d commit 5dc7adf
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 67 deletions.
175 changes: 114 additions & 61 deletions hfutils/operate/download.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
"""
This module provides functions for downloading files and directories from Hugging Face repositories.
It includes utilities for downloading individual files, archives, and entire directories,
with support for concurrent downloads, retries, and progress tracking.
The module interacts with the Hugging Face Hub API to fetch repository contents and
download files, handling various repository types and revisions.
Key features:
- Download individual files from Hugging Face repositories
- Download and extract archive files
- Download entire directories with pattern matching and ignore rules
- Concurrent downloads with configurable worker count
- Retry mechanism for failed downloads
- Progress tracking with tqdm
- Support for different repository types (dataset, model, space)
- Token-based authentication for accessing private repositories
This module is particularly useful for managing and synchronizing local copies of
Hugging Face repository contents, especially when dealing with large datasets or models.
"""

import logging
import os.path
import shutil
Expand All @@ -12,9 +36,50 @@
from ..utils import tqdm, TemporaryDirectory, hf_normpath


def _raw_download_file(td: str, local_file: str, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
hf_token: Optional[str] = None):
"""
Download a file from a Hugging Face repository to a temporary directory and then move it to the final location.
This internal function handles the actual download process using the Hugging Face Hub client.
:param td: Temporary directory path.
:type td: str
:param local_file: The final local file path where the downloaded file will be moved.
:type local_file: str
:param repo_id: The identifier of the repository.
:type repo_id: str
:param file_in_repo: The file path within the repository.
:type file_in_repo: str
:param repo_type: The type of the repository ('dataset', 'model', 'space').
:type repo_type: RepoTypeTyping
:param revision: The revision of the repository (e.g., branch, tag, commit hash).
:type revision: str
:param hf_token: Hugging Face token for API client.
:type hf_token: str, optional
"""
hf_client = get_hf_client(hf_token=hf_token)
relative_filename = os.path.join(*file_in_repo.split("/"))
temp_path = os.path.join(td, relative_filename)
try:
hf_client.hf_hub_download(
repo_id=repo_id,
repo_type=repo_type,
filename=hf_normpath(file_in_repo),
revision=revision,
local_dir=td,
)
finally:
if os.path.exists(temp_path):
if os.path.dirname(local_file):
os.makedirs(os.path.dirname(local_file), exist_ok=True)
shutil.move(temp_path, local_file)


def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
resume_download: bool = True, hf_token: Optional[str] = None):
hf_token: Optional[str] = None):
"""
Download a file from a Hugging Face repository and save it to a local file.
Expand All @@ -28,29 +93,19 @@ def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str,
:type repo_type: RepoTypeTyping
:param revision: The revision of the repository (e.g., branch, tag, commit hash).
:type revision: str
:param resume_download: Resume the existing download.
:type resume_download: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
hf_client = get_hf_client(hf_token)
relative_filename = os.path.join(*file_in_repo.split("/"))
with TemporaryDirectory() as td:
temp_path = os.path.join(td, relative_filename)
try:
hf_client.hf_hub_download(
repo_id=repo_id,
repo_type=repo_type,
filename=hf_normpath(file_in_repo),
revision=revision,
local_dir=td,
resume_download=resume_download,
)
finally:
if os.path.exists(temp_path):
if os.path.dirname(local_file):
os.makedirs(os.path.dirname(local_file), exist_ok=True)
shutil.move(temp_path, local_file)
_raw_download_file(
td=td,
local_file=local_file,
repo_id=repo_id,
file_in_repo=file_in_repo,
repo_type=repo_type,
revision=revision,
hf_token=hf_token,
)


def download_archive_as_directory(local_directory: str, repo_id: str, file_in_repo: str,
Expand Down Expand Up @@ -84,7 +139,7 @@ def download_directory_as_directory(
local_directory: str, repo_id: str, dir_in_repo: str = '.', pattern: str = '**/*',
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
resume_download: bool = True, max_workers: int = 8, max_retries: int = 5,
max_workers: int = 8, max_retries: int = 5,
soft_mode_when_check: bool = False, hf_token: Optional[str] = None
):
"""
Expand All @@ -110,8 +165,6 @@ def download_directory_as_directory(
:type max_workers: int
:param max_retries: Max retry times when downloading. Default is ``5``.
:type max_retries: int
:param resume_download: Resume the existing download.
:type resume_download: bool
:param soft_mode_when_check: Just check the size of the expected file when enabled. Default is False.
:type soft_mode_when_check: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
Expand All @@ -129,46 +182,46 @@ def download_directory_as_directory(
progress = tqdm(files, silent=silent, desc=f'Downloading {dir_in_repo!r} ...')

def _download_one_file(rel_file):
current_resume_download = resume_download
try:
dst_file = os.path.join(local_directory, rel_file)
file_in_repo = hf_normpath(f'{dir_in_repo}/{rel_file}')
if os.path.exists(dst_file) and is_local_file_ready(
repo_id=repo_id,
repo_type=repo_type,
local_file=dst_file,
file_in_repo=file_in_repo,
revision=revision,
hf_token=hf_token,
soft_mode=soft_mode_when_check,
):
logging.info(f'Local file {rel_file} is ready, download skipped.')
else:
tries = 0
while True:
try:
download_file_to_file(
local_file=dst_file,
repo_id=repo_id,
file_in_repo=file_in_repo,
repo_type=repo_type,
revision=revision,
resume_download=current_resume_download,
hf_token=hf_token,
)
except requests.exceptions.RequestException as err:
if tries < max_retries:
tries += 1
logging.warning(f'Download {rel_file!r} failed, retry ({tries}/{max_retries}) - {err!r}.')
current_resume_download = True
with TemporaryDirectory() as td:
try:
dst_file = os.path.join(local_directory, rel_file)
file_in_repo = hf_normpath(f'{dir_in_repo}/{rel_file}')
if os.path.exists(dst_file) and is_local_file_ready(
repo_id=repo_id,
repo_type=repo_type,
local_file=dst_file,
file_in_repo=file_in_repo,
revision=revision,
hf_token=hf_token,
soft_mode=soft_mode_when_check,
):
logging.info(f'Local file {rel_file} is ready, download skipped.')
else:
tries = 0
while True:
try:
_raw_download_file(
td=td,
local_file=dst_file,
repo_id=repo_id,
file_in_repo=file_in_repo,
repo_type=repo_type,
revision=revision,
hf_token=hf_token,
)
except requests.exceptions.RequestException as err:
if tries < max_retries:
tries += 1
logging.warning(
f'Download {rel_file!r} failed, retry ({tries}/{max_retries}) - {err!r}.')
else:
raise
else:
raise
else:
break
break

progress.update()
except Exception as err:
logging.error(f'Unexpected error when downloading {rel_file!r} - {err!r}')
progress.update()
except Exception as err:
logging.exception(f'Unexpected error when downloading {rel_file!r} - {err!r}')

tp = ThreadPoolExecutor(max_workers=max_workers)
for file in files:
Expand Down
13 changes: 7 additions & 6 deletions test/operate/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hbutils.testing import isolated_directory

from hfutils.operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory
from hfutils.operate.download import _raw_download_file
from test.testings import get_testfile, file_compare, dir_compare


Expand Down Expand Up @@ -37,9 +38,9 @@ def test_download_directory_as_directory(self):
def _my_download(*args, **kwargs):
nonlocal call_times
call_times += 1
return download_file_to_file(*args, **kwargs)
return _raw_download_file(*args, **kwargs)

with patch('hfutils.operate.download.download_file_to_file', _my_download), \
with patch('hfutils.operate.download._raw_download_file', _my_download), \
isolated_directory():
download_directory_as_directory(
'download_dir',
Expand All @@ -59,9 +60,9 @@ def test_download_directory_as_directory_partial(self):
def _my_download(*args, **kwargs):
nonlocal call_times
call_times += 1
return download_file_to_file(*args, **kwargs)
return _raw_download_file(*args, **kwargs)

with patch('hfutils.operate.download.download_file_to_file', _my_download), \
with patch('hfutils.operate.download._raw_download_file', _my_download), \
isolated_directory({'download_dir': src_dir}):
download_directory_as_directory(
'download_dir',
Expand All @@ -80,9 +81,9 @@ def test_download_directory_as_directory_with_pattern(self):
def _my_download(*args, **kwargs):
nonlocal call_times
call_times += 1
return download_file_to_file(*args, **kwargs)
return _raw_download_file(*args, **kwargs)

with patch('hfutils.operate.download.download_file_to_file', _my_download), \
with patch('hfutils.operate.download._raw_download_file', _my_download), \
isolated_directory():
download_directory_as_directory(
'download_dir',
Expand Down

0 comments on commit 5dc7adf

Please sign in to comment.