12
12
13
13
import json
14
14
import os
15
+ import tarfile
16
+ from pathlib import Path
17
+
18
+ from huggingface_hub import hf_hub_download
19
+ from ruamel .yaml .comments import CommentedBase
15
20
16
21
# just expand this list when adding new models:
17
22
MODELOPTIONS = [
@@ -52,34 +57,54 @@ def parse_available_supermodels():
52
57
return json .load (file )
53
58
54
59
60
+ def _handle_downloaded_file (
61
+ file_path : str , target_dir : str , rename_mapping : dict | None = None
62
+ ):
63
+ """Handle the downloaded file from HuggingFace"""
64
+ file_name = os .path .basename (file_path )
65
+ try :
66
+ with tarfile .open (file_path , mode = "r:gz" ) as tar :
67
+ for member in tar :
68
+ if not member .isdir ():
69
+ fname = Path (member .name ).name
70
+ tar .makefile (member , os .path .join (target_dir , fname ))
71
+ except tarfile .ReadError : # The model is a .pt file
72
+ if rename_mapping is not None :
73
+ file_name = rename_mapping .get (file_name , file_name )
74
+ if os .path .islink (file_path ):
75
+ file_path_ = os .readlink (file_path )
76
+ if not os .path .isabs (file_path_ ):
77
+ file_path_ = os .path .abspath (
78
+ os .path .join (os .path .dirname (file_path ), file_path_ )
79
+ )
80
+ file_path = file_path_
81
+ os .rename (file_path , os .path .join (target_dir , file_name ))
82
+
83
+
55
84
def download_huggingface_model (
56
- modelname , target_dir = "." , remove_hf_folder = True , rename_mapping : dict | None = None
85
+ model_name : str ,
86
+ target_dir : str = "." ,
87
+ remove_hf_folder : bool = True ,
88
+ rename_mapping : dict | None = None ,
57
89
):
58
90
"""
59
- Download a DeepLabCut Model Zoo Project from Hugging Face
60
-
61
- Parameters
62
- ----------
63
- modelname : string
64
- Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
65
- target_dir : directory (as string)
66
- Directory where to store the model weights and pose_cfg.yaml file
67
- remove_hf_folder : bool, default True
68
- Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
69
- rename_mapping : dict, default None
70
- Dictionary to rename the downloaded file. If None, the original filename is used.
91
+ Downloads a DeepLabCut Model Zoo Project from Hugging Face.
92
+
93
+ Args:
94
+ model_name (str): Name of the ModelZoo model.
95
+ For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
96
+ target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
97
+ remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
98
+ after downloading and decompressing the data into DeepLabCut format. Defaults to True.
99
+ rename_mapping (dict, optional): A dictionary to rename the downloaded file.
100
+ If None, the original filename is used. Defaults to None.
71
101
"""
72
- from huggingface_hub import hf_hub_download
73
- import tarfile
74
- from pathlib import Path
75
- from ruamel .yaml .comments import CommentedBase
76
-
77
- neturls = _load_model_names ()
78
- if modelname not in neturls :
79
- raise ValueError (f"`modelname` should be one of: { ', ' .join (modelname )} ." )
102
+ net_urls = _load_model_names ()
103
+ if model_name not in net_urls :
104
+ raise ValueError (f"`modelname` should be one of: { ', ' .join (net_urls )} ." )
80
105
81
- print ("Loading...." , modelname )
82
- urls = neturls [ modelname ]
106
+ print ("Loading...." , model_name )
107
+ urls = net_urls [ model_name ]
83
108
if isinstance (urls , CommentedBase ):
84
109
urls = list (urls )
85
110
else :
@@ -98,26 +123,10 @@ def download_huggingface_model(
98
123
hf_folder = f"models--{ url [0 ]} --{ url [1 ]} "
99
124
path_ = os .path .join (target_dir , hf_folder , "snapshots" )
100
125
commit = os .listdir (path_ )[0 ]
101
- filename = os .path .join (path_ , commit , targzfn )
102
- try :
103
- with tarfile .open (filename , mode = "r:gz" ) as tar :
104
- for member in tar :
105
- if not member .isdir ():
106
- fname = Path (member .name ).name
107
- tar .makefile (member , os .path .join (target_dir , fname ))
108
- except tarfile .ReadError : # The model is a .pt file
109
- if rename_mapping is not None :
110
- targzfn = rename_mapping .get (targzfn , targzfn )
111
- if os .path .islink (filename ):
112
- filename_ = os .readlink (filename )
113
- if not os .path .isabs (filename_ ):
114
- filename_ = os .path .abspath (os .path .join (os .path .dirname (filename ), filename_ ))
115
- filename = filename_
116
- os .rename (filename , os .path .join (target_dir , targzfn ))
126
+ file_name = os .path .join (path_ , commit , targzfn )
127
+ _handle_downloaded_file (file_name , target_dir , rename_mapping )
117
128
118
129
if remove_hf_folder :
119
130
import shutil
120
131
121
132
shutil .rmtree (os .path .join (target_dir , hf_folder ))
122
-
123
- '../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df'
0 commit comments