diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 3886a8f8c9..37e3a1779d 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -3,9 +3,10 @@ import logging import os import random -from typing import Dict, List, Union +from typing import Any, Optional, Union import numpy as np +import numpy.typing as npt import torch import torchaudio import tqdm @@ -32,29 +33,34 @@ def _parse_sample(item): elif len(item) == 3: text, wav_file, speaker_name = item else: - raise ValueError(" [!] Dataset cannot parse the sample.") + msg = "Dataset cannot parse the sample." + raise ValueError(msg) return text, wav_file, speaker_name, language_name, attn_file -def noise_augment_audio(wav): +def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray: return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) -def string2filename(string): +def string2filename(string: str) -> str: # generate a safe and reversible filename based on a string - filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") - return filename + return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") -def get_audio_size(audiopath) -> int: +def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: """Return the number of samples in the audio file.""" + if not isinstance(audiopath, str): + audiopath = str(audiopath) extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: - raise RuntimeError( - f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" - ) + msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" + raise RuntimeError(msg) - return torchaudio.info(audiopath).num_frames + try: + return torchaudio.info(audiopath).num_frames + except RuntimeError as e: + msg = f"Failed to decode {audiopath}" + raise RuntimeError(msg) from e class TTSDataset(Dataset): @@ -63,31 +69,32 @@ def __init__( outputs_per_step: int = 1, compute_linear_spec: bool = False, ap: AudioProcessor = None, - samples: List[Dict] = None, + samples: Optional[list[dict]] = None, tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, compute_energy: bool = False, - f0_cache_path: str = None, - energy_cache_path: str = None, + f0_cache_path: Optional[str] = None, + energy_cache_path: Optional[str] = None, return_wav: bool = False, batch_group_size: int = 0, min_text_len: int = 0, max_text_len: int = float("inf"), min_audio_len: int = 0, max_audio_len: int = float("inf"), - phoneme_cache_path: str = None, + phoneme_cache_path: Optional[str] = None, precompute_num_workers: int = 0, - speaker_id_mapping: Dict = None, - d_vector_mapping: Dict = None, - language_id_mapping: Dict = None, + speaker_id_mapping: Optional[dict] = None, + d_vector_mapping: Optional[dict] = None, + language_id_mapping: Optional[dict] = None, use_noise_augment: bool = False, start_by_longest: bool = False, - ): + ) -> None: """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. If you need something different, you can subclass and override. Args: + ---- outputs_per_step (int): Number of time frames predicted per step. compute_linear_spec (bool): compute linear spectrogram if True. @@ -139,6 +146,7 @@ def __init__( use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + """ super().__init__() self.batch_group_size = batch_group_size @@ -168,25 +176,38 @@ def __init__( if self.tokenizer.use_phonemes: self.phoneme_dataset = PhonemeDataset( - self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.tokenizer, + phoneme_cache_path, + precompute_num_workers=precompute_num_workers, ) if compute_f0: self.f0_dataset = F0Dataset( - self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.ap, + cache_path=f0_cache_path, + precompute_num_workers=precompute_num_workers, ) if compute_energy: self.energy_dataset = EnergyDataset( - self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.ap, + cache_path=energy_cache_path, + precompute_num_workers=precompute_num_workers, ) self.print_logs() @property - def lengths(self): + def lengths(self) -> list[int]: lens = [] for item in self.samples: _, wav_file, *_ = _parse_sample(item) - audio_len = get_audio_size(wav_file) + try: + audio_len = get_audio_size(wav_file) + except RuntimeError: + logger.warning(f"Failed to compute length for {item['audio_file']}") + audio_len = 0 lens.append(audio_len) return lens @@ -195,7 +216,7 @@ def samples(self): return self._samples @samples.setter - def samples(self, new_samples): + def samples(self, new_samples) -> None: self._samples = new_samples if hasattr(self, "f0_dataset"): self.f0_dataset.samples = new_samples @@ -204,7 +225,7 @@ def samples(self, new_samples): if hasattr(self, "phoneme_dataset"): self.phoneme_dataset.samples = new_samples - def __len__(self): + def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx): @@ -251,7 +272,7 @@ def get_token_ids(self, idx, text): token_ids = self.tokenizer.text_to_ids(text) return np.array(token_ids, dtype=np.int32) - def load_data(self, idx): + def load_data(self, idx) -> dict[str, Any]: item = self.samples[idx] raw_text = item["text"] @@ -285,7 +306,7 @@ def load_data(self, idx): if self.compute_energy: energy = self.get_energy(idx)["energy"] - sample = { + return { "raw_text": raw_text, "token_ids": token_ids, "wav": wav, @@ -298,13 +319,16 @@ def load_data(self, idx): "wav_file_name": os.path.basename(item["audio_file"]), "audio_unique_name": item["audio_unique_name"], } - return sample @staticmethod def _compute_lengths(samples): new_samples = [] for item in samples: - audio_length = get_audio_size(item["audio_file"]) + try: + audio_length = get_audio_size(item["audio_file"]) + except RuntimeError: + logger.warning(f"Failed to compute length, skipping {item['audio_file']}") + continue text_lenght = len(item["text"]) item["audio_length"] = audio_length item["text_length"] = text_lenght @@ -312,7 +336,7 @@ def _compute_lengths(samples): return new_samples @staticmethod - def filter_by_length(lengths: List[int], min_len: int, max_len: int): + def filter_by_length(lengths: list[int], min_len: int, max_len: int): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] @@ -325,10 +349,9 @@ def filter_by_length(lengths: List[int], min_len: int, max_len: int): return ignore_idx, keep_idx @staticmethod - def sort_by_length(samples: List[List]): + def sort_by_length(samples: list[list]): audio_lengths = [s["audio_length"] for s in samples] - idxs = np.argsort(audio_lengths) # ascending order - return idxs + return np.argsort(audio_lengths) # ascending order @staticmethod def create_buckets(samples, batch_group_size: int): @@ -348,7 +371,7 @@ def _select_samples_by_idx(idxs, samples): samples_new.append(samples[idx]) return samples_new - def preprocess_samples(self): + def preprocess_samples(self) -> None: r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length range. """ @@ -374,7 +397,8 @@ def preprocess_samples(self): samples = self._select_samples_by_idx(sorted_idxs, samples) if len(samples) == 0: - raise RuntimeError(" [!] No samples left") + msg = "No samples left." + raise RuntimeError(msg) # shuffle batch groups # create batches with similar length items @@ -388,36 +412,37 @@ def preprocess_samples(self): self.samples = samples logger.info("Preprocessing samples") - logger.info("Max text length: {}".format(np.max(text_lengths))) - logger.info("Min text length: {}".format(np.min(text_lengths))) - logger.info("Avg text length: {}".format(np.mean(text_lengths))) - logger.info("Max audio length: {}".format(np.max(audio_lengths))) - logger.info("Min audio length: {}".format(np.min(audio_lengths))) - logger.info("Avg audio length: {}".format(np.mean(audio_lengths))) + logger.info(f"Max text length: {np.max(text_lengths)}") + logger.info(f"Min text length: {np.min(text_lengths)}") + logger.info(f"Avg text length: {np.mean(text_lengths)}") + logger.info(f"Max audio length: {np.max(audio_lengths)}") + logger.info(f"Min audio length: {np.min(audio_lengths)}") + logger.info(f"Avg audio length: {np.mean(audio_lengths)}") logger.info("Num. instances discarded samples: %d", len(ignore_idx)) - logger.info("Batch group size: {}.".format(self.batch_group_size)) + logger.info(f"Batch group size: {self.batch_group_size}.") @staticmethod def _sort_batch(batch, text_lengths): """Sort the batch by the input text length for RNN efficiency. Args: + ---- batch (Dict): Batch returned by `__getitem__`. text_lengths (List[int]): Lengths of the input character sequences. + """ text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) batch = [batch[idx] for idx in ids_sorted_decreasing] return batch, text_lengths, ids_sorted_decreasing def collate_fn(self, batch): - r""" - Perform preprocessing and create a final data batch: + """Perform preprocessing and create a final data batch. + 1. Sort batch instances by text-length 2. Convert Audio signal to features. 3. PAD sequences wrt r. 4. Load to Torch. """ - # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) @@ -562,23 +587,18 @@ def collate_fn(self, batch): "audio_unique_names": batch["audio_unique_name"], } - raise TypeError( - ( - "batch must contain tensors, numbers, dicts or lists;\ - found {}".format( - type(batch[0]) - ) - ) - ) + msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}" + raise TypeError(msg) class PhonemeDataset(Dataset): - """Phoneme Dataset for converting input text to phonemes and then token IDs + """Phoneme Dataset for converting input text to phonemes and then token IDs. At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data loading latency. If `cache_path` is already present, it skips the pre-computation. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -590,15 +610,16 @@ class PhonemeDataset(Dataset): precompute_num_workers (int): Number of workers used for pre-computing the phonemes. Defaults to 0. + """ def __init__( self, - samples: Union[List[Dict], List[List]], + samples: Union[list[dict], list[list]], tokenizer: "TTSTokenizer", cache_path: str, - precompute_num_workers=0, - ): + precompute_num_workers: int = 0, + ) -> None: self.samples = samples self.tokenizer = tokenizer self.cache_path = cache_path @@ -606,16 +627,16 @@ def __init__( os.makedirs(cache_path) self.precompute(precompute_num_workers) - def __getitem__(self, index): + def __getitem__(self, index) -> dict[str, Any]: item = self.samples[index] ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"]) ph_hat = self.tokenizer.ids_to_text(ids) return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def compute_or_load(self, file_name, text, language): + def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]: """Compute phonemes for the given text. If the phonemes are already cached, load them from cache. @@ -629,11 +650,11 @@ def compute_or_load(self, file_name, text, language): np.save(cache_path, ids) return ids - def get_pad_id(self): - """Get pad token ID for sequence padding""" + def get_pad_id(self) -> int: + """Get pad token ID for sequence padding.""" return self.tokenizer.pad_id - def precompute(self, num_workers=1): + def precompute(self, num_workers: int = 1) -> None: """Precompute phonemes for all samples. We use pytorch dataloader because we are lazy. @@ -642,7 +663,11 @@ def precompute(self, num_workers=1): with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) for _ in dataloder: pbar.update(batch_size) @@ -667,12 +692,13 @@ def print_logs(self, level: int = 0) -> None: class F0Dataset: - """F0 Dataset for computing F0 from wav files in CPU + """F0 Dataset for computing F0 from wav files in CPU. Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It also computes the mean and std of F0 values if `normalize_f0` is True. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -688,17 +714,18 @@ class F0Dataset: normalize_f0 (bool): Whether to normalize F0 values by mean and std. Defaults to True. + """ def __init__( self, - samples: Union[List[List], List[Dict]], + samples: Union[list[list], list[dict]], ap: "AudioProcessor", audio_config=None, # pylint: disable=unused-argument - cache_path: str = None, - precompute_num_workers=0, - normalize_f0=True, - ): + cache_path: Optional[str] = None, + precompute_num_workers: int = 0, + normalize_f0: bool = True, + ) -> None: self.samples = samples self.ap = ap self.cache_path = cache_path @@ -720,10 +747,10 @@ def __getitem__(self, idx): f0 = self.normalize(f0) return {"audio_unique_name": item["audio_unique_name"], "f0": f0} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def precompute(self, num_workers=0): + def precompute(self, num_workers: int = 0) -> None: logger.info("Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 @@ -731,7 +758,11 @@ def precompute(self, num_workers=0): normalize_f0 = self.normalize_f0 self.normalize_f0 = False dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) computed_data = [] for batch in dataloder: @@ -750,9 +781,8 @@ def get_pad_id(self): return self.pad_id @staticmethod - def create_pitch_file_path(file_name, cache_path): - pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") - return pitch_file + def create_pitch_file_path(file_name: str, cache_path: str) -> str: + return os.path.join(cache_path, file_name + "_pitch.npy") @staticmethod def _compute_and_save_pitch(ap, wav_file, pitch_file=None): @@ -768,7 +798,7 @@ def compute_pitch_stats(pitch_vecs): mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def load_stats(self, cache_path): + def load_stats(self, cache_path) -> None: stats_path = os.path.join(cache_path, "pitch_stats.npy") stats = np.load(stats_path, allow_pickle=True).item() self.mean = stats["mean"].astype(np.float32) @@ -789,9 +819,7 @@ def denormalize(self, pitch): return pitch def compute_or_load(self, wav_file, audio_unique_name): - """ - compute pitch and return a numpy array of pitch values - """ + """Compute pitch and return a numpy array of pitch values.""" pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) if not os.path.exists(pitch_file): pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) @@ -816,12 +844,13 @@ def print_logs(self, level: int = 0) -> None: class EnergyDataset: - """Energy Dataset for computing Energy from wav files in CPU + """Energy Dataset for computing Energy from wav files in CPU. Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It also computes the mean and std of Energy values if `normalize_Energy` is True. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -837,16 +866,17 @@ class EnergyDataset: normalize_Energy (bool): Whether to normalize Energy values by mean and std. Defaults to True. + """ def __init__( self, - samples: Union[List[List], List[Dict]], + samples: Union[list[list], list[dict]], ap: "AudioProcessor", - cache_path: str = None, + cache_path: Optional[str] = None, precompute_num_workers=0, normalize_energy=True, - ): + ) -> None: self.samples = samples self.ap = ap self.cache_path = cache_path @@ -868,10 +898,10 @@ def __getitem__(self, idx): energy = self.normalize(energy) return {"audio_unique_name": item["audio_unique_name"], "energy": energy} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def precompute(self, num_workers=0): + def precompute(self, num_workers=0) -> None: logger.info("Pre-computing energys...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 @@ -879,7 +909,11 @@ def precompute(self, num_workers=0): normalize_energy = self.normalize_energy self.normalize_energy = False dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) computed_data = [] for batch in dataloder: @@ -900,8 +934,7 @@ def get_pad_id(self): @staticmethod def create_energy_file_path(wav_file, cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] - energy_file = os.path.join(cache_path, file_name + "_energy.npy") - return energy_file + return os.path.join(cache_path, file_name + "_energy.npy") @staticmethod def _compute_and_save_energy(ap, wav_file, energy_file=None): @@ -917,7 +950,7 @@ def compute_energy_stats(energy_vecs): mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def load_stats(self, cache_path): + def load_stats(self, cache_path) -> None: stats_path = os.path.join(cache_path, "energy_stats.npy") stats = np.load(stats_path, allow_pickle=True).item() self.mean = stats["mean"].astype(np.float32) @@ -938,9 +971,7 @@ def denormalize(self, energy): return energy def compute_or_load(self, wav_file, audio_unique_name): - """ - compute energy and return a numpy array of energy values - """ + """Compute energy and return a numpy array of energy values.""" energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) if not os.path.exists(energy_file): energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)