From 18c895df943d182dd29a9f0f6d025b7a834fe2f1 Mon Sep 17 00:00:00 2001 From: MrPandir <137798474+MrPandir@users.noreply.github.com> Date: Sat, 16 Mar 2024 13:48:41 +0100 Subject: [PATCH] feat: add support to change rate and pitch (#14) * feat: add minimal workable code * feat: add removal of html tags from text * feat: add pitch and rate interpolation * feat: add errors of incorrect pitch and rate * feat(app): add pitch and rate limit * style: improve style --- app/app.py | 13 ++++++++++-- tts/exceptions.py | 10 +++++++++ tts/tts.py | 54 +++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/app/app.py b/app/app.py index d1b3c78..b48bd25 100644 --- a/app/app.py +++ b/app/app.py @@ -37,6 +37,8 @@ def generate( sample_rate: Annotated[ int, Parameter(examples=sample_rate_examples, default=48_000) ], + pitch: Annotated[int, Parameter(ge=0, le=100, default=50)], + rate: Annotated[int, Parameter(ge=0, le=100, default=50)], ) -> Response: if len(text) > text_length_limit: raise TextTooLongHTTPException( @@ -44,7 +46,7 @@ def generate( ) try: - audio = tts.generate(text, speaker, sample_rate) + audio = tts.generate(text, speaker, sample_rate, pitch, rate) except NotFoundModelException: raise NotFoundSpeakerHTTPException({"speaker": speaker}) except NotCorrectTextException: @@ -57,6 +59,9 @@ def generate( raise InvalidSampleRateHTTPException( {"sample_rate": sample_rate, "valid_sample_rates": tts.VALID_SAMPLE_RATES} ) + except (InvalidPitchException, InvalidRateException): + # This will never happen because litestar ensures compliance with the parameters `ge` and `le`. 
+ pass else: return Response(audio, media_type="audio/wav") @@ -65,12 +70,16 @@ def generate( async def speakers() -> dict[str, list[str]]: return tts.speakers + @get(["/", "/docs"], include_in_schema=False) async def docs() -> Redirect: return Redirect("/schema") + app = Litestar( [generate, speakers, docs], - openapi_config=OpenAPIConfig(title="Silero TTS API", version="1.0.0", root_schema_site="swagger"), + openapi_config=OpenAPIConfig( + title="Silero TTS API", version="1.0.0", root_schema_site="swagger" + ), cors_config=CORSConfig(), ) diff --git a/tts/exceptions.py b/tts/exceptions.py index 8153e6f..7fc85d7 100644 --- a/tts/exceptions.py +++ b/tts/exceptions.py @@ -17,3 +17,13 @@ class InvalidSampleRateException(Exception): def __init__(self, sample_rate: int) -> None: self.sample_rate = sample_rate super().__init__(f"Invalid sample rate {sample_rate}. Supported sample rates are 8 000, 24 000, and 48 000.") + +class InvalidPitchException(Exception): + def __init__(self, pitch: int) -> None: + self.pitch = pitch + super().__init__(f"Invalid pitch {pitch}. Pitch should be in range from 0 to 100.") + +class InvalidRateException(Exception): + def __init__(self, rate: int) -> None: + self.rate = rate + super().__init__(f"Invalid rate {rate}. 
Rate should be in range from 0 to 100.") diff --git a/tts/tts.py b/tts/tts.py index eaa755d..f0bb619 100644 --- a/tts/tts.py +++ b/tts/tts.py @@ -35,15 +35,27 @@ def __init__(self): for model_path in Path("models").glob("*.pt"): self._load_model(model_path) - def generate(self, text: str, speaker: str, sample_rate: int) -> bytes: + def generate( + self, text: str, speaker: str, sample_rate: int, pitch: int, rate: int + ) -> bytes: model = self.model_by_speaker.get(speaker) if not model: raise NotFoundModelException(speaker) if sample_rate not in self.VALID_SAMPLE_RATES: raise InvalidSampleRateException(sample_rate) + if not 0 <= pitch <= 100: + raise InvalidPitchException(pitch) + if not 0 <= rate <= 100: + raise InvalidRateException(rate) + + pitch = self._interpolate_pitch(pitch) + rate = self._interpolate_rate(rate) + text = self._delete_dashes(text) - tensor = self._generate_audio(model, text, speaker, sample_rate) + text = self._delete_html_brackets(text) + + tensor = self._generate_audio(model, text, speaker, sample_rate, pitch, rate) return self._convert_to_wav(tensor, sample_rate) def _load_model(self, model_path: Path): @@ -64,17 +76,46 @@ def _load_speakers(self, model: "TTSModelMultiAcc_v3", language: str): self.speakers[language] = model.speakers for speaker in model.speakers: self.model_by_speaker[speaker] = model - + def _delete_dashes(self, text: str) -> str: # This fixes the problem: # https://github.com/twirapp/silero-tts-api-server/issues/8 return text.replace("-", "").replace("‑", "") + def _delete_html_brackets(self, text: str) -> str: + # Safeguarding against pitch and rate modifications with HTML tags in text. + # It also prevents audio generation from raising `ValueError` when the text contains HTML tags.
+ return text.replace("<", "").replace(">", "") + + def _interpolate_pitch(self, pitch: int) -> int: + # One interesting feature of the models is that when a pitch of -100 is input, + # it transforms to `1.0 + (-100 / 100) = 0`, making the sound equivalent to generating `1.0 + (0 / 100) = 1`. + # This makes the voice the same for 0 and 1 + if pitch == 0: + return -101 + + SCALE_FACTOR = 2 + OFFSET = -100 + return pitch * SCALE_FACTOR + OFFSET + + def _interpolate_rate(self, rate: int) -> int: + OFFSET = 50 + return rate + OFFSET + def _generate_audio( - self, model: "TTSModelMultiAcc_v3", text: str, speaker: str, sample_rate: int + self, + model: "TTSModelMultiAcc_v3", + text: str, + speaker: str, + sample_rate: int, + pitch: int, + rate: int, ) -> torch.Tensor: + ssml_text = f"{text}" try: - return model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate) + return model.apply_tts( + ssml_text=ssml_text, speaker=speaker, sample_rate=sample_rate + ) except ValueError: raise NotCorrectTextException(text) except Exception as error: @@ -84,7 +125,7 @@ def _generate_audio( def _convert_to_wav(self, tensor: torch.Tensor, sample_rate: int) -> bytes: audio = self._normalize_audio(tensor) - with BytesIO() as buffer, wave.open(buffer, 'wb') as wav: + with BytesIO() as buffer, wave.open(buffer, "wb") as wav: wav.setnchannels(1) # mono wav.setsampwidth(2) # quality is 16 bit. Do not change wav.setframerate(sample_rate) @@ -97,4 +138,5 @@ def _normalize_audio(self, tensor: torch.Tensor): audio: np.ndarray = tensor.numpy() * MAX_INT16 return audio.astype(np.int16) + tts = TTS()