Support different archs
kkoutini committed May 11, 2023
1 parent 7e68e49 commit 17f03c8
Showing 5 changed files with 13 additions and 6 deletions.
10 changes: 9 additions & 1 deletion README.md
@@ -13,7 +13,7 @@ pip3 install torch==1.8.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.o
# Installation
Install the latest version of this repo:
```shell
-pip install -e 'git+https://github.com/kkoutini/passt_hear21@0.0.21#egg=hear21passt'
+pip install -e 'git+https://github.com/kkoutini/passt_hear21@0.0.22#egg=hear21passt'
```

The models follow the [common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api) of HEAR 21
@@ -51,6 +51,14 @@ logits = model(wave_signal)
The class label indices can be found [here](https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/metadata/class_labels_indices.csv)


+You can also use other pre-trained models, for example the model trained with knowledge distillation (KD), `passt_s_kd_p16_128_ap486`:
+```python
+from hear21passt.base import get_basic_model
+
+model = get_basic_model(mode="logits", arch="passt_s_kd_p16_128_ap486")
+logits = model(wave_signal)
+
+```

# Supporting longer clips

2 changes: 1 addition & 1 deletion hear21passt/__init__.py
@@ -1,5 +1,5 @@

-__version__ = "0.0.21"
+__version__ = "0.0.22"


def embeding_size(hop=50, embeding_size=1000):
3 changes: 1 addition & 2 deletions hear21passt/base.py
@@ -53,8 +53,7 @@ def get_basic_model(**kwargs):
timem=192,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)

-net = get_model_passt(arch="passt_s_swa_p16_128_ap476")
+net = get_model_passt(arch=kwargs.get("arch", "passt_s_swa_p16_128_ap476"))
model = PasstBasicWrapper(mel=mel, net=net, **kwargs)
return model
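
Note on the hunk above: `kwargs.get("arch", ...)` reads the `arch` key but leaves it in `kwargs`, so the same key is also forwarded through `PasstBasicWrapper(..., **kwargs)`. Whether the real wrapper accepts an extra `arch` keyword is not visible in this diff. The sketch below uses hypothetical stand-ins (`StrictWrapper`, `LenientWrapper`, `build_with_get`, `build_with_pop`, and a fake `get_model_passt`), not the actual `hear21passt` classes. It only illustrates how `dict.get` and `dict.pop` differ when the remaining keyword arguments are forwarded.

```python
# Hypothetical stand-ins for illustration; these are NOT the hear21passt implementations.
DEFAULT_ARCH = "passt_s_swa_p16_128_ap476"


def get_model_passt(arch):
    """Fake model factory: returns a string tagged with the requested arch."""
    return f"net<{arch}>"


class StrictWrapper:
    """Stand-in wrapper that rejects unknown keyword arguments."""

    def __init__(self, net, mode="all"):
        self.net = net
        self.mode = mode


class LenientWrapper:
    """Stand-in wrapper that swallows unknown keyword arguments."""

    def __init__(self, net, mode="all", **kwargs):
        self.net = net
        self.mode = mode
        self.extra = kwargs


def build_with_get(wrapper_cls, **kwargs):
    # Mirrors the pattern in the diff: 'arch' is read but stays in kwargs,
    # so it is forwarded to the wrapper as well.
    net = get_model_passt(arch=kwargs.get("arch", DEFAULT_ARCH))
    return wrapper_cls(net=net, **kwargs)


def build_with_pop(wrapper_cls, **kwargs):
    # Alternative: 'arch' is consumed here and never reaches the wrapper.
    net = get_model_passt(arch=kwargs.pop("arch", DEFAULT_ARCH))
    return wrapper_cls(net=net, **kwargs)
```

With the `get` variant, the wrapper must tolerate the extra key (as `LenientWrapper` does); with the `pop` variant, a wrapper with a fixed signature also works. Both fall back to `passt_s_swa_p16_128_ap476` when no `arch` is given.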

2 changes: 1 addition & 1 deletion hear21passt/base2level.py
@@ -42,7 +42,7 @@ def get_concat_2level_model(**kwargs):
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)

-net = get_model_passt(arch="passt_s_swa_p16_128_ap476")
+net = get_model_passt(arch=kwargs.get("arch", "passt_s_swa_p16_128_ap476"))
model = PasstBasicWrapper(mel=mel, net=net, timestamp_embedding_size=1295 * 2, **kwargs)
return model

2 changes: 1 addition & 1 deletion hear21passt/base2levelmel.py
@@ -42,7 +42,7 @@ def get_concat_2levelmel_model(**kwargs):
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)

-net = get_model_passt(arch="passt_s_swa_p16_128_ap476")
+net = get_model_passt(arch=kwargs.get("arch", "passt_s_swa_p16_128_ap476"))
model = PasstBasicWrapper(mel=mel, net=net, timestamp_embedding_size=768 + 1295 * 2, **kwargs)
return model
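
Note on the two concat factories: the `timestamp_embedding_size` values they pass are hard-coded expressions and do not depend on `arch`. The diff does not say what `1295` and `768` stand for, so the constant names below are hypothetical labels; the sketch only evaluates the two expressions as written.

```python
# Constant names are hypothetical; the diff only shows the literals 1295 and 768.
PER_LEVEL_SIZE = 1295   # factor used in both concat factories
EXTRA_TERM = 768        # additive term that appears only in base2levelmel.py

# base2level.py: timestamp_embedding_size=1295 * 2
two_level_size = PER_LEVEL_SIZE * 2

# base2levelmel.py: timestamp_embedding_size=768 + 1295 * 2
two_level_mel_size = EXTRA_TERM + PER_LEVEL_SIZE * 2

print(two_level_size, two_level_mel_size)  # 2590 3358
```

Because these sizes are fixed, passing an `arch` whose output dimensions differ from the default model's may need a matching size; that is an assumption to check against the wrapper, not something this commit shows.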
