diff --git a/.gitignore b/.gitignore index 6906016..1df7ba5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,27 @@ idea/* /trace_loss_nvidia.txt /conf /etc +.ipynb_checkpoints/Untitled-checkpoint.ipynb +dataset/audio/__pycache__/__init__.cpython-36.pyc +*.pyc +Untitled.ipynb +mel.npy +*.png +*.npy +Testing/2log_v2/no_exp_before_bins_fs2v2_2_31k_test_tts.wav +Testing/exp_log/test_tts.wav +Testing/exp_log_v2/exp_before_bins_fs2v2_2_31k_test_tts.wav +mel.png +mel.npy +Testing/v2_2/test_tts.wav +*.npy +*.png +mel.png +*.wav +*.npy +.ipynb_checkpoints/pitch_cwt-checkpoint.ipynb +pitch_cwt.ipynb +*.wav +Testing/test_tts.wav +*.wav +Testing/test_tts.wav diff --git a/Testing/test_tts.wav b/Testing/test_tts.wav new file mode 100644 index 0000000..cc5e972 Binary files /dev/null and b/Testing/test_tts.wav differ diff --git a/configs/default.yaml b/configs/default.yaml index 22f5c10..dcea266 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,6 +1,6 @@ data: - data_dir: 'H:\Deepsync\backup\fastspeech\data\' - wav_dir: 'H:\Deepsync\backup\deepsync\LJSpeech-1.1\wavs\' + data_dir: './data/LJSpeech/good_file/' + wav_dir: '/mnt/Karan/LJSpeech-1.1/wavs/' # Compute statistics e_mean: 21.578571319580078 e_std: 18.916799545288086 @@ -10,7 +10,7 @@ data: f0_mean: 206.5135564772342 f0_std: 53.633228905750336 p_min: 71.0 - p_max: 676.2260946528305 # 799.8901977539062 + p_max: 500.0 # 799.8901977539062 train_filelist: "./filelists/train_filelist.txt" valid_filelist: "./filelists/valid_filelist.txt" tts_cleaner_names: ['english_cleaners'] @@ -30,6 +30,7 @@ audio: bits : 9 # bit depth of signal mu_law : True # Recommended to suppress noise if using raw bits in hp.voc_mode below peak_norm : False # Normalise to the peak of each wav file + cwt_bins : 10 @@ -46,7 +47,7 @@ model: aheads: 2 elayers: 4 eunits: 1024 - ddim: 384 + ddim: 256 dlayers: 4 dunits: 1024 positionwise_layer_type : "conv1d" # linear @@ -110,7 +111,7 @@ train: # optimization related eos: False #True opt: 'noam' - accum_grad: 4 + accum_grad: 1 grad_clip: 1.0 weight_decay: 0.001 patience: 0 @@ -126,7 +127,7 @@ train: seed: 1 # random seed number resume: "" # the snapshot path to resume (if set empty, no effect) use_phonemes: True - batch_size : 16 + batch_size : 48 # other melgan_vocoder : True save_interval : 1000 @@ -135,4 +136,4 @@ train: summary_interval : 200 validation_step : 500 tts_max_mel_len : 870 # if you have a couple of extremely long spectrograms you might want to use this - tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training \ No newline at end of file + tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training diff --git a/core/variance_predictor.py b/core/variance_predictor.py index 11b93b6..d415a96 100644 --- a/core/variance_predictor.py +++ b/core/variance_predictor.py @@ -2,7 +2,9 @@ import torch.nn.functional as F from typing import Optional from core.modules import LayerNorm - +#import pycwt +import numpy as np +from sklearn import preprocessing class VariancePredictor(torch.nn.Module): def __init__( @@ -149,7 +151,11 @@ def inference(self, xs: torch.Tensor, alpha: float = 1.0): """ out = self.predictor.inference(xs, False, alpha=alpha) - return self.to_one_hot(out) # Need to do One hot code + #print(out.shape, type(out)) + #out = torch.from_numpy(np.load("/results/chkpts/LJ/Fastspeech2_V2/data/energy/LJ001-0001.npy")).cuda() + #print(out, "Energy Pricted") + out = torch.exp(out) + return self.to_one_hot(out), 
out # Need to do One hot code def to_one_hot(self, x): # e = de_norm_mean_std(e, hp.e_mean, hp.e_std) @@ -171,6 +177,7 @@ def __init__( min=0, max=0, n_bins=256, + out=5, ): """Initilize pitch predictor module. @@ -195,9 +202,29 @@ def __init__( ) ), ) - self.predictor = VariancePredictor(idim) + self.offset = offset + self.conv = torch.nn.ModuleList() + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chans, + n_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.ReLU(), + LayerNorm(n_chans), + torch.nn.Dropout(dropout_rate), + ) + ] + self.spectrogram_out = torch.nn.Linear(n_chans, out) + self.mean = torch.nn.Linear(n_chans, 1) + self.std = torch.nn.Linear(n_chans, 1) - def forward(self, xs: torch.Tensor, x_masks: torch.Tensor): + def forward(self, xs: torch.Tensor, olens: torch.Tensor, x_masks: torch.Tensor): """Calculate forward propagation. Args: @@ -208,9 +235,42 @@ def forward(self, xs: torch.Tensor, x_masks: torch.Tensor): Tensor: Batch of predicted durations in log domain (B, Tmax). """ - return self.predictor(xs, x_masks) + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) - def inference(self, xs: torch.Tensor, alpha: float = 1.0): + # NOTE: calculate in log domain + xs = xs.transpose(1, -1) + f0_spec = self.spectrogram_out(xs) # (B, Tmax, 10) + + if x_masks is not None: + # print("olen:", olens) + #f0_spec = f0_spec.transpose(1, -1) + # print("F0 spec dimension:", f0_spec.shape) + # print("x_masks dimension:", x_masks.shape) + f0_spec = f0_spec.masked_fill(x_masks, 0.0) + #f0_spec = f0_spec.transpose(1, -1) + # print("F0 spec dimension:", f0_spec.shape) + #xs = xs.transpose(1, -1) + xs = xs.masked_fill(x_masks, 0.0) + #xs = xs.transpose(1, -1) + # print("xs dimension:", xs.shape) + x_avg = xs.sum(dim=1).squeeze(1) + # print(x_avg) + # print("xs dim :", x_avg.shape) + # print("olens ;", olens.shape) + if olens is not None: + x_avg = x_avg / olens.unsqueeze(1) + # print(x_avg) + f0_mean = self.mean(x_avg).squeeze(-1) + f0_std = self.std(x_avg).squeeze(-1) + + # if x_masks is not None: + # f0_spec = f0_spec.masked_fill(x_masks, 0.0) + + return f0_spec, f0_mean, f0_std + + def inference(self, xs: torch.Tensor, olens = None, alpha: float = 1.0): """Inference duration. Args: @@ -221,8 +281,14 @@ def inference(self, xs: torch.Tensor, alpha: float = 1.0): LongTensor: Batch of predicted durations in linear domain (B, Tmax). 
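+            Tensor: F0 contour reconstructed from the predicted CWT spectrogram
+                and the predicted log-F0 mean/std (1, Lmax).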
""" - out = self.predictor.inference(xs, False, alpha=alpha) - return self.to_one_hot(out) + f0_spec, f0_mean, f0_std = self.forward(xs, olens, x_masks=None) # (B, Tmax, 10) + #print(f0_spec) + f0_reconstructed = self.inverse(f0_spec, f0_mean, f0_std) + #print(f0_reconstructed) + #f0_reconstructed = torch.from_numpy(np.load("/results/chkpts/LJ/Fastspeech2_V2/data/pitch/LJ001-0001.npy").reshape(1,-1)).cuda() + #print(f0_reconstructed, "Pitch coef output") + + return self.to_one_hot(f0_reconstructed), f0_reconstructed def to_one_hot(self, x: torch.Tensor): # e = de_norm_mean_std(e, hp.e_mean, hp.e_std) @@ -231,6 +297,24 @@ def to_one_hot(self, x: torch.Tensor): quantize = torch.bucketize(x, self.pitch_bins).to(device=x.device) # .cuda() return F.one_hot(quantize.long(), 256).float() + def inverse(self, Wavelet_lf0, f0_mean, f0_std): + scales = np.array([0.01, 0.02, 0.04, 0.08, 0.16]) #np.arange(1,11) + #print(Wavelet_lf0.shape) + Wavelet_lf0 = Wavelet_lf0.squeeze(0).cpu().numpy() + lf0_rec = np.zeros([Wavelet_lf0.shape[0], len(scales)]) + for i in range(0,len(scales)): + lf0_rec[:,i] = Wavelet_lf0[:,i]*((i+200+2.5)**(-2.5)) + + lf0_rec_sum = np.sum(lf0_rec,axis = 1) + lf0_rec_sum_norm = preprocessing.scale(lf0_rec_sum) + + f0_reconstructed = (torch.Tensor(lf0_rec_sum_norm).cuda()*f0_std) + f0_mean + + f0_reconstructed = torch.exp(f0_reconstructed) + #print(f0_reconstructed.shape) + #print(f0_reconstructed.shape) + return f0_reconstructed.reshape(1,-1) + class PitchPredictorLoss(torch.nn.Module): """Loss function module for duration predictor. diff --git a/dataset/audio/__init__.py b/dataset/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/audio/audio_processing.py b/dataset/audio/audio_processing.py new file mode 100644 index 0000000..e176f49 --- /dev/null +++ b/dataset/audio/audio_processing.py @@ -0,0 +1,258 @@ +import math +import numpy as np +import librosa +from scipy.signal import lfilter +import pyworld as pw +import torch +from scipy.signal import get_window +import librosa.util as librosa_util + + +def label_2_float(x, bits): + return 2 * x / (2 ** bits - 1.0) - 1.0 + + +def float_2_label(x, bits): + assert abs(x).max() <= 1.0 + x = (x + 1.0) * (2 ** bits - 1) / 2 + return x.clip(0, 2 ** bits - 1) + + +def load_wav(path, hp): + return librosa.load(path, sr=hp.audio.sample_rate)[0] + + +def save_wav(x, path, hp): + librosa.output.write_wav(path, x.astype(np.float32), sr=hp.audio.sample_rate) + + +def split_signal(x): + unsigned = x + 2 ** 15 + coarse = unsigned // 256 + fine = unsigned % 256 + return coarse, fine + + +def combine_signal(coarse, fine): + return coarse * 256 + fine - 2 ** 15 + + +def encode_16bits(x): + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + + +mel_basis = None + + +def energy(y): + # Extract energy + S = librosa.magphase(stft(y))[0] + e = np.sqrt(np.sum(S ** 2, axis=0)) # np.linalg.norm(S, axis=0) + return e.squeeze() # (Number of frames) => (654,) + + +def pitch(y, hp): + # Extract Pitch/f0 from raw waveform using PyWORLD + y = y.astype(np.float64) + """ + f0_floor : float + Lower F0 limit in Hz. + Default: 71.0 + f0_ceil : float + Upper F0 limit in Hz. 
+        Default: 800.0
+    """
+    f0, timeaxis = pw.dio(
+        y,
+        hp.audio.sample_rate,
+        frame_period=hp.audio.hop_length / hp.audio.sample_rate * 1000,
+    )  # For hop size 256 frame period is 11.6 ms
+    return f0  # (Number of Frames) = (654,)
+
+
+def linear_to_mel(spectrogram, hp):
+    global mel_basis
+    if mel_basis is None:
+        mel_basis = build_mel_basis(hp)
+    return np.dot(mel_basis, spectrogram)
+
+
+def build_mel_basis(hp):
+    return librosa.filters.mel(
+        hp.audio.sample_rate,
+        hp.audio.n_fft,
+        n_mels=hp.audio.num_mels,
+        fmin=hp.audio.fmin,
+    )
+
+
+def normalize(S, hp):
+    return np.clip((S - hp.audio.min_level_db) / -hp.audio.min_level_db, 0, 1)
+
+
+def denormalize(S, hp):
+    return (np.clip(S, 0, 1) * -hp.audio.min_level_db) + hp.audio.min_level_db
+
+
+def amp_to_db(x):
+    return 20 * np.log10(np.maximum(1e-5, x))
+
+
+def db_to_amp(x):
+    return np.power(10.0, x * 0.05)
+
+
+def spectrogram(y, hp):
+    D = stft(y, hp)
+    S = amp_to_db(np.abs(D)) - hp.audio.ref_level_db
+    return normalize(S, hp)
+
+
+def melspectrogram(y, hp):
+    D = stft(y, hp)
+    S = amp_to_db(linear_to_mel(np.abs(D), hp))
+    return normalize(S, hp)
+
+
+def stft(y, hp):
+    return librosa.stft(
+        y=y,
+        n_fft=hp.audio.n_fft,
+        hop_length=hp.audio.hop_length,
+        win_length=hp.audio.win_length,
+    )
+
+
+def pre_emphasis(x, hp):
+    return lfilter([1, -hp.audio.preemphasis], [1], x)
+
+
+def de_emphasis(x, hp):
+    return lfilter([1], [1, -hp.audio.preemphasis], x)
+
+
+def encode_mu_law(x, mu):
+    mu = mu - 1
+    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
+    return np.floor((fx + 1) / 2 * mu + 0.5)
+
+
+def decode_mu_law(y, mu, from_labels=True):
+    # TODO : get rid of log2 - makes no sense
+    if from_labels:
+        y = label_2_float(y, math.log2(mu))
+    mu = mu - 1
+    x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
+    return x
+
+
+def reconstruct_waveform(mel, hp, n_iter=32):
+    """Uses Griffin-Lim phase reconstruction to convert from a normalized
+    mel spectrogram back into a waveform."""
+    denormalized = denormalize(mel, hp)
+    amp_mel = db_to_amp(denormalized)
+    S = librosa.feature.inverse.mel_to_stft(
+        amp_mel,
+        power=1,
+        sr=hp.audio.sample_rate,
+        n_fft=hp.audio.n_fft,
+        fmin=hp.audio.fmin,
+    )
+    wav = librosa.core.griffinlim(
+        S, n_iter=n_iter, hop_length=hp.audio.hop_length, win_length=hp.audio.win_length
+    )
+    return wav
+
+
+def quantize_input(input, min, max, num_bins=256):
+    bins = np.linspace(min, max, num=num_bins)
+    quantize = np.digitize(input, bins)
+    return quantize
+
+
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length=200,
+    win_length=800,
+    n_fft=800,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time Fourier transforms.
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+    n_frames : int > 0
+        The number of analysis frames
+    hop_length : int > 0
+        The number of samples to advance between frames
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+    n_fft : int > 0
+        The length of each analysis frame.
+ dtype : np.dtype + The data type of the output + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles).cuda()) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/dataset/audio/pitch.py b/dataset/audio/pitch.py new file mode 100644 index 0000000..bfeedb6 --- /dev/null +++ b/dataset/audio/pitch.py @@ -0,0 +1,199 @@ +"""F0 extractor using DIO + Stonemask algorithm.""" + +import logging + +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +import numpy as np +import pyworld +import torch +import torch.nn.functional as F +import pycwt + +from scipy.interpolate import interp1d +from typeguard import check_argument_types + + +class Dio(): + """F0 estimation with dio + stonemask algortihm. + This is f0 extractor based on dio + stonmask algorithm introduced in `WORLD: + a vocoder-based high-quality speech synthesis system for real-time applications`_. + .. _`WORLD: a vocoder-based high-quality speech synthesis system for real-time + applications`: https://doi.org/10.1587/transinf.2015EDP7457 + Note: + This module is based on NumPy implementation. Therefore, the computational graph + is not connected. + Todo: + Replace this module with PyTorch-based implementation. 
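+    In addition to the frame-level F0, ``forward`` returns the utterance-level
+    mean and standard deviation of the interpolated log-F0 and the real part of
+    the continuous wavelet transform (CWT) of the normalized log-F0, computed
+    with ``pycwt`` (Mexican hat mother wavelet).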
+ """ + + + def __init__( + self, + fs: int = 22050, + n_fft: int = 1024, + hop_length: int = 256, + f0min: Optional[int] = 71, + f0max: Optional[int] = 400, + use_token_averaged_f0: bool = False, + use_continuous_f0: bool = True, + use_log_f0: bool = True, + ): + assert check_argument_types() + super().__init__() + self.fs = fs + self.n_fft = n_fft + self.hop_length = hop_length + self.frame_period = 1000 * hop_length / fs + self.f0min = f0min + self.f0max = f0max + self.use_token_averaged_f0 = use_token_averaged_f0 + self.use_continuous_f0 = use_continuous_f0 + self.use_log_f0 = use_log_f0 + + def output_size(self) -> int: + return 1 + + def get_parameters(self) -> Dict[str, Any]: + return dict( + fs=self.fs, + n_fft=self.n_fft, + hop_length=self.hop_length, + f0min=self.f0min, + f0max=self.f0max, + use_token_averaged_f0=self.use_token_averaged_f0, + use_continuous_f0=self.use_continuous_f0, + use_log_f0=self.use_log_f0, + ) + + def forward( + self, + input: torch.Tensor, + feats_lengths: torch.Tensor = None, + durations: torch.Tensor = None, + utterance: list = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If not provide, we assume that the inputs have the same length + # F0 extraction + + # input shape = [T,] + pitch, pitch_log = self._calculate_f0(input) + # (Optional): Adjust length to match with the mel-spectrogram + if feats_lengths is not None: + pitch = [ + self._adjust_num_frames(p, fl).view(-1) + for p, fl in zip(pitch, feats_lengths) + ] + pitch_log = [ + self._adjust_num_frames(p, fl).view(-1) + for p, fl in zip(pitch_log, feats_lengths) + ] + + pitch_log_norm, mean, std = self._normalize(pitch_log, durations) + coefs = self._cwt(pitch_log_norm) + # (Optional): Average by duration to calculate token-wise f0 + if self.use_token_averaged_f0: + pitch = self._average_by_duration(pitch, durations) + pitch_lengths = len(durations) + else: + pitch_lengths = 22 #input.new_tensor([len(p) for p in pitch], dtype=torch.long) + # Return with the shape (B, T, 1) + return pitch, mean, std, coefs + + + def _calculate_f0(self, input: torch.Tensor) -> torch.Tensor: + x = input.cpu().numpy().astype(np.double) + #print(self.frame_period) + f0, timeaxis = pyworld.dio( + x, + self.fs, + f0_floor=self.f0min, + f0_ceil=self.f0max, + frame_period=self.frame_period, + ) + + f0 = pyworld.stonemask(x, f0, timeaxis, self.fs) + if self.use_continuous_f0: + f0 = self._convert_to_continuous_f0(f0) + + f0_log = np.zeros_like(f0) + + if self.use_log_f0: + nonzero_idxs = np.where(f0 != 0)[0] + f0_log[nonzero_idxs] = np.log(f0[nonzero_idxs]) + + return input.new_tensor(f0.reshape(-1), dtype=torch.float), input.new_tensor(f0_log.reshape(-1), dtype=torch.float) + + + @staticmethod + def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor: + if num_frames > len(x): + x = F.pad(x, (0, num_frames - len(x))) + elif num_frames < len(x): + x = x[:num_frames] + return x + + + @staticmethod + def _convert_to_continuous_f0(f0: np.array) -> np.array: + if (f0 == 0).all(): + logging.warn("All frames seems to be unvoiced.") + return f0 + + # padding start and end of f0 sequence + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + #start_idx = np.where(f0 == start_f0)[0][0] + #end_idx = np.where(f0 == end_f0)[0][-1] + if f0[0] == 0: + f0[0] = 1.845 + if f0[-1] == 0: + f0[-1] = 1.845 # get non-zero frame index + nonzero_idxs = np.where(f0 != 0)[0] + # perform linear interpolation + interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs], kind= 'linear') + f0 = interp_fn(np.arange(0, f0.shape[0])) 
+ return f0 + + @staticmethod + def _average_by_duration(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: + #print(d.sum(), len(x)) + if d.sum() != len(x): + d[-1] += 1 + d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) + x_avg = [ + x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) + for start, end in zip(d_cumsum[:-1], d_cumsum[1:]) + ] + return torch.stack(x_avg) + + def _normalize(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor : + #if d.sum() != len(x): + # d[-1] += 1 + #d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) + norm_pitch = (x - x.mean())/x.std() + p_average = x.mean() + p_std = x.std() + + """ + for i in range(0, len(d_cumsum)-1): + pitch_i = x[d_cumsum[i]: d_cumsum[i+1]] + #print(pitch_i, "Pitch input") + p_average.append(pitch_i.mean()) + p_std.append(pitch_i.std()) + #print(pitch_i.std(), "pitch std") + #print(pitch_i.mean(), "pitch mean") + norm_pitch.extend((pitch_i - pitch_i.mean())/pitch_i.std()) + #print(norm_pitch[i], "Normalised pitch") + #print(norm_pitch, p_average, p_std) + """ + return norm_pitch, p_average, p_std + + def _cwt(self, x: torch.Tensor) -> torch.Tensor: + mother = pycwt.MexicanHat() + coefs, scales, _, _, _, _ = pycwt.cwt(x.numpy(), 0.25, 0.25, 0.5, J=9, wavelet=mother) # + #coefs shape = [10, T] + + return coefs.real diff --git a/dataset/audio/pitch_mod.py b/dataset/audio/pitch_mod.py new file mode 100644 index 0000000..5463e36 --- /dev/null +++ b/dataset/audio/pitch_mod.py @@ -0,0 +1,183 @@ +"""F0 extractor using DIO + Stonemask algorithm.""" + +import logging + +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +import numpy as np +import pyworld +import torch +import torch.nn.functional as F +import pycwt as wavelet + +from scipy.interpolate import interp1d +from typeguard import check_argument_types + + +class Dio(): + """F0 estimation with dio + stonemask algortihm. + This is f0 extractor based on dio + stonmask algorithm introduced in `WORLD: + a vocoder-based high-quality speech synthesis system for real-time applications`_. + .. _`WORLD: a vocoder-based high-quality speech synthesis system for real-time + applications`: https://doi.org/10.1587/transinf.2015EDP7457 + Note: + This module is based on NumPy implementation. Therefore, the computational graph + is not connected. + Todo: + Replace this module with PyTorch-based implementation. 
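+    This variant is the one imported by ``nvidia_preprocessing.py``: ``forward``
+    returns the frame-level F0 together with the utterance-level mean and
+    standard deviation of the continuous log-F0 and the CWT coefficients of the
+    normalized log-F0 (Mexican hat mother wavelet, via ``pycwt``).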
+ """ + + + def __init__( + self, + fs: int = 22050, + n_fft: int = 1024, + hop_length: int = 256, + f0min: Optional[int] = 71, + f0max: Optional[int] = 400, + use_token_averaged_f0: bool = False, + use_continuous_f0: bool = True, + use_log_f0: bool = True, + J: int = 10 + ): + assert check_argument_types() + super().__init__() + self.fs = fs + self.n_fft = n_fft + self.hop_length = hop_length + self.frame_period = 1000 * hop_length / fs + self.f0min = f0min + self.f0max = f0max + self.use_token_averaged_f0 = use_token_averaged_f0 + self.use_continuous_f0 = use_continuous_f0 + self.use_log_f0 = use_log_f0 + self.J = J + + def output_size(self) -> int: + return 1 + + def get_parameters(self) -> Dict[str, Any]: + return dict( + fs=self.fs, + n_fft=self.n_fft, + hop_length=self.hop_length, + f0min=self.f0min, + f0max=self.f0max, + use_token_averaged_f0=self.use_token_averaged_f0, + use_continuous_f0=self.use_continuous_f0, + use_log_f0=self.use_log_f0, + ) + + def forward( + self, + input: torch.Tensor, + feats_lengths: torch.Tensor = None, + durations: torch.Tensor = None, + utterance: list = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If not provide, we assume that the inputs have the same length + # F0 extraction + + # input shape = [T,] + pitch = self._calculate_f0(input) + # (Optional): Adjust length to match with the mel-spectrogram + pitch, pitch_log = self._convert_to_continuous_f0(pitch) + + if feats_lengths is not None: + pitch = [ + self._adjust_num_frames(p, fl).view(-1) + for p, fl in zip(pitch, feats_lengths) + ] + pitch_log = [ + self._adjust_num_frames(p, fl).view(-1) + for p, fl in zip(pitch_log, feats_lengths) + ] + + pitch_log_norm, mean, std = self._normalize(pitch_log) + coefs, scales = self._cwt(pitch_log_norm) + # (Optional): Average by duration to calculate token-wise f0 + if self.use_token_averaged_f0: + pitch = self._average_by_duration(pitch, durations) + pitch_lengths = len(durations) + else: + pitch_lengths = 22 #input.new_tensor([len(p) for p in pitch], dtype=torch.long) + # Return with the shape (B, T, 1) + return pitch, mean, std, coefs + + + def _calculate_f0(self, input: torch.Tensor) -> torch.Tensor: + x = input.cpu().numpy().astype(np.double) + #print(self.frame_period) + _f0, t = pyworld.dio(x, self.fs, f0_floor = self.f0min, f0_ceil=self.f0max, frame_period=self.frame_period) # raw pitch extractor + f0 = pyworld.stonemask(x, _f0, t, self.fs) # pitch refinement + #sp = pw.cheaptrick(x, f0, t, self.fs, fft_size=self.n_fft) + #ap = pw.d4c(x, f0, t, fs, fft_size=self.n_fft) # extract aperiodicity + + return input.new_tensor(f0.reshape(-1), dtype=torch.float) + + + @staticmethod + def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor: + if num_frames > len(x): + x = F.pad(x, (0, num_frames - len(x))) + elif num_frames < len(x): + x = x[:num_frames] + return x + + + @staticmethod + def _convert_to_continuous_f0(f0: np.array) -> np.array: + + uv = np.float64(f0 != 0) + # get start and end of f0 + if (f0 == 0).all(): + print("all of the f0 values are 0.") + return uv, f0 + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + + # padding start and end of f0 sequence + start_idx = np.where(f0 == start_f0)[0][0] + end_idx = np.where(f0 == end_f0)[0][-1] + f0[:start_idx] = start_f0 + f0[end_idx:] = end_f0 + + # get non-zero frame index + nz_frames = np.where(f0 != 0)[0] + + # perform linear interpolation + f = interp1d(nz_frames, f0[nz_frames]) + cont_f0 = f(np.arange(0, f0.shape[0])) + cont_f0_lpf = np.log(cont_f0) + + return 
cont_f0, cont_f0_lpf + + @staticmethod + def _average_by_duration(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: + #print(d.sum(), len(x)) + if d.sum() != len(x): + d[-1] += 1 + d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) + x_avg = [ + x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) + for start, end in zip(d_cumsum[:-1], d_cumsum[1:]) + ] + return torch.stack(x_avg) + + def _normalize(self, x: torch.Tensor) -> torch.Tensor : + + norm_pitch = (x - x.mean())/x.std() + return norm_pitch, x.mean(), x.std() + + def _cwt(self, x: torch.Tensor) -> np.array: + mother = wavelet.MexicanHat() + dt = 0.005 + dj = 2 + s0 = dt*2 + J = self.J - 1 + Wavelet_lf0, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(x, dt, dj, s0, J, mother) + Wavelet_lf0 = np.real(Wavelet_lf0).T + + return Wavelet_lf0, scales diff --git a/dataset/dataloader.py b/dataset/dataloader.py index b44ca64..f0daa37 100644 --- a/dataset/dataloader.py +++ b/dataset/dataloader.py @@ -59,6 +59,9 @@ def __getitem__(self, index): p = remove_outlier( np.load(f"{self.path}pitch/{id}.npy") ) # self._norm_mean_std(np.load(f'{self.path}pitch/{id}.npy'), self.f0_mean, self.f0_std, True) + p_avg = np.load(f"{self.path}p_avg/{id}.npy") + p_std = np.load(f"{self.path}p_std/{id}.npy") + p_cwt_cont = np.load(f"{self.path}p_cwt_coef/{id}.npy") mel_len = mel.shape[1] durations = durations[: len(x)] durations[-1] = durations[-1] + (mel.shape[1] - sum(durations)) @@ -71,6 +74,9 @@ def __getitem__(self, index): np.array(durations), e, p, + p_avg, + p_std, + p_cwt_cont ) # Mel [T, num_mel] def __len__(self): @@ -107,6 +113,10 @@ def collate_tts(batch): energys = pad_list([torch.from_numpy(y[5]).float() for y in batch], 0) pitches = pad_list([torch.from_numpy(y[6]).float() for y in batch], 0) + pitches_avg = torch.Tensor([torch.from_numpy(x[7]).float() for x in batch]) #pad_list([torch.from_numpy(y[7]).float() for y in batch], 0) + pitches_std = torch.Tensor([torch.from_numpy(x[8]).float() for x in batch]) #pad_list([torch.from_numpy(y[8]).float() for y in batch], 0) + pitches_contour = pad_list([torch.from_numpy(y[9]).float() for y in batch], 0) + # make labels for stop prediction labels = mels.new_zeros(mels.size(0), mels.size(1)) for i, l in enumerate(olens): @@ -115,7 +125,7 @@ def collate_tts(batch): # scale spectrograms to -4 <--> 4 # mels = (mels * 8.) 
- 4 - return inputs, ilens, mels, labels, olens, ids, durations, energys, pitches + return inputs, ilens, mels, labels, olens, ids, durations, energys, pitches, pitches_avg, pitches_std, pitches_contour class BinnedLengthSampler(Sampler): diff --git a/evaluation.py b/evaluation.py index 43efedf..bf1760b 100644 --- a/evaluation.py +++ b/evaluation.py @@ -17,11 +17,11 @@ def evaluate(hp, validloader, model): l1 = torch.nn.L1Loss() model.eval() for valid in validloader: - x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid + x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_, p_avg_, p_std_, p_cwt_cont_ = valid with torch.no_grad(): ilens = torch.tensor([x_[-1].shape[0]], dtype=torch.long, device=x_.device) - _, after_outs, d_outs, e_outs, p_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel] + _, after_outs, d_outs, e_outs, p_outs, p_avg_outs, p_std_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel] # e_orig = model.energy_predictor.to_one_hot(e_).squeeze() # p_orig = model.pitch_predictor.to_one_hot(p_).squeeze() @@ -30,7 +30,7 @@ def evaluate(hp, validloader, model): dur_diff.append(l1(d_outs, dur_.cuda()).item()) #.numpy() energy_diff.append(l1(e_outs, e_.cuda()).item()) #.numpy() - pitch_diff.append(l1(p_outs, p_.cuda()).item()) #.numpy() + pitch_diff.append(l1(p_outs, p_cwt_cont_.cuda()).item()) #.numpy() '''_, target = read_wav_np( hp.data.wav_dir + f"{ids_[-1]}.wav", sample_rate=hp.audio.sample_rate) diff --git a/fastspeech.py b/fastspeech.py index 6796091..614c23a 100644 --- a/fastspeech.py +++ b/fastspeech.py @@ -109,6 +109,7 @@ def __init__(self, idim: int, odim: int, hp: Dict): dropout_rate=hp.model.duration_predictor_dropout_rate, min=hp.data.p_min, max=hp.data.p_max, + out = hp.audio.cwt_bins ) self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim) @@ -191,18 +192,23 @@ def _forward( if is_inference: d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) + #print(d_outs.sum(dim=1), d_outs.shape) + #d_outs = torch.Tensor([[3, 3, 10, 2, 6, 13, 21, 10, 12, 6, 3, 10, 9, 8, 3, 9, 14, 10, 8, 8, 5, 3, 5, 3, 5, 8, 8, 6, 9, 9, 7, 5, 9, 4, 3, 6, 7, 3, 3, 8, 4, 4, 6, 12, 21, 5, 7, 36, 7, 6, 9, 14, 18, 2, 6, 2, 3, 8, 5, 15, 9, 6, 3, 10, 7, 9, 9, 4, 3, 3, 7, 24, 5, 5, 8, 8, 4, 9, 5, 4, 3, 2, 11, 3, 14, 9, 6, 8, 9, 4, 7, 3, 3, 9, 4, 3, 6, 5, 4, 15, 3, 3, 9, 5, 8, 7, 4, 6, 9, 9, 6, 19]]).cuda() + #print(d_outs.sum(dim=1), d_outs.shape) + #print(hs.shape, "Hs shape before LR") hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim) - one_hot_energy = self.energy_predictor.inference(hs) # (B, Lmax, adim) - one_hot_pitch = self.pitch_predictor.inference(hs) # (B, Lmax, adim) + one_hot_energy, energy_output = self.energy_predictor.inference(hs) # (B, Lmax, adim) + one_hot_pitch, pitch_reconstructed = self.pitch_predictor.inference(hs, d_outs.sum(dim=1)) + #one_hot_pitch = self.pitch_predictor.inverse(f0, f_mean, f_std) # (B, Lmax, adim) else: with torch.no_grad(): # ds = self.duration_calculator(xs, ilens, ys, olens) # (B, Tmax) one_hot_energy = self.energy_predictor.to_one_hot( - es + es.detach() ) # (B, Lmax, adim) torch.Size([32, 868, 256]) # print("one_hot_energy:", one_hot_energy.shape) one_hot_pitch = self.pitch_predictor.to_one_hot( - ps + ps.detach() ) # (B, Lmax, adim) torch.Size([32, 868, 256]) # print("one_hot_pitch:", one_hot_pitch.shape) mel_masks = 
make_pad_mask(olens).to(xs.device) @@ -213,7 +219,8 @@ def _forward( # print("After Hs:",hs.shape) #torch.Size([32, 868, 256]) e_outs = self.energy_predictor(hs, mel_masks) # print("e_outs:", e_outs.shape) #torch.Size([32, 868]) - p_outs = self.pitch_predictor(hs, mel_masks) + mel_masks = make_pad_mask(olens).unsqueeze(-1).to(xs.device) + p_outs, p_avg_outs, p_std_outs = self.pitch_predictor(hs, olens, mel_masks) # print("p_outs:", p_outs.shape) #torch.Size([32, 868]) hs = hs + self.pitch_embed(one_hot_pitch) # (B, Lmax, adim) hs = hs + self.energy_embed(one_hot_energy) # (B, Lmax, adim) @@ -238,9 +245,9 @@ def _forward( ).transpose(1, 2) if is_inference: - return before_outs, after_outs, d_outs, one_hot_energy, one_hot_pitch + return before_outs, after_outs, d_outs, one_hot_energy, one_hot_pitch, pitch_reconstructed, energy_output else: - return before_outs, after_outs, d_outs, e_outs, p_outs + return before_outs, after_outs, d_outs, e_outs, p_outs, p_avg_outs, p_std_outs def forward( self, @@ -251,6 +258,9 @@ def forward( ds: torch.Tensor, es: torch.Tensor, ps: torch.Tensor, + ps_spec: torch.Tensor, + ps_avg: torch.Tensor = None, + ps_std: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Calculate forward propagation. Args: @@ -267,7 +277,7 @@ def forward( ys = ys[:, : max(olens)] # torch.Size([32, 868, 80]) -> [B, Lmax, odim] # forward propagation - before_outs, after_outs, d_outs, e_outs, p_outs = self._forward( + before_outs, after_outs, d_outs, e_outs, p_outs, p_avg_outs, p_std_outs = self._forward( xs, ilens, olens, ds, es, ps, is_inference=False ) @@ -278,6 +288,7 @@ def forward( # ys = ys[:, :max_olen] # apply mask to remove padded part + if self.use_masking: in_masks = make_non_pad_mask(ilens).to(xs.device) d_outs = d_outs.masked_select(in_masks) @@ -286,9 +297,11 @@ def forward( mel_masks = make_non_pad_mask(olens).to(ys.device) before_outs = before_outs.masked_select(out_masks) es = es.masked_select(mel_masks) # Write size - ps = ps.masked_select(mel_masks) # Write size + ps_spec = ps_spec.masked_select(out_masks) # Write size e_outs = e_outs.masked_select(mel_masks) # Write size - p_outs = p_outs.masked_select(mel_masks) # Write size + p_outs = p_outs.masked_select(out_masks) # Write size + #p_avg_outs = p_avg_outs.masked_select(mel_masks) # Write size + #p_std_outs = p_std_outs.masked_select(mel_masks) # Write size after_outs = ( after_outs.masked_select(out_masks) if after_outs is not None else None ) @@ -301,8 +314,10 @@ def forward( after_loss = self.criterion(after_outs, ys) l1_loss = before_loss + after_loss duration_loss = self.duration_criterion(d_outs, ds) - energy_loss = self.energy_criterion(e_outs, es) - pitch_loss = self.pitch_criterion(p_outs, ps) + energy_loss = self.energy_criterion(e_outs, torch.log(es)) + pitch_loss = self.pitch_criterion(p_outs, ps_spec) + pitch__mean_loss = self.pitch_criterion(p_avg_outs, ps_avg) + pitch_std_loss = self.pitch_criterion(p_std_outs, ps_std) # make weighted mask and apply it if self.use_weighted_masking: @@ -321,7 +336,7 @@ def forward( duration_loss.mul(duration_weights).masked_select(duration_masks).sum() ) - loss = l1_loss + duration_loss + energy_loss + pitch_loss + loss = l1_loss + duration_loss + energy_loss + pitch_loss + pitch__mean_loss + pitch_std_loss report_keys = [ {"l1_loss": l1_loss.item()}, {"before_loss": before_loss.item()}, @@ -329,6 +344,8 @@ def forward( {"duration_loss": duration_loss.item()}, {"energy_loss": energy_loss.item()}, {"pitch_loss": pitch_loss.item()}, + 
{"pitch__mean_loss": pitch__mean_loss.item()}, + {"pitch_std_loss": pitch_std_loss.item()}, {"loss": loss.item()}, ] @@ -352,9 +369,9 @@ def inference(self, x: torch.Tensor) -> torch.Tensor: xs = x.unsqueeze(0) # inference - _, outs, _, _, _ = self._forward(xs, ilens, is_inference=True) # (L, odim) + _, outs, _, _, _, pitch, energy = self._forward(xs, ilens, is_inference=True) # (L, odim) - return outs[0] + return outs[0], pitch, energy def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: """Make masks for self-attention. diff --git a/inference.py b/inference.py index 15f8135..bdcd7d1 100644 --- a/inference.py +++ b/inference.py @@ -89,7 +89,7 @@ def preprocess(text): clean_content = english_cleaners(text) clean_content = punctuation_removers(clean_content) phonemes = g2p(clean_content) - + phonemes = ["" if x == " " else x for x in phonemes] phonemes = ["pau" if x == "," else x for x in phonemes] phonemes = ["pau" if x == "." else x for x in phonemes] @@ -125,7 +125,8 @@ def synth(text, model, hp): with torch.no_grad(): print("predicting") - outs = model.inference(text) # model(text) for jit script + print(text.shape) + outs, p, e = model.inference(text) # model(text) for jit script mel = outs return mel @@ -180,25 +181,23 @@ def main(args): if hp.train.melgan_vocoder: m = m.unsqueeze(0) print("Mel shape: ", m.shape) - vocoder = torch.hub.load("seungwonpark/melgan", "melgan") + vocoder = torch.jit.load("/results/chkpts/david/hifi-gan/v2/hifigan_david_v2_1370k.pt") # LJ/Hifi-GAN/original/hifigan_pre_trained_v1.pt torch.hub.load("seungwonpark/melgan", "melgan") JARED/Hifi-GAN/v1/hifigan_jared_1105k.pt vocoder.eval() + zero = torch.full((1, 80, 10), -11.5129).to(m.device) + m = torch.cat((m, zero), dim=2) + if torch.cuda.is_available(): vocoder = vocoder.cuda() mel = m.cuda() - - with torch.no_grad(): - wav = vocoder.inference( - mel - ) # mel ---> batch, num_mels, frames [1, 80, 234] - wav = wav.cpu().float().numpy() - else: - stft = STFT(filter_length=1024, hop_length=256, win_length=1024) - print(m.size()) - m = m.unsqueeze(0) - wav = griffin_lim(m, stft, 30) - wav = wav.cpu().numpy() - save_path = "{}/test_tts.wav".format(args.out) - write(save_path, hp.audio.sample_rate, wav.astype("int16")) + for i in range(0,len(para_mel)): + with torch.no_grad(): + wav = vocoder( + para_mel[i].unsqueeze(0) + ) # mel ---> batch, num_mels, frames [1, 80, 234] + #print(wav) + wav = wav.cpu().float().numpy() + save_path = f"{args.out}/fastspeech2_v2_david_hifigan_v2_{i}.wav" + write(save_path, hp.audio.sample_rate, wav.astype("float32")) # NOTE: you need this func to generate our sphinx doc diff --git a/nvidia_preprocessing.py b/nvidia_preprocessing.py index 337ba32..7a4e2e4 100644 --- a/nvidia_preprocessing.py +++ b/nvidia_preprocessing.py @@ -6,11 +6,14 @@ import numpy as np from utils.stft import TacotronSTFT from utils.util import read_wav_np -from dataset.audio_processing import pitch from utils.hparams import HParam +from dataset.audio.pitch_mod import Dio +from utils.util import str_to_int_list +import warnings +warnings.filterwarnings("error") -def main(args, hp): +def preprocess(data_path, hp, file): stft = TacotronSTFT( filter_length=hp.audio.n_fft, hop_length=hp.audio.hop_length, @@ -20,28 +23,68 @@ def main(args, hp): mel_fmin=hp.audio.fmin, mel_fmax=hp.audio.fmax, ) + pitch = Dio() + - wav_files = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True) + wav_files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) mel_path = 
os.path.join(hp.data.data_dir, "mels") energy_path = os.path.join(hp.data.data_dir, "energy") pitch_path = os.path.join(hp.data.data_dir, "pitch") + pitch_avg_path = os.path.join(hp.data.data_dir, "p_avg") + pitch_std_path = os.path.join(hp.data.data_dir, "p_std") + pitch_cwt_coefs = os.path.join(hp.data.data_dir, "p_cwt_coef") os.makedirs(mel_path, exist_ok=True) os.makedirs(energy_path, exist_ok=True) os.makedirs(pitch_path, exist_ok=True) + os.makedirs(pitch_avg_path, exist_ok=True) + os.makedirs(pitch_std_path, exist_ok=True) + os.makedirs(pitch_cwt_coefs, exist_ok=True) + print("Sample Rate : ", hp.audio.sample_rate) - for wavpath in tqdm.tqdm(wav_files, desc="preprocess wav to mel"): - sr, wav = read_wav_np(wavpath, hp.audio.sample_rate) - p = pitch(wav, hp) # [T, ] T = Number of frames - wav = torch.from_numpy(wav).unsqueeze(0) - mel, mag = stft.mel_spectrogram(wav) # mel [1, 80, T] mag [1, num_mag, T] - mel = mel.squeeze(0) # [num_mel, T] - mag = mag.squeeze(0) # [num_mag, T] - e = torch.norm(mag, dim=0) # [T, ] - p = p[: mel.shape[1]] - id = os.path.basename(wavpath).split(".")[0] - np.save("{}/{}.npy".format(mel_path, id), mel.numpy(), allow_pickle=False) - np.save("{}/{}.npy".format(energy_path, id), e.numpy(), allow_pickle=False) - np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False) + + with open("{}".format(file), encoding="utf-8") as f: + _metadata = [line.strip().split("|") for line in f] + + + for metadata in tqdm.tqdm(_metadata, desc="preprocess wav to mel"): + try: + wavpath = os.path.join(data_path, metadata[4]) + sr, wav = read_wav_np(wavpath, hp.audio.sample_rate) + input_wav = torch.from_numpy(wav) + + dur = str_to_int_list(metadata[2]) + dur = torch.from_numpy(np.array(dur)) + + p, avg, std, p_coef = pitch.forward(input_wav, durations = dur) # shape in order - (T,) (no of utternace, ), (no of utternace, ), (10, T) + #print(p.shape, avg.shape, std.shape, p_coef.shape) + + wav = torch.from_numpy(wav).unsqueeze(0) + mel, mag = stft.mel_spectrogram(wav) # mel [1, 80, T] mag [1, num_mag, T] + mel = mel.squeeze(0) # [num_mel, T] + mag = mag.squeeze(0) # [num_mag, T] + e = torch.norm(mag, dim=0) # [T, ] + + id = os.path.basename(wavpath).split(".")[0] + + assert(e.numpy().shape == p.shape) + + np.save("{}/{}.npy".format(mel_path, id), mel.numpy(), allow_pickle=False) + np.save("{}/{}.npy".format(energy_path, id), e.numpy(), allow_pickle=False) + np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False) + np.save("{}/{}.npy".format(pitch_avg_path, id), avg, allow_pickle=False) + np.save("{}/{}.npy".format(pitch_std_path, id), std, allow_pickle=False) + np.save("{}/{}.npy".format(pitch_cwt_coefs, id), p_coef.reshape(-1, hp.audio.cwt_bins), allow_pickle=False) + + except Exception as e: + print("{}\n".format(metadata[3])) + + + +def main(args, hp): + print("Preprocess Training dataset :") + preprocess(args.data_path, hp, hp.data.train_filelist) + print("Preprocess Validation dataset :") + preprocess(args.data_path, hp, hp.data.valid_filelist) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 6eb3e05..a5c238f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ configargparse tensorboardX typeguard==2.9.1 g2p_en +pycwt diff --git a/train_fastspeech.py b/train_fastspeech.py index 99f42a6..db9f7bc 100644 --- a/train_fastspeech.py +++ b/train_fastspeech.py @@ -12,7 +12,7 @@ import tqdm import time from evaluation import evaluate -from utils.plot import generate_audio, plot_spectrogram_to_numpy +from 
utils.plot import generate_audio, plot_spectrogram_to_numpy, plot_waveform_to_numpy from core.optimizer import get_std_opt from utils.util import read_wav_np from dataset.texts import valid_symbols @@ -93,7 +93,7 @@ def train(args, hp, hp_str, logger, vocoder): pbar = tqdm.tqdm(dataloader, desc="Loading train data") for data in pbar: global_step += 1 - x, input_length, y, _, out_length, _, dur, e, p = data + x, input_length, y, _, out_length, _, dur, e, p, p_avg, p_std, p_cwt_cont = data # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel] # # stop_token : [batch, T_in], out_length : [batch] @@ -105,6 +105,9 @@ def train(args, hp, hp_str, logger, vocoder): dur.cuda(), e.cuda(), p.cuda(), + p_cwt_cont.cuda(), + p_avg.cuda(), + p_std.cuda() ) loss = loss.mean() / hp.train.accum_grad running_loss += loss.item() @@ -148,7 +151,7 @@ def train(args, hp, hp_str, logger, vocoder): if step % hp.train.validation_step == 0: for valid in validloader: - x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid + x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_, p_avg_, p_std_, p_cwt_cont_ = valid model.eval() with torch.no_grad(): loss_, report_dict_ = model( @@ -159,9 +162,12 @@ def train(args, hp, hp_str, logger, vocoder): dur_.cuda(), e_.cuda(), p_.cuda(), + p_cwt_cont_.cuda(), + p_avg_.cuda(), + p_std_.cuda() ) - mels_ = model.inference(x_[-1].cuda()) # [T, num_mel] + mels_, pitch_reconstructed, energy_reconstructed = model.inference(x_[-1].cuda()) # [T, num_mel] model.train() for r in report_dict_: @@ -174,6 +180,7 @@ def train(args, hp, hp_str, logger, vocoder): writer.add_scalar("validation/{}".format(k), v, step) mels_ = mels_.T # Out: [num_mels, T] + writer.add_image( "melspectrogram_target_{}".format(ids_[-1]), plot_spectrogram_to_numpy( @@ -186,18 +193,32 @@ def train(args, hp, hp_str, logger, vocoder): "melspectrogram_prediction_{}".format(ids_[-1]), plot_spectrogram_to_numpy(mels_.data.cpu().numpy()), step, - dataformats="HWC", + dataformats="HWC" + ) + + writer.add_figure( + "Pitch_target_vs_prediction/{}".format(ids_[-1]), + plot_waveform_to_numpy(pitch_reconstructed.cpu().numpy().reshape(-1,), p_.cpu().numpy().reshape(-1,)), + step, ) - # print(mels.unsqueeze(0).shape) + writer.add_figure( + "Energy_target_vs_prediction/{}".format(ids_[-1]), + plot_waveform_to_numpy(energy_reconstructed.cpu().numpy().reshape(-1,), e_.cpu().numpy().reshape(-1,)), + step, + ) + + mels = mels_.unsqueeze(0) + zero = torch.full((1, 80, 10), -11.5129).to(mels.device) + mels = torch.cat((mels, zero), dim=2) - audio = generate_audio( - mels_.unsqueeze(0), vocoder - ) # selecting the last data point to match mel generated above - audio = audio.cpu().float().numpy() + + audio = vocoder(mels) #generate_audio(mels_.unsqueeze(0), vocoder) # selecting the last data point to match mel generated above + audio = audio.detach().cpu().float().numpy() audio = audio / ( audio.max() - audio.min() ) # get values between -1 and 1 + audio = audio.reshape(-1,1) writer.add_audio( tag="generated_audio_{}".format(ids_[-1]), @@ -218,7 +239,7 @@ def train(args, hp, hp_str, logger, vocoder): sample_rate=hp.audio.sample_rate, ) - ## + if step % hp.train.save_interval == 0: avg_p, avg_e, avg_d = evaluate(hp, validloader, model) writer.add_scalar("evaluation/Pitch_Loss", avg_p, step) @@ -443,9 +464,7 @@ def main(cmd_args): random.seed(hp.train.seed) np.random.seed(hp.train.seed) - vocoder = torch.hub.load( - "seungwonpark/melgan", "melgan" - ) # load the vocoder for validation + vocoder = 
torch.jit.load('vocgan_jared_first_1871233_2220.pt').cuda() # torch.hub.load( "seungwonpark/melgan", "melgan" ) # load the vocoder for validation if hp.train.GTA: create_gta(args, hp, hp_str, logger) diff --git a/utils/plot.py b/utils/plot.py index 8b38655..2aa7419 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -63,20 +63,22 @@ def save_figure_to_numpy(fig, spectrogram=False): return data -def plot_waveform_to_numpy(waveform): - fig, ax = plt.subplots(figsize=(12, 3)) - ax.plot() - ax.plot(range(len(waveform)), waveform, linewidth=0.1, alpha=0.7, color="blue") +def plot_waveform_to_numpy(waveform1, waveform2): + if len(waveform1) < len(waveform2): + l = len(waveform1) + else: + l = len(waveform2) + + fig = plt.figure(figsize = (12,7)) + x = np.linspace(1, l, l) + plt.plot(x, waveform1[0:l], label = "reconstructed" , color="blue") + plt.plot(x, waveform2[0:l], label = "ground_truth", color="orange") plt.xlabel("Samples") plt.ylabel("Amplitude") - plt.ylim(-1, 1) plt.tight_layout() - fig.canvas.draw() - data = save_figure_to_numpy(fig) - plt.close() - return data + return fig def plot_spectrogram_to_numpy(spectrogram):