Skip to content

Commit

Permalink
Merge pull request #30 from DeepLabCut/niels/fix_tests_update_readme
Browse files Browse the repository at this point in the history
update README to add new models
  • Loading branch information
AlexEMG authored Oct 18, 2024
2 parents 5a22726 + 90a945a commit aa2da6c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
38 changes: 36 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ download_huggingface_model("superanimal_quadruped", model_dir)
```

PyTorch models available for a given dataset (compatible with DeepLabCut>=3.0) can be
listed using the `dlclibrary.get_available_detectors` and
`dlclibrary.get_available_models` methods. Example use:
listed using the `dlclibrary.get_available_detectors` and
`dlclibrary.get_available_models` methods. The datasets for which models are available
can be listed using `dlclibrary.get_available_datasets`. Example use:

```python
>>> import dlclibrary
>>> dlclibrary.get_available_datasets()
['superanimal_bird', 'superanimal_topviewmouse', 'superanimal_quadruped']

>>> dlclibrary.get_available_detectors("superanimal_bird")
['fasterrcnn_mobilenet_v3_large_fpn', 'ssdlite']

Expand All @@ -51,6 +55,8 @@ listed using the `dlclibrary.get_available_detectors` and

## How to add a new model?

### TensorFlow models

Pick a good model_name. Follow the (novel) naming convention (modeltype_species), e.g. ```superanimal_topviewmouse```.

1. Add the model_name with path and commit ID to: https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_urls.yaml
Expand All @@ -59,3 +65,31 @@ Pick a good model_name. Follow the (novel) naming convention (modeltype_species)
https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_download.py#L15

3. For superanimal models also fill in the configs!

### PyTorch models (for `deeplabcut >= 3.0.0`)

PyTorch models are listed in [`dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml`](
https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml
). The file is organized as:

```yaml
my_cool_dataset: # name of the dataset used to train the model
detectors:
detector_name: path/to/huggingface-detector.pt # add detectors under `detector`
pose_models:
pose_model_name: path/to/huggingface-pose-model.pt # add pose models under `pose_models`
other_pose_model_name: path/to/huggingface-other-pose-model.pt
```
This will allow users to download the models using the format `datatsetName_modelName`,
i.e. for this example 3 models would be available: `my_cool_dataset_detector_name`,
`my_cool_dataset_pose_model_name` and `my_cool_dataset_other_pose_model_name`.

To add a new model for `deeplabcut >= 3.0.0`, simply:

- add a new line under detectors or pose models if the dataset is already defined
- add the structure if the model was trained on a new dataset

The models will then be listed when calling `dlclibrary.get_available_detectors` or
`dlclibrary.get_available_models`! You can list the datasets for which models are
available using `dlclibrary.get_available_datasets`.
1 change: 1 addition & 0 deletions dlclibrary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from dlclibrary.dlcmodelzoo.modelzoo_download import (
download_huggingface_model,
get_available_datasets,
get_available_detectors,
get_available_models,
parse_available_supermodels,
Expand Down
10 changes: 9 additions & 1 deletion dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"mouse_pupil_vclose",
"horse_sideview",
"full_macaque",
"superanimal_bird",
"superanimal_quadruped",
"superanimal_topviewmouse",
]
Expand Down Expand Up @@ -85,6 +84,15 @@ def parse_available_supermodels():
return super_animal_models


def get_available_datasets() -> list[str]:
"""Only for PyTorch models.
Returns:
The name of datasets for which models are available
"""
return list(_load_pytorch_models().keys())


def get_available_detectors(dataset: str) -> list[str]:
""" Only for PyTorch models.
Expand Down

0 comments on commit aa2da6c

Please sign in to comment.