Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions TTS/vocoder/tf/layers/parallel_wavegan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import tensorflow as tf
from tensorflow.keras import layers

class ResidualBlock(tf.keras.layers.Layer):
    """Gated, dilated 1-D convolution block with residual and skip outputs.

    Tensors are channels-last, i.e. ``(batch, time, channels)``: the padding
    layer pads axis 1 and the gate split is taken on the last axis.

    Args:
        kernel_size: Width of the dilated convolution kernel. Must be odd
            unless ``use_causal_conv`` is set.
        res_channels: Number of channels of the residual output. The input
            ``x`` must have the same number of channels because it is added
            to the residual projection.
        gate_channels: Number of channels produced by the dilated convolution;
            split in half into the tanh and sigmoid branches.
        skip_channels: Number of channels of the skip output.
        aux_channels: If greater than 0, a bias-free 1x1 convolution is
            created to project the conditioning tensor to ``gate_channels``.
        dropout: Dropout rate applied to the block input.
        dilation: Dilation rate of the convolution.
        use_causal_conv: If True, all padding is placed on the left of the
            time axis so that an output frame never sees future frames.
        bias: Whether the dilated, output and skip convolutions use a bias.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 kernel_size=3,
                 res_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 dilation=1,
                 use_causal_conv=False,
                 bias=True,
                 **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)

        if use_causal_conv:
            pad_left = (kernel_size - 1) * dilation
            pad_right = 0
        else:
            assert (kernel_size - 1) % 2 == 0, "kernel_size must be odd for symmetric padding"
            pad_left = pad_right = (kernel_size - 1) // 2 * dilation

        self.use_causal_conv = use_causal_conv

        # Padding is applied explicitly so the convolution itself can stay
        # 'valid' for both the causal and the symmetric case.
        self.pad = layers.ZeroPadding1D(padding=(pad_left, pad_right))
        self.conv = layers.Conv1D(
            filters=gate_channels,
            kernel_size=kernel_size,
            dilation_rate=dilation,
            padding='valid',
            use_bias=bias,
            name='conv'
        )

        if aux_channels > 0:
            self.conv1x1_aux = layers.Conv1D(
                filters=gate_channels,
                kernel_size=1,
                use_bias=False,
                name='conv1x1_aux'
            )
        else:
            self.conv1x1_aux = None

        self.conv1x1_out = layers.Conv1D(
            filters=res_channels,
            kernel_size=1,
            padding='same',
            use_bias=bias,
            name='conv1x1_out'
        )
        self.conv1x1_skip = layers.Conv1D(
            filters=skip_channels,
            kernel_size=1,
            padding='same',
            use_bias=bias,
            name='conv1x1_skip'
        )
        self.dropout = layers.Dropout(dropout)

    def call(self, x, c, training=False):
        """Run the block.

        Args:
            x: Input tensor, channels-last ``(batch, time, res_channels)``.
            c: Conditioning tensor with the same time length as ``x``, or
                None to skip conditioning.
            training: Enables dropout when True.

        Returns:
            Tuple ``(residual_output, skip_output)``.
        """
        residual = x

        x = self.dropout(x, training=training)
        x = self.pad(x)
        x = self.conv(x)

        if self.use_causal_conv:
            # Trim the TIME axis (axis 1) back to the input length. The
            # previous code indexed four axes on a rank-3 tensor and called
            # the PyTorch-only ``.size()``, so causal mode always raised.
            x = x[:, :tf.shape(residual)[1], :]

        # tf uses channels-last, so the gate halves live on the last axis.
        xa, xb = tf.split(x, num_or_size_splits=2, axis=-1)

        if c is not None:
            assert self.conv1x1_aux is not None
            ca, cb = tf.split(self.conv1x1_aux(c), num_or_size_splits=2, axis=-1)
            xa = xa + ca
            xb = xb + cb

        x = tf.math.tanh(xa) * tf.math.sigmoid(xb)

        skip = self.conv1x1_skip(x)
        res_out = self.conv1x1_out(x)
        # NOTE(review): the residual sum is scaled by a fixed 0.25; some
        # reference implementations use sqrt(0.5) here -- confirm intent.
        x = (res_out + residual) * 0.25
        return x, skip

92 changes: 92 additions & 0 deletions TTS/vocoder/tf/layers/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import tensorflow as tf
from tensorflow.keras import layers

class UpsampleNetwork(tf.keras.layers.Layer):
    """Stack of (upsample -> smoothing Conv2D [-> activation]) stages.

    The input is expected as ``(batch, time, freq, 1)``: ``UpSampling2D`` is
    configured with ``size=(scale, 1)``, i.e. it stretches axis 1 and leaves
    axis 2 untouched. The trailing singleton axis is squeezed on output.

    Args:
        upsample_factors: Iterable of integer time-upsampling factors, one
            stage per entry.
        nonlinear_activation: Name of a class in ``tf.keras.layers`` to append
            after every convolution, or None for no activation.
        nonlinear_activation_params: Keyword arguments for that activation
            class. None is treated as an empty dict.
        interpolate_mode: Interpolation passed to ``UpSampling2D``.
        freq_axis_kernel_size: Odd kernel size along the frequency axis.
        use_causal_conv: If True the smoothing convolution is padded only on
            the left of the time axis.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 upsample_factors,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 **kwargs):
        super(UpsampleNetwork, self).__init__(**kwargs)
        assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."

        self.upsample_factors = upsample_factors
        self.nonlinear = nonlinear_activation
        self.nonlinear_params = nonlinear_activation_params or {}
        self.freq_axis_kernel_size = freq_axis_kernel_size
        self.use_causal_conv = use_causal_conv

        freq_axis_padding = (freq_axis_kernel_size - 1) // 2

        self.layers_list = []
        for scale in upsample_factors:
            self.layers_list.append(
                layers.UpSampling2D(size=(scale, 1),
                                    interpolation=interpolate_mode,
                                    data_format='channels_last')
            )

            # Kernel is (time, freq) to match the (batch, time, freq, 1)
            # layout: the long ``2 * scale + 1`` side must run along the axis
            # that was just stretched. It was previously transposed, which
            # smoothed across frequency bins instead of time.
            # NOTE(review): any weight-conversion code that assumed the old
            # (freq, time) kernel shape must transpose those two axes.
            kernel_size = (scale * 2 + 1, freq_axis_kernel_size)

            if use_causal_conv:
                # Left-only padding on time keeps the output length equal to
                # the input length while never looking at future frames.
                self.layers_list.append(
                    layers.ZeroPadding2D(
                        padding=((scale * 2, 0),
                                 (freq_axis_padding, freq_axis_padding)))
                )
                conv_padding = 'valid'
            else:
                conv_padding = 'same'

            # NOTE(review): every stage reuses the same layer name; confirm
            # this is what the checkpoint/weight-loading code expects.
            self.layers_list.append(
                layers.Conv2D(1,
                              kernel_size=kernel_size,
                              padding=conv_padding,
                              use_bias=False,
                              name='upsample_convolution'))

            if nonlinear_activation is not None:
                activation_cls = getattr(tf.keras.layers, nonlinear_activation)
                self.layers_list.append(activation_cls(**self.nonlinear_params))

    def call(self, c):
        """Apply every stage in order and drop the trailing singleton axis.

        Args:
            c: Tensor of shape ``(batch, time, freq, 1)``.

        Returns:
            Tensor of shape ``(batch, upsampled_time, freq)``.
        """
        c2d = c
        for layer in self.layers_list:
            c2d = layer(c2d)
        return tf.squeeze(c2d, -1)

class ConvUpsample(tf.keras.layers.Layer):
    """Context convolution over the conditioning features, then upsampling.

    Args:
        upsample_factors: Iterable of integer upsampling factors forwarded to
            ``UpsampleNetwork``.
        nonlinear_activation: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation_params: Forwarded to ``UpsampleNetwork``. None is
            treated as "no parameters".
        interpolate_mode: Forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: Forwarded to ``UpsampleNetwork``.
        aux_channels: Number of filters of the input convolution.
        aux_context_window: Number of context frames; sets the input
            convolution kernel to ``2 * window + 1`` (or ``window + 1`` when
            causal). The convolution is unpadded, so it shortens the time axis
            by ``kernel_size - 1`` frames.
        use_causal_conv: Requests causal behaviour. The trailing-frame trim in
            ``call`` is only active when ``aux_context_window > 0``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 upsample_factors,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False,
                 **kwargs):
        super(ConvUpsample, self).__init__(**kwargs)
        self.aux_context_window = aux_context_window
        # With a zero-width window there is nothing to trim, and
        # ``[:-0]`` would produce an empty tensor, so the flag is gated.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        if use_causal_conv:
            kernel_size = aux_context_window + 1
        else:
            kernel_size = 2 * aux_context_window + 1

        self.conv_in = tf.keras.layers.Conv1D(
            filters=aux_channels,
            kernel_size=kernel_size,
            use_bias=False,
            name='conv_in')

        self.upsample_net = UpsampleNetwork(
            upsample_factors=upsample_factors,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
            name='upsample_net')

    def call(self, c):
        """Upsample the conditioning features.

        Args:
            c: Tensor of shape ``(batch, channels, time)``; it is transposed
                to channels-last before the convolution.

        Returns:
            The output of ``UpsampleNetwork`` for the convolved features.
        """
        c2d = tf.transpose(c, [0, 2, 1])
        c2d = self.conv_in(c2d)
        if self.use_causal_conv:
            # Drop the trailing context frames on the TIME axis (axis 1).
            # The previous slice indexed four axes on this rank-3 tensor.
            c2d = c2d[:, :-self.aux_context_window, :]
        c_upsampled = tf.expand_dims(c2d, axis=-1)
        return self.upsample_net(c_upsampled)
97 changes: 97 additions & 0 deletions TTS/vocoder/tf/models/parallel_wavegan_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

from TTS.vocoder.tf.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.tf.layers.upsample import ConvUpsample

class ParallelWaveganGenerator(tf.keras.Model):
    """Non-autoregressive generator built from gated residual blocks.

    Random noise of the target length is refined by a stack of
    ``ResidualBlock`` layers, each conditioned on the upsampled input
    features; the summed skip outputs go through two ReLU + 1x1 convolutions.

    Args:
        in_channels: Stored on the instance; not otherwise used here.
        out_channels: Number of channels of the generated output.
        kernel_size: Kernel size of every residual block.
        num_res_blocks: Total number of residual blocks. Must be a multiple
            of ``stacks``.
        stacks: Number of dilation cycles; the dilation restarts at 1 every
            ``num_res_blocks // stacks`` blocks.
        res_channels: Residual channels of each block.
        gate_channels: Gate channels of each block.
        skip_channels: Skip channels of each block.
        aux_channels: Number of conditioning feature channels.
        dropout: Dropout rate of each block.
        bias: Whether the residual blocks' convolutions use a bias.
        use_weight_norm: Stored on the instance; no weight normalisation is
            applied by this class.
        upsample_factors: Iterable of integer upsampling factors. Defaults to
            ``[4, 4, 4, 4]`` when None.
        inference_padding: Number of frames added to each end of the time
            axis by ``inference``.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 num_res_blocks=30,
                 stacks=3,
                 res_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 upsample_factors=None,
                 inference_padding=2,
                 **kwargs):
        super().__init__(**kwargs)
        if upsample_factors is None:
            # Built per instance to avoid a shared mutable default.
            upsample_factors = [4, 4, 4, 4]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.num_res_blocks = num_res_blocks
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.upsample_factors = upsample_factors
        # Plain Python int so it multiplies cleanly with int32 shape tensors
        # (np.prod returns a NumPy int64 scalar).
        self.upsample_scale = int(np.prod(upsample_factors))
        self.inference_padding = inference_padding
        self.use_weight_norm = use_weight_norm

        assert num_res_blocks % stacks == 0
        layers_per_stack = num_res_blocks // stacks

        self.first_conv = layers.Conv1D(filters=res_channels,
                                        kernel_size=1,
                                        use_bias=True,
                                        name='first_conv')

        # aux_channels is forwarded so the context convolution keeps the
        # feature dimension instead of always using its own default.
        self.upsample_net = ConvUpsample(
            upsample_factors=upsample_factors,
            aux_channels=aux_channels)

        self.residual_blocks = []
        for layer in range(num_res_blocks):
            dilation = 2 ** (layer % layers_per_stack)
            rb = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                name=f'conv_layers.{layer}'
            )
            self.residual_blocks.append(rb)

        self.post_relu1 = layers.ReLU()
        self.conv_post1 = layers.Conv1D(filters=skip_channels, kernel_size=1,
                                        padding='same', use_bias=True)
        self.post_relu2 = layers.ReLU()
        self.conv_post2 = layers.Conv1D(filters=out_channels, kernel_size=1,
                                        padding='same', use_bias=True)

    def call(self, c, training=False):
        """Generate output from conditioning features.

        Args:
            c: Conditioning tensor of shape ``(batch, aux_channels, frames)``;
                the last axis is read as the frame count.
            training: Enables dropout inside the residual blocks.

        Returns:
            Tensor of shape ``(batch, frames * upsample_scale, out_channels)``.
        """
        batch = tf.shape(c)[0]
        t_in = tf.shape(c)[-1]
        t_out = t_in * self.upsample_scale
        x = tf.random.normal((batch, t_out, 1))

        c_up = self.upsample_net(c)
        # Checked on dynamic shapes so it also holds inside tf.function,
        # where static dimensions may be unknown.
        tf.debugging.assert_equal(
            tf.shape(c_up)[1], t_out,
            message=" [!] Upsampling scale does not match the expected output.")

        x = self.first_conv(x)
        skips = 0
        for rb in self.residual_blocks:
            x, s = rb(x, c_up, training=training)
            skips += s
        # Normalise the summed skip outputs by sqrt(1 / number of blocks).
        skips *= tf.math.sqrt(1.0 / float(self.num_res_blocks))

        x = skips
        x = self.post_relu1(x)
        x = self.conv_post1(x)
        x = self.post_relu2(x)
        x = self.conv_post2(x)
        return x

    def inference(self, c):
        """Pad the time axis of ``c`` and run the generator without dropout.

        Args:
            c: Conditioning tensor of shape ``(batch, aux_channels, frames)``.

        Returns:
            The generator output for the padded input.
        """
        pad = self.inference_padding
        # Time is the LAST axis of ``c`` (see ``call``), so that is the axis
        # to pad; padding axis 1 would add fake feature channels instead.
        c_padded = tf.pad(c, [[0, 0], [0, 0], [pad, pad]], mode='SYMMETRIC')
        return self(c_padded, training=False)
14 changes: 14 additions & 0 deletions TTS/vocoder/tf/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,18 @@ def setup_generator(c):
upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model.lower() in 'parallel_wavegan_generator':
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=c.generator_model_params['num_res_blocks'],
stacks=c.generator_model_params['stacks'],
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=c.audio['num_mels'],
dropout=0.0,
use_weight_norm=True,
upsample_factors=c.generator_model_params['upsample_factors'])
return model