[minimal working example] HuggingFace hub models now supported (#109)

* Add gitignore
* Add support for TIMM image feature extractors
* docs
* renamed imagenet labels and added in21k labels
* extract_frames(hf): loading pretr weights now;
* extract_frames(hf): implemented show pred for timm models
* utils: test for model_name (should be specified)
* hf.yaml: rm model_name default; style fix
* extract_frames: a note with assumption
* renamed hf to timm
* timm.md: init
* conda_env, install_conda: upd for timm
* test_timm: test timm models
* extract_frames: not all hf models have 'tag'
* rename extract_frames.py to extract_timm.py
* README, index: added timm models

---------

Co-authored-by: Bruno Korbar <bjuncek@gmail.com>
Showing 15 changed files with 22,118 additions and 13 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Model | ||
feature_type: 'timm' | ||
model_name: null # any timm model | ||
batch_size: 1 # Batch size (only frame-wise extractors are supported) | ||
extraction_fps: null # Leave as 'null' (None) to keep the original video fps | ||
extraction_total: null # Extract a fixed number of frames. Mutually exclusive with 'extraction_fps' | ||
|
||
# Extraction Parameters | ||
device: 'cuda:0' # device as in `torch`, can be 'cpu' | ||
on_extraction: 'print' # what to do once the features are extracted. Can be ['print', 'save_numpy', 'save_pickle'] | ||
output_path: './output' # where to store results if saved | ||
tmp_path: './tmp' # folder to store the temporary files used for extraction (frames or audio files) | ||
keep_tmp_files: false # whether to keep temporary files after feature extraction | ||
show_pred: false # whether to show the model's predictions on its pre-training dataset for each feature (ImageNet 1K or 21K) | ||
pred_texts: null # provide a list of multiple sentences. if `null`, will perform zero-shot on Kinetics 400 | ||
|
||
# config | ||
config: null | ||
|
||
# Video paths | ||
video_paths: null | ||
file_with_video_paths: null # if the list of videos is large, you might put them in a txt file, use this argument to specify the path |
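
The config above leaves `model_name` as `null` and describes `extraction_total` as mutually exclusive with the fps setting. The following is a minimal sketch of checks one might run on such a config before extraction. `validate_timm_args` is a hypothetical helper written for illustration; it is not taken from the repository, whose own checks may differ.

```python
from typing import Any, Mapping


def validate_timm_args(args: Mapping[str, Any]) -> None:
    """Hypothetical sanity checks for the timm config (not the repository's code)."""
    # `model_name` defaults to null in the config, so the user has to set it.
    if args.get('model_name') is None:
        raise ValueError('`model_name` must be specified, e.g. model_name=efficientnet_b0')
    # The config comments describe these two options as mutually exclusive.
    if args.get('extraction_fps') is not None and args.get('extraction_total') is not None:
        raise ValueError('`extraction_fps` and `extraction_total` are mutually exclusive')
    # Assumption: a batch has to hold at least one frame.
    if int(args.get('batch_size', 1)) < 1:
        raise ValueError('`batch_size` must be a positive integer')


# Passes silently: a model is named and only one of the two frame options is set.
validate_timm_args({'model_name': 'efficientnet_b0', 'batch_size': 1,
                    'extraction_fps': 5, 'extraction_total': None})
```

A plain mapping is used here to keep the sketch free of dependencies; an `omegaconf.DictConfig` also supports `.get`, so the same checks should carry over.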
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# timm | ||
|
||
`video_features` ❤️ [timm](https://huggingface.co/docs/timm/index). | ||
We support all models from the `timm` library (technically, those for which you can specify `pretrained=True`). | ||
|
||
For details, see the [timm docs](https://huggingface.co/docs/timm/index) and, | ||
specifically [model summaries](https://huggingface.co/docs/timm/models) and | ||
[model benchmark results](https://huggingface.co/docs/timm/results). | ||
|
||
## Supported Arguments | ||
<!-- the <div> makes columns wider --> | ||
| <div style="width: 12em">Argument</div> | <div style="width: 8em">Default</div> | Description | | ||
| --------------------------------------- | ------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| `model_name` | `null` | Any model from `timm.list_pretrained()`, e.g. `efficientnet_b0` or `efficientnet_b0.ra_in1k`. | | ||
| `batch_size` | `1` | You may speed up extraction of features by increasing the batch size as much as your GPU permits. | | ||
| `extraction_fps` | `null` | If specified (e.g. as `5`), the video will be re-encoded to the `extraction_fps` fps. Leave unspecified or `null` to skip re-encoding. | | ||
| `device` | `"cuda:0"` | The device specification. It follows the PyTorch style. Use `"cuda:3"` for the 4th GPU on the machine or `"cpu"` for CPU-only. | | ||
| `video_paths` | `null` | A list of videos for feature extraction. E.g. `"[./sample/v_ZNVhz7ctTq0.mp4, ./sample/v_GGSY1Qvo990.mp4]"` or just one path `"./sample/v_GGSY1Qvo990.mp4"`. | | ||
| `file_with_video_paths` | `null` | A path to a text file with video paths (one path per line). Hint: given a folder `./dataset` with `.mp4` files one could use: `find ./dataset -name "*mp4" > ./video_paths.txt`. | | ||
| `on_extraction` | `print` | If `print`, the features are printed to the terminal. If `save_numpy` or `save_pickle`, the features are saved to a `.npy` or a `.pkl` file, respectively. | | ||
| `output_path` | `"./output"` | A path to a folder for storing the extracted features (if `on_extraction` is either `save_numpy` or `save_pickle`). | | ||
| `keep_tmp_files` | `false` | If `true`, the re-encoded videos will be kept in `tmp_path`. | | ||
| `tmp_path` | `"./tmp"` | A path to a folder for storing temporary files (e.g. re-encoded videos). | | ||
| `show_pred` | `false` | If `true`, the script will print the model's predictions for the dataset it was trained on. It is useful for debugging. This flag is only supported for models that were trained on ImageNet 1K or 21K. | | ||
|
||
|
||
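
The `file_with_video_paths` argument expects a text file with one video path per line. A minimal sketch of how such a file could be read is shown below; `read_video_paths` is a hypothetical helper for illustration, and skipping blank lines is an assumption rather than documented behaviour of `video_features`.

```python
import tempfile
from pathlib import Path
from typing import List


def read_video_paths(list_file: str) -> List[str]:
    """Hypothetical reader: one video path per line, blank lines skipped."""
    lines = Path(list_file).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


# Write a small list file to a temporary folder and read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    list_file = Path(tmp_dir) / 'video_paths.txt'
    list_file.write_text('./sample/v_ZNVhz7ctTq0.mp4\n\n./sample/v_GGSY1Qvo990.mp4\n')
    print(read_video_paths(str(list_file)))
```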
## Examples | ||
|
||
```bash | ||
python main.py \ | ||
feature_type=timm \ | ||
model_name=efficientnet_b0 \ | ||
device="cuda:0" \ | ||
video_paths="[./sample/v_ZNVhz7ctTq0.mp4, ./sample/v_GGSY1Qvo990.mp4]" | ||
``` | ||
|
||
To select particular weights, pass the full name (architecture and tag) as the `model_name` argument, as you would with `timm`, | ||
e.g. | ||
```bash | ||
python main.py \ | ||
feature_type=timm \ | ||
model_name=efficientnet_b0.ra_in1k \ | ||
device="cuda:0" \ | ||
video_paths="[./sample/v_GGSY1Qvo990.mp4]" | ||
``` | ||
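
The two examples above differ only in the part of `model_name` after the dot: `efficientnet_b0` names an architecture, while `efficientnet_b0.ra_in1k` also names a tag for specific pretrained weights. The sketch below illustrates that naming convention. `split_model_name` is a hypothetical helper written for this note; `timm` resolves names itself, and this code is not part of `timm` or `video_features`.

```python
from typing import Tuple


def split_model_name(model_name: str) -> Tuple[str, str]:
    """Hypothetical helper: split 'architecture.tag' into its two parts."""
    # Names such as 'hf-hub:nateraw/resnet50-oxford-iiit-pet' are hub references,
    # not 'architecture.tag' pairs, so they are returned unchanged with no tag.
    if model_name.startswith('hf-hub:'):
        return model_name, ''
    architecture, _, tag = model_name.partition('.')
    return architecture, tag


for name in ('efficientnet_b0', 'efficientnet_b0.ra_in1k',
             'swin_small_patch4_window7_224.ms_in22k'):
    print(split_model_name(name))
```

The empty string returned for a bare architecture name only means that no tag was written; which weights `timm` then loads is decided by `timm`, not by this sketch.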
|
||
To check the model's predictions on the dataset it was trained on (ImageNet 1K or 21K), use the `show_pred` argument. | ||
```bash | ||
python main.py \ | ||
feature_type=timm \ | ||
model_name=swin_small_patch4_window7_224.ms_in22k \ | ||
device="cuda:0" \ | ||
extraction_fps=1 \ | ||
video_paths="[./sample/v_GGSY1Qvo990.mp4]" \ | ||
show_pred=true | ||
# Logits | Prob. | Label | ||
# 12.029 | 0.456 | barbell | ||
# 11.676 | 0.321 | weight, free_weight, exercising_weight | ||
# 9.653 | 0.042 | pusher, thruster | ||
# 9.499 | 0.036 | dumbbell | ||
# 8.787 | 0.018 | bench_press | ||
|
||
# Logits | Prob. | Label | ||
# 11.742 | 0.467 | barbell | ||
# 11.233 | 0.281 | weight, free_weight, exercising_weight | ||
# 9.489 | 0.049 | dumbbell | ||
# 8.923 | 0.028 | pusher, thruster | ||
# 8.406 | 0.017 | bench_press | ||
|
||
# Logits | Prob. | Label | ||
# 12.257 | 0.571 | barbell | ||
# 11.391 | 0.240 | weight, free_weight, exercising_weight | ||
# 9.708 | 0.045 | dumbbell | ||
# 9.031 | 0.023 | pusher, thruster | ||
# 8.756 | 0.017 | bench_press | ||
|
||
# Logits | Prob. | Label | ||
# 12.469 | 0.571 | barbell | ||
# 11.655 | 0.253 | weight, free_weight, exercising_weight | ||
# 9.818 | 0.040 | dumbbell | ||
# 9.648 | 0.034 | pusher, thruster | ||
# 8.527 | 0.011 | bench_press | ||
|
||
... | ||
``` | ||
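
In the output above, the five probabilities in each table do not sum to 1, which is what one would expect if the probabilities come from a softmax over all classes and only the top five rows are printed. The sketch below shows that computation on toy data. It assumes a softmax is used; the repository's `show_predictions_on_dataset` is not shown in this diff, and `top_k_predictions`, the logits, and the labels here are invented for illustration.

```python
import math
from typing import List, Sequence, Tuple


def top_k_predictions(logits: Sequence[float], labels: Sequence[str],
                      k: int = 5) -> List[Tuple[float, float, str]]:
    """Return (logit, probability, label) for the k largest logits."""
    # Softmax over ALL classes; subtracting the max keeps exp() numerically stable.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k largest logits, highest first.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return [(logits[i], probs[i], labels[i]) for i in order]


# Toy logits and labels, invented for illustration only.
toy_logits = [2.0, 0.5, 1.0, -1.0]
toy_labels = ['barbell', 'dumbbell', 'bench_press', 'pusher']
print('Logits | Prob. | Label')
for logit, prob, label in top_k_predictions(toy_logits, toy_labels, k=3):
    print(f'{logit:6.3f} | {prob:.3f} | {label}')
```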
|
||
## Credits | ||
* [timm](https://huggingface.co/docs/timm/index) library | ||
|
||
## License | ||
`video_features` is released under the MIT license; `timm` is released under [Apache 2.0](https://github.com/huggingface/pytorch-image-models/blob/main/LICENSE). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .extract_timm import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
|
||
import omegaconf | ||
from typing import Dict, List | ||
import torch | ||
from PIL import Image | ||
from torchvision.transforms import Compose | ||
from models._base.base_framewise_extractor import BaseFrameWiseExtractor | ||
from utils.utils import show_predictions_on_dataset | ||
|
||
try: | ||
    import timm | ||
    from timm.data import resolve_data_config | ||
    from timm.data.transforms_factory import create_transform | ||
except ImportError: | ||
    raise ImportError("This feature requires the timm library to be installed.") | ||
|
||
class ExtractTIMM(BaseFrameWiseExtractor): | ||
|
||
    def __init__(self, args: omegaconf.DictConfig) -> None: | ||
        super().__init__( | ||
            feature_type=args.feature_type, | ||
            on_extraction=args.on_extraction, | ||
            tmp_path=args.tmp_path, | ||
            output_path=args.output_path, | ||
            keep_tmp_files=args.keep_tmp_files, | ||
            device=args.device, | ||
            model_name=args.model_name, | ||
            batch_size=args.batch_size, | ||
            extraction_fps=args.extraction_fps, | ||
            extraction_total=args.extraction_total, | ||
            show_pred=args.show_pred, | ||
        ) | ||
|
||
        # transforms are set in `load_model` | ||
        self.transforms = None | ||
        self.name2module = self.load_model() | ||
|
||
    def load_model(self) -> Dict[str, torch.nn.Module]: | ||
        """Defines the model, loads the checkpoint and related transforms, | ||
        sends the model to the device. | ||
        Returns: | ||
            Dict[str, torch.nn.Module]: model-agnostic dict holding modules for extraction and show_pred | ||
        """ | ||
        model = timm.create_model(self.model_name, pretrained=True) | ||
|
||
        # transforms | ||
        self.transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model)) | ||
        self.transforms = Compose([lambda np_array: Image.fromarray(np_array), self.transforms]) | ||
        print(self.transforms) | ||
|
||
        model.to(self.device) | ||
        model.eval() | ||
|
||
        # remove the classifier after getting it | ||
        class_head = model.get_classifier() | ||
        model.reset_classifier(0) | ||
|
||
        # to be used in `maybe_show_pred` to determine how to show predictions | ||
        self.hf_arch = model.default_cfg['architecture'] | ||
        self.hf_tag = model.default_cfg.get('tag', '') | ||
|
||
        return {'model': model, 'class_head': class_head, } | ||
|
||
    def run_on_a_batch(self, batch: List) -> torch.Tensor: | ||
        """This is a hack for timm models to output features. | ||
        Ideally, you want to use model_spec to define behaviour at forward pass in | ||
        the config file. | ||
        """ | ||
        model = self.name2module['model'] | ||
        batch = torch.cat(batch).to(self.device) | ||
        batch_feats = model(batch) | ||
        self.maybe_show_pred(batch_feats) | ||
        return batch_feats | ||
|
||
    def maybe_show_pred(self, feats: torch.Tensor): | ||
        if self.show_pred: | ||
            logits = self.name2module['class_head'](feats) | ||
            # NOTE: these hardcoded ends assume that the end of the tag corresponds to the last training dset | ||
            if self.hf_tag.endswith(('in1k', 'in1k_288', 'in1k_320', 'in1k_384', 'in1k_475', 'in1k_512',)): | ||
                show_predictions_on_dataset(logits, 'imagenet1k') | ||
            elif self.hf_tag.endswith(('in21k', 'in21k_288', 'in21k_320', 'in21k_384', 'in21k_475', | ||
                                       'in21k_512', | ||
                                       'in22k', 'in22k_288', 'in22k_320', 'in22k_384', 'in22k_475', | ||
                                       'in22k_512',)): | ||
                show_predictions_on_dataset(logits, 'imagenet21k') | ||
            else: | ||
                print(f'No show_pred for {self.hf_arch} with tag {self.hf_tag}; use `show_pred=False`') |
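
The `maybe_show_pred` method above picks a label set from the end of the model tag, under the stated assumption that the tag ends with the last training dataset. A compact sketch of the same idea is given below. `dataset_from_tag` is a hypothetical helper, and unlike the hardcoded suffix lists in the diff it accepts any numeric resolution suffix, which is a deliberate generalisation rather than the repository's behaviour.

```python
import re
from typing import Optional

# Assumption: the tag ends with the last training dataset,
# optionally followed by an input resolution such as '_384'.
_TAG_PATTERN = re.compile(r'(in1k|in21k|in22k)(_\d+)?$')


def dataset_from_tag(tag: str) -> Optional[str]:
    """Hypothetical helper: map a timm tag to a label set name, or None if unknown."""
    match = _TAG_PATTERN.search(tag)
    if match is None:
        return None
    # Both 'in21k' and 'in22k' map to the ImageNet 21K labels, as in the diff above.
    return 'imagenet1k' if match.group(1) == 'in1k' else 'imagenet21k'


for tag in ('ra_in1k', 'ms_in22k', 'fb_in22k_ft_in1k_384', 'dino', ''):
    print(repr(tag), '->', dataset_from_tag(tag))
```

Returning `None` for unknown tags mirrors the `else` branch above, where no predictions are shown.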
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import sys | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
sys.path.insert(0, '.') # nopep8 | ||
|
||
from models.timm import ExtractTIMM as Extractor | ||
from tests.utils import base_test_script | ||
|
||
# a bit ugly, but this assumes that the feature being tested has the same name as this file's folder, | ||
# e.g. for r21d: ./tests/r21d/THIS_FILE | ||
# this avoids duplicating the same test code for different features | ||
THIS_FILE_PATH = __file__ | ||
FEATURE_TYPE = Path(THIS_FILE_PATH).parent.name | ||
|
||
# True when run for the first time, then must be False | ||
TO_MAKE_REF = False | ||
|
||
signature = 'device, video_paths, model_name, batch_size, extraction_fps, to_make_ref' | ||
test_params = [ | ||
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'vit_base_patch16_224.dino', 1, 1, TO_MAKE_REF), | ||
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'coat_tiny.in1k', 1, 1, TO_MAKE_REF), | ||
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'hf-hub:nateraw/resnet50-oxford-iiit-pet', 4, 2, TO_MAKE_REF), | ||
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'mobilenetv3_small_050', 1, None, TO_MAKE_REF), | ||
] | ||
|
||
|
||
@pytest.mark.parametrize(signature, test_params) | ||
def test(device, video_paths, model_name, batch_size, extraction_fps, to_make_ref): | ||
    # get config | ||
    patch_kwargs = dict( | ||
        device=device, | ||
        video_paths=video_paths, | ||
        model_name=model_name, | ||
        batch_size=batch_size, | ||
        extraction_fps=extraction_fps | ||
    ) | ||
    base_test_script(FEATURE_TYPE, Extractor, to_make_ref, **patch_kwargs) |
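
The test above derives `FEATURE_TYPE` from the name of the folder that contains the test file. The short sketch below shows what that expression evaluates to. The path `./tests/timm/test_timm.py` is an assumption based on the `./tests/r21d/THIS_FILE` pattern in the comment, and `feature_type_from_test_path` is a hypothetical name used only here.

```python
from pathlib import Path


def feature_type_from_test_path(test_file: str) -> str:
    """Return the name of the folder that holds the test file."""
    return Path(test_file).parent.name


# Assumed location of the test file, following the pattern in the comment above.
print(feature_type_from_test_path('./tests/timm/test_timm.py'))  # prints: timm
```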
File renamed without changes.