diff --git a/TTS/vocoder/tf/layers/parallel_wavegan.py b/TTS/vocoder/tf/layers/parallel_wavegan.py
new file mode 100644
index 000000000..e3ca1659e
--- /dev/null
+++ b/TTS/vocoder/tf/layers/parallel_wavegan.py
@@ -0,0 +1,81 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+class ResidualBlock(tf.keras.layers.Layer):
+    def __init__(self,
+                 kernel_size=3,
+                 res_channels=64,
+                 gate_channels=128,
+                 skip_channels=64,
+                 aux_channels=80,
+                 dropout=0.0,
+                 dilation=1,
+                 use_causal_conv=False,
+                 bias=True,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.use_causal_conv = use_causal_conv
+
+        if use_causal_conv:
+            # causal convolution pads the time axis on the left only
+            pad_left = (kernel_size - 1) * dilation
+            pad_right = 0
+        else:
+            assert (kernel_size - 1) % 2 == 0, "kernel_size must be odd for symmetric padding"
+            pad_left = pad_right = (kernel_size - 1) // 2 * dilation
+
+        self.pad = layers.ZeroPadding1D(padding=(pad_left, pad_right))
+        self.conv = layers.Conv1D(filters=gate_channels,
+                                  kernel_size=kernel_size,
+                                  dilation_rate=dilation,
+                                  padding='valid',
+                                  use_bias=bias,
+                                  name='conv')
+
+        if aux_channels > 0:
+            self.conv1x1_aux = layers.Conv1D(filters=gate_channels,
+                                             kernel_size=1,
+                                             use_bias=False,
+                                             name='conv1x1_aux')
+        else:
+            self.conv1x1_aux = None
+
+        self.conv1x1_out = layers.Conv1D(filters=res_channels,
+                                         kernel_size=1,
+                                         padding='same',
+                                         use_bias=bias,
+                                         name='conv1x1_out')
+        self.conv1x1_skip = layers.Conv1D(filters=skip_channels,
+                                          kernel_size=1,
+                                          padding='same',
+                                          use_bias=bias,
+                                          name='conv1x1_skip')
+        self.dropout = layers.Dropout(dropout)
+
+    def call(self, x, c, training=False):
+        residual = x
+
+        x = self.dropout(x, training=training)
+        x = self.pad(x)
+        x = self.conv(x)
+
+        if self.use_causal_conv:
+            # crop any extra right context so the time axis matches the input
+            x = x[:, :tf.shape(residual)[1], :]
+
+        # gated activation - split the gate channels in half (channels-last)
+        xa, xb = tf.split(x, num_or_size_splits=2, axis=-1)
+
+        if c is not None:
+            assert self.conv1x1_aux is not None
+            c = self.conv1x1_aux(c)
+            ca, cb = tf.split(c, num_or_size_splits=2, axis=-1)
+            xa = xa + ca
+            xb = xb + cb
+
+        x = tf.math.tanh(xa) * tf.math.sigmoid(xb)
+
+        skip = self.conv1x1_skip(x)
+        x = (self.conv1x1_out(x) + residual) * 0.25
+        return x, skip
\ No newline at end of file
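
As a sanity check for the residual block above, a minimal shape test (the sizes here are illustrative assumptions, not part of the patch):

    import tensorflow as tf
    from TTS.vocoder.tf.layers.parallel_wavegan import ResidualBlock

    block = ResidualBlock(kernel_size=3, res_channels=64, gate_channels=128,
                          skip_channels=64, aux_channels=80, dilation=2)
    x = tf.random.normal((1, 100, 64))  # [B, T, res_channels], channels-last
    c = tf.random.normal((1, 100, 80))  # aux features at the same time resolution
    res, skip = block(x, c)
    print(res.shape, skip.shape)        # (1, 100, 64) (1, 100, 64)
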
diff --git a/TTS/vocoder/tf/layers/upsample.py b/TTS/vocoder/tf/layers/upsample.py
new file mode 100644
index 000000000..9a0fe23a1
--- /dev/null
+++ b/TTS/vocoder/tf/layers/upsample.py
@@ -0,0 +1,96 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+class UpsampleNetwork(tf.keras.layers.Layer):
+    def __init__(self,
+                 upsample_factors,
+                 nonlinear_activation=None,
+                 nonlinear_activation_params=None,
+                 interpolate_mode="nearest",
+                 freq_axis_kernel_size=1,
+                 use_causal_conv=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.upsample_factors = upsample_factors
+        self.nonlinear = nonlinear_activation
+        self.nonlinear_params = nonlinear_activation_params or {}
+        self.freq_axis_kernel_size = freq_axis_kernel_size
+        self.use_causal_conv = use_causal_conv
+
+        self.layers_list = []
+        for scale in upsample_factors:
+            self.layers_list.append(
+                layers.UpSampling2D(size=(scale, 1),
+                                    interpolation=interpolate_mode,
+                                    data_format='channels_last'))
+
+            assert (freq_axis_kernel_size - 1) % 2 == 0, "Even freq axis kernel sizes are not supported."
+            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+            # inputs are [B, T, C, 1], so the kernel is (time, freq)
+            kernel_size = (scale * 2 + 1, freq_axis_kernel_size)
+
+            if use_causal_conv:
+                padding = ((scale * 2, 0), (freq_axis_padding, freq_axis_padding))
+            else:
+                padding = ((scale, scale), (freq_axis_padding, freq_axis_padding))
+            self.layers_list.append(layers.ZeroPadding2D(padding=padding))
+            self.layers_list.append(
+                layers.Conv2D(filters=1,
+                              kernel_size=kernel_size,
+                              padding='valid',
+                              use_bias=False,
+                              name='upsample_convolution'))
+
+            if nonlinear_activation is not None:
+                activation_cls = getattr(tf.keras.layers, nonlinear_activation)
+                self.layers_list.append(activation_cls(**self.nonlinear_params))
+
+    def call(self, c):
+        # c: [B, T, C, 1]
+        for layer in self.layers_list:
+            c = layer(c)
+        # drop the dummy channel axis -> [B, T * prod(upsample_factors), C]
+        return tf.squeeze(c, -1)
+
+
+class ConvUpsample(tf.keras.layers.Layer):
+    def __init__(self,
+                 upsample_factors,
+                 nonlinear_activation=None,
+                 nonlinear_activation_params=None,
+                 interpolate_mode="nearest",
+                 freq_axis_kernel_size=1,
+                 aux_channels=80,
+                 aux_context_window=0,
+                 use_causal_conv=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.aux_context_window = aux_context_window
+        self.use_causal_conv = use_causal_conv and aux_context_window > 0
+
+        # no future context in the causal case
+        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
+        self.conv_in = layers.Conv1D(filters=aux_channels,
+                                     kernel_size=kernel_size,
+                                     use_bias=False,
+                                     name='conv_in')
+
+        self.upsample_net = UpsampleNetwork(
+            upsample_factors=upsample_factors,
+            nonlinear_activation=nonlinear_activation,
+            nonlinear_activation_params=nonlinear_activation_params,
+            interpolate_mode=interpolate_mode,
+            freq_axis_kernel_size=freq_axis_kernel_size,
+            use_causal_conv=use_causal_conv,
+            name='upsample_net')
+
+    def call(self, c):
+        # c: [B, C, T] -> [B, T, C] for the channels-last Conv1D
+        c = tf.transpose(c, [0, 2, 1])
+        c = self.conv_in(c)
+        if self.use_causal_conv:
+            c = c[:, :-self.aux_context_window, :]
+        # add a dummy channel axis for the 2D upsampling stack
+        c = tf.expand_dims(c, axis=-1)
+        return self.upsample_net(c)
\ No newline at end of file
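
A quick sketch of the expected upsampling behaviour, assuming 80 mel bins and upsample_factors=[4, 4, 4, 4] (i.e. hop length 256); the concrete numbers are assumptions for illustration:

    import tensorflow as tf
    from TTS.vocoder.tf.layers.upsample import ConvUpsample

    upsampler = ConvUpsample(upsample_factors=[4, 4, 4, 4], aux_channels=80)
    c = tf.random.normal((1, 80, 50))  # [B, num_mels, T_frames]
    c_up = upsampler(c)
    print(c_up.shape)                  # (1, 50 * 256, 80) == (1, 12800, 80)
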
diff --git a/TTS/vocoder/tf/models/parallel_wavegan_generator.py b/TTS/vocoder/tf/models/parallel_wavegan_generator.py
new file mode 100644
index 000000000..70b1abbd4
--- /dev/null
+++ b/TTS/vocoder/tf/models/parallel_wavegan_generator.py
@@ -0,0 +1,103 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from TTS.vocoder.tf.layers.parallel_wavegan import ResidualBlock
+from TTS.vocoder.tf.layers.upsample import ConvUpsample
+
+
+class ParallelWaveganGenerator(tf.keras.Model):
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 num_res_blocks=30,
+                 stacks=3,
+                 res_channels=64,
+                 gate_channels=128,
+                 skip_channels=64,
+                 aux_channels=80,
+                 dropout=0.0,
+                 bias=True,
+                 use_weight_norm=True,
+                 upsample_factors=[4, 4, 4, 4],
+                 inference_padding=2,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.aux_channels = aux_channels
+        self.num_res_blocks = num_res_blocks
+        self.stacks = stacks
+        self.kernel_size = kernel_size
+        self.upsample_factors = upsample_factors
+        self.upsample_scale = int(np.prod(upsample_factors))
+        self.inference_padding = inference_padding
+        self.use_weight_norm = use_weight_norm
+
+        assert num_res_blocks % stacks == 0
+        layers_per_stack = num_res_blocks // stacks
+
+        self.first_conv = layers.Conv1D(filters=res_channels,
+                                        kernel_size=1,
+                                        use_bias=True,
+                                        name='first_conv')
+
+        self.upsample_net = ConvUpsample(upsample_factors=upsample_factors,
+                                         aux_channels=aux_channels)
+
+        self.residual_blocks = []
+        for layer in range(num_res_blocks):
+            dilation = 2 ** (layer % layers_per_stack)
+            self.residual_blocks.append(
+                ResidualBlock(kernel_size=kernel_size,
+                              res_channels=res_channels,
+                              gate_channels=gate_channels,
+                              skip_channels=skip_channels,
+                              aux_channels=aux_channels,
+                              dilation=dilation,
+                              dropout=dropout,
+                              bias=bias,
+                              name=f'conv_layers.{layer}'))
+
+        self.post_relu1 = layers.ReLU()
+        self.conv_post1 = layers.Conv1D(filters=skip_channels,
+                                        kernel_size=1,
+                                        padding='same',
+                                        use_bias=True)
+        self.post_relu2 = layers.ReLU()
+        self.conv_post2 = layers.Conv1D(filters=out_channels,
+                                        kernel_size=1,
+                                        padding='same',
+                                        use_bias=True)
+
+    def call(self, c, training=False):
+        # c: [B, C, T] aux features. Draw the noise input at the
+        # upsampled time resolution.
+        batch = tf.shape(c)[0]
+        t_out = tf.shape(c)[-1] * self.upsample_scale
+        x = tf.random.normal((batch, t_out, 1))
+
+        c_up = self.upsample_net(c)
+        assert c_up.shape[1] == x.shape[1], \
+            f" [!] Upsampling scale does not match the expected output. {c_up.shape} vs {x.shape}"
+
+        x = self.first_conv(x)
+        skips = 0
+        for rb in self.residual_blocks:
+            x, s = rb(x, c_up, training=training)
+            skips += s
+        skips *= tf.math.sqrt(1.0 / float(self.num_res_blocks))
+
+        x = skips
+        x = self.post_relu1(x)
+        x = self.conv_post1(x)
+        x = self.post_relu2(x)
+        x = self.conv_post2(x)
+        return x
+
+    def inference(self, c):
+        # pad the time axis (last axis of [B, C, T]) before synthesis
+        c = tf.pad(c, [[0, 0], [0, 0], [self.inference_padding, self.inference_padding]],
+                   mode='SYMMETRIC')
+        return self.call(c, training=False)
diff --git a/TTS/vocoder/tf/utils/generic_utils.py b/TTS/vocoder/tf/utils/generic_utils.py
index 0daf2d6e1..e3de1383f 100644
--- a/TTS/vocoder/tf/utils/generic_utils.py
+++ b/TTS/vocoder/tf/utils/generic_utils.py
@@ -32,4 +32,18 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model.lower() in 'parallel_wavegan_generator':
+        model = MyModel(
+            in_channels=1,
+            out_channels=1,
+            kernel_size=3,
+            num_res_blocks=c.generator_model_params['num_res_blocks'],
+            stacks=c.generator_model_params['stacks'],
+            res_channels=64,
+            gate_channels=128,
+            skip_channels=64,
+            aux_channels=c.audio['num_mels'],
+            dropout=0.0,
+            use_weight_norm=True,
+            upsample_factors=c.generator_model_params['upsample_factors'])
     return model
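
For completeness, an end-to-end smoke test of the generator with its default settings (all sizes are illustrative assumptions):

    import tensorflow as tf
    from TTS.vocoder.tf.models.parallel_wavegan_generator import ParallelWaveganGenerator

    model = ParallelWaveganGenerator(num_res_blocks=30, stacks=3,
                                     upsample_factors=[4, 4, 4, 4])
    mel = tf.random.normal((1, 80, 50))  # [B, num_mels, T_frames]
    audio = model.inference(mel)
    print(audio.shape)  # (1, (50 + 2 * 2) * 256, 1) == (1, 13824, 1)
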