Skip to content

Commit 5a22726

Browse files
authored
Merge pull request #29 from DeepLabCut/niels/sa_detectors
SuperAnimal Model Updates
2 parents b7d9dde + c4237cb commit 5a22726

File tree

6 files changed

+145
-11
lines changed

6 files changed

+145
-11
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ model_dir.mkdir()
3535
download_huggingface_model("superanimal_quadruped", model_dir)
3636
```
3737

38+
PyTorch models available for a given dataset (compatible with DeepLabCut>=3.0) can be
39+
listed using the `dlclibrary.get_available_detectors` and
40+
`dlclibrary.get_available_models` methods. Example use:
41+
42+
```python
43+
>>> import dlclibrary
44+
>>> dlclibrary.get_available_detectors("superanimal_bird")
45+
['fasterrcnn_mobilenet_v3_large_fpn', 'ssdlite']
46+
47+
>>> dlclibrary.get_available_models("superanimal_bird")
48+
['resnet_50']
49+
```
50+
51+
3852
## How to add a new model?
3953

4054
Pick a good model_name. Follow the (novel) naming convention (modeltype_species), e.g. ```superanimal_topviewmouse```.

dlclibrary/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from dlclibrary.dlcmodelzoo.modelzoo_download import (
1313
download_huggingface_model,
14+
get_available_detectors,
15+
get_available_models,
1416
parse_available_supermodels,
1517
)
1618
from dlclibrary.version import __version__, VERSION

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pathlib import Path
1717

1818
from huggingface_hub import hf_hub_download
19+
from ruamel.yaml import YAML
1920
from ruamel.yaml.comments import CommentedBase
2021

2122
# just expand this list when adding new models:
@@ -27,12 +28,9 @@
2728
"mouse_pupil_vclose",
2829
"horse_sideview",
2930
"full_macaque",
30-
"superanimal_topviewmouse_dlcrnet",
31-
"superanimal_quadruped_dlcrnet",
32-
"superanimal_topviewmouse_hrnetw32",
33-
"superanimal_quadruped_hrnetw32",
34-
"superanimal_topviewmouse", # DeepLabCut 2.X backwards compatibility
35-
"superanimal_quadruped", # DeepLabCut 2.X backwards compatibility
31+
"superanimal_bird",
32+
"superanimal_quadruped",
33+
"superanimal_topviewmouse",
3634
]
3735

3836

@@ -43,20 +41,66 @@ def _get_dlclibrary_path():
4341
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]
4442

4543

46-
def _load_model_names():
44+
def _load_pytorch_models() -> dict[str, dict[str, dict[str, str]]]:
45+
"""Load URLs and commit hashes for available models."""
46+
urls = Path(_get_dlclibrary_path()) / "dlcmodelzoo" / "modelzoo_urls_pytorch.yaml"
47+
with open(urls) as file:
48+
data = YAML(pure=True).load(file)
49+
50+
return data
51+
52+
53+
def _load_pytorch_dataset_models(dataset: str) -> dict[str, dict[str, str]]:
4754
"""Load URLs and commit hashes for available models."""
48-
from ruamel.yaml import YAML
55+
models = _load_pytorch_models()
56+
if not dataset in models:
57+
raise ValueError(
58+
f"Could not find any models for {dataset}. Models are available for "
59+
f"{list(models.keys())}"
60+
)
4961

62+
return models[dataset]
63+
64+
65+
def _load_model_names():
66+
"""Load URLs and commit hashes for available models."""
5067
fn = os.path.join(_get_dlclibrary_path(), "dlcmodelzoo", "modelzoo_urls.yaml")
5168
with open(fn) as file:
52-
return YAML().load(file)
69+
model_names = YAML().load(file)
70+
71+
# add PyTorch models
72+
for dataset, model_types in _load_pytorch_models().items():
73+
for model_type, models in model_types.items():
74+
for model, url in models.items():
75+
model_names[f"{dataset}_{model}"] = url
76+
77+
return model_names
5378

5479

5580
def parse_available_supermodels():
5681
libpath = _get_dlclibrary_path()
5782
json_path = os.path.join(libpath, "dlcmodelzoo", "superanimal_models.json")
5883
with open(json_path) as file:
59-
return json.load(file)
84+
super_animal_models = json.load(file)
85+
return super_animal_models
86+
87+
88+
def get_available_detectors(dataset: str) -> list[str]:
89+
""" Only for PyTorch models.
90+
91+
Returns:
92+
The detectors available for the dataset.
93+
"""
94+
return list(_load_pytorch_dataset_models(dataset)["detectors"].keys())
95+
96+
97+
def get_available_models(dataset: str) -> list[str]:
98+
""" Only for PyTorch models.
99+
100+
Returns:
101+
The pose models available for the dataset.
102+
"""
103+
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())
60104

61105

62106
def _handle_downloaded_file(
@@ -103,7 +147,9 @@ def download_huggingface_model(
103147
"""
104148
net_urls = _load_model_names()
105149
if model_name not in net_urls:
106-
raise ValueError(f"`modelname` should be one of: {', '.join(net_urls)}.")
150+
raise ValueError(
151+
f"`modelname={model_name}` should be one of: {', '.join(net_urls)}."
152+
)
107153

108154
print("Loading....", model_name)
109155
urls = net_urls[model_name]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# DeepLabCut 3.0: SuperAnimal detectors and pose model URLS
2+
3+
superanimal_bird:
4+
detectors:
5+
fasterrcnn_mobilenet_v3_large_fpn: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_fasterrcnn_mobilenet_v3_large_fpn.pt
6+
ssdlite: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_ssdlite.pt
7+
pose_models:
8+
resnet_50: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_resnet_50.pt
9+
10+
superanimal_topviewmouse:
11+
detectors:
12+
fasterrcnn_mobilenet_v3_large_fpn: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_fasterrcnn_mobilenet_v3_large_fpn.pt
13+
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_fasterrcnn_resnet50_fpn_v2.pt
14+
pose_models:
15+
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_hrnet_w32.pt
16+
resnet_50: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_resnet_50.pt
17+
18+
superanimal_quadruped:
19+
detectors:
20+
fasterrcnn_mobilenet_v3_large_fpn: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_fasterrcnn_mobilenet_v3_large_fpn.pt
21+
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_fasterrcnn_resnet50_fpn_v2.pt
22+
pose_models:
23+
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_hrnet_w32.pt
24+
resnet_50: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_resnet_50.pt

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"dlclibrary",
3333
[
3434
"dlclibrary/dlcmodelzoo/modelzoo_urls.yaml",
35+
"dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml",
3536
"dlclibrary/dlcmodelzoo/superanimal_models.json",
3637
"dlclibrary/dlcmodelzoo/superanimal_configs/superquadruped.yaml",
3738
"dlclibrary/dlcmodelzoo/superanimal_configs/supertopview.yaml",

tests/test_pytorch_models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#
2+
# DeepLabCut Toolbox (deeplabcut.org)
3+
# © A. & M.W. Mathis Labs
4+
# https://github.com/DeepLabCut/DeepLabCut
5+
#
6+
# Please see AUTHORS for contributors.
7+
# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
8+
#
9+
# Licensed under GNU Lesser General Public License v3.0
10+
#
11+
import os
12+
import pytest
13+
14+
import dlclibrary
15+
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo
16+
17+
18+
@pytest.mark.parametrize(
19+
"data",
20+
[
21+
("superanimal_bird", ["ssdlite"]),
22+
("superanimal_topviewmouse", ["fasterrcnn_resnet50_fpn_v2"]),
23+
("superanimal_quadruped", ["fasterrcnn_resnet50_fpn_v2"]),
24+
]
25+
)
26+
def test_get_super_animal_detectors(data: tuple[str, list[str]]):
27+
dataset, expected_detectors = data
28+
detectors = modelzoo.get_available_detectors(dataset)
29+
assert len(detectors) >= len(expected_detectors)
30+
for det in expected_detectors:
31+
assert det in detectors
32+
33+
34+
@pytest.mark.parametrize(
35+
"data",
36+
[
37+
("superanimal_bird", ["resnet_50"]),
38+
("superanimal_topviewmouse", ["hrnet_w32"]),
39+
("superanimal_quadruped", ["hrnet_w32"]),
40+
]
41+
)
42+
def test_get_super_animal_pose_models(data: tuple[str, list[str]]):
43+
dataset, expected_pose_models = data
44+
pose_models = modelzoo.get_available_models(dataset)
45+
assert len(pose_models) >= len(expected_pose_models)
46+
for pose_model in expected_pose_models:
47+
assert pose_model in pose_models

0 commit comments

Comments
 (0)