Skip to content

Commit 90a945a

Browse files
committed
improved readme, add code to list datasets
1 parent 21db191 commit 90a945a

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ download_huggingface_model("superanimal_quadruped", model_dir)
3636
```
3737

3838
PyTorch models available for a given dataset (compatible with DeepLabCut>=3.0) can be
39-
listed using the `dlclibrary.get_available_detectors` and
40-
`dlclibrary.get_available_models` methods. Example use:
39+
listed using the `dlclibrary.get_available_detectors` and
40+
`dlclibrary.get_available_models` methods. The datasets for which models are available
41+
can be listed using `dlclibrary.get_available_datasets`. Example use:
4142

4243
```python
4344
>>> import dlclibrary
45+
>>> dlclibrary.get_available_datasets()
46+
['superanimal_bird', 'superanimal_topviewmouse', 'superanimal_quadruped']
47+
4448
>>> dlclibrary.get_available_detectors("superanimal_bird")
4549
['fasterrcnn_mobilenet_v3_large_fpn', 'ssdlite']
4650

@@ -87,4 +91,5 @@ To add a new model for `deeplabcut >= 3.0.0`, simply:
8791
- add the structure if the model was trained on a new dataset
8892

8993
The models will then be listed when calling `dlclibrary.get_available_detectors` or
90-
`dlclibrary.get_available_models`!
94+
`dlclibrary.get_available_models`! You can list the datasets for which models are
95+
available using `dlclibrary.get_available_datasets`.

dlclibrary/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from dlclibrary.dlcmodelzoo.modelzoo_download import (
1313
download_huggingface_model,
14+
get_available_datasets,
1415
get_available_detectors,
1516
get_available_models,
1617
parse_available_supermodels,

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ def parse_available_supermodels():
8484
return super_animal_models
8585

8686

87+
def get_available_datasets() -> list[str]:
88+
"""Only for PyTorch models.
89+
90+
Returns:
91+
The name of datasets for which models are available
92+
"""
93+
return list(_load_pytorch_models().keys())
94+
95+
8796
def get_available_detectors(dataset: str) -> list[str]:
8897
""" Only for PyTorch models.
8998

0 commit comments

Comments
 (0)