
Commit 421cc21

mrjunjieli committed: fix lint errors
1 parent f7e33d8 · commit 421cc21

File tree: 4 files changed, +121 −80 lines


wesep/bin/train.py

Lines changed: 2 additions & 2 deletions

@@ -321,9 +321,9 @@ def train(config="conf/config.yaml", **kwargs):
         se_loss_weight=loss_args,
         multi_task=multi_task,
         SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", 0),
-        fbank_args= configs["dataset_args"].get('fbank_args',None),
+        fbank_args=configs["dataset_args"].get('fbank_args', None),
         sample_rate=configs["dataset_args"]['resample_rate'],
-        speaker_feat = configs["dataset_args"].get('speaker_feat', True)
+        speaker_feat=configs["dataset_args"].get('speaker_feat', True)
     )

     val_loss, _ = executor.cv(

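Note: both touched arguments read optional keys from the dataset config with
dict.get and a default, so older configs that lack them keep working. A minimal
sketch of the pattern (the config values below are illustrative, not from the
repo):

    # Optional keys fall back to a default instead of raising KeyError.
    configs = {"dataset_args": {"resample_rate": 16000}}
    dataset_args = configs["dataset_args"]

    fbank_args = dataset_args.get("fbank_args", None)      # optional -> None
    speaker_feat = dataset_args.get("speaker_feat", True)  # optional -> True
    sample_rate = dataset_args["resample_rate"]            # required key
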
wesep/models/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -5,11 +5,12 @@
 import wesep.modules.metric_gan.discriminator as discriminator
 import wesep.models.bsrnn_multi_optim as bsrnn_multi

+
 def get_model(model_name: str):
     if model_name.startswith("ConvTasNet"):
         return getattr(convtasnet, model_name)
     elif model_name.startswith("BSRNN_Multi"):
-        return getattr(bsrnn_multi,model_name)
+        return getattr(bsrnn_multi, model_name)
     elif model_name.startswith("BSRNN"):
         return getattr(bsrnn, model_name)
     elif model_name.startswith("DPCCN"):

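Note: get_model dispatches on name prefixes with str.startswith, so the
"BSRNN_Multi" branch must stay above the plain "BSRNN" branch; otherwise every
BSRNN_Multi* name would match the shorter prefix and be looked up in the wrong
module. A self-contained sketch of why the order matters (the two namespace
objects are hypothetical stand-ins for the bsrnn / bsrnn_multi modules):

    import types

    # Hypothetical stand-ins for the two model modules.
    bsrnn = types.SimpleNamespace(BSRNN="BSRNN from bsrnn")
    bsrnn_multi = types.SimpleNamespace(BSRNN_Multi="BSRNN_Multi from bsrnn_multi")

    def get_model(name: str):
        if name.startswith("BSRNN_Multi"):    # longer prefix checked first
            return getattr(bsrnn_multi, name)
        elif name.startswith("BSRNN"):
            return getattr(bsrnn, name)
        raise ValueError(f"unknown model: {name}")

    assert get_model("BSRNN_Multi") == "BSRNN_Multi from bsrnn_multi"
    # With the branches swapped, "BSRNN_Multi" would be looked up in the
    # plain bsrnn namespace and raise AttributeError.
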
wesep/models/bsrnn_multi_optim.py

Lines changed: 107 additions & 70 deletions

@@ -11,6 +11,7 @@
 from wesep.modules.common.speaker import SpeakerFuseLayer
 from wesep.modules.common.speaker import SpeakerTransform

+
 class ResRNN(nn.Module):

     def __init__(self, input_size, hidden_size, bidirectional=True):
@@ -30,16 +31,17 @@ def __init__(self, input_size, hidden_size, bidirectional=True):
         )

         # linear projection layer
-        self.proj = nn.Linear(hidden_size * 2,
-                              input_size)  # hidden_size = feature_dim * 2
+        self.proj = nn.Linear(
+            hidden_size * 2, input_size
+        )  # hidden_size = feature_dim * 2

     def forward(self, input):
         # input shape: batch, dim, seq

         rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())
-        rnn_output = self.proj(rnn_output.contiguous().view(
-            -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],
-                                           input.shape[1])
+        rnn_output = self.proj(
+            rnn_output.contiguous().view(-1, rnn_output.shape[2])
+        ).view(input.shape[0], input.shape[2], input.shape[1])

         return input + rnn_output.transpose(1, 2).contiguous()

@@ -57,26 +59,31 @@ def __init__(self, in_channel, nband=7, bidirectional=True):

         self.nband = nband
         self.feature_dim = in_channel // nband
-        self.band_rnn = ResRNN(self.feature_dim,
-                               self.feature_dim * 2,
-                               bidirectional=bidirectional)
-        self.band_comm = ResRNN(self.feature_dim,
-                                self.feature_dim * 2,
-                                bidirectional=bidirectional)
+        self.band_rnn = ResRNN(
+            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
+        )
+        self.band_comm = ResRNN(
+            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
+        )

     def forward(self, input, dummy: Optional[torch.Tensor] = None):
         # input shape: B, nband*N, T
         B, N, T = input.shape

         band_output = self.band_rnn(
-            input.view(B * self.nband, self.feature_dim,
-                       -1)).view(B, self.nband, -1, T)
+            input.view(B * self.nband, self.feature_dim, -1)
+        ).view(B, self.nband, -1, T)

         # band comm
-        band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(
-            B * T, -1, self.nband))
-        output = (self.band_comm(band_output).view(
-            B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())
+        band_output = (
+            band_output.permute(0, 3, 2, 1).contiguous().view(B * T, -1, self.nband)
+        )
+        output = (
+            self.band_comm(band_output)
+            .view(B, T, -1, self.nband)
+            .permute(0, 3, 2, 1)
+            .contiguous()
+        )

         return output.view(B, N, T)

@@ -108,15 +115,17 @@ def __init__(
                     embed_dim=spk_emb_dim,
                     feat_dim=feature_dim,
                     fuse_type=spk_fuse_type,
-                ))
+                )
+            )
             self.separation.append(BSNet(nband * feature_dim, nband))
         else:
             self.separation.append(
                 SpeakerFuseLayer(
                     embed_dim=spk_emb_dim,
                     feat_dim=feature_dim,
                     fuse_type=spk_fuse_type,
-                ))
+                )
+            )
             for _ in range(num_repeat):
                 self.separation.append(BSNet(nband * feature_dim, nband))

@@ -131,11 +140,9 @@ def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
             for i, sep_func in enumerate(self.separation):
                 x = sep_func(x, spk_embedding)
                 if i % 2 == 0:
-                    x = x.view(batch_size * nch, self.nband * self.feature_dim,
-                               -1)
+                    x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
                 else:
-                    x = x.view(batch_size * nch, self.nband, self.feature_dim,
-                               -1)
+                    x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
         else:
             x = self.separation[0](x, spk_embedding)
             x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
@@ -253,7 +260,8 @@ def __init__(
                 nn.Sequential(
                     nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
                     nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
-                ))
+                )
+            )

         self.separator = FuseSeparation(
             nband=self.nband,
@@ -270,14 +278,14 @@ def __init__(
         for i in range(self.nband):
             self.mask.append(
                 nn.Sequential(
-                    nn.GroupNorm(1, self.feature_dim,
-                                 torch.finfo(torch.float32).eps),
+                    nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
                     nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),
                     nn.Tanh(),
                     nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),
                     nn.Tanh(),
                     nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),
-                ))
+                )
+            )

     def pad_input(self, input, window, stride):
         """
@@ -308,8 +316,9 @@ def forward(self, input, embeddings):
             wav_input,
             n_fft=self.win,
             hop_length=self.stride,
-            window=torch.hann_window(self.win).to(wav_input.device).type(
-                wav_input.type()),
+            window=torch.hann_window(self.win)
+            .to(wav_input.device)
+            .type(wav_input.type()),
             return_complex=True,
         )
@@ -319,23 +328,26 @@ def forward(self, input, embeddings):
         subband_mix_spec = []
         band_idx = 0
         for i in range(len(self.band_width)):
-            subband_spec.append(spec_RI[:, :, band_idx:band_idx +
-                                        self.band_width[i]].contiguous())
-            subband_mix_spec.append(spec[:, band_idx:band_idx +
-                                         self.band_width[i]])  # B*nch, BW, T
+            subband_spec.append(
+                spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()
+            )
+            subband_mix_spec.append(
+                spec[:, band_idx : band_idx + self.band_width[i]]
+            )  # B*nch, BW, T
             band_idx += self.band_width[i]

         # normalization and bottleneck
         subband_feature = []
         for i, bn_func in enumerate(self.BN):
             subband_feature.append(
-                bn_func(subband_spec[i].view(batch_size * nch,
-                                             self.band_width[i] * 2, -1)))
+                bn_func(
+                    subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1)
+                )
+            )
         subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T
         # print(subband_feature.size(), spk_emb_input.size())

-        predict_speaker_lable = torch.tensor(0.0).to(
-            spk_emb_input.device)  # dummy
+        predict_speaker_lable = torch.tensor(0.0).to(spk_emb_input.device)  # dummy
         if self.joint_training:
             if not self.spk_feat:
                 if self.feat_type == "consistent":
@@ -344,7 +356,8 @@ def forward(self, input, embeddings):
                     spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
                     spk_emb_input = spk_emb_input.log()
                     spk_emb_input = spk_emb_input - torch.mean(
-                        spk_emb_input, dim=-1, keepdim=True)
+                        spk_emb_input, dim=-1, keepdim=True
+                    )
                     spk_emb_input = spk_emb_input.permute(0, 2, 1)

                     tmp_spk_emb_input = self.spk_model(spk_emb_input)
@@ -357,51 +370,58 @@ def forward(self, input, embeddings):
         spk_embedding = self.spk_transform(spk_emb_input)
         spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)

-        sep_output = self.separator(subband_feature, spk_embedding,
-                                    torch.tensor(nch))
+        sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch))

         sep_subband_spec = []
         for i, mask_func in enumerate(self.mask):
             this_output = mask_func(sep_output[:, i]).view(
-                batch_size * nch, 2, 2, self.band_width[i], -1)
+                batch_size * nch, 2, 2, self.band_width[i], -1
+            )
             this_mask = this_output[:, 0] * torch.sigmoid(
-                this_output[:, 1])  # B*nch, 2, K, BW, T
+                this_output[:, 1]
+            )  # B*nch, 2, K, BW, T
             this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
             this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
-            est_spec_real = (subband_mix_spec[i].real * this_mask_real -
-                             subband_mix_spec[i].imag * this_mask_imag
-                             )  # B*nch, BW, T
-            est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +
-                             subband_mix_spec[i].imag * this_mask_real
-                             )  # B*nch, BW, T
-            sep_subband_spec.append(torch.complex(est_spec_real,
-                                                  est_spec_imag))
+            est_spec_real = (
+                subband_mix_spec[i].real * this_mask_real
+                - subband_mix_spec[i].imag * this_mask_imag
+            )  # B*nch, BW, T
+            est_spec_imag = (
+                subband_mix_spec[i].real * this_mask_imag
+                + subband_mix_spec[i].imag * this_mask_real
+            )  # B*nch, BW, T
+            sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
         est_spec = torch.cat(sep_subband_spec, 1)  # B*nch, F, T
         output = torch.istft(
             est_spec.view(batch_size * nch, self.enc_dim, -1),
             n_fft=self.win,
             hop_length=self.stride,
-            window=torch.hann_window(self.win).to(wav_input.device).type(
-                wav_input.type()),
+            window=torch.hann_window(self.win)
+            .to(wav_input.device)
+            .type(wav_input.type()),
             length=nsample,
         )

         output = output.view(batch_size, nch, -1)
         s = torch.squeeze(output, dim=1)
         if torch.is_grad_enabled():
             self_embedding = s.detach()
-            self_predict_speaker_lable = torch.tensor(0.0).to(self_embedding.device)  # dummy
+            self_predict_speaker_lable = torch.tensor(0.0).to(
+                self_embedding.device
+            )  # dummy
             if self.joint_training:
-                if self.feat_type=='consistent':
+                if self.feat_type == "consistent":
                     with torch.no_grad():
                         self_embedding = self.preEmphasis(self_embedding)
-                        self_embedding = self.spk_encoder(self_embedding)+1e-8
+                        self_embedding = self.spk_encoder(self_embedding) + 1e-8
                         self_embedding = self_embedding.log()
-                        self_embedding = self_embedding - torch.mean(self_embedding, dim=-1, keepdim=True)
+                        self_embedding = self_embedding - torch.mean(
+                            self_embedding, dim=-1, keepdim=True
+                        )
                         self_embedding = self_embedding.permute(0, 2, 1)

                         self_tmp_spk_emb_input = self.spk_model(self_embedding)
-                        if isinstance(self_tmp_spk_emb_input,tuple):
+                        if isinstance(self_tmp_spk_emb_input, tuple):
                             self_spk_emb_input = self_tmp_spk_emb_input[-1]
                         else:
                             self_spk_emb_input = self_tmp_spk_emb_input
@@ -410,29 +430,46 @@ def forward(self, input, embeddings):
             self_spk_embedding = self.spk_transform(self_spk_emb_input)
             self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3)

-            self_sep_output = self.separator(subband_feature, self_spk_embedding, torch.tensor(nch))
+            self_sep_output = self.separator(
+                subband_feature, self_spk_embedding, torch.tensor(nch)
+            )

             self_sep_subband_spec = []
             for i, mask_func in enumerate(self.mask):
-                this_output = mask_func(self_sep_output[:, i]).view(batch_size * nch, 2, 2, self.band_width[i], -1)
-                this_mask = this_output[:, 0] * torch.sigmoid(this_output[:, 1])  # B*nch, 2, K, BW, T
+                this_output = mask_func(self_sep_output[:, i]).view(
+                    batch_size * nch, 2, 2, self.band_width[i], -1
+                )
+                this_mask = this_output[:, 0] * torch.sigmoid(
+                    this_output[:, 1]
+                )  # B*nch, 2, K, BW, T
                 this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
                 this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
-                est_spec_real = subband_mix_spec[i].real * this_mask_real - subband_mix_spec[
-                    i].imag * this_mask_imag  # B*nch, BW, T
-                est_spec_imag = subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[
-                    i].imag * this_mask_real  # B*nch, BW, T
-                self_sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
+                est_spec_real = (
+                    subband_mix_spec[i].real * this_mask_real
+                    - subband_mix_spec[i].imag * this_mask_imag
+                )  # B*nch, BW, T
+                est_spec_imag = (
+                    subband_mix_spec[i].real * this_mask_imag
+                    + subband_mix_spec[i].imag * this_mask_real
+                )  # B*nch, BW, T
+                self_sep_subband_spec.append(
+                    torch.complex(est_spec_real, est_spec_imag)
+                )
             self_est_spec = torch.cat(self_sep_subband_spec, 1)  # B*nch, F, T
-            self_output = torch.istft(self_est_spec.view(batch_size * nch, self.enc_dim, -1),
-                                      n_fft=self.win, hop_length=self.stride,
-                                      window=torch.hann_window(self.win).to(wav_input.device).type(wav_input.type()),
-                                      length=nsample)
+            self_output = torch.istft(
+                self_est_spec.view(batch_size * nch, self.enc_dim, -1),
+                n_fft=self.win,
+                hop_length=self.stride,
+                window=torch.hann_window(self.win)
+                .to(wav_input.device)
+                .type(wav_input.type()),
+                length=nsample,
+            )

             self_output = self_output.view(batch_size, nch, -1)
             self_s = torch.squeeze(self_output, dim=1)

-            return s,self_s, predict_speaker_lable,self_predict_speaker_lable
+            return s, self_s, predict_speaker_lable, self_predict_speaker_lable

         return s, predict_speaker_lable

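Note: the reformatted est_spec_real / est_spec_imag assignments above are an
element-wise complex multiplication of each sub-band mixture spectrogram by a
complex ratio mask, following (a + bi)(c + di) = (ac - bd) + i(ad + bc). A
standalone check of that identity (tensor shapes here are arbitrary):

    import torch

    mix = torch.randn(2, 10, dtype=torch.complex64)  # mixture spectrogram, a + bi
    mask_real = torch.randn(2, 10)                   # c
    mask_imag = torch.randn(2, 10)                   # d

    # The two assignments from the diff, written out:
    est_real = mix.real * mask_real - mix.imag * mask_imag  # ac - bd
    est_imag = mix.real * mask_imag + mix.imag * mask_real  # ad + bc
    est = torch.complex(est_real, est_imag)

    # Equivalent to native complex multiplication:
    assert torch.allclose(est, mix * torch.complex(mask_real, mask_imag))
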
wesep/utils/executor.py

Lines changed: 10 additions & 7 deletions

@@ -20,8 +20,9 @@
 # if your python version < 3.7 use the below one
 import torch

-from wesep.utils.funcs import clip_gradients,compute_fbank,apply_cmvn
-import random
+from wesep.utils.funcs import clip_gradients, compute_fbank, apply_cmvn
+import random
+

 class Executor:

@@ -85,19 +86,21 @@ def train(
                 spk_label = spk_label.to(device)

                 with torch.cuda.amp.autocast(enabled=enable_amp):
-                    if SSA_enroll_prob >0:
-                        if SSA_enroll_prob>random.random():
+                    if SSA_enroll_prob > 0:
+                        if SSA_enroll_prob > random.random():
                             with torch.no_grad():
                                 outputs = model(features, enroll)
                                 est_speech = outputs[0]
                                 self_fbank = est_speech
-                                if fbank_args!=None and speaker_feat==True:
-                                    self_fbank = compute_fbank(est_speech,**fbank_args,sample_rate=sample_rate)
+                                if fbank_args is not None and speaker_feat:
+                                    self_fbank = compute_fbank(
+                                        est_speech, **fbank_args,
+                                        sample_rate=sample_rate)
                                 self_fbank = apply_cmvn(self_fbank)
                             outputs = model(features, self_fbank)
                         else:
                             outputs = model(features, enroll)
-                    else:
+                    else:
                         outputs = model(features, enroll)
                     if not isinstance(outputs, (list, tuple)):
                         outputs = [outputs]
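Note on the control flow reformatted here: with probability SSA_enroll_prob the
trainer first runs a no-grad forward pass, takes the estimated speech,
optionally converts it to fbank features (when fbank_args is set and
speaker_feat is enabled), applies CMVN, and feeds the result back as the
enrollment for the gradient-carrying pass; otherwise it trains on the provided
enrollment. A condensed sketch of that logic, assuming model, compute_fbank and
apply_cmvn behave as in the diff:

    import random
    import torch
    from wesep.utils.funcs import compute_fbank, apply_cmvn

    def forward_with_ssa(model, features, enroll, ssa_enroll_prob,
                         fbank_args, speaker_feat, sample_rate):
        # Self-enrollment: sometimes enroll on the model's own estimate
        # instead of the reference utterance.
        if ssa_enroll_prob > 0 and ssa_enroll_prob > random.random():
            with torch.no_grad():
                est_speech = model(features, enroll)[0]
                self_enroll = est_speech
                if fbank_args is not None and speaker_feat:
                    self_enroll = compute_fbank(est_speech, **fbank_args,
                                                sample_rate=sample_rate)
                self_enroll = apply_cmvn(self_enroll)
            return model(features, self_enroll)
        return model(features, enroll)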