Skip to content

Commit

Permalink
Merge pull request #3404 from freds0/dev
Browse files Browse the repository at this point in the history
Training fastspeech2 with External Speaker Embeddings
  • Loading branch information
erogol authored Dec 12, 2023
2 parents c99e885 + f911791 commit 936084b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions TTS/tts/models/forward_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,15 @@ def __init__(
)

self.duration_predictor = DurationPredictor(
self.args.hidden_channels + self.embedded_speaker_dim,
self.args.hidden_channels,
self.args.duration_predictor_hidden_channels,
self.args.duration_predictor_kernel_size,
self.args.duration_predictor_dropout_p,
)

if self.args.use_pitch:
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.embedded_speaker_dim,
self.args.hidden_channels,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
Expand All @@ -263,7 +263,7 @@ def __init__(

if self.args.use_energy:
self.energy_predictor = DurationPredictor(
self.args.hidden_channels + self.embedded_speaker_dim,
self.args.hidden_channels,
self.args.energy_predictor_hidden_channels,
self.args.energy_predictor_kernel_size,
self.args.energy_predictor_dropout_p,
Expand Down Expand Up @@ -299,7 +299,8 @@ def init_multispeaker(self, config: Coqpit):
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
if self.args.d_vector_dim != self.args.hidden_channels:
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
#self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
Expand Down Expand Up @@ -403,10 +404,13 @@ def _forward_encoder(
# [B, T, C]
x_emb = self.emb(x)
# encoder pass
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
#o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
# speaker conditioning
# TODO: try different ways of conditioning
if g is not None:
if g is not None:
if hasattr(self, "proj_g"):
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
o_en = o_en + g
return o_en, x_mask, g, x_emb

Expand Down

0 comments on commit 936084b

Please sign in to comment.