diff --git a/Modules/diffusion/modules.py b/Modules/diffusion/modules.py index 9df85b99..ed297fb4 100644 --- a/Modules/diffusion/modules.py +++ b/Modules/diffusion/modules.py @@ -14,28 +14,46 @@ """ Utils """ - class AdaLayerNorm(nn.Module): def __init__(self, style_dim, channels, eps=1e-5): super().__init__() self.channels = channels self.eps = eps - self.fc = nn.Linear(style_dim, channels*2) + self.fc = nn.Linear(style_dim, channels * 2) - def forward(self, x, s): - x = x.transpose(-1, -2) - x = x.transpose(1, -1) - + def forward(self, x, s, mask=None): + # x: (B, T, C) + # s: (B, style_dim) + # mask: (B, T), optional + h = self.fc(s) - h = h.view(h.size(0), h.size(1), 1) + h = h.view(h.size(0), h.size(1), 1) # (B, 2C, 1) gamma, beta = torch.chunk(h, chunks=2, dim=1) - gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) - - - x = F.layer_norm(x, (self.channels,), eps=self.eps) - x = (1 + gamma) * x + beta - return x.transpose(1, -1).transpose(-1, -2) + gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) # (B, C, 1) + + if mask is not None: + # Expand mask to match (B, T, C) + mask = mask.unsqueeze(-1) # (B, T, 1) + + # Calculate masked mean and variance + mean = (x * mask).sum(dim=1) / mask.sum(dim=1) + var = ((x - mean.unsqueeze(1)) ** 2 * mask).sum(dim=1) / mask.sum(dim=1) + + # Normalize + x_normalized = (x - mean.unsqueeze(1)) / torch.sqrt(var.unsqueeze(1) + self.eps) + + # Apply gamma and beta + x_normalized = (1 + gamma) * x_normalized + beta + + # Apply mask to the final output to ensure masked positions remain zero + x_normalized = x_normalized * mask + else: + # Standard layer norm without mask + x = F.layer_norm(x, (self.channels,), eps=self.eps) + x_normalized = (1 + gamma) * x + beta + + return x_normalized.transpose(1, -1).transpose(-1, -2) class StyleTransformer1d(nn.Module): def __init__( @@ -67,6 +85,7 @@ def __init__( use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets, rel_pos_max_distance=rel_pos_max_distance, + context_features=384, ) for i in range(num_layers) ] @@ -141,15 +160,14 @@ def get_mapping( return mapping - def run(self, x, time, embedding, features): + def run(self, x, time, embedding, features, context, context_mask): mapping = self.get_mapping(time, features) x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1) mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1) - for block in self.blocks: x = x + mapping - x = block(x, features) + x = block(x, features, context=context, context_mask=context_mask) x = x.mean(axis=1).unsqueeze(1) x = self.to_out(x) @@ -162,6 +180,8 @@ def forward(self, x: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None, features: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor: b, device = embedding.shape[0], embedding.device @@ -175,14 +195,12 @@ def forward(self, x: Tensor, if embedding_scale != 1.0: # Compute both normal and fixed embedding outputs - out = self.run(x, time, embedding=embedding, features=features) - out_masked = self.run(x, time, embedding=fixed_embedding, features=features) + out = self.run(x, time, embedding=embedding, features=features, context=context, context_mask=context_mask) + out_masked = self.run(x, time, embedding=fixed_embedding, features=features, context=context, context_mask=context_mask) # Scale conditional output using classifier-free guidance return out_masked + (out - out_masked) * embedding_scale else: - return self.run(x, time, embedding=embedding, features=features) - - return x + return self.run(x, time, embedding=embedding, features=features, context=context, context_mask=context_mask) class StyleTransformerBlock(nn.Module): @@ -226,10 +244,10 @@ def __init__( self.feed_forward = FeedForward(features=features, multiplier=multiplier) - def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor]) -> Tensor: x = self.attention(x, s) + x if self.use_cross_attention: - x = self.cross_attention(x, s, context=context) + x + x = self.cross_attention(x, s, context=context, context_mask=context_mask) + x x = self.feed_forward(x) + x return x @@ -247,10 +265,9 @@ def __init__( rel_pos_max_distance: Optional[int] = None, ): super().__init__() - self.context_features = context_features mid_features = head_features * num_heads context_features = default(context_features, features) - + self.context_features = context_features self.norm = AdaLayerNorm(style_dim, features) self.norm_context = AdaLayerNorm(style_dim, context_features) self.to_q = nn.Linear( @@ -268,13 +285,13 @@ def __init__( rel_pos_max_distance=rel_pos_max_distance, ) - def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None) -> Tensor: assert_message = "You must provide a context when using context_features" - assert not self.context_features or exists(context), assert_message + # assert not self.context_features or exists(context), assert_message # Use context if provided context = default(context, x) # Normalize then compute q from input and k,v from context - x, context = self.norm(x, s), self.norm_context(context, s) + x, context = self.norm(x, s), self.norm_context(context, s, context_mask) q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) # Compute and return attention diff --git a/Modules/heteroGraph.py b/Modules/heteroGraph.py index 57adb040..c9258d74 100755 --- a/Modules/heteroGraph.py +++ b/Modules/heteroGraph.py @@ -25,7 +25,6 @@ def __init__(self, hidden_channels, out_channels, num_heads, num_layers, data): self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): - x_dict = { node_type: self.lin_dict[node_type](x).relu_() for node_type, x in x_dict.items() @@ -33,8 +32,7 @@ def forward(self, x_dict, edge_index_dict): for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) + + out_text = self.lin(x_dict["text"]) - return x_dict["text"] - - - + return out_text \ No newline at end of file diff --git a/Modules/slmadv.py b/Modules/slmadv.py index 8d5a8557..8fc5a738 100644 --- a/Modules/slmadv.py +++ b/Modules/slmadv.py @@ -19,28 +19,28 @@ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, s def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None): text_mask = length_to_mask(ref_lengths).to(ref_text.device) - h_bert = self.model.bert(ref_text, attention_mask=(~text_mask).int()) - style = torch.zeros(h_bert.size(0), 1, 256).to(h_bert.device) - bert_dur = torch.cat([h_bert, style.expand(-1, h_bert.size(1), -1)], dim=-1) + bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int()) d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) - # d_en = bert_dur.transpose(-1, -2) if use_ind and np.random.rand() < 0.5: s_preds = s_trg else: num_steps = np.random.randint(3, 5) + context = torch.zeros(1, 1, 384).to(ref_text.device) if ref_s is not None: s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), - embedding=bert_dur, - embedding_scale=1, - features=ref_s, # reference from the same speaker as the embedding - embedding_mask_proba=0.1, - num_steps=num_steps).squeeze(1) + embedding=bert_dur, + embedding_scale=1, + features=ref_s, # reference from the same speaker as the embedding + context=context, + embedding_mask_proba=0.1, + num_steps=num_steps).squeeze(1) else: s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), - embedding=bert_dur, - embedding_scale=1, - embedding_mask_proba=0.1, - num_steps=num_steps).squeeze(1) + embedding=bert_dur, + embedding_scale=1, + context=context, + embedding_mask_proba=0.1, + num_steps=num_steps).squeeze(1) s_dur = s_preds[:, 128:] s = s_preds[:, :128] diff --git a/models.py b/models.py index 8bad8747..8ef02d5b 100644 --- a/models.py +++ b/models.py @@ -698,23 +698,23 @@ def build_model(args, text_aligner, pitch_extractor, bert): style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder - style_predictor = StylePredictor(query_dim=512, key_dim=384, num_units=256, num_heads=2) + # style_predictor = StylePredictor(query_dim=512, key_dim=384, num_units=256, num_heads=2) # define diffusion model if args.multispeaker: transformer = StyleTransformer1d(channels=args.style_dim*2, - context_embedding_features=bert.config.hidden_size + 256, + context_embedding_features=bert.config.hidden_size, context_features=args.style_dim*2, **args.diffusion.transformer) else: transformer = Transformer1d(channels=args.style_dim*2, - context_embedding_features=bert.config.hidden_size + 256, + context_embedding_features=bert.config.hidden_size, **args.diffusion.transformer) diffusion = AudioDiffusionConditional( in_channels=1, embedding_max_length=bert.config.max_position_embeddings, - embedding_features=bert.config.hidden_size + 256, + embedding_features=bert.config.hidden_size, embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements, channels=args.style_dim*2, context_features=args.style_dim*2, @@ -741,7 +741,7 @@ def build_model(args, text_aligner, pitch_extractor, bert): nets = Munch( bert=bert, - bert_encoder=nn.Linear(bert.config.hidden_size + 256, args.hidden_dim), + bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim), predictor=predictor, decoder=decoder, @@ -749,7 +749,7 @@ def build_model(args, text_aligner, pitch_extractor, bert): predictor_encoder=predictor_encoder, style_encoder=style_encoder, - style_predictor=style_predictor, + # style_predictor=style_predictor, diffusion=diffusion, text_aligner = text_aligner, diff --git a/train_second.py b/train_second.py index a0c9497b..ca59d188 100644 --- a/train_second.py +++ b/train_second.py @@ -153,7 +153,7 @@ def main(config_path): None, first_stage_path, load_only_params=True, - ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log + ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion', 'hgt']) # keep starting epoch for tensorboard log # these epochs should be counted from the start epoch diff_epoch += start_epoch @@ -253,7 +253,7 @@ def main(config_path): _ = [model[key].eval() for key in model] - model.style_predictor.train() + # model.style_predictor.train() model.predictor.train() model.bert.train() model.bert_encoder.train() @@ -303,70 +303,64 @@ def main(config_path): # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool) ss = [] gs = [] - bert_dur = [] + contexts = [] for bib in range(len(mel_input_length)): - ### Heterograph-based encoding - history = histories[bib] - history_len = history["history_len"] - if history_len > 0: - current_text_tensor = history["current_text_tensor"] - history_text_tensors = history["history_text_tensors"] - history_text_tensors = torch.stack(history_text_tensors) - history_acoustic_features = history["history_acoustic_features"] - history_ss = [] - history_gs = [] - for history_acoustic_feature in history_acoustic_features: - history_s = model.predictor_encoder(history_acoustic_feature.unsqueeze(0).unsqueeze(1)) - history_ss.append(history_s) - history_s = model.style_encoder(history_acoustic_feature.unsqueeze(0).unsqueeze(1)) - history_gs.append(history_s) - history_ss = torch.stack(history_ss).squeeze() - history_gs = torch.stack(history_gs).squeeze() - if history_ss.dim() == 1: - history_ss = history_ss.unsqueeze(0) - if history_gs.dim() == 1: - history_gs = history_gs.unsqueeze(0) - - data = HeteroData() - data["text"].x = history_text_tensors - data["acoustic"].x = history_ss - data["prosody"].x = history_gs - - edge = [] - for _i in range(data["prosody"].x.shape[0]): - for _j in range(data["acoustic"].x.shape[0]): - edge.append([_j, _i]) - data["acoustic", "to", "prosody"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - data["acoustic", "to", "acoustic"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - data["prosody", "to", "prosody"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - - edge = [] - # the length of the text is one more than the length of the acoustic/prosodic features - for _i in range(data["text"].x.shape[0]): - for _j in range(data["acoustic"].x.shape[0]): - edge.append([_j, _i]) - data["acoustic", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - data["prosody", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - - edge = [] - for _i in range(data["text"].x.shape[0]): - for _j in range(data["text"].x.shape[0]): - edge.append([_j, _i]) - data["text", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) - data = T.ToUndirected()(data) - - data, model.hgt = data.to(device), model.hgt.to(device) - out_text = model.hgt(data.x_dict, data.edge_index_dict) - - q = current_text_tensor.unsqueeze(0).unsqueeze(0) - k = v = out_text[:-1].unsqueeze(0) - style = model.style_predictor(q, k, v)[0] # (1, 1, 256) - else: - style = torch.zeros(1, 1, 256).to(device) - - h_bert = model.bert(texts[bib].unsqueeze(0), attention_mask=(~text_mask[bib].unsqueeze(0)).int()) - h_bert = torch.cat([h_bert, style.expand(-1, h_bert.size(1), -1)], dim=-1) - bert_dur.append(h_bert) + if epoch >= diff_epoch: + ### Heterograph-based encoding + history = histories[bib] + history_len = history["history_len"] + if history_len > 0: + current_text_tensor = history["current_text_tensor"] + history_text_tensors = history["history_text_tensors"] + history_text_tensors = torch.stack(history_text_tensors) + history_acoustic_features = history["history_acoustic_features"] + history_ss = [] + history_gs = [] + for history_acoustic_feature in history_acoustic_features: + history_s = model.predictor_encoder(history_acoustic_feature.unsqueeze(0).unsqueeze(1)) + history_ss.append(history_s) + history_s = model.style_encoder(history_acoustic_feature.unsqueeze(0).unsqueeze(1)) + history_gs.append(history_s) + history_ss = torch.stack(history_ss).squeeze() + history_gs = torch.stack(history_gs).squeeze() + if history_ss.dim() == 1: + history_ss = history_ss.unsqueeze(0) + if history_gs.dim() == 1: + history_gs = history_gs.unsqueeze(0) + + data = HeteroData() + data["text"].x = history_text_tensors + data["acoustic"].x = history_ss + data["prosody"].x = history_gs + + edge = [] + for _i in range(data["prosody"].x.shape[0]): + for _j in range(data["acoustic"].x.shape[0]): + edge.append([_j, _i]) + data["acoustic", "to", "prosody"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + data["acoustic", "to", "acoustic"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + data["prosody", "to", "prosody"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + + edge = [] + # the length of the text is one more than the length of the acoustic/prosodic features + for _i in range(data["text"].x.shape[0]): + for _j in range(data["acoustic"].x.shape[0]): + edge.append([_j, _i]) + data["acoustic", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + data["prosody", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + + edge = [] + for _i in range(data["text"].x.shape[0]): + for _j in range(data["text"].x.shape[0]): + edge.append([_j, _i]) + data["text", "to", "text"].edge_index = torch.tensor(edge).contiguous().transpose(-2, -1) + data = T.ToUndirected()(data) + + data, model.hgt = data.to(device), model.hgt.to(device) + context = model.hgt(data.x_dict, data.edge_index_dict) # (H, 384) H means the number of history utterances + contexts.append(context) + else: + contexts.append(torch.zeros(1, 384).to(device)) # current style mel_length = int(mel_input_length[bib].item()) @@ -376,16 +370,22 @@ def main(config_path): s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1)) gs.append(s) - bert_dur = torch.stack(bert_dur).squeeze() - d_en = model.bert_encoder(bert_dur).transpose(-1, -2) - # d_en = bert_dur.transpose(-1, -2) - s_dur = torch.stack(ss).squeeze() # global prosodic styles gs = torch.stack(gs).squeeze() # global acoustic styles s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) # denoiser training if epoch >= diff_epoch: + # context padding + context_mask = length_to_mask(torch.tensor([c.size(0) for c in contexts])).to(device) # (B, H) + max_context_len = max([c.size(0) for c in contexts]) + for j in range(len(contexts)): + if contexts[j].size(0) < max_context_len: + contexts[j] = F.pad(contexts[j], (0, 0, 0, max_context_len - contexts[j].size(0))) + context = torch.stack(contexts) # historical context + num_steps = np.random.randint(3, 5) if model_params.diffusion.dist.estimate_sigma_data: @@ -394,20 +394,24 @@ def main(config_path): if multispeaker: s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device), - embedding=bert_dur, - embedding_scale=1, - features=ref, # reference from the same speaker as the embedding - embedding_mask_proba=0.1, - num_steps=num_steps).squeeze(1) - loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss + embedding=bert_dur, + embedding_scale=1, + features=ref, # reference from the same speaker as the embedding + context=context, + context_mask=~context_mask, + embedding_mask_proba=0.1, + num_steps=num_steps).squeeze(1) + loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref, context=context, context_mask=~context_mask).mean() # EDM loss loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss else: s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device), - embedding=bert_dur, - embedding_scale=1, - embedding_mask_proba=0.1, - num_steps=num_steps).squeeze(1) - loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss + embedding=bert_dur, + embedding_scale=1, + context=context, + context_mask=~context_mask, + embedding_mask_proba=0.1, + num_steps=num_steps).squeeze(1) + loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, context=context, context_mask=~context_mask).mean() # EDM loss loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss else: loss_sty = 0 @@ -536,7 +540,7 @@ def main(config_path): optimizer.step('predictor') optimizer.step('predictor_encoder') optimizer.step('hgt') - optimizer.step('style_predictor') + # optimizer.step('style_predictor') if epoch >= diff_epoch: optimizer.step('diffusion') @@ -685,13 +689,9 @@ def main(config_path): gs = torch.stack(gs).squeeze() s_trg = torch.cat([s, gs], dim=-1).detach() - # bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) - h_bert = model.bert(texts, attention_mask=(~text_mask).int()) - style = torch.zeros(h_bert.size(0), 1, 256).to(h_bert.device) - bert_dur = torch.cat([h_bert, style.expand(-1, h_bert.size(1), -1)], dim=-1) - + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) d_en = model.bert_encoder(bert_dur).transpose(-1, -2) - # d_en = bert_dur.transpose(-1, -2) + d, p = model.predictor(d_en, s, input_lengths, s2s_attn_mono, @@ -805,18 +805,22 @@ def main(config_path): ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1)) ref_s = torch.cat([ref_ss, ref_sp], dim=1) + context = torch.zeros(1, 1, 384).to(device) for bib in range(len(d_en)): + if multispeaker: s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device), - embedding=bert_dur[bib].unsqueeze(0), - embedding_scale=1, + embedding=bert_dur[bib].unsqueeze(0), + embedding_scale=1, features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding - num_steps=5).squeeze(1) + context=context, + num_steps=5).squeeze(1) else: s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device), - embedding=bert_dur[bib].unsqueeze(0), - embedding_scale=1, - num_steps=5).squeeze(1) + embedding=bert_dur[bib].unsqueeze(0), + embedding_scale=1, + context=context, + num_steps=5).squeeze(1) s = s_pred[:, 128:] ref = s_pred[:, :128]