diff --git a/examples/librimix/tse/v2/confs/bsrnn.yaml b/examples/librimix/tse/v2/confs/bsrnn.yaml
index 98b2fb0..f8b9791 100644
--- a/examples/librimix/tse/v2/confs/bsrnn.yaml
+++ b/examples/librimix/tse/v2/confs/bsrnn.yaml
@@ -23,6 +23,11 @@ dataset_args:
   specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
   reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
   noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
+  # Self-estimated Speech Augmentation (SSA), see our SLT paper: https://www.arxiv.org/abs/2409.09589
+  # Only the single-optimization method is supported here;
+  # for multi-optimization, please refer to bsrnn_multi_optimization.yaml.
+  SSA_enroll_prob:
+    Single_optimization: 0.6 # prob of applying SSA to the enrollment speech

 enable_amp: false
 exp_dir: exp/BSRNN
diff --git a/wesep/bin/train.py b/wesep/bin/train.py
index 772d27d..65048c5 100644
--- a/wesep/bin/train.py
+++ b/wesep/bin/train.py
@@ -320,6 +320,10 @@ def train(config="conf/config.yaml", **kwargs):
             device=device,
             se_loss_weight=loss_args,
             multi_task=multi_task,
+            SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", None),
+            fbank_args=configs["dataset_args"].get("fbank_args", None),
+            sample_rate=configs["dataset_args"]["resample_rate"],
+            speaker_feat=configs["dataset_args"].get("speaker_feat", True),
         )

         val_loss, _ = executor.cv(
diff --git a/wesep/utils/executor.py b/wesep/utils/executor.py
index 59154f5..ca7e95a 100644
--- a/wesep/utils/executor.py
+++ b/wesep/utils/executor.py
@@ -20,8 +20,8 @@
 # if your python version < 3.7 use the below one

 import torch

-from wesep.utils.funcs import clip_gradients
-
+from wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn
+import random

 class Executor:

@@ -45,6 +45,10 @@ def train(
         device=torch.device("cuda"),
         se_loss_weight=1.0,
         multi_task=False,
+        SSA_enroll_prob=None,
+        fbank_args=None,
+        sample_rate=16000,
+        speaker_feat=True,
     ):
         """Train one epoch"""
         model = models[0]
@@ -81,7 +85,20 @@
             spk_label = spk_label.to(device)

             with torch.cuda.amp.autocast(enabled=enable_amp):
-                outputs = model(features, enroll)
+                if SSA_enroll_prob and SSA_enroll_prob.get("Single_optimization", 0) > 0:
+                    if SSA_enroll_prob["Single_optimization"] > random.random():
+                        with torch.no_grad():  # first pass: self-estimate the target speech
+                            outputs = model(features, enroll)
+                            est_speech = outputs[0]
+                        self_fbank = est_speech
+                        if fbank_args is not None and speaker_feat:
+                            self_fbank = compute_fbank(est_speech, **fbank_args, sample_rate=sample_rate)
+                            self_fbank = apply_cmvn(self_fbank)
+                        outputs = model(features, self_fbank)  # second pass with self-estimated enrollment
+                    else:
+                        outputs = model(features, enroll)
+                else:
+                    outputs = model(features, enroll)

                 if not isinstance(outputs, (list, tuple)):
                     outputs = [outputs]
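
For readers who want the SSA logic separated from the surrounding training loop, here is a minimal sketch of the single-optimization path added to `Executor.train` above. The helper name `forward_with_ssa` and its argument layout are illustrative only; `model`, `compute_fbank`, and `apply_cmvn` are assumed to behave like their wesep counterparts (the model returns a list/tuple whose first element is the estimated speech, and the fbank helpers map a waveform to normalized speaker features).

```python
import random

import torch


def forward_with_ssa(model, features, enroll, ssa_prob,
                     compute_fbank=None, apply_cmvn=None,
                     fbank_args=None, sample_rate=16000, speaker_feat=True):
    """Single-optimization SSA: with probability ``ssa_prob``, condition the
    optimized forward pass on the model's own first-pass estimate instead of
    the real enrollment cue."""
    if ssa_prob and random.random() < ssa_prob:
        # First pass (no gradients): estimate the target speech from the
        # mixture and the real enrollment.
        with torch.no_grad():
            est_speech = model(features, enroll)[0]
        self_enroll = est_speech
        if fbank_args is not None and speaker_feat:
            # Convert the estimate into the speaker-feature format expected
            # by the speaker encoder (fbank followed by CMVN).
            self_enroll = apply_cmvn(
                compute_fbank(est_speech, sample_rate=sample_rate, **fbank_args)
            )
        # Second pass: gradients flow through this call only.
        return model(features, self_enroll)
    # No SSA this step: use the real enrollment as usual.
    return model(features, enroll)
```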