Skip to content

Commit 209b461

Browse files
committed
Refactor and other formatting improvements
1 parent 15c13ef commit 209b461

File tree

1 file changed

+50
-41
lines changed

1 file changed

+50
-41
lines changed

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313
import json
1414
import os
15+
import tarfile
16+
from pathlib import Path
17+
18+
from huggingface_hub import hf_hub_download
19+
from ruamel.yaml.comments import CommentedBase
1520

1621
# just expand this list when adding new models:
1722
MODELOPTIONS = [
@@ -52,34 +57,54 @@ def parse_available_supermodels():
5257
return json.load(file)
5358

5459

60+
def _handle_downloaded_file(
61+
file_path: str, target_dir: str, rename_mapping: dict | None = None
62+
):
63+
"""Handle the downloaded file from HuggingFace"""
64+
file_name = os.path.basename(file_path)
65+
try:
66+
with tarfile.open(file_path, mode="r:gz") as tar:
67+
for member in tar:
68+
if not member.isdir():
69+
fname = Path(member.name).name
70+
tar.makefile(member, os.path.join(target_dir, fname))
71+
except tarfile.ReadError: # The model is a .pt file
72+
if rename_mapping is not None:
73+
file_name = rename_mapping.get(file_name, file_name)
74+
if os.path.islink(file_path):
75+
file_path_ = os.readlink(file_path)
76+
if not os.path.isabs(file_path_):
77+
file_path_ = os.path.abspath(
78+
os.path.join(os.path.dirname(file_path), file_path_)
79+
)
80+
file_path = file_path_
81+
os.rename(file_path, os.path.join(target_dir, file_name))
82+
83+
5584
def download_huggingface_model(
56-
modelname, target_dir=".", remove_hf_folder=True, rename_mapping: dict | None = None
85+
model_name: str,
86+
target_dir: str = ".",
87+
remove_hf_folder: bool = True,
88+
rename_mapping: dict | None = None,
5789
):
5890
"""
59-
Download a DeepLabCut Model Zoo Project from Hugging Face
60-
61-
Parameters
62-
----------
63-
modelname : string
64-
Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
65-
target_dir : directory (as string)
66-
Directory where to store the model weights and pose_cfg.yaml file
67-
remove_hf_folder : bool, default True
68-
Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
69-
rename_mapping : dict, default None
70-
Dictionary to rename the downloaded file. If None, the original filename is used.
91+
Downloads a DeepLabCut Model Zoo Project from Hugging Face.
92+
93+
Args:
94+
model_name (str): Name of the ModelZoo model.
95+
For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
96+
target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
97+
remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
98+
after downloading and decompressing the data into DeepLabCut format. Defaults to True.
99+
rename_mapping (dict, optional): A dictionary to rename the downloaded file.
100+
If None, the original filename is used. Defaults to None.
71101
"""
72-
from huggingface_hub import hf_hub_download
73-
import tarfile
74-
from pathlib import Path
75-
from ruamel.yaml.comments import CommentedBase
76-
77-
neturls = _load_model_names()
78-
if modelname not in neturls:
79-
raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.")
102+
net_urls = _load_model_names()
103+
if model_name not in net_urls:
104+
raise ValueError(f"`modelname` should be one of: {', '.join(net_urls)}.")
80105

81-
print("Loading....", modelname)
82-
urls = neturls[modelname]
106+
print("Loading....", model_name)
107+
urls = net_urls[model_name]
83108
if isinstance(urls, CommentedBase):
84109
urls = list(urls)
85110
else:
@@ -98,26 +123,10 @@ def download_huggingface_model(
98123
hf_folder = f"models--{url[0]}--{url[1]}"
99124
path_ = os.path.join(target_dir, hf_folder, "snapshots")
100125
commit = os.listdir(path_)[0]
101-
filename = os.path.join(path_, commit, targzfn)
102-
try:
103-
with tarfile.open(filename, mode="r:gz") as tar:
104-
for member in tar:
105-
if not member.isdir():
106-
fname = Path(member.name).name
107-
tar.makefile(member, os.path.join(target_dir, fname))
108-
except tarfile.ReadError: # The model is a .pt file
109-
if rename_mapping is not None:
110-
targzfn = rename_mapping.get(targzfn, targzfn)
111-
if os.path.islink(filename):
112-
filename_ = os.readlink(filename)
113-
if not os.path.isabs(filename_):
114-
filename_ = os.path.abspath(os.path.join(os.path.dirname(filename), filename_))
115-
filename = filename_
116-
os.rename(filename, os.path.join(target_dir, targzfn))
126+
file_name = os.path.join(path_, commit, targzfn)
127+
_handle_downloaded_file(file_name, target_dir, rename_mapping)
117128

118129
if remove_hf_folder:
119130
import shutil
120131

121132
shutil.rmtree(os.path.join(target_dir, hf_folder))
122-
123-
'../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df'

0 commit comments

Comments
 (0)