Skip to content

Commit

Permalink
dev(narugo): add soft mode
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Aug 9, 2024
1 parent d63e354 commit b1505cf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
7 changes: 6 additions & 1 deletion hfutils/entry/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ def _add_download_subcommand(cli: click.Group) -> click.Group:
help='Password for the archive file. Only applied when -a is used.', show_default=True)
@click.option('-w', '--wildcard', 'wildcard', type=str, default=None,
help='Wildcard for files to download. Only applied when -d is used.', show_default=True)
@click.option('-s', '--soft_mode_when_check', 'soft_mode_when_check', is_flag=True, type=bool, default=False,
help='Just check the file size when validating the downloaded files.', show_default=True)
@click.option('--tmpdir', 'tmpdir', type=str, default=None,
help='Use custom temporary Directory.', show_default=True)
@command_wrap()
def download(
repo_id: str, repo_type: RepoTypeTyping,
file_in_repo: Optional[str], archive_in_repo: Optional[str], dir_in_repo: Optional[str],
output_path: str, revision: str, max_workers: int,
password: Optional[str], wildcard: Optional[str], tmpdir: Optional[str]
password: Optional[str], wildcard: Optional[str], soft_mode_when_check: bool, tmpdir: Optional[str]
):
"""
Download data from HuggingFace repositories.
Expand All @@ -83,6 +85,8 @@ def download(
:type password: str, optional
:param wildcard: Wildcard for files to download. Only applied when -d is used.
:type password: str, optional
:param soft_mode_when_check: Just check the size of the expected file when enabled. Default is False.
:type soft_mode_when_check: bool
:param tmpdir: Use custom temporary Directory.
:type tmpdir: str, optional
"""
Expand Down Expand Up @@ -130,6 +134,7 @@ def download(
revision=revision,
silent=False,
max_workers=max_workers,
soft_mode_when_check=soft_mode_when_check,
)

else:
Expand Down
5 changes: 4 additions & 1 deletion hfutils/operate/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def download_directory_as_directory(
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
resume_download: bool = True, max_workers: int = 8, max_retries: int = 5,
hf_token: Optional[str] = None
soft_mode_when_check: bool = False, hf_token: Optional[str] = None
):
"""
Download all files in a directory from a Hugging Face repository to a local directory.
Expand All @@ -112,6 +112,8 @@ def download_directory_as_directory(
:type max_retries: int
:param resume_download: Resume the existing download.
:type resume_download: bool
:param soft_mode_when_check: Just check the size of the expected file when enabled. Default is False.
:type soft_mode_when_check: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
Expand All @@ -138,6 +140,7 @@ def _download_one_file(rel_file):
file_in_repo=file_in_repo,
revision=revision,
hf_token=hf_token,
soft_mode=soft_mode_when_check,
):
logging.info(f'Local file {rel_file} is ready, download skipped.')
else:
Expand Down
14 changes: 11 additions & 3 deletions hfutils/operate/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .base import RepoTypeTyping, get_hf_client


def _raw_check_local_file(repo_file: RepoFile, local_file: str, chunk_for_hash: int = 1 << 20) -> bool:
def _raw_check_local_file(repo_file: RepoFile, local_file: str, chunk_for_hash: int = 1 << 20,
soft_mode: bool = False) -> bool:
"""
Checks if the local file matches the file on the Hugging Face Hub repository.
Expand All @@ -19,6 +20,8 @@ def _raw_check_local_file(repo_file: RepoFile, local_file: str, chunk_for_hash:
:type local_file: str
:param chunk_for_hash: The chunk size for calculating the hash. Default is 1 << 20.
:type chunk_for_hash: int
:param soft_mode: Just check the size of the expected file when enabled. Default is False.
:type soft_mode: bool
:return: True if the local file matches the file on the repository, False otherwise.
:rtype: bool
"""
Expand All @@ -28,6 +31,9 @@ def _raw_check_local_file(repo_file: RepoFile, local_file: str, chunk_for_hash:
f'the remote file {repo_file.path!r} ({repo_file.size}).')
return False

if soft_mode:
return True

if repo_file.lfs:
sha = sha256()
expected_hash = repo_file.lfs.sha256
Expand All @@ -54,7 +60,7 @@ def _raw_check_local_file(repo_file: RepoFile, local_file: str, chunk_for_hash:

def is_local_file_ready(local_file: str, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
hf_token: Optional[str] = None) -> bool:
soft_mode: bool = False, hf_token: Optional[str] = None) -> bool:
"""
Checks if the local file is ready by comparing it with the file on the Hugging Face Hub repository.
Expand All @@ -68,6 +74,8 @@ def is_local_file_ready(local_file: str, repo_id: str, file_in_repo: str,
:type repo_type: RepoTypeTyping
:param revision: The revision of the repository. Default is 'main'.
:type revision: str
:param soft_mode: Just check the size of the expected file when enabled. Default is False.
:type soft_mode: bool
:param hf_token: The Hugging Face API token. Default is None.
:type hf_token: Optional[str]
:return: True if the local file matches the file on the repository, False otherwise.
Expand All @@ -83,6 +91,6 @@ def is_local_file_ready(local_file: str, repo_id: str, file_in_repo: str,
if len(infos) == 0:
raise EntryNotFoundError(f'Entry {repo_type}s/{repo_id}/{file_in_repo} not found.')
elif len(infos) == 1:
return _raw_check_local_file(infos[0], local_file)
return _raw_check_local_file(infos[0], local_file, soft_mode=soft_mode)
else:
assert False, f'Should not reach here, multiple files with the same name found - {infos!r}' # pragma: no cover

0 comments on commit b1505cf

Please sign in to comment.