From 39cc8a671085ae059ac36f3cf2d68e12a05425e4 Mon Sep 17 00:00:00 2001 From: MrPandir Date: Thu, 14 Mar 2024 18:23:37 +0100 Subject: [PATCH] refactor: do not use the model method to convert in WAV --- tts/tts.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tts/tts.py b/tts/tts.py index 6137a5e..eaa755d 100644 --- a/tts/tts.py +++ b/tts/tts.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from pathlib import Path from io import BytesIO +import wave import torch from torch.package import PackageImporter @@ -42,7 +43,8 @@ def generate(self, text: str, speaker: str, sample_rate: int) -> bytes: raise InvalidSampleRateException(sample_rate) text = self._delete_dashes(text) - return self._generate_audio(model, text, speaker, sample_rate) + tensor = self._generate_audio(model, text, speaker, sample_rate) + return self._convert_to_wav(tensor, sample_rate) def _load_model(self, model_path: Path): package = PackageImporter(model_path) @@ -70,30 +72,29 @@ def _delete_dashes(self, text: str) -> str: def _generate_audio( self, model: "TTSModelMultiAcc_v3", text: str, speaker: str, sample_rate: int - ) -> bytes: + ) -> torch.Tensor: try: - audio: torch.Tensor = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate) + return model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate) except ValueError: raise NotCorrectTextException(text) except Exception as error: if str(error) == "Model couldn't generate your text, probably it's too long": raise TextTooLongException(text) raise - else: - return self._convert_to_wav(model, audio, sample_rate) - - def _convert_to_wav(self, model: "TTSModelMultiAcc_v3", audio: torch.Tensor, sample_rate: int) -> bytes: - with BytesIO() as buffer: - model.write_wave( - buffer, - audio=self._normalize_audio(audio), - sample_rate=sample_rate, - ) + + def _convert_to_wav(self, tensor: torch.Tensor, sample_rate: int) -> bytes: + audio = self._normalize_audio(tensor) + with BytesIO() as buffer, wave.open(buffer, 'wb') as wav: + wav.setnchannels(1) # mono + wav.setsampwidth(2) # quality is 16 bit. Do not change + wav.setframerate(sample_rate) + wav.writeframes(audio) + buffer.seek(0) return buffer.read() - def _normalize_audio(self, audio: torch.Tensor): - audio: np.ndarray = audio.numpy() * MAX_INT16 + def _normalize_audio(self, tensor: torch.Tensor): + audio: np.ndarray = tensor.numpy() * MAX_INT16 return audio.astype(np.int16) tts = TTS()