diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py index cb986ae..4483cfa 100644 --- a/dlclibrary/dlcmodelzoo/modelzoo_download.py +++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py @@ -12,6 +12,11 @@ import json import os +import tarfile +from pathlib import Path + +from huggingface_hub import hf_hub_download +from ruamel.yaml.comments import CommentedBase # just expand this list when adding new models: MODELOPTIONS = [ @@ -52,34 +57,54 @@ def parse_available_supermodels(): return json.load(file) +def _handle_downloaded_file( + file_path: str, target_dir: str, rename_mapping: dict | None = None +): + """Handle the downloaded file from HuggingFace""" + file_name = os.path.basename(file_path) + try: + with tarfile.open(file_path, mode="r:gz") as tar: + for member in tar: + if not member.isdir(): + fname = Path(member.name).name + tar.makefile(member, os.path.join(target_dir, fname)) + except tarfile.ReadError: # The model is a .pt file + if rename_mapping is not None: + file_name = rename_mapping.get(file_name, file_name) + if os.path.islink(file_path): + file_path_ = os.readlink(file_path) + if not os.path.isabs(file_path_): + file_path_ = os.path.abspath( + os.path.join(os.path.dirname(file_path), file_path_) + ) + file_path = file_path_ + os.rename(file_path, os.path.join(target_dir, file_name)) + + def download_huggingface_model( - modelname, target_dir=".", remove_hf_folder=True, rename_mapping: dict | None = None + model_name: str, + target_dir: str = ".", + remove_hf_folder: bool = True, + rename_mapping: dict | None = None, ): """ - Download a DeepLabCut Model Zoo Project from Hugging Face - - Parameters - ---------- - modelname : string - Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo - target_dir : directory (as string) - Directory where to store the model weights and pose_cfg.yaml file - remove_hf_folder : bool, default True - Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format. - rename_mapping : dict, default None - Dictionary to rename the downloaded file. If None, the original filename is used. + Downloads a DeepLabCut Model Zoo Project from Hugging Face. + + Args: + model_name (str): Name of the ModelZoo model. + For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo. + target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored. + remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace + after downloading and decompressing the data into DeepLabCut format. Defaults to True. + rename_mapping (dict, optional): A dictionary to rename the downloaded file. + If None, the original filename is used. Defaults to None. """ - from huggingface_hub import hf_hub_download - import tarfile - from pathlib import Path - from ruamel.yaml.comments import CommentedBase - - neturls = _load_model_names() - if modelname not in neturls: - raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.") + net_urls = _load_model_names() + if model_name not in net_urls: + raise ValueError(f"`modelname` should be one of: {', '.join(net_urls)}.") - print("Loading....", modelname) - urls = neturls[modelname] + print("Loading....", model_name) + urls = net_urls[model_name] if isinstance(urls, CommentedBase): urls = list(urls) else: @@ -98,26 +123,10 @@ def download_huggingface_model( hf_folder = f"models--{url[0]}--{url[1]}" path_ = os.path.join(target_dir, hf_folder, "snapshots") commit = os.listdir(path_)[0] - filename = os.path.join(path_, commit, targzfn) - try: - with tarfile.open(filename, mode="r:gz") as tar: - for member in tar: - if not member.isdir(): - fname = Path(member.name).name - tar.makefile(member, os.path.join(target_dir, fname)) - except tarfile.ReadError: # The model is a .pt file - if rename_mapping is not None: - targzfn = rename_mapping.get(targzfn, targzfn) - if os.path.islink(filename): - filename_ = os.readlink(filename) - if not os.path.isabs(filename_): - filename_ = os.path.abspath(os.path.join(os.path.dirname(filename), filename_)) - filename = filename_ - os.rename(filename, os.path.join(target_dir, targzfn)) + file_name = os.path.join(path_, commit, targzfn) + _handle_downloaded_file(file_name, target_dir, rename_mapping) if remove_hf_folder: import shutil shutil.rmtree(os.path.join(target_dir, hf_folder)) - -'../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df' \ No newline at end of file