diff --git a/README.md b/README.md
index b2b93a3..e231234 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ An OpenAI API compatible text to speech server.
 Full Compatibility:
 * `tts-1`: `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer` (configurable)
 * `tts-1-hd`: `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer` (configurable, uses OpenAI samples by default)
-* response_format: `mp3`, `opus`, `aac`, or `flac`
+* response_format: `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`
 * speed 0.25-4.0 (and more)
 
 Details:
@@ -20,6 +20,8 @@ Details:
 * Custom cloned voices can be used for tts-1-hd, See: [Custom Voices Howto](#custom-voices-howto)
 * 🌐 [Multilingual](#multilingual) support with XTTS voices
 * [Custom fine-tuned XTTS model support](#custom-fine-tuned-model-support)
+* Configurable [generation parameters](#generation-parameters)
+* Streamed output while generating
 * Occasionally, certain words or symbols may sound incorrect, you can fix them with regex via `pre_process_map.yaml`
 
@@ -27,6 +29,14 @@ If you find a better voice match for `tts-1` or `tts-1-hd`, please let me know s
 
 ## Recent Changes
 
+Version 0.14.0, 2024-06-26
+
+* Added `response_format`: `wav` and `pcm` support
+* Output streaming (while generating) for `tts-1` and `tts-1-hd`
+* Enhanced [generation parameters](#generation-parameters) for xtts models (temperature, top_p, etc.)
+* Idle unload timer (optional) - doesn't work perfectly yet
+* Improved error handling
+
 Version 0.13.0, 2024-06-25
 
 * Added [Custom fine-tuned XTTS model support](#custom-fine-tuned-model-support)
@@ -313,3 +323,21 @@ tts-1-hd:
     model_path: voices/halo
 ```
 3) The model will be loaded when you access the voice for the first time (`--preload` doesn't work with custom models yet)
+
+## Generation Parameters
+
+The generation of XTTSv2 voices can be fine-tuned with the following options (defaults shown below):
+
+```yaml
+tts-1-hd:
+  alloy:
+    model: xtts
+    speaker: voices/alloy.wav
+    enable_text_splitting: True
+    length_penalty: 1.0
+    repetition_penalty: 10
+    speed: 1.0
+    temperature: 0.75
+    top_k: 50
+    top_p: 0.85
+```
\ No newline at end of file
diff --git a/download_voices_tts-1-hd.bat b/download_voices_tts-1-hd.bat
index e2cadc6..d156e96 100644
--- a/download_voices_tts-1-hd.bat
+++ b/download_voices_tts-1-hd.bat
@@ -2,10 +2,7 @@
 set COQUI_TOS_AGREED=1
 set TTS_HOME=voices
 
-set MODELS=%*
-if "%MODELS%" == "" set MODELS=xtts
-
-for %%i in (%MODELS%) do (
+for %%i in (%*) do (
   python -c "from TTS.utils.manage import ModelManager; ModelManager().download_model('%%i')"
 )
 call download_samples.bat
diff --git a/download_voices_tts-1-hd.sh b/download_voices_tts-1-hd.sh
index 3e5e391..f101b00 100755
--- a/download_voices_tts-1-hd.sh
+++ b/download_voices_tts-1-hd.sh
@@ -2,8 +2,7 @@
 export COQUI_TOS_AGREED=1
 export TTS_HOME=voices
 
-MODELS=${*:-xtts}
-for model in $MODELS; do
+for model in $*; do
   python -c "from TTS.utils.manage import ModelManager; ModelManager().download_model('$model')"
 done
 ./download_samples.sh
\ No newline at end of file
diff --git a/requirements-rocm.txt b/requirements-rocm.txt
index 5225085..ec6a7ae 100644
--- a/requirements-rocm.txt
+++ b/requirements-rocm.txt
@@ -4,7 +4,9 @@ loguru
 # piper-tts
 piper-tts==1.2.0
 # xtts
-TTS
+TTS==0.22.0
+# https://github.com/huggingface/transformers/issues/31040
+transformers<4.41.0
 # XXX, 3.8+ has some issue for now
 spacy==3.7.4
diff --git a/requirements.txt b/requirements.txt
index 1155e14..930cb1d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,9 @@ loguru
 # piper-tts
 piper-tts==1.2.0
 # xtts
-TTS
+TTS==0.22.0
+# https://github.com/huggingface/transformers/issues/31040
+transformers<4.41.0
 # XXX, 3.8+ has some issue for now
 spacy==3.7.4
diff --git a/sample.env b/sample.env
index 702d400..5913a4a 100644
--- a/sample.env
+++ b/sample.env
@@ -2,5 +2,5 @@ TTS_HOME=voices
 HF_HOME=voices
 #PRELOAD_MODEL=xtts
 #PRELOAD_MODEL=xtts_v2.0.2
-#EXTRA_ARGS=--log-level DEBUG
+#EXTRA_ARGS=--log-level DEBUG --unload-timer 300
 #USE_ROCM=1
\ No newline at end of file
diff --git a/speech.py b/speech.py
index a1c1266..9c75a0e 100755
--- a/speech.py
+++ b/speech.py
@@ -1,51 +1,105 @@
 #!/usr/bin/env python3
 import argparse
 import os
-import sys
+import gc
 import re
 import subprocess
-import tempfile
+import sys
+import threading
+import time
 import yaml
+import contextlib
+
 from fastapi.responses import StreamingResponse
-import uvicorn
-from pydantic import BaseModel
 from loguru import logger
-
+from pydantic import BaseModel
+import uvicorn
 from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
+
+@contextlib.asynccontextmanager
+async def lifespan(app):
+    yield
+    gc.collect()
+    try:
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    except Exception:
+        pass
+
+app = OpenAIStub(lifespan=lifespan)
 xtts = None
 args = None
-app = OpenAIStub()
+
+def unload_model():
+    import torch, gc
+    global xtts
+    if xtts:
+        logger.info("Unloading model")
+        xtts.xtts.to('cpu') # this was required to free up GPU memory...
+        del xtts
+        xtts = None
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
 
 class xtts_wrapper():
-    def __init__(self, model_name, device, model_path=None):
+    check_interval: int = 1
+
+    def __init__(self, model_name, device, model_path=None, unload_timer=None):
         self.model_name = model_name
+        self.unload_timer = unload_timer
+        self.last_used = time.time()
+        self.timer = None
+        self.lock = threading.Lock()
 
         logger.info(f"Loading model {self.model_name} to {device}")
 
-        if model_path: # custom model # and config_path
-            config_path=os.path.join(model_path, 'config.json')
-            self.xtts = TTS(model_path=model_path, config_path=config_path).to(device)
-        else:
-            self.xtts = TTS(model_name=model_name).to(device)
-
-    def tts(self, text, speaker_wav, speed, language):
-        tf, file_path = tempfile.mkstemp(suffix='.wav', prefix='openedai-speech-')
-
+        if model_path is None:
+            model_path = ModelManager().download_model(model_name)[0]
+
+        config_path = os.path.join(model_path, 'config.json')
+        config = XttsConfig()
+        config.load_json(config_path)
+        self.xtts = Xtts.init_from_config(config)
+        self.xtts.load_checkpoint(config, checkpoint_dir=model_path, use_deepspeed=False) # XXX there are no prebuilt deepspeed wheels??
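+        # note: loading Xtts directly via XttsConfig/load_checkpoint (rather than the
+        # high-level TTS.api wrapper) is what exposes inference_stream() used below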
+        self.xtts = self.xtts.to(device=device)
+        self.xtts.eval()
+
+        if self.unload_timer:
+            logger.info(f"Setting unload timer to {self.unload_timer} seconds")
+            self.not_idle()
+            self.check_idle()
+
+    def not_idle(self):
+        with self.lock:
+            self.last_used = time.time()
+
+    def check_idle(self):
+        with self.lock:
+            if time.time() - self.last_used >= self.unload_timer:
+                logger.info("Unloading TTS model due to inactivity")
+                unload_model()
+            else:
+                # Reschedule the check
+                self.timer = threading.Timer(self.check_interval, self.check_idle)
+                self.timer.daemon = True
+                self.timer.start()
+
+    def tts(self, text, language, speaker_wav, **hf_generate_kwargs):
+        self.not_idle()
         try:
-            # TODO: support speaker= as voice id instead of just wav
-            file_path = self.xtts.tts_to_file(
-                text=text,
-                language=language,
-                speaker_wav=speaker_wav,
-                speed=speed,
-                file_path=file_path,
-            )
+            with torch.no_grad():
+                gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
 
-        finally:
-            os.unlink(file_path)
+                for wav in self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs):
+                    yield wav.cpu().numpy().tobytes() # assumes wav data is f32le
+                    self.not_idle()
 
-        return tf
+        finally:
+            self.not_idle()
 
 def default_exists(filename: str):
     if not os.path.exists(filename):
@@ -92,10 +146,10 @@ class GenerateSpeechRequest(BaseModel):
 
 def build_ffmpeg_args(response_format, input_format, sample_rate):
     # Convert the output to the desired format using ffmpeg
-    if input_format == 'raw':
-        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "s16le", "-ar", sample_rate, "-ac", "1", "-i", "-"]
-    else:
+    if input_format == 'WAV':
         ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"]
+    else:
+        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"]
 
     if response_format == "mp3":
         ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"])
@@ -105,6 +159,10 @@ def build_ffmpeg_args(response_format, input_format, sample_rate):
         ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"])
     elif response_format == "flac":
         ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"])
+    elif response_format == "wav":
+        ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"])
+    elif response_format == "pcm": # even though pcm is technically 'raw', we still use ffmpeg to adjust the speed
+        ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"])
 
     return ffmpeg_args
 
@@ -121,18 +179,27 @@ async def generate_speech(request: GenerateSpeechRequest):
     model = request.model
     voice = request.voice
-    response_format = request.response_format
+    response_format = request.response_format.lower()
     speed = request.speed
 
     # Set the Content-Type header based on the requested format
     if response_format == "mp3":
         media_type = "audio/mpeg"
     elif response_format == "opus":
-        media_type = "audio/ogg;codecs=opus"
+        media_type = "audio/ogg;codec=opus" # codecs?
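+        # (RFC 6381 defines the MIME parameter as the plural 'codecs', so the previous value was likely correct)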
elif response_format == "aac": media_type = "audio/aac" elif response_format == "flac": media_type = "audio/x-flac" + elif response_format == "wav": + media_type = "audio/wav" + elif response_format == "pcm": + if model == 'tts-1': # piper + media_type = "audio/pcm;rate=22050" + elif model == 'tts-1-hd': + media_type = "audio/pcm;rate=24000" + else: + BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format') ffmpeg_args = None tts_io_out = None @@ -158,51 +225,77 @@ async def generate_speech(request: GenerateSpeechRequest): tts_proc.stdin.write(bytearray(input_text.encode('utf-8'))) tts_proc.stdin.close() tts_io_out = tts_proc.stdout - ffmpeg_args = build_ffmpeg_args(response_format, input_format="raw", sample_rate="22050") + ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate="22050") + # Pipe the output from piper/xtts to the input of ffmpeg + ffmpeg_args.extend(["-"]) + ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE) + + return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type) # Use xtts for tts-1-hd elif model == 'tts-1-hd': voice_map = map_voice_to_speaker(voice, 'tts-1-hd') try: - tts_model = voice_map['model'] - speaker = voice_map['speaker'] + tts_model = voice_map.pop('model') + speaker = voice_map.pop('speaker') except KeyError as e: raise ServiceUnavailableError(f"Configuration error: tts-1-hd voice '{voice}' is missing setting. KeyError: {e}") - language = voice_map.get('language', 'en') - tts_model_path = voice_map.get('model_path', None) + if xtts and xtts.model_name != tts_model: + unload_model() - if xtts is not None and xtts.model_name != tts_model: - import torch, gc - del xtts - xtts = None - gc.collect() - torch.cuda.empty_cache() + tts_model_path = voice_map.pop('model_path', None) # XXX changing this on the fly is ignored if you keep the same name + + if xtts is None: + xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path, unload_timer=args.unload_timer) + + ffmpeg_args = build_ffmpeg_args(response_format, input_format="f32le", sample_rate="24000") - else: - if xtts is None: - xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path) + # tts speed doesn't seem to work well + speed = voice_map.pop('speed', speed) + if speed < 0.5: + speed = speed / 0.5 + ffmpeg_args.extend(["-af", "atempo=0.5"]) + if speed > 1.0: + ffmpeg_args.extend(["-af", f"atempo={speed}"]) + speed = 1.0 - ffmpeg_args = build_ffmpeg_args(response_format, input_format="WAV", sample_rate="24000") + language = voice_map.pop('language', 'en') - # tts speed doesn't seem to work well - if speed < 0.5: - speed = speed / 0.5 - ffmpeg_args.extend(["-af", "atempo=0.5"]) - if speed > 1.0: - ffmpeg_args.extend(["-af", f"atempo={speed}"]) - speed = 1.0 + comment = voice_map.pop('comment', None) # ignored. 
-        tts_io_out = xtts.tts(text=input_text, speaker_wav=speaker, speed=speed, language=language)
+        hf_generate_kwargs = dict(
+            speed=speed,
+            **voice_map,
+        )
+
+        hf_generate_kwargs.setdefault('enable_text_splitting', True) # change the default to True
+
+        # Pipe the output from piper/xtts to the input of ffmpeg
+        ffmpeg_args.extend(["-"])
+        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+
+        def generator():
+            try:
+                for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
+                    ffmpeg_proc.stdin.write(chunk)
+
+            except Exception as e:
+                logger.error(f"Exception: {repr(e)}")
+                raise
+
+            finally:
+                ffmpeg_proc.stdin.close()
+
+        worker = threading.Thread(target=generator)
+        worker.daemon = True
+        worker.start()
+
+        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
 
     else:
         raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')
 
-    # Pipe the output from piper/xtts to the input of ffmpeg
-    ffmpeg_args.extend(["-"])
-    ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE)
-
-    return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
 
 # We return 'mps' but currently XTTS will not work with mps devices as the cuda support is incomplete
 def auto_torch_device():
@@ -220,6 +313,7 @@
 
     parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
     parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
+    parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds")
     parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
     parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
     parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
@@ -233,10 +327,13 @@ def auto_torch_device():
     logger.add(sink=sys.stderr, level=args.log_level)
 
     if args.xtts_device != "none":
-        from TTS.api import TTS
+        import torch
+        from TTS.tts.configs.xtts_config import XttsConfig
+        from TTS.tts.models.xtts import Xtts
+        from TTS.utils.manage import ModelManager
 
     if args.preload:
-        xtts = xtts_wrapper(args.preload, device=args.xtts_device)
+        xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)
 
     app.register_model('tts-1')
     app.register_model('tts-1-hd')
diff --git a/voice_to_speaker.default.yaml b/voice_to_speaker.default.yaml
index 5d54d39..0604830 100644
--- a/voice_to_speaker.default.yaml
+++ b/voice_to_speaker.default.yaml
@@ -48,3 +48,11 @@ tts-1-hd:
   me:
     model: xtts_v2.0.2 # you can specify different xtts version
     speaker: voices/me.wav # this could be you
+    enable_text_splitting: True
+    length_penalty: 1.0
+    repetition_penalty: 10
+    speed: 1.0
+    temperature: 0.75
+    top_k: 50
+    top_p: 0.85
+    comment: You can also add a comment here; it will be kept but otherwise ignored.
\ No newline at end of file
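
A minimal client sketch for the new streamed `pcm` output (illustrative, not part of the patch): it assumes the server's default port 8000, the OpenAI-compatible `/v1/audio/speech` route, and an `alloy` voice configured as above.

```python
# Stream raw PCM from the server while it is still generating.
# For tts-1-hd the stream is s16le at 24000 Hz, per the media_type set above.
# Host, port, route, and voice are assumptions; adjust for your deployment.
import requests

response = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "tts-1-hd",
        "voice": "alloy",
        "input": "Streamed output while generating.",
        "response_format": "pcm",
    },
    stream=True,
)
response.raise_for_status()

with open("speech.raw", "wb") as f:
    # chunks arrive as XTTS produces them, not after the full clip is rendered
    for chunk in response.iter_content(chunk_size=4096):
        f.write(chunk)
```

The capture can then be played back with, for example, `ffplay -f s16le -ar 24000 -ac 1 speech.raw`.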