-
Notifications
You must be signed in to change notification settings - Fork 0
/
GPTSoVits.py
184 lines (155 loc) · 7.34 KB
/
GPTSoVits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import common
import requests
import urllib.parse
import subprocess
class GPTSoVitsAPI():
    """
    Thin HTTP client for a GPTSoVits inference server.

    Two server flavours are supported: the classic v1 API from the main
    branch and the v3 API (APIv3.py) from the fast_inference_ branch.

    Attributes:
        api_url (str): Url of the GPTSoVits API.
        usingV3 (bool): whether using APIv3.py from the fast_inference_ branch.
        ttsInferYamlPath (str | None): TTS inference YAML path. Forwarded by
            build_tts_v3_request(); left out of the request when None.

    Methods:
        set_models(ckpt_path: str, pth_path: str) -> None: Load the GPT and SoVITS weights on a v3 server.
        tts_v1(ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto') -> requests.Response: Run TTS inference using the classic v1 API.
        build_tts_v3_request(ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto') -> str: Build the full GET url for a v3 TTS request.
        tts_v3(ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto', streamed: bool = False) -> requests.Response: Run TTS inference using the v3 API from the fast_inference_ branch.
        tts(ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto', streamed: bool = False) -> requests.Response: Run TTS inference based on the API version.
        changeReferenceAudio(ref_audio: str, ref_text: str, ref_language: str = 'auto') -> None: Change reference audio.
        control(command: str): Restart or exit.
    """

    def __init__(self, api_url: str, isTTSv3: bool = False, ckpt_path: str = None, pth_path: str = None, ttsInferYamlPath: str = None) -> None:
        """
        Initialize the GPTSoVitsAPI.

        When the v3 API is selected, the model weights are pushed to the
        server right away through set_models().

        Args:
            api_url (str): Url of the GPTSoVits API.
            isTTSv3 (bool): whether using APIv3.py from the fast_inference_ branch. Defaults to False.
            ckpt_path (str, optional): Path to the GPT weights, passed to set_models() for the v3 API.
            pth_path (str, optional): Path to the SoVITS weights, passed to set_models() for the v3 API.
            ttsInferYamlPath (str, optional): TTS inference YAML path used by build_tts_v3_request(). Defaults to None.
        """
        self.api_url = api_url
        self.usingV3 = isTTSv3
        # Always define the attribute: build_tts_v3_request() reads it and
        # would otherwise fail with AttributeError.
        self.ttsInferYamlPath = ttsInferYamlPath
        if isTTSv3:
            common.log('Using v3 API from the fast_inference_ branch')
            self.set_models(ckpt_path, pth_path)
        else:
            common.log('Using classic v1 API from the main branch')

    def set_models(self, ckpt_path: str, pth_path: str) -> None:
        """
        Set the models for the v3 API.

        Issues one GET request per weight file and reports the outcome
        through common.log / common.panic.

        Args:
            ckpt_path (str): Path to the TTS model checkpoint.
            pth_path (str): Path to the TTS model state dict.
        """
        gpt_response = requests.get(f'{self.api_url}/set_gpt_weights', params={
            'weights_path': ckpt_path
        })
        sovits_response = requests.get(f'{self.api_url}/set_sovits_weights', params={
            'weights_path': pth_path
        })
        if gpt_response.status_code == 200 and sovits_response.status_code == 200:
            common.log('Successfully set models for v3 API')
        else:
            common.panic(f'Failed to set models for v3 API: {gpt_response.content} {sovits_response.content}')

    # text to speech function for v1 API
    def tts_v1(self, ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto') -> requests.Response:
        """
        Run TTS inference using the classic v1 API.

        The request is always sent with stream=True, so the response body
        is not downloaded until the caller reads it.

        Args:
            ref_audio (str): Reference audio path.
            ref_text (str): Reference text.
            text (str): Text to be synthesized.
            ref_language (str, optional): Reference audio language. Defaults to 'auto'.
            text_language (str, optional): Text language. Defaults to 'auto'.
        Returns:
            requests.Response: Response object from the API.
        """
        return requests.post(f'{self.api_url}/', json={
            "refer_wav_path": ref_audio,
            "prompt_text": ref_text,
            "prompt_language": ref_language,
            "text": text,
            "text_language": text_language
        }, stream=True)

    def build_tts_v3_request(self, ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto') -> str:
        """
        Build the request for the v3 API.

        No network access happens here; the url-encoded query is only
        assembled and returned.

        Args:
            ref_audio (str): Reference audio path.
            ref_text (str): Reference text.
            text (str): Text to be synthesized.
            ref_language (str, optional): Reference audio language. Defaults to 'auto'.
            text_language (str, optional): Text language. Defaults to 'auto'.
        Returns:
            str: Request string for the v3 API.
        """
        query = {
            "text": text,
            "text_lang": text_language,
            "ref_audio_path": ref_audio,
            "prompt_text": ref_text,
            "prompt_lang": ref_language,
            "media_type": "aac",
            "streaming_mode": True,
            "parallel_infer": False,
        }
        # urlencode() would serialise None as the literal text "None", so the
        # YAML path is only sent when one has actually been configured.
        if self.ttsInferYamlPath is not None:
            query["tts_infer_yaml_path"] = self.ttsInferYamlPath
        return f'{self.api_url}/tts?{urllib.parse.urlencode(query)}'

    def tts_v3(self, ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto', streamed: bool = False) -> requests.Response:
        """
        Run TTS inference using the v3 API from the fast_inference_ branch.

        Args:
            ref_audio (str): Reference audio path. Should be less than 10 seconds long.
            ref_text (str): Reference text.
            text (str): Text to be synthesized.
            ref_language (str, optional): Reference audio language. Defaults to 'auto'.
            text_language (str, optional): Text language. Defaults to 'auto'.
            streamed (bool, optional): Passed to requests as stream=; when True the body is fetched lazily. Defaults to False.
        Returns:
            requests.Response: Response object from the API.
        """
        return requests.get(f'{self.api_url}/tts', params={
            "text": text,
            "text_lang": text_language,
            "ref_audio_path": ref_audio,
            "prompt_text": ref_text,
            "prompt_lang": ref_language,
            "media_type": "aac",
            "streaming_mode": True,
            "parallel_infer": False
        }, stream=streamed)

    def tts(self, ref_audio: str, ref_text: str, text: str, ref_language: str = 'auto', text_language: str = 'auto', streamed: bool = False) -> requests.Response:
        """
        Run TTS inference based on the API version.

        Args:
            ref_audio (str): Reference audio path.
            ref_text (str): Reference text.
            text (str): Text to be synthesized.
            ref_language (str, optional): Reference audio language. Defaults to 'auto'.
            text_language (str, optional): Text language. Defaults to 'auto'.
            streamed (bool, optional): Whether to stream the response. Makes no difference for v1 API. Defaults to False.
        Returns:
            requests.Response: Response object from the API.
        """
        if self.usingV3:
            return self.tts_v3(ref_audio, ref_text, text, ref_language, text_language, streamed)
        return self.tts_v1(ref_audio, ref_text, text, ref_language, text_language)

    # change reference audio
    def changeReferenceAudio(self, ref_audio: str, ref_text: str, ref_language: str = 'auto') -> None:
        """
        Change the reference audio by POSTing to the /change_ref endpoint.

        Args:
            ref_audio (str): Reference audio path.
            ref_text (str): Reference text.
            ref_language (str, optional): Reference audio language. Defaults to 'auto'.
        Raises:
            RuntimeError: If the server answers with HTTP status 400. Other
                status codes are not inspected.
        """
        response = requests.post(f'{self.api_url}/change_ref', json={
            "refer_wav_path": ref_audio,
            "prompt_text": ref_text,
            "prompt_language": ref_language
        })
        if response.status_code == 400:
            raise RuntimeError(f'{__name__}: Failed to change reference audio')

    # restart or exit
    def control(self, command: str):
        """
        Send a control command to the /control endpoint.

        The response is not inspected.

        Args:
            command (str): Command forwarded as-is to the server (restart
                or exit, going by the comment above this method).
        """
        requests.post(f'{self.api_url}/control', json={
            "command": command})
def run_get_text(script_path: str = 'get_text.py') -> subprocess.Popen:
    """
    Run the get_text.py script to get text from the user.

    The script is started with the interpreter that is running this
    module, so it does not depend on a `which` binary or on whatever
    `python3` happens to be first on PATH. The child process is not
    waited on; the caller owns the returned handle.

    Args:
        script_path (str, optional): Path of the script to launch. A relative path is resolved against the current working directory. Defaults to 'get_text.py'.
    Returns:
        subprocess.Popen: Popen object of the get_text.py script.
    """
    # Function-scope import keeps this block self-contained.
    import sys

    # Argument list (no shell) so the path is never shell-interpreted.
    return subprocess.Popen([sys.executable, script_path])