from wesep.modules.common.speaker import SpeakerFuseLayer
from wesep.modules.common.speaker import SpeakerTransform

+
class ResRNN(nn.Module):

    def __init__(self, input_size, hidden_size, bidirectional=True):
@@ -30,16 +31,17 @@ def __init__(self, input_size, hidden_size, bidirectional=True):
        )

        # linear projection layer
-        self.proj = nn.Linear(hidden_size * 2,
-                              input_size)  # hidden_size = feature_dim * 2
+        self.proj = nn.Linear(
+            hidden_size * 2, input_size
+        )  # hidden_size = feature_dim * 2

    def forward(self, input):
        # input shape: batch, dim, seq

        rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous())
-        rnn_output = self.proj(rnn_output.contiguous().view(
-            -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2],
-                                           input.shape[1])
+        rnn_output = self.proj(
+            rnn_output.contiguous().view(-1, rnn_output.shape[2])
+        ).view(input.shape[0], input.shape[2], input.shape[1])

        return input + rnn_output.transpose(1, 2).contiguous()
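The reformatted ResRNN is behavior-preserving: it normalizes a (batch, dim, seq) tensor, runs a bidirectional RNN over the sequence axis, projects back to the input width, and adds the result to its input. A minimal self-contained sketch of that contract, with made-up sizes and plain GroupNorm/LSTM standing in for however self.norm and self.rnn are actually configured:

import torch
import torch.nn as nn

class TinyResRNN(nn.Module):
    # Same shape contract as ResRNN: (batch, dim, seq) -> (batch, dim, seq),
    # with a residual connection around the projected RNN output.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.norm = nn.GroupNorm(1, input_size)
        self.rnn = nn.LSTM(input_size, hidden_size,
                           batch_first=True, bidirectional=True)
        self.proj = nn.Linear(hidden_size * 2, input_size)

    def forward(self, x):
        out, _ = self.rnn(self.norm(x).transpose(1, 2).contiguous())
        out = self.proj(out.contiguous().view(-1, out.shape[2])).view(
            x.shape[0], x.shape[2], x.shape[1])
        return x + out.transpose(1, 2).contiguous()

x = torch.randn(2, 16, 100)  # batch, dim, seq
assert TinyResRNN(16, 32)(x).shape == x.shape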
@@ -57,26 +59,31 @@ def __init__(self, in_channel, nband=7, bidirectional=True):

        self.nband = nband
        self.feature_dim = in_channel // nband
-        self.band_rnn = ResRNN(self.feature_dim,
-                               self.feature_dim * 2,
-                               bidirectional=bidirectional)
-        self.band_comm = ResRNN(self.feature_dim,
-                                self.feature_dim * 2,
-                                bidirectional=bidirectional)
+        self.band_rnn = ResRNN(
+            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
+        )
+        self.band_comm = ResRNN(
+            self.feature_dim, self.feature_dim * 2, bidirectional=bidirectional
+        )

    def forward(self, input, dummy: Optional[torch.Tensor] = None):
        # input shape: B, nband*N, T
        B, N, T = input.shape

        band_output = self.band_rnn(
-            input.view(B * self.nband, self.feature_dim,
-                       -1)).view(B, self.nband, -1, T)
+            input.view(B * self.nband, self.feature_dim, -1)
+        ).view(B, self.nband, -1, T)

        # band comm
-        band_output = (band_output.permute(0, 3, 2, 1).contiguous().view(
-            B * T, -1, self.nband))
-        output = (self.band_comm(band_output).view(
-            B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous())
+        band_output = (
+            band_output.permute(0, 3, 2, 1).contiguous().view(B * T, -1, self.nband)
+        )
+        output = (
+            self.band_comm(band_output)
+            .view(B, T, -1, self.nband)
+            .permute(0, 3, 2, 1)
+            .contiguous()
+        )

        return output.view(B, N, T)
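BSNet's two ResRNN instances differ only in which axis plays the sequence role: band_rnn runs along time within each band, while the "band comm" step folds batch and time together so band_comm runs across the nband axis. The permute/view pair that prepares and undoes that folding is an exact inverse, as this toy trace with made-up sizes shows:

import torch

B, nband, N, T = 2, 7, 16, 50  # hypothetical sizes
band_output = torch.randn(B, nband, N, T)

# fold (B, T) into the batch; the band axis becomes the sequence axis
folded = band_output.permute(0, 3, 2, 1).contiguous().view(B * T, N, nband)
# ... band_comm would process `folded` here ...
restored = folded.view(B, T, N, nband).permute(0, 3, 2, 1).contiguous()
assert torch.equal(restored, band_output)  # the reshapes round-trip exactly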
@@ -108,15 +115,17 @@ def __init__(
                        embed_dim=spk_emb_dim,
                        feat_dim=feature_dim,
                        fuse_type=spk_fuse_type,
-                    ))
+                    )
+                )
                self.separation.append(BSNet(nband * feature_dim, nband))
        else:
            self.separation.append(
                SpeakerFuseLayer(
                    embed_dim=spk_emb_dim,
                    feat_dim=feature_dim,
                    fuse_type=spk_fuse_type,
-                ))
+                )
+            )
            for _ in range(num_repeat):
                self.separation.append(BSNet(nband * feature_dim, nband))
@@ -131,11 +140,9 @@ def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)):
            for i, sep_func in enumerate(self.separation):
                x = sep_func(x, spk_embedding)
                if i % 2 == 0:
-                    x = x.view(batch_size * nch, self.nband * self.feature_dim,
-                               -1)
+                    x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
                else:
-                    x = x.view(batch_size * nch, self.nband, self.feature_dim,
-                               -1)
+                    x = x.view(batch_size * nch, self.nband, self.feature_dim, -1)
        else:
            x = self.separation[0](x, spk_embedding)
            x = x.view(batch_size * nch, self.nband * self.feature_dim, -1)
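The i % 2 branching mirrors how self.separation interleaves SpeakerFuseLayer and BSNet entries: after a fuse step the bands are flattened into the (B*nch, nband*N, T) layout that BSNet consumes, and after a BSNet step they are split back to (B*nch, nband, N, T) for the next fuse. A shape-only sketch of that alternation, with hypothetical sizes:

import torch

B_nch, nband, N, T = 2, 7, 16, 50  # hypothetical sizes
x = torch.randn(B_nch, nband, N, T)
for i in range(4):  # stand-in for the fuse/BSNet module list
    if i % 2 == 0:
        x = x.view(B_nch, nband * N, T)  # flattened layout for BSNet
    else:
        x = x.view(B_nch, nband, N, T)   # band-split layout for fusing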
@@ -253,7 +260,8 @@ def __init__(
                nn.Sequential(
                    nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
                    nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
-                ))
+                )
+            )

        self.separator = FuseSeparation(
            nband=self.nband,
@@ -270,14 +278,14 @@ def __init__(
        for i in range(self.nband):
            self.mask.append(
                nn.Sequential(
-                    nn.GroupNorm(1, self.feature_dim,
-                                 torch.finfo(torch.float32).eps),
+                    nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
                    nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1),
                    nn.Tanh(),
                    nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1),
                    nn.Tanh(),
                    nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1),
-                ))
+                )
+            )

    def pad_input(self, input, window, stride):
        """
@@ -308,8 +316,9 @@ def forward(self, input, embeddings):
            wav_input,
            n_fft=self.win,
            hop_length=self.stride,
-            window=torch.hann_window(self.win).to(wav_input.device).type(
-                wav_input.type()),
+            window=torch.hann_window(self.win)
+            .to(wav_input.device)
+            .type(wav_input.type()),
            return_complex=True,
        )
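The analysis window is rebuilt on the waveform's device and dtype at each call. With matched n_fft, hop_length, and window, torch.stft and torch.istft round-trip to the original samples; a minimal sketch with hypothetical parameters:

import torch

win, stride, nsample = 512, 128, 16000  # hypothetical analysis settings
wav = torch.randn(1, nsample)
window = torch.hann_window(win).to(wav.device).type(wav.type())

spec = torch.stft(wav, n_fft=win, hop_length=stride,
                  window=window, return_complex=True)  # (1, win//2 + 1, frames)
out = torch.istft(spec, n_fft=win, hop_length=stride,
                  window=window, length=nsample)
assert torch.allclose(out, wav, atol=1e-5)  # matched STFT/iSTFT reconstructs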
@@ -319,23 +328,26 @@ def forward(self, input, embeddings):
        subband_mix_spec = []
        band_idx = 0
        for i in range(len(self.band_width)):
-            subband_spec.append(spec_RI[:, :, band_idx:band_idx +
-                                        self.band_width[i]].contiguous())
-            subband_mix_spec.append(spec[:, band_idx:band_idx +
-                                         self.band_width[i]])  # B*nch, BW, T
+            subband_spec.append(
+                spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()
+            )
+            subband_mix_spec.append(
+                spec[:, band_idx : band_idx + self.band_width[i]]
+            )  # B*nch, BW, T
            band_idx += self.band_width[i]

        # normalization and bottleneck
        subband_feature = []
        for i, bn_func in enumerate(self.BN):
            subband_feature.append(
-                bn_func(subband_spec[i].view(batch_size * nch,
-                                             self.band_width[i] * 2, -1)))
+                bn_func(
+                    subband_spec[i].view(batch_size * nch, self.band_width[i] * 2, -1)
+                )
+            )
        subband_feature = torch.stack(subband_feature, 1)  # B, nband, N, T
        # print(subband_feature.size(), spk_emb_input.size())

-        predict_speaker_lable = torch.tensor(0.0).to(
-            spk_emb_input.device)  # dummy
+        predict_speaker_lable = torch.tensor(0.0).to(spk_emb_input.device)  # dummy
        if self.joint_training:
            if not self.spk_feat:
                if self.feat_type == "consistent":
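The slicing loop partitions the frequency axis into contiguous subbands whose widths sum to the full bin count, so concatenating the slices restores the spectrogram exactly. A toy version of the bookkeeping, with made-up band widths:

import torch

band_width = [3, 5, 8]                      # hypothetical per-band bin counts
spec = torch.randn(2, sum(band_width), 40)  # B*nch, F, T

subbands, band_idx = [], 0
for bw in band_width:
    subbands.append(spec[:, band_idx : band_idx + bw])  # B*nch, BW, T
    band_idx += bw
assert torch.equal(torch.cat(subbands, 1), spec)  # the bands tile F exactly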
@@ -344,7 +356,8 @@ def forward(self, input, embeddings):
                    spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8
                    spk_emb_input = spk_emb_input.log()
                    spk_emb_input = spk_emb_input - torch.mean(
-                        spk_emb_input, dim=-1, keepdim=True)
+                        spk_emb_input, dim=-1, keepdim=True
+                    )
                    spk_emb_input = spk_emb_input.permute(0, 2, 1)

            tmp_spk_emb_input = self.spk_model(spk_emb_input)
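This branch derives speaker features on the fly: the encoder output is floored at 1e-8, log-compressed, mean-normalized per frequency bin across time (a CMN-style step), and transposed to (B, T, F) for the speaker model. The same recipe restated compactly, with hypothetical shapes:

import torch

feats = torch.rand(2, 80, 200) + 1e-8  # hypothetical encoder output: B, F, T
feats = feats.log()
feats = feats - torch.mean(feats, dim=-1, keepdim=True)  # per-bin mean over time
feats = feats.permute(0, 2, 1)  # B, T, F, as the speaker model expects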
@@ -357,51 +370,58 @@ def forward(self, input, embeddings):
        spk_embedding = self.spk_transform(spk_emb_input)
        spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)

-        sep_output = self.separator(subband_feature, spk_embedding,
-                                    torch.tensor(nch))
+        sep_output = self.separator(subband_feature, spk_embedding, torch.tensor(nch))

        sep_subband_spec = []
        for i, mask_func in enumerate(self.mask):
            this_output = mask_func(sep_output[:, i]).view(
-                batch_size * nch, 2, 2, self.band_width[i], -1)
+                batch_size * nch, 2, 2, self.band_width[i], -1
+            )
            this_mask = this_output[:, 0] * torch.sigmoid(
-                this_output[:, 1])  # B*nch, 2, K, BW, T
+                this_output[:, 1]
+            )  # B*nch, 2, K, BW, T
            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
-            est_spec_real = (subband_mix_spec[i].real * this_mask_real -
-                             subband_mix_spec[i].imag * this_mask_imag
-                             )  # B*nch, BW, T
-            est_spec_imag = (subband_mix_spec[i].real * this_mask_imag +
-                             subband_mix_spec[i].imag * this_mask_real
-                             )  # B*nch, BW, T
-            sep_subband_spec.append(torch.complex(est_spec_real,
-                                                  est_spec_imag))
+            est_spec_real = (
+                subband_mix_spec[i].real * this_mask_real
+                - subband_mix_spec[i].imag * this_mask_imag
+            )  # B*nch, BW, T
+            est_spec_imag = (
+                subband_mix_spec[i].real * this_mask_imag
+                + subband_mix_spec[i].imag * this_mask_real
+            )  # B*nch, BW, T
+            sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
        est_spec = torch.cat(sep_subband_spec, 1)  # B*nch, F, T
        output = torch.istft(
            est_spec.view(batch_size * nch, self.enc_dim, -1),
            n_fft=self.win,
            hop_length=self.stride,
-            window=torch.hann_window(self.win).to(wav_input.device).type(
-                wav_input.type()),
+            window=torch.hann_window(self.win)
+            .to(wav_input.device)
+            .type(wav_input.type()),
            length=nsample,
        )

        output = output.view(batch_size, nch, -1)
        s = torch.squeeze(output, dim=1)
        if torch.is_grad_enabled():
            self_embedding = s.detach()
-            self_predict_speaker_lable = torch.tensor(0.0).to(self_embedding.device)  # dummy
+            self_predict_speaker_lable = torch.tensor(0.0).to(
+                self_embedding.device
+            )  # dummy
            if self.joint_training:
-                if self.feat_type == 'consistent':
+                if self.feat_type == "consistent":
                    with torch.no_grad():
                        self_embedding = self.preEmphasis(self_embedding)
-                        self_embedding = self.spk_encoder(self_embedding)+ 1e-8
+                        self_embedding = self.spk_encoder(self_embedding) + 1e-8
                        self_embedding = self_embedding.log()
-                        self_embedding = self_embedding - torch.mean(self_embedding, dim=-1, keepdim=True)
+                        self_embedding = self_embedding - torch.mean(
+                            self_embedding, dim=-1, keepdim=True
+                        )
                        self_embedding = self_embedding.permute(0, 2, 1)

                self_tmp_spk_emb_input = self.spk_model(self_embedding)
-                if isinstance(self_tmp_spk_emb_input,tuple):
+                if isinstance(self_tmp_spk_emb_input, tuple):
                    self_spk_emb_input = self_tmp_spk_emb_input[-1]
                else:
                    self_spk_emb_input = self_tmp_spk_emb_input
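Both masking loops implement a complex ratio mask: with mixture S = a + bi and mask M = c + di, the estimate is S*M, whose real and imaginary parts are exactly the two product-sum expressions above. A quick equivalence check with simplified shapes (the real code carries an extra channel axis):

import torch

BW, T = 11, 40  # hypothetical band width and frame count
mix = torch.randn(2, BW, T, dtype=torch.complex64)
mask_real = torch.randn(2, BW, T)
mask_imag = torch.randn(2, BW, T)

est_real = mix.real * mask_real - mix.imag * mask_imag
est_imag = mix.real * mask_imag + mix.imag * mask_real
est = torch.complex(est_real, est_imag)

assert torch.allclose(est, mix * torch.complex(mask_real, mask_imag))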
@@ -410,29 +430,46 @@ def forward(self, input, embeddings):
            self_spk_embedding = self.spk_transform(self_spk_emb_input)
            self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3)

-            self_sep_output = self.separator(subband_feature, self_spk_embedding, torch.tensor(nch))
+            self_sep_output = self.separator(
+                subband_feature, self_spk_embedding, torch.tensor(nch)
+            )

            self_sep_subband_spec = []
            for i, mask_func in enumerate(self.mask):
-                this_output = mask_func(self_sep_output[:, i]).view(batch_size * nch, 2, 2, self.band_width[i], -1)
-                this_mask = this_output[:, 0] * torch.sigmoid(this_output[:, 1])  # B*nch, 2, K, BW, T
+                this_output = mask_func(self_sep_output[:, i]).view(
+                    batch_size * nch, 2, 2, self.band_width[i], -1
+                )
+                this_mask = this_output[:, 0] * torch.sigmoid(
+                    this_output[:, 1]
+                )  # B*nch, 2, K, BW, T
                this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
                this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
-                est_spec_real = subband_mix_spec[i].real * this_mask_real - subband_mix_spec[
-                    i].imag * this_mask_imag  # B*nch, BW, T
-                est_spec_imag = subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[
-                    i].imag * this_mask_real  # B*nch, BW, T
-                self_sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
+                est_spec_real = (
+                    subband_mix_spec[i].real * this_mask_real
+                    - subband_mix_spec[i].imag * this_mask_imag
+                )  # B*nch, BW, T
+                est_spec_imag = (
+                    subband_mix_spec[i].real * this_mask_imag
+                    + subband_mix_spec[i].imag * this_mask_real
+                )  # B*nch, BW, T
+                self_sep_subband_spec.append(
+                    torch.complex(est_spec_real, est_spec_imag)
+                )
            self_est_spec = torch.cat(self_sep_subband_spec, 1)  # B*nch, F, T
-            self_output = torch.istft(self_est_spec.view(batch_size * nch, self.enc_dim, -1),
-                                      n_fft=self.win, hop_length=self.stride,
-                                      window=torch.hann_window(self.win).to(wav_input.device).type(wav_input.type()),
-                                      length=nsample)
+            self_output = torch.istft(
+                self_est_spec.view(batch_size * nch, self.enc_dim, -1),
+                n_fft=self.win,
+                hop_length=self.stride,
+                window=torch.hann_window(self.win)
+                .to(wav_input.device)
+                .type(wav_input.type()),
+                length=nsample,
+            )

            self_output = self_output.view(batch_size, nch, -1)
            self_s = torch.squeeze(self_output, dim=1)

-            return s,self_s, predict_speaker_lable,self_predict_speaker_lable
+            return s, self_s, predict_speaker_lable, self_predict_speaker_lable

        return s, predict_speaker_lable
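Note the two return arities: the self-enrollment branch is gated on torch.is_grad_enabled(), so with gradients enabled the forward pass returns (s, self_s, predict_speaker_lable, self_predict_speaker_lable), while under torch.no_grad() it returns only (s, predict_speaker_lable). A hypothetical call site (model, mix, and enroll are placeholders):

s, self_s, lbl, self_lbl = model(mix, enroll)  # grad enabled: four outputs

with torch.no_grad():
    s, lbl = model(mix, enroll)                # grad disabled: two outputs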