12
12
from wesep .modules .common .speaker import PreEmphasis
13
13
from wesep .modules .common .speaker import SpeakerFuseLayer
14
14
from wesep .modules .common .speaker import SpeakerTransform
15
- from wesep .utils .funcs import compute_fbank ,apply_cmvn
15
+ from wesep .utils .funcs import compute_fbank , apply_cmvn
16
16
17
17
18
18
class ResRNN (nn .Module ):
@@ -94,7 +94,7 @@ def forward(self, query, key, value):
94
94
if query .dim () == 4 :
95
95
spk_embeddings = []
96
96
for i in range (query .shape [1 ]):
97
- x = query [:, i , :, :].squeeze (dim = 1 ) #(batch, feature, time)
97
+ x = query [:, i , :, :].squeeze (dim = 1 ) # (batch, feature, time)
98
98
x , _ = self .multihead_attn (x .transpose (1 , 2 ),
99
99
key .transpose (1 , 2 ),
100
100
value .transpose (1 , 2 ))
@@ -106,7 +106,7 @@ def forward(self, query, key, value):
106
106
value .transpose (1 , 2 ))
107
107
spk_embeddings = x .transpose (1 , 2 )
108
108
return spk_embeddings
109
-
109
+
110
110
class FuseSeparation (nn .Module ):
111
111
112
112
def __init__ (
@@ -143,7 +143,7 @@ def __init__(
143
143
SpeakerFuseLayer (
144
144
embed_dim = spk_emb_dim ,
145
145
feat_dim = feature_dim ,
146
- fuse_type = spk_fuse_type .lstrip ("cross_" ),
146
+ fuse_type = spk_fuse_type .removeprefix ("cross_" ),
147
147
))
148
148
self .separation .append (BSNet (nband * feature_dim , nband ))
149
149
else :
@@ -152,7 +152,7 @@ def __init__(
152
152
SpeakerFuseLayer (
153
153
embed_dim = spk_emb_dim ,
154
154
feat_dim = feature_dim ,
155
- fuse_type = spk_fuse_type .lstrip ("cross_" ),
155
+ fuse_type = spk_fuse_type .removeprefix ("cross_" ),
156
156
))
157
157
for _ in range (num_repeat ):
158
158
self .separation .append (BSNet (nband * feature_dim , nband ))
@@ -265,7 +265,7 @@ def __init__(
265
265
else :
266
266
self .spk_transform = nn .Identity ()
267
267
268
- if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb' ):
268
+ if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb' ):
269
269
self .spk_model = get_speaker_model (spk_model )(** spk_args )
270
270
if spk_model_init :
271
271
pretrained_model = torch .load (spk_model_init )
@@ -300,7 +300,7 @@ def __init__(
300
300
self .pred_linear = nn .Linear (spk_emb_dim , spksInTrain )
301
301
else :
302
302
self .pred_linear = nn .Identity ()
303
-
303
+
304
304
spec_map = 2
305
305
if spectral_feat :
306
306
spec_map += 1
@@ -373,7 +373,7 @@ def forward(self, input, embeddings):
373
373
374
374
spec_RI = torch .stack ([spec .real , spec .imag ], 1 ) # B*nch, 2, F, T
375
375
376
- ########################## Calculate the spectral level feature
376
+ # Calculate the spectral level feature
377
377
if self .spectral_feat :
378
378
aux_c = torch .stft (
379
379
spk_emb_input ,
@@ -383,22 +383,22 @@ def forward(self, input, embeddings):
383
383
spk_emb_input .type ()),
384
384
return_complex = True ,
385
385
)
386
- if self .spectral_feat == 'tfmap_spec' :
386
+ if self .spectral_feat == 'tfmap_spec' :
387
387
mix_mag_ori = torch .abs (spec )
388
388
enroll_mag = torch .abs (aux_c )
389
389
390
390
mix_mag = F .normalize (mix_mag_ori , p = 2 , dim = 1 )
391
391
enroll_mag = F .normalize (enroll_mag , p = 2 , dim = 1 )
392
-
393
- mix_mag = mix_mag .permute (0 ,2 , 1 ).contiguous ()
392
+
393
+ mix_mag = mix_mag .permute (0 , 2 , 1 ).contiguous ()
394
394
att_scores = torch .matmul (mix_mag , enroll_mag )
395
395
att_weights = F .softmax (att_scores , dim = - 1 )
396
- enroll_mag = enroll_mag .permute (0 ,2 , 1 ).contiguous ()
396
+ enroll_mag = enroll_mag .permute (0 , 2 , 1 ).contiguous ()
397
397
tf_map = torch .matmul (att_weights , enroll_mag )
398
- tf_map = tf_map .permute (0 ,2 , 1 ).contiguous ()
398
+ tf_map = tf_map .permute (0 , 2 , 1 ).contiguous ()
399
399
400
400
tf_map = tf_map / tf_map .norm (dim = 1 , keepdim = True )
401
- ######### Recover the energy of estimated tfmap feature
401
+ # Recover the energy of estimated tfmap feature
402
402
tf_map = (
403
403
torch .sum (mix_mag_ori * tf_map , dim = 1 , keepdim = True )
404
404
* tf_map
@@ -407,8 +407,8 @@ def forward(self, input, embeddings):
407
407
# tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)
408
408
409
409
spec_RI = torch .cat ((spec_RI , tf_map .unsqueeze (1 )), dim = 1 )
410
-
411
- if self .spectral_feat == 'tfmap_emb' : # Only supports Ecapa-TDNN model .
410
+
411
+ if self .spectral_feat == 'tfmap_emb' : # Only Ecapa-TDNN.
412
412
with torch .no_grad ():
413
413
signal_dim = wav_input .dim ()
414
414
extended_shape = (
@@ -454,33 +454,33 @@ def forward(self, input, embeddings):
454
454
spk_emb = apply_cmvn (spk_emb )
455
455
456
456
spk_emb = self .spk_model (spk_emb )
457
- if isinstance (spk_emb ,tuple ):
457
+ if isinstance (spk_emb , tuple ):
458
458
spk_emb_frame = spk_emb [0 ]
459
459
else :
460
460
spk_emb_frame = spk_emb
461
461
mix_emb = self .spk_model (mix_emb )
462
- if isinstance (mix_emb ,tuple ):
462
+ if isinstance (mix_emb , tuple ):
463
463
mix_emb_frame = mix_emb [0 ]
464
464
else :
465
465
mix_emb_frame = mix_emb
466
466
467
467
mix_emb_frame_ = F .normalize (mix_emb_frame , p = 2 , dim = 1 )
468
468
spk_emb_frame_ = F .normalize (spk_emb_frame , p = 2 , dim = 1 )
469
469
470
- mix_emb_frame_ = mix_emb_frame_ .transpose (1 ,2 )
470
+ mix_emb_frame_ = mix_emb_frame_ .transpose (1 , 2 )
471
471
att_scores = torch .matmul (mix_emb_frame_ , spk_emb_frame_ )
472
472
att_weights = F .softmax (att_scores , dim = - 1 )
473
473
474
474
mix_mag_ori = torch .abs (spec )
475
475
enroll_mag = torch .abs (aux_c )
476
476
477
- enroll_mag = enroll_mag .transpose (1 ,2 )
477
+ enroll_mag = enroll_mag .transpose (1 , 2 )
478
478
# enroll_mag = F.normalize(enroll_mag, p=2, dim=1)
479
479
tf_map = torch .matmul (att_weights , enroll_mag )
480
- tf_map = tf_map .transpose (1 ,2 )
480
+ tf_map = tf_map .transpose (1 , 2 )
481
481
482
482
tf_map = tf_map / tf_map .norm (dim = 1 , keepdim = True )
483
- ######### Recover the energy of estimated tfmap feature
483
+ # Recover the energy of estimated tfmap feature
484
484
tf_map = (
485
485
torch .sum (mix_mag_ori * tf_map , dim = 1 , keepdim = True )
486
486
* tf_map
@@ -531,7 +531,7 @@ def forward(self, input, embeddings):
531
531
532
532
if self .spk_fuse_type and self .spk_fuse_type .startswith ("cross_" ):
533
533
tmp_spk_emb_input = self .spk_model ._get_frame_level_feat (
534
- spk_emb_input )
534
+ spk_emb_input )
535
535
else :
536
536
tmp_spk_emb_input = self .spk_model (spk_emb_input )
537
537
if isinstance (tmp_spk_emb_input , tuple ):
@@ -543,7 +543,7 @@ def forward(self, input, embeddings):
543
543
spk_embedding = self .spk_transform (spk_emb_input )
544
544
if self .spk_fuse_type and not self .spk_fuse_type .startswith ("cross_" ):
545
545
spk_embedding = spk_embedding .unsqueeze (1 ).unsqueeze (3 )
546
-
546
+
547
547
sep_output = self .separator (subband_feature , spk_embedding ,
548
548
torch .tensor (nch ))
549
549
0 commit comments