
Self-estimated Speech Augmentation to bsrnn.yaml
mrjunjieli committed Sep 21, 2024
1 parent 56e49a3 commit 8527114
Showing 3 changed files with 29 additions and 3 deletions.
5 changes: 5 additions & 0 deletions examples/librimix/tse/v2/confs/bsrnn.yaml
@@ -23,6 +23,11 @@ dataset_args:
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
# Self-estimated Speech Augmentation (SSA). Please refer to our SLT paper: https://www.arxiv.org/abs/2409.09589
# Only the single-optimization method is supported here.
# If you want to use multi-optimization, please refer to bsrnn_multi_optimization.yaml
SSA_enroll_prob:
Single_optimization: 0.6 # prob to add SSA on enrollment speech

enable_amp: false
exp_dir: exp/BSRNN
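
For reference, a minimal sketch of how this config value is interpreted, mirroring the logic this commit adds to wesep/utils/executor.py. The estimate_enrollment helper below is hypothetical and stands in for the gradient-free forward pass plus fbank extraction:

import random

def pick_enrollment(enroll, estimate_enrollment, ssa_cfg):
    """Choose the enrollment cue for one training batch.

    ssa_cfg is dataset_args["SSA_enroll_prob"], e.g. {"Single_optimization": 0.6}:
    with that probability the self-estimated target speech replaces the real
    enrollment; otherwise the original enrollment is kept.
    """
    prob = (ssa_cfg or {}).get("Single_optimization", 0)
    if prob > random.random():
        return estimate_enrollment()  # hypothetical stand-in for the no-grad pass
    return enroll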
4 changes: 4 additions & 0 deletions wesep/bin/train.py
@@ -320,6 +320,10 @@ def train(config="conf/config.yaml", **kwargs):
device=device,
se_loss_weight=loss_args,
multi_task=multi_task,
SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", None),
fbank_args=configs["dataset_args"].get("fbank_args", None),
sample_rate=configs["dataset_args"]["resample_rate"],
speaker_feat=configs["dataset_args"].get("speaker_feat", True),
)

val_loss, _ = executor.cv(
23 changes: 20 additions & 3 deletions wesep/utils/executor.py
@@ -20,8 +20,8 @@
# if your Python version is < 3.7, use the one below
import torch

from wesep.utils.funcs import clip_gradients

from wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn
import random

class Executor:

@@ -45,6 +45,10 @@ def train(
device=torch.device("cuda"),
se_loss_weight=1.0,
multi_task=False,
SSA_enroll_prob=None,  # e.g. {"Single_optimization": 0.6}; None or 0 disables SSA
fbank_args=None,
sample_rate=16000,
speaker_feat=True,
):
"""Train one epoch"""
model = models[0]
@@ -81,7 +85,20 @@
spk_label = spk_label.to(device)

with torch.cuda.amp.autocast(enabled=enable_amp):
outputs = model(features, enroll)
if SSA_enroll_prob and SSA_enroll_prob.get("Single_optimization", 0) > random.random():
    # Self-estimated Speech Augmentation (SSA): estimate the target speech
    # without gradients, then reuse it as the enrollment cue for the
    # gradient-carrying forward pass.
    with torch.no_grad():
        outputs = model(features, enroll)
        est_speech = outputs[0]
    self_fbank = est_speech
    if fbank_args is not None and speaker_feat:
        self_fbank = compute_fbank(est_speech, **fbank_args, sample_rate=sample_rate)
        self_fbank = apply_cmvn(self_fbank)
    outputs = model(features, self_fbank)
else:
    outputs = model(features, enroll)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
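
Note on the block above: the first forward pass runs under torch.no_grad(), so only the second pass, which consumes the self-estimated enrollment, contributes gradients. This is the single-optimization variant referenced in bsrnn.yaml; the multi-optimization variant is handled in bsrnn_multi_optimization.yaml. When fbank_args is set and speaker_feat is true, the estimated speech is first converted to CMVN-normalized fbank features via compute_fbank and apply_cmvn before being used as the enrollment cue.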

