
Commit c261deb

update for checks
1 parent 06e8a99 commit c261deb

3 files changed: +26 −26 lines

examples/librimix/tse/v2/confs/bsrnn_feats.yaml

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ model_args:
   spk_model: ECAPA_TDNN_GLOB_c512
   spk_model_init: ./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
   spk_args:
-    embed_dim: &embed_dim 192
+    embed_dim: &embed_dim 192
     feat_dim: 80
     pooling_func: ASTP
 #################################################################
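
Note: the - and + lines above read the same because the page strips leading and trailing whitespace, so this is most likely an indentation or trailing-whitespace fix, consistent with the commit message "update for checks". The &embed_dim syntax defines a YAML anchor, which lets other keys in the config reuse the value 192 via an alias. A minimal, self-contained sketch of the mechanism (the fuse_args/spk_emb_dim keys are hypothetical, not taken from this config):

import yaml  # PyYAML

# The anchor &embed_dim names the value 192; *embed_dim aliases it
# elsewhere in the same YAML document.
cfg = yaml.safe_load("""
spk_args:
  embed_dim: &embed_dim 192
fuse_args:
  spk_emb_dim: *embed_dim
""")
assert cfg["fuse_args"]["spk_emb_dim"] == 192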

wesep/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def get_model(model_name: str):
     elif model_name.startswith("BSRNN_Multi"):
         return getattr(bsrnn_multi, model_name)
     elif model_name.startswith("BSRNN_Feats"):
-        return getattr(bsrnn_feats, model_name)
+        return getattr(bsrnn_feats, model_name)
     elif model_name.startswith("BSRNN"):
         return getattr(bsrnn, model_name)
     elif model_name.startswith("DPCCN"):
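
Here too the - and + lines differ only in whitespace. One property of this dispatcher worth keeping in mind: str.startswith matches prefixes, so the "BSRNN_Feats" branch must be tested before the bare "BSRNN" branch or it would never be reached. A standalone sketch of that ordering constraint (illustrative, not the wesep code itself):

def dispatch(model_name: str) -> str:
    # The more specific prefix must come first: every "BSRNN_Feats*"
    # name also starts with "BSRNN" and would match that branch.
    if model_name.startswith("BSRNN_Feats"):
        return "bsrnn_feats"
    elif model_name.startswith("BSRNN"):
        return "bsrnn"
    return "unknown"

assert dispatch("BSRNN_Feats") == "bsrnn_feats"
assert dispatch("BSRNN") == "bsrnn"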

wesep/models/bsrnn_feats.py

Lines changed: 24 additions & 24 deletions
@@ -12,7 +12,7 @@
 from wesep.modules.common.speaker import PreEmphasis
 from wesep.modules.common.speaker import SpeakerFuseLayer
 from wesep.modules.common.speaker import SpeakerTransform
-from wesep.utils.funcs import compute_fbank,apply_cmvn
+from wesep.utils.funcs import compute_fbank, apply_cmvn


 class ResRNN(nn.Module):
@@ -94,7 +94,7 @@ def forward(self, query, key, value):
         if query.dim() == 4:
             spk_embeddings = []
             for i in range(query.shape[1]):
-                x = query[:, i, :, :].squeeze(dim=1) #(batch, feature, time)
+                x = query[:, i, :, :].squeeze(dim=1)  # (batch, feature, time)
                 x, _ = self.multihead_attn(x.transpose(1, 2),
                                            key.transpose(1, 2),
                                            value.transpose(1, 2))
@@ -106,7 +106,7 @@ def forward(self, query, key, value):
                                        value.transpose(1, 2))
             spk_embeddings = x.transpose(1, 2)
         return spk_embeddings
-
+
 class FuseSeparation(nn.Module):

     def __init__(
@@ -143,7 +143,7 @@ def __init__(
                 SpeakerFuseLayer(
                     embed_dim=spk_emb_dim,
                     feat_dim=feature_dim,
-                    fuse_type=spk_fuse_type.lstrip("cross_"),
+                    fuse_type=spk_fuse_type.removeprefix("cross_"),
                 ))
             self.separation.append(BSNet(nband * feature_dim, nband))
         else:
@@ -152,7 +152,7 @@ def __init__(
                 SpeakerFuseLayer(
                     embed_dim=spk_emb_dim,
                     feat_dim=feature_dim,
-                    fuse_type=spk_fuse_type.lstrip("cross_"),
+                    fuse_type=spk_fuse_type.removeprefix("cross_"),
                 ))
             for _ in range(num_repeat):
                 self.separation.append(BSNet(nband * feature_dim, nband))
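
Besides the spacing fixes, this pair of hunks carries the one behavioral change in the commit: str.lstrip treats its argument as a set of characters to strip from the left, not as a prefix, so it can eat into the string past "cross_". str.removeprefix (Python 3.9+) removes only the exact leading substring. A minimal illustration (the value "cross_concat" is hypothetical):

s = "cross_concat"
print(s.lstrip("cross_"))        # "ncat"   -- strips any of the chars {c, r, o, s, _}
print(s.removeprefix("cross_"))  # "concat" -- strips only the exact prefix "cross_"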
@@ -265,7 +265,7 @@ def __init__(
         else:
             self.spk_transform = nn.Identity()

-        if joint_training and (spk_fuse_type or spectral_feat=='tfmap_emb'):
+        if joint_training and (spk_fuse_type or spectral_feat == 'tfmap_emb'):
             self.spk_model = get_speaker_model(spk_model)(**spk_args)
             if spk_model_init:
                 pretrained_model = torch.load(spk_model_init)
@@ -300,7 +300,7 @@ def __init__(
             self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain)
         else:
             self.pred_linear = nn.Identity()
-
+
         spec_map = 2
         if spectral_feat:
             spec_map += 1
@@ -373,7 +373,7 @@ def forward(self, input, embeddings):

         spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T

-        ########################## Calculate the spectral level feature
+        # Calculate the spectral level feature
         if self.spectral_feat:
             aux_c = torch.stft(
                 spk_emb_input,
@@ -383,22 +383,22 @@ def forward(self, input, embeddings):
                     spk_emb_input.type()),
                 return_complex=True,
             )
-            if self.spectral_feat=='tfmap_spec':
+            if self.spectral_feat == 'tfmap_spec':
                 mix_mag_ori = torch.abs(spec)
                 enroll_mag = torch.abs(aux_c)

                 mix_mag = F.normalize(mix_mag_ori, p=2, dim=1)
                 enroll_mag = F.normalize(enroll_mag, p=2, dim=1)
-
-                mix_mag = mix_mag.permute(0,2,1).contiguous()
+
+                mix_mag = mix_mag.permute(0, 2, 1).contiguous()
                 att_scores = torch.matmul(mix_mag, enroll_mag)
                 att_weights = F.softmax(att_scores, dim=-1)
-                enroll_mag = enroll_mag.permute(0,2,1).contiguous()
+                enroll_mag = enroll_mag.permute(0, 2, 1).contiguous()
                 tf_map = torch.matmul(att_weights, enroll_mag)
-                tf_map = tf_map.permute(0,2,1).contiguous()
+                tf_map = tf_map.permute(0, 2, 1).contiguous()

                 tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)
-                ######### Recover the energy of estimated tfmap feature
+                # Recover the energy of estimated tfmap feature
                 tf_map = (
                     torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True)
                     * tf_map
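
For orientation, the tfmap_spec branch builds a soft alignment between mixture and enrollment magnitude spectrograms and uses it to project enrollment time-frequency content onto the mixture's time axis, then restores per-frame energy. A self-contained sketch of the same computation with made-up shapes (B = batch, F = frequency bins, Tm/Te = mixture/enrollment frames):

import torch
import torch.nn.functional as F

B, Fbins, Tm, Te = 2, 257, 100, 80
mix_mag_ori = torch.rand(B, Fbins, Tm)          # |STFT| of the mixture
enroll_mag = torch.rand(B, Fbins, Te)           # |STFT| of the enrollment

mix_mag = F.normalize(mix_mag_ori, p=2, dim=1)  # unit-norm spectrum per frame
enroll = F.normalize(enroll_mag, p=2, dim=1)

att = torch.matmul(mix_mag.permute(0, 2, 1), enroll)  # (B, Tm, Te) cosine scores
w = F.softmax(att, dim=-1)                            # each mixture frame attends over enrollment frames
tf_map = torch.matmul(w, enroll.permute(0, 2, 1)).permute(0, 2, 1)  # (B, F, Tm)

tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)
# Recover per-frame energy by projecting onto the mixture magnitude
tf_map = torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True) * tf_map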
@@ -407,8 +407,8 @@ def forward(self, input, embeddings):
                 # tf_map = tf_map * mix_mag_ori.norm(dim=1, keepdim=True)

                 spec_RI = torch.cat((spec_RI, tf_map.unsqueeze(1)), dim=1)
-
-            if self.spectral_feat=='tfmap_emb': #Only supports Ecapa-TDNN model.
+
+            if self.spectral_feat == 'tfmap_emb':  # Only Ecapa-TDNN.
                 with torch.no_grad():
                     signal_dim = wav_input.dim()
                     extended_shape = (
@@ -454,33 +454,33 @@ def forward(self, input, embeddings):
                     spk_emb = apply_cmvn(spk_emb)

                     spk_emb = self.spk_model(spk_emb)
-                    if isinstance(spk_emb,tuple):
+                    if isinstance(spk_emb, tuple):
                         spk_emb_frame = spk_emb[0]
                     else:
                         spk_emb_frame = spk_emb
                     mix_emb = self.spk_model(mix_emb)
-                    if isinstance(mix_emb,tuple):
+                    if isinstance(mix_emb, tuple):
                         mix_emb_frame = mix_emb[0]
                     else:
                         mix_emb_frame = mix_emb

                     mix_emb_frame_ = F.normalize(mix_emb_frame, p=2, dim=1)
                     spk_emb_frame_ = F.normalize(spk_emb_frame, p=2, dim=1)

-                    mix_emb_frame_ = mix_emb_frame_.transpose(1,2)
+                    mix_emb_frame_ = mix_emb_frame_.transpose(1, 2)
                     att_scores = torch.matmul(mix_emb_frame_, spk_emb_frame_)
                     att_weights = F.softmax(att_scores, dim=-1)

                     mix_mag_ori = torch.abs(spec)
                     enroll_mag = torch.abs(aux_c)

-                    enroll_mag = enroll_mag.transpose(1,2)
+                    enroll_mag = enroll_mag.transpose(1, 2)
                     # enroll_mag = F.normalize(enroll_mag, p=2, dim=1)
                     tf_map = torch.matmul(att_weights, enroll_mag)
-                    tf_map = tf_map.transpose(1,2)
+                    tf_map = tf_map.transpose(1, 2)

                     tf_map = tf_map / tf_map.norm(dim=1, keepdim=True)
-                    ######### Recover the energy of estimated tfmap feature
+                    # Recover the energy of estimated tfmap feature
                     tf_map = (
                         torch.sum(mix_mag_ori * tf_map, dim=1, keepdim=True)
                         * tf_map
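
The tfmap_emb branch follows the same attend-and-copy pattern, but scores the alignment in speaker-embedding space: frame-level embeddings from the speaker model (run under torch.no_grad) replace raw magnitudes as attention queries and keys, while the values are still enrollment magnitudes. A short sketch under assumed shapes (D = frame-embedding dimension; the random tensors stand in for self.spk_model outputs):

import torch
import torch.nn.functional as F

B, D, Fbins, Tm, Te = 2, 192, 257, 100, 80
mix_emb_frame = torch.rand(B, D, Tm)   # stand-in for frame-level mixture embeddings
spk_emb_frame = torch.rand(B, D, Te)   # stand-in for frame-level enrollment embeddings
enroll_mag = torch.rand(B, Fbins, Te)  # |STFT| of the enrollment

mix_n = F.normalize(mix_emb_frame, p=2, dim=1).transpose(1, 2)  # (B, Tm, D)
spk_n = F.normalize(spk_emb_frame, p=2, dim=1)                  # (B, D, Te)
w = F.softmax(torch.matmul(mix_n, spk_n), dim=-1)               # (B, Tm, Te)
tf_map = torch.matmul(w, enroll_mag.transpose(1, 2)).transpose(1, 2)  # (B, F, Tm)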
@@ -531,7 +531,7 @@ def forward(self, input, embeddings):

         if self.spk_fuse_type and self.spk_fuse_type.startswith("cross_"):
             tmp_spk_emb_input = self.spk_model._get_frame_level_feat(
-                spk_emb_input)
+                spk_emb_input)
         else:
             tmp_spk_emb_input = self.spk_model(spk_emb_input)
         if isinstance(tmp_spk_emb_input, tuple):
@@ -543,7 +543,7 @@ def forward(self, input, embeddings):
         spk_embedding = self.spk_transform(spk_emb_input)
         if self.spk_fuse_type and not self.spk_fuse_type.startswith("cross_"):
             spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3)
-
+
         sep_output = self.separator(subband_feature, spk_embedding,
                                     torch.tensor(nch))
