Skip to content

Commit

Permalink
refactor: do not use the model method to convert to WAV
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPandir committed Mar 14, 2024
1 parent 766f0ff commit 39cc8a6
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions tts/tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from pathlib import Path
from io import BytesIO
import wave

import torch
from torch.package import PackageImporter
Expand Down Expand Up @@ -42,7 +43,8 @@ def generate(self, text: str, speaker: str, sample_rate: int) -> bytes:
raise InvalidSampleRateException(sample_rate)

text = self._delete_dashes(text)
return self._generate_audio(model, text, speaker, sample_rate)
tensor = self._generate_audio(model, text, speaker, sample_rate)
return self._convert_to_wav(tensor, sample_rate)

def _load_model(self, model_path: Path):
package = PackageImporter(model_path)
Expand Down Expand Up @@ -70,30 +72,29 @@ def _delete_dashes(self, text: str) -> str:

def _generate_audio(
self, model: "TTSModelMultiAcc_v3", text: str, speaker: str, sample_rate: int
) -> bytes:
) -> torch.Tensor:
try:
audio: torch.Tensor = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
return model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
except ValueError:
raise NotCorrectTextException(text)
except Exception as error:
if str(error) == "Model couldn't generate your text, probably it's too long":
raise TextTooLongException(text)
raise
else:
return self._convert_to_wav(model, audio, sample_rate)

def _convert_to_wav(self, model: "TTSModelMultiAcc_v3", audio: torch.Tensor, sample_rate: int) -> bytes:
with BytesIO() as buffer:
model.write_wave(
buffer,
audio=self._normalize_audio(audio),
sample_rate=sample_rate,
)

def _convert_to_wav(self, tensor: torch.Tensor, sample_rate: int) -> bytes:
    """Encode an audio tensor as an in-memory WAV file and return its bytes.

    The tensor is first converted to int16 samples by ``_normalize_audio``
    and then written as a single-channel, 2-bytes-per-sample WAV stream
    at ``sample_rate`` using the standard-library ``wave`` module.
    """
    pcm = self._normalize_audio(tensor)
    with BytesIO() as stream:
        with wave.open(stream, "wb") as writer:
            writer.setnchannels(1)  # mono
            # 2 bytes per sample must match the int16 data produced by
            # _normalize_audio; changing one without the other corrupts audio.
            writer.setsampwidth(2)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm)
        # The wave writer does not close a file object it did not open,
        # so the buffer is still readable here.
        return stream.getvalue()

def _normalize_audio(self, audio: torch.Tensor):
audio: np.ndarray = audio.numpy() * MAX_INT16
def _normalize_audio(self, tensor: torch.Tensor) -> np.ndarray:
    """Convert an audio tensor to an int16 NumPy array.

    The tensor values are multiplied by ``MAX_INT16`` and cast to
    ``np.int16``. Scaled values are clipped to the representable int16
    range first, so an out-of-range sample saturates instead of wrapping
    around when it is cast.

    NOTE(review): assumes the tensor holds float samples nominally in
    [-1, 1] -- confirm against the model's output.
    """
    scaled: np.ndarray = tensor.numpy() * MAX_INT16
    bounds = np.iinfo(np.int16)
    # Casting a float outside the int16 range is not well defined and
    # typically wraps; clip so loud peaks saturate rather than flip sign.
    return np.clip(scaled, bounds.min, bounds.max).astype(np.int16)

# Instance of TTS created when this statement runs; presumably the shared
# module-level object other modules import -- confirm against callers.
tts = TTS()

0 comments on commit 39cc8a6

Please sign in to comment.