Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 45 additions & 28 deletions Modules/diffusion/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,46 @@
"""
Utils
"""

class AdaLayerNorm(nn.Module):
    """Adaptive LayerNorm: affine parameters predicted from a style vector.

    Instead of learned per-channel weight/bias, a linear layer maps the
    style embedding ``s`` to a (gamma, beta) pair applied after
    normalization. Optionally supports a time mask for padded sequences.
    """

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Predicts concatenated [gamma, beta], each of size `channels`.
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s, mask=None):
        # x: (B, T, C) features
        # s: (B, style_dim) style embedding
        # mask: (B, T), optional; nonzero = valid position — TODO confirm convention
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)  # (B, 2C, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        # After the swap each is (B, 1, C), broadcasting over time.
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        if mask is not None:
            # Expand to (B, T, 1) and match x's dtype for arithmetic.
            mask = mask.unsqueeze(-1).to(x.dtype)

            # NOTE(review): this branch normalizes over the TIME axis (dim=1),
            # while the unmasked branch layer-norms over CHANNELS — confirm
            # this asymmetry is intended.
            # Guard against fully-masked rows (would divide by zero).
            denom = mask.sum(dim=1).clamp_min(1.0)
            mean = (x * mask).sum(dim=1) / denom
            var = ((x - mean.unsqueeze(1)) ** 2 * mask).sum(dim=1) / denom

            # Normalize, then apply style-conditioned affine transform.
            x_normalized = (x - mean.unsqueeze(1)) / torch.sqrt(var.unsqueeze(1) + self.eps)
            x_normalized = (1 + gamma) * x_normalized + beta

            # Zero out padded positions so they carry no signal downstream.
            x_normalized = x_normalized * mask
        else:
            # Standard layer norm over the channel dimension.
            x = F.layer_norm(x, (self.channels,), eps=self.eps)
            x_normalized = (1 + gamma) * x + beta

        # For 3-D input these two swaps of dims (1, 2) cancel out (a no-op);
        # kept for parity with the original implementation.
        return x_normalized.transpose(1, -1).transpose(-1, -2)

class StyleTransformer1d(nn.Module):
def __init__(
Expand Down Expand Up @@ -67,6 +85,7 @@ def __init__(
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
context_features=384,
)
for i in range(num_layers)
]
Expand Down Expand Up @@ -141,15 +160,14 @@ def get_mapping(

return mapping

def run(self, x, time, embedding, features):
def run(self, x, time, embedding, features, context, context_mask):

mapping = self.get_mapping(time, features)
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

for block in self.blocks:
x = x + mapping
x = block(x, features)
x = block(x, features, context=context, context_mask=context_mask)

x = x.mean(axis=1).unsqueeze(1)
x = self.to_out(x)
Expand All @@ -162,6 +180,8 @@ def forward(self, x: Tensor,
embedding_mask_proba: float = 0.0,
embedding: Optional[Tensor] = None,
features: Optional[Tensor] = None,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
embedding_scale: float = 1.0) -> Tensor:

b, device = embedding.shape[0], embedding.device
Expand All @@ -175,14 +195,12 @@ def forward(self, x: Tensor,

if embedding_scale != 1.0:
# Compute both normal and fixed embedding outputs
out = self.run(x, time, embedding=embedding, features=features)
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
out = self.run(x, time, embedding=embedding, features=features, context=context, context_mask=context_mask)
out_masked = self.run(x, time, embedding=fixed_embedding, features=features, context=context, context_mask=context_mask)
# Scale conditional output using classifier-free guidance
return out_masked + (out - out_masked) * embedding_scale
else:
return self.run(x, time, embedding=embedding, features=features)

return x
return self.run(x, time, embedding=embedding, features=features, context=context, context_mask=context_mask)


class StyleTransformerBlock(nn.Module):
Expand Down Expand Up @@ -226,10 +244,10 @@ def __init__(

self.feed_forward = FeedForward(features=features, multiplier=multiplier)

def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None) -> Tensor:
    """Residual self-attention -> optional cross-attention -> feed-forward.

    ``context_mask`` defaults to None so existing attention-only callers
    (which never pass a mask) keep working unchanged.
    """
    x = self.attention(x, s) + x
    if self.use_cross_attention:
        x = self.cross_attention(x, s, context=context, context_mask=context_mask) + x
    x = self.feed_forward(x) + x
    return x

Expand All @@ -247,10 +265,9 @@ def __init__(
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.context_features = context_features
mid_features = head_features * num_heads
context_features = default(context_features, features)

self.context_features = context_features
self.norm = AdaLayerNorm(style_dim, features)
self.norm_context = AdaLayerNorm(style_dim, context_features)
self.to_q = nn.Linear(
Expand All @@ -268,13 +285,13 @@ def __init__(
rel_pos_max_distance=rel_pos_max_distance,
)

def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None) -> Tensor:
assert_message = "You must provide a context when using context_features"
assert not self.context_features or exists(context), assert_message
# assert not self.context_features or exists(context), assert_message
# Use context if provided
context = default(context, x)
# Normalize then compute q from input and k,v from context
x, context = self.norm(x, s), self.norm_context(context, s)
x, context = self.norm(x, s), self.norm_context(context, s, context_mask)

q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
# Compute and return attention
Expand Down
8 changes: 3 additions & 5 deletions Modules/heteroGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ def __init__(self, hidden_channels, out_channels, num_heads, num_layers, data):
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x_dict, edge_index_dict):
    """Encode each node type, apply hetero conv layers, return 'text' outputs.

    x_dict: {node_type: feature tensor}
    edge_index_dict: {edge_type: edge index tensor}
    Returns the linear head applied to the final 'text' node embeddings.
    """
    # Per-node-type input projection with in-place ReLU.
    x_dict = {
        node_type: self.lin_dict[node_type](x).relu_()
        for node_type, x in x_dict.items()
    }

    # Message passing over the heterogeneous graph.
    for conv in self.convs:
        x_dict = conv(x_dict, edge_index_dict)

    # Final linear head over text-node embeddings only.
    return self.lin(x_dict["text"])
26 changes: 13 additions & 13 deletions Modules/slmadv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, s

def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
text_mask = length_to_mask(ref_lengths).to(ref_text.device)
h_bert = self.model.bert(ref_text, attention_mask=(~text_mask).int())
style = torch.zeros(h_bert.size(0), 1, 256).to(h_bert.device)
bert_dur = torch.cat([h_bert, style.expand(-1, h_bert.size(1), -1)], dim=-1)
bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
# d_en = bert_dur.transpose(-1, -2)
if use_ind and np.random.rand() < 0.5:
s_preds = s_trg
else:
num_steps = np.random.randint(3, 5)
context = torch.zeros(1, 1, 384).to(ref_text.device)
if ref_s is not None:
s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
embedding=bert_dur,
embedding_scale=1,
features=ref_s, # reference from the same speaker as the embedding
embedding_mask_proba=0.1,
num_steps=num_steps).squeeze(1)
embedding=bert_dur,
embedding_scale=1,
features=ref_s, # reference from the same speaker as the embedding
context=context,
embedding_mask_proba=0.1,
num_steps=num_steps).squeeze(1)
else:
s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
embedding=bert_dur,
embedding_scale=1,
embedding_mask_proba=0.1,
num_steps=num_steps).squeeze(1)
embedding=bert_dur,
embedding_scale=1,
context=context,
embedding_mask_proba=0.1,
num_steps=num_steps).squeeze(1)

s_dur = s_preds[:, 128:]
s = s_preds[:, :128]
Expand Down
12 changes: 6 additions & 6 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,23 +698,23 @@ def build_model(args, text_aligner, pitch_extractor, bert):
style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder

style_predictor = StylePredictor(query_dim=512, key_dim=384, num_units=256, num_heads=2)
# style_predictor = StylePredictor(query_dim=512, key_dim=384, num_units=256, num_heads=2)

# define diffusion model
if args.multispeaker:
transformer = StyleTransformer1d(channels=args.style_dim*2,
context_embedding_features=bert.config.hidden_size + 256,
context_embedding_features=bert.config.hidden_size,
context_features=args.style_dim*2,
**args.diffusion.transformer)
else:
transformer = Transformer1d(channels=args.style_dim*2,
context_embedding_features=bert.config.hidden_size + 256,
context_embedding_features=bert.config.hidden_size,
**args.diffusion.transformer)

diffusion = AudioDiffusionConditional(
in_channels=1,
embedding_max_length=bert.config.max_position_embeddings,
embedding_features=bert.config.hidden_size + 256,
embedding_features=bert.config.hidden_size,
embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
channels=args.style_dim*2,
context_features=args.style_dim*2,
Expand All @@ -741,15 +741,15 @@ def build_model(args, text_aligner, pitch_extractor, bert):

nets = Munch(
bert=bert,
bert_encoder=nn.Linear(bert.config.hidden_size + 256, args.hidden_dim),
bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),

predictor=predictor,
decoder=decoder,
text_encoder=text_encoder,

predictor_encoder=predictor_encoder,
style_encoder=style_encoder,
style_predictor=style_predictor,
# style_predictor=style_predictor,
diffusion=diffusion,

text_aligner = text_aligner,
Expand Down
Loading