From d3a6d5e2cbededd7cf6e1658dee8ce3bedd9b328 Mon Sep 17 00:00:00 2001 From: Eric Wright Date: Tue, 9 Apr 2024 21:32:48 -0400 Subject: [PATCH] Added support for AllTalk TTS Server --- pandrator.py | 143 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 134 insertions(+), 9 deletions(-) diff --git a/pandrator.py b/pandrator.py index 7cdee35..cd1588e 100644 --- a/pandrator.py +++ b/pandrator.py @@ -169,7 +169,7 @@ def __init__(self, master): self.dubbing_switch.grid_remove() # Hide the dubbing switch by default ctk.CTkLabel(session_settings_frame, text="TTS Service:").grid(row=2, column=0, padx=10, pady=5, sticky=tk.W) - self.tts_service_dropdown = ctk.CTkOptionMenu(session_settings_frame, variable=self.tts_service, values=["XTTS", "Silero", "VoiceCraft"], command=self.update_tts_service) + self.tts_service_dropdown = ctk.CTkOptionMenu(session_settings_frame, variable=self.tts_service, values=["XTTS", "Silero", "VoiceCraft", "AllTalk TTS"], command=self.update_tts_service) self.tts_service_dropdown.grid(row=2, column=1, padx=10, pady=5, sticky=tk.EW) self.connect_to_server_button = ctk.CTkButton(session_settings_frame, text="Connect to Server", command=self.connect_to_server) @@ -740,7 +740,31 @@ def toggle_external_server(self): self.populate_speaker_dropdown() def connect_to_server(self): - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "AllTalk TTS": + if self.use_external_server.get(): + external_server_url = self.external_server_url.get() + try: + response = requests.get(f"{external_server_url}/docs") + if response.status_code == 200: + self.external_server_connected = True + self.populate_speaker_dropdown() + messagebox.showinfo("Connected", "Successfully connected to the external AllTalk TTS server.") + else: + messagebox.showerror("Error", f"Failed to connect to the external AllTalk TTS server. Status code: {response.status_code}") + except requests.exceptions.RequestException as e: + messagebox.showerror("Error", f"Failed to connect to the external AllTalk TTS server: {str(e)}") + else: + try: + response = requests.get("http://localhost:7851/docs") + if response.status_code == 200: + self.populate_speaker_dropdown() + messagebox.showinfo("Connected", "Successfully connected to the local AllTalk TTS server.") + else: + messagebox.showerror("Error", f"Failed to connect to the local AllTalk TTS server. Status code: {response.status_code}") + + except requests.exceptions.RequestException as e: + messagebox.showerror("Error", f"Failed to connect to the local AllTalk TTS server: {str(e)}") + elif self.tts_service.get() == "XTTS": if self.use_external_server.get(): external_server_url = self.external_server_url.get() try: @@ -772,7 +796,38 @@ def connect_to_server(self): messagebox.showerror("Error", f"Failed to connect to the local XTTS server: {str(e)}") def populate_speaker_dropdown(self): - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "AllTalk TTS": + if self.use_external_server.get() and self.external_server_connected: + external_server_url = self.external_server_url.get() + try: + response = requests.get(f"{external_server_url}/api/voices") + if response.status_code == 200: + # Extract the list of voices + json_data = response.json() + speakers = json_data["voices"] + + self.speaker_dropdown.configure(values=speakers) + if speakers: + self.selected_speaker.set(speakers[0]) + else: + messagebox.showerror("Error", f"Failed to fetch speakers from the external server. Status code: {response.status_code}") + except requests.exceptions.RequestException as e: + messagebox.showerror("Error", f"Failed to connect to the external server: {str(e)}") + else: + try: + response = requests.get("http://localhost:7851/api/voices") + if response.status_code == 200: + # Extract the list of voices + json_data = response.json() + speakers = json_data["voices"] + self.speaker_dropdown.configure(values=speakers) + if speakers: + self.selected_speaker.set(speakers[0]) + else: + messagebox.showerror("Error", f"Failed to fetch speakers from the local server. Status code: {response.status_code}") + except requests.exceptions.RequestException as e: + messagebox.showerror("Error", f"Failed to connect to the local server: {str(e)}") + elif self.tts_service.get() == "XTTS": if self.use_external_server.get() and self.external_server_connected: external_server_url = self.external_server_url.get() try: @@ -1046,7 +1101,7 @@ def load_models(self): CTkMessagebox(title="Error", message="Failed to connect to the LLM API.", icon="cancel") def update_tts_service(self, event=None): - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "XTTS" or self.tts_service.get() == "AllTalk TTS": self.connect_to_server_button.grid() self.use_external_server_switch.grid() if self.use_external_server.get(): @@ -1194,7 +1249,12 @@ def start_optimisation_thread(self): def check_server_connection(self): try: - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "AllTalk TTS": + if self.use_external_server.get() and self.external_server_connected: + url = f"{self.external_server_url.get()}/docs" + else: + url = "http://localhost:7851/docs" + elif self.tts_service.get() == "XTTS": if self.use_external_server.get() and self.external_server_connected: url = f"{self.external_server_url.get()}/docs" else: @@ -1212,7 +1272,7 @@ def check_server_connection(self): messagebox.showerror("Error", f"{self.tts_service.get()} server returned status code {response.status_code}. Cannot start generation.") return False except requests.exceptions.RequestException as e: - if self.tts_service.get() == "XTTS" and self.use_external_server.get(): + if (self.tts_service.get() == "XTTS" or self.tts_service.get() == "AllTalk TTS") and self.use_external_server.get(): messagebox.showerror("Error", f"Failed to connect to the external XTTS server:\n{str(e)}") else: messagebox.showerror("Error", f"Failed to connect to {self.tts_service.get()} server:\n{str(e)}") @@ -1892,7 +1952,7 @@ def split_into_sentences(self, text): "Kalmyk (v3)": "xal" } - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "XTTS" or self.tts_service.get() == "AllTalk TTS" : language = self.language_var.get() # Replace self.language.get() with self.language_var.get() else: # Silero silero_language_name = self.language_var.get() # Replace self.language.get() with self.language_var.get() @@ -2039,7 +2099,7 @@ def save_json(self, data, filename): json.dump(numbered_data, f, indent=2) def update_language_dropdown(self, event=None): - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "XTTS" or self.tts_service.get() == "AllTalk TTS": languages = ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko", "hi"] self.language_dropdown.configure(values=languages, state="normal") # Enable the language dropdown self.language_var.set("en") @@ -2077,7 +2137,72 @@ def update_language_dropdown(self, event=None): def tts_to_audio(self, text): best_audio = None best_mos = -1 - if self.tts_service.get() == "XTTS": + if self.tts_service.get() == "AllTalk TTS": + language = self.language_dropdown.get() # Get the value directly from the language dropdown/combobox + speaker = self.selected_speaker.get() + speaker = f"{speaker}" + + for attempt in range(self.max_attempts.get()): + try: + timestamp = str(int(time.time())) + data = { + "text_input": text, + "text_filtering": "standard", + "character_voice_gen": speaker, + "narrator_enabled": False, + "narrator_voice_gen": speaker, + "text_not_inside": "character", + "language": language, + "output_file_name": "output" + timestamp, + "output_file_timestamp": False, + "autoplay": False, + "autoplay_volume": 0.1, + "streaming": False + } + print(f"Request data: {data}") # Print the request data + if self.external_server_connected: + server_url = self.external_server_url.get() + else: + server_url = "http://localhost:7851/api/tts-generate/" + + response = requests.post(server_url, data=data) + + if response.status_code == 200: + # Extract the path to the wave file + json_data = response.json() + local_file_path = json_data["output_file_path"] + if os.path.exists(local_file_path): + + print(f"Using local file at {local_file_path}") # Print the response status code + audio = AudioSegment.from_file(local_file_path, format="wav") + else: + output_cache_url = json_data["output_cache_url"] + print(f"Getting file from server at {output_cache_url}") # Print the response status code + + response = requests.get(output_cache_url) + + if response.status_code == 200: + audio_data = io.BytesIO(response.content) + audio = AudioSegment.from_file(audio_data, format="wav") + else: + print(f"Error {response.status_code}: Failed to retreive audio file") + + if self.enable_tts_evaluation.get(): + mos_score = self.evaluate_tts(text, audio) + if mos_score > best_mos: + best_audio = audio + best_mos = mos_score + + if mos_score >= float(self.target_mos_value.get()): + return best_audio + else: + return audio + else: + print(f"Error {response.status_code}: Failed to convert text to audio.") + except Exception as e: + print(f"Error in tts_to_audio: {str(e)}") + + elif self.tts_service.get() == "XTTS": language = self.language_dropdown.get() # Get the value directly from the language dropdown/combobox speaker = self.selected_speaker.get() speaker_wav = f"{speaker}.wav"