Skip to content

Commit

Permalink
Merge pull request #2 from deepghs/dev/custom
Browse files (browse the repository at this point in the history)
dev(narugo): allow the usage of custom hf token
  • Loading branch information
narugo1992 authored Jan 5, 2024
2 parents 0247090 + 3799c53 commit 53194da
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 17 deletions.
21 changes: 15 additions & 6 deletions hfutils/operate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,33 @@ def _get_hf_token() -> Optional[str]:


@lru_cache()
def get_hf_client(hf_token: Optional[str] = None) -> HfApi:
    """
    Get the Hugging Face API client.

    The result is memoized with :func:`functools.lru_cache`, keyed on ``hf_token``, so repeated
    calls with the same argument return the same client object. As a consequence, a call
    without ``hf_token`` resolves the fallback token only once, when that cache entry is created;
    later calls reuse the cached client.

    :param hf_token: Hugging Face token for the API client. When not given (or empty), the
        ``HF_TOKEN`` variable is used instead, as resolved by ``_get_hf_token()``.
    :type hf_token: str, optional

    :return: The Hugging Face API client.
    :rtype: HfApi
    """
    # ``or`` (rather than an ``is None`` check) also routes an empty string to the fallback token.
    return HfApi(token=hf_token or _get_hf_token())


@lru_cache()
def get_hf_fs(hf_token: Optional[str] = None) -> HfFileSystem:
    """
    Get the Hugging Face file system.

    The result is memoized with :func:`functools.lru_cache`, keyed on ``hf_token``, so repeated
    calls with the same argument return the same file system object. As a consequence, a call
    without ``hf_token`` resolves the fallback token only once, when that cache entry is created;
    later calls reuse the cached object.

    :param hf_token: Hugging Face token for the file system. When not given (or empty), the
        ``HF_TOKEN`` variable is used instead, as resolved by ``_get_hf_token()``.
    :type hf_token: str, optional

    :return: The Hugging Face file system.
    :rtype: HfFileSystem
    """
    # use_listings_cache=False is necessary; otherwise the results of glob and ls
    # are cached and the unit tests break.
    # ``or`` (rather than an ``is None`` check) also routes an empty string to the fallback token.
    return HfFileSystem(token=hf_token or _get_hf_token(), use_listings_cache=False)


_DEFAULT_IGNORE_PATTERNS = ['.git*']
Expand Down Expand Up @@ -70,7 +76,8 @@ def _is_file_ignored(file_segments: List[str], ignore_patterns: List[str]) -> bo

def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset',
subdir: str = '', revision: str = 'main',
ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET) -> List[str]:
ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
hf_token: Optional[str] = None) -> List[str]:
"""
List files in a Hugging Face repository based on the given parameters.
Expand All @@ -84,13 +91,15 @@ def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset'
:type revision: str
:param ignore_patterns: List of file patterns to ignore.
:type ignore_patterns: List[str]
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
:return: A list of file paths.
:rtype: List[str]
"""
if ignore_patterns is _IGNORE_PATTERN_UNSET:
ignore_patterns = _DEFAULT_IGNORE_PATTERNS
hf_fs = get_hf_fs()
hf_fs = get_hf_fs(hf_token)
if repo_type == 'model':
repo_root_path = repo_id
elif repo_type == 'dataset':
Expand Down
18 changes: 13 additions & 5 deletions hfutils/operate/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
resume_download: bool = True):
resume_download: bool = True, hf_token: Optional[str] = None):
"""
Download a file from a Hugging Face repository and save it to a local file.
Expand All @@ -26,8 +26,10 @@ def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str,
:type revision: str
:param resume_download: Resume the existing download.
:type resume_download: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
hf_client = get_hf_client()
hf_client = get_hf_client(hf_token)
relative_filename = os.path.join(*file_in_repo.split("/"))
with TemporaryDirectory() as td:
temp_path = os.path.join(td, relative_filename)
Expand All @@ -51,7 +53,7 @@ def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str,

def download_archive_as_directory(local_directory: str, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
password: Optional[str] = None):
password: Optional[str] = None, hf_token: Optional[str] = None):
"""
Download an archive file from a Hugging Face repository and extract it to a local directory.
Expand All @@ -67,17 +69,20 @@ def download_archive_as_directory(local_directory: str, repo_id: str, file_in_re
:type revision: str
:param password: The password of the archive file.
:type password: str, optional
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
with TemporaryDirectory() as td:
archive_file = os.path.join(td, os.path.basename(file_in_repo))
download_file_to_file(archive_file, repo_id, file_in_repo, repo_type, revision)
download_file_to_file(archive_file, repo_id, file_in_repo, repo_type, revision, hf_token=hf_token)
archive_unpack(archive_file, local_directory, password=password)


def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_repo: str = '.',
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
resume_download: bool = True, max_workers: int = 8):
resume_download: bool = True, max_workers: int = 8,
hf_token: Optional[str] = None):
"""
Download all files in a directory from a Hugging Face repository to a local directory.
Expand All @@ -99,6 +104,8 @@ def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_r
:type max_workers: int
:param resume_download: Resume the existing download.
:type resume_download: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
files = list_files_in_repository(repo_id, repo_type, dir_in_repo, revision, ignore_patterns)
progress = tqdm(files, silent=silent, desc=f'Downloading {dir_in_repo!r} ...')
Expand All @@ -111,6 +118,7 @@ def _download_one_file(rel_file):
repo_type=repo_type,
revision=revision,
resume_download=resume_download,
hf_token=hf_token,
)
progress.update()

Expand Down
21 changes: 15 additions & 6 deletions hfutils/operate/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def upload_file_to_file(local_file, repo_id: str, file_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
message: Optional[str] = None):
message: Optional[str] = None, hf_token: Optional[str] = None):
"""
Upload a local file to a specified path in a Hugging Face repository.
Expand All @@ -28,8 +28,10 @@ def upload_file_to_file(local_file, repo_id: str, file_in_repo: str,
:type revision: str
:param message: The commit message for the upload.
:type message: Optional[str]
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
hf_client = get_hf_client()
hf_client = get_hf_client(hf_token)
hf_client.upload_file(
repo_id=repo_id,
repo_type=repo_type,
Expand All @@ -42,7 +44,8 @@ def upload_file_to_file(local_file, repo_id: str, file_in_repo: str,

def upload_directory_as_archive(local_directory, repo_id: str, archive_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
message: Optional[str] = None, silent: bool = False):
message: Optional[str] = None, silent: bool = False,
hf_token: Optional[str] = None):
"""
Upload a local directory as an archive file to a specified path in a Hugging Face repository.
Expand All @@ -60,12 +63,15 @@ def upload_directory_as_archive(local_directory, repo_id: str, archive_in_repo:
:type message: Optional[str]
:param silent: If True, suppress progress bar output.
:type silent: bool
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
archive_type = get_archive_type(archive_in_repo)
with TemporaryDirectory() as td:
local_archive_file = os.path.join(td, os.path.basename(archive_in_repo))
archive_pack(archive_type, local_directory, local_archive_file, silent=silent)
upload_file_to_file(local_archive_file, repo_id, archive_in_repo, repo_type, revision, message)
upload_file_to_file(local_archive_file, repo_id, archive_in_repo,
repo_type, revision, message, hf_token=hf_token)


_PATH_SEP = re.compile(r'[/\\]+')
Expand All @@ -74,7 +80,8 @@ def upload_directory_as_archive(local_directory, repo_id: str, archive_in_repo:
def upload_directory_as_directory(local_directory, repo_id: str, path_in_repo: str,
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
message: Optional[str] = None, time_suffix: bool = True,
clear: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET):
clear: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
hf_token: Optional[str] = None):
"""
Upload a local directory and its files to a specified path in a Hugging Face repository.
Expand All @@ -96,8 +103,10 @@ def upload_directory_as_directory(local_directory, repo_id: str, path_in_repo: s
:type clear: bool
:param ignore_patterns: List of file patterns to ignore.
:type ignore_patterns: List[str]
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
hf_client = get_hf_client()
hf_client = get_hf_client(hf_token)
if clear:
pre_exist_files = {
tuple(file.split('/')) for file in
Expand Down

0 comments on commit 53194da

Please sign in to comment.