Skip to content

Commit

Permalink
Bandit Plus model added
Browse files Browse the repository at this point in the history
  • Loading branch information
ZFTurbo committed Dec 18, 2023
1 parent 0871663 commit 9fbe4b6
Show file tree
Hide file tree
Showing 41 changed files with 6,693 additions and 53 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Available models for training:
* VitLarge23 based on [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch). Key: `segm_models`.
* Band Split RoFormer [[Paper](https://arxiv.org/abs/2309.02612), [Repository](https://github.com/lucidrains/BS-RoFormer)] . Key: `bs_roformer`.
* Mel-Band RoFormer [[Paper](https://arxiv.org/abs/2310.01809), [Repository](https://github.com/lucidrains/BS-RoFormer)]. Key: `mel_band_roformer`.
* Swin Upernet [[Paper](https://arxiv.org/abs/2103.14030)] Key: `swin_upernet`.
* Swin Upernet [[Paper](https://arxiv.org/abs/2103.14030)] Key: `swin_upernet`.
* BandIt Plus [[Paper](https://arxiv.org/abs/2309.02539), [Repository](https://github.com/karnwatcharasupat/bandit)] Key: `bandit`.

**Note 1**: For `segm_models` there are many different encoders is possible. [Look here](https://github.com/qubvel/segmentation_models.pytorch#encoders-).

Expand Down
70 changes: 70 additions & 0 deletions configs/config_dnr_bandit_bsrnn_multi_mus64.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
name: "MultiMaskMultiSourceBandSplitRNN"
audio:
chunk_size: 264600
num_channels: 2
sample_rate: 44100
min_mean_abs: 0.001

model:
in_channel: 1
stems: ['speech', 'music', 'effects']
band_specs: "musical"
n_bands: 64
fs: 44100
require_no_overlap: false
require_no_gap: true
normalize_channel_independently: false
treat_channel_as_feature: true
n_sqm_modules: 8
emb_dim: 128
rnn_dim: 256
bidirectional: true
rnn_type: "GRU"
mlp_dim: 512
hidden_activation: "Tanh"
hidden_activation_kwargs: null
complex_mask: true
n_fft: 2048
win_length: 2048
hop_length: 512
window_fn: "hann_window"
wkwargs: null
power: null
center: true
normalized: true
pad_mode: "constant"
onesided: true

training:
batch_size: 4
gradient_accumulation_steps: 4
grad_clip: 0
instruments:
- speech
- music
- effects
lr: 9.0e-05
patience: 2
reduce_factor: 0.95
target_instrument: null
num_epochs: 1000
num_steps: 1000
augmentation: false # enable augmentations by audiomentations and pedalboard
augmentation_type: simple1
use_mp3_compress: false # Deprecated
augmentation_mix: true # Mix several stems of the same type with some probability
augmentation_loudness: true # randomly change loudness of each stem
augmentation_loudness_type: 1 # Type 1 or 2
augmentation_loudness_min: 0.5
augmentation_loudness_max: 1.5
q: 0.95
coarse_loss_clip: true
ema_momentum: 0.999
optimizer: adam
other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true

inference:
batch_size: 1
dim_t: 256
num_overlap: 4
69 changes: 69 additions & 0 deletions configs/config_vocals_bandit_bsrnn_multi_mus64.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name: "MultiMaskMultiSourceBandSplitRNN"
audio:
chunk_size: 264600
num_channels: 2
sample_rate: 44100
min_mean_abs: 0.001

model:
in_channel: 1
stems: ['vocals', 'other']
band_specs: "musical"
n_bands: 64
fs: 44100
require_no_overlap: false
require_no_gap: true
normalize_channel_independently: false
treat_channel_as_feature: true
n_sqm_modules: 8
emb_dim: 128
rnn_dim: 256
bidirectional: true
rnn_type: "GRU"
mlp_dim: 512
hidden_activation: "Tanh"
hidden_activation_kwargs: null
complex_mask: true
n_fft: 2048
win_length: 2048
hop_length: 512
window_fn: "hann_window"
wkwargs: null
power: null
center: true
normalized: true
pad_mode: "constant"
onesided: true

training:
batch_size: 4
gradient_accumulation_steps: 4
grad_clip: 0
instruments:
- vocals
- other
lr: 9.0e-05
patience: 2
reduce_factor: 0.95
target_instrument: null
num_epochs: 1000
num_steps: 1000
augmentation: false # enable augmentations by audiomentations and pedalboard
augmentation_type: simple1
use_mp3_compress: false # Deprecated
augmentation_mix: true # Mix several stems of the same type with some probability
augmentation_loudness: true # randomly change loudness of each stem
augmentation_loudness_type: 1 # Type 1 or 2
augmentation_loudness_min: 0.5
augmentation_loudness_max: 1.5
q: 0.95
coarse_loss_clip: true
ema_momentum: 0.999
optimizer: adam
other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true

inference:
batch_size: 1
dim_t: 256
num_overlap: 4
2 changes: 1 addition & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def run_folder(model, args, config, device, verbose=False):

def proc_folder(args):
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet")
parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
parser.add_argument("--config_path", type=str, help="path to config file")
parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to valid weights")
parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
Expand Down
Loading

0 comments on commit 9fbe4b6

Please sign in to comment.