This repository has been archived by the owner on Sep 12, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
300 lines (232 loc) · 9.78 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
def run(media_folder, extension_folder, app):
    """
    Run extension speech_train: register the voice-training HTTP routes on ``app``.

    Side effects of calling this function:
    - creates ``<media_folder>/user_trained_voice`` if it does not exist;
    - replaces ``sys.stdout`` / ``sys.stderr`` with in-memory buffers so that
      printed output can be polled through ``/console_log``;
    - registers ``/change_processor``, ``/console_log``, ``/tacotron2/`` and
      ``/waveglow/`` on ``app``;
    - pops the first entry of ``sys.path`` before returning.

    Parameters
    ----------
    media_folder: path to media folder to save files
    extension_folder: path to extension folder where run.py file is located;
        it is temporarily put on ``sys.path`` to import the ``tacotron2`` and
        ``waveglow`` training packages
    app: Flask application; at request time ``app.config['SYSNTHESIZE_STATUS']``
        is expected to be a dict with a ``"status_code"`` key (200 = idle)

    Returns
    -------
    None
    """
    import os
    import sys
    import time
    import random
    import torch
    import yaml
    import glob
    import json
    from time import gmtime, strftime
    from flask_cors import cross_origin
    from flask import request
    from io import StringIO

    save_train_voice_folder = os.path.join(media_folder, 'user_trained_voice')
    os.makedirs(save_train_voice_folder, exist_ok=True)

    def current_time(time_format: str = None):
        """Return the current UTC time formatted with ``time_format`` (default ``%Y%m%d%H%M%S``)."""
        return strftime(time_format or "%Y%m%d%H%M%S", gmtime())

    def clear_memory():
        """Drop cached models and release cached CUDA memory."""
        # A dict, matching the value the training routes below assign.
        app.config['models'] = {}
        torch.cuda.empty_cache()
        # NOTE(review): presumably gives the allocator time to free memory before
        # the second empty_cache(); this blocks the request for 10 seconds.
        time.sleep(10)
        torch.cuda.empty_cache()
        return {"status": 200}

    def set_status(status_code, message):
        """Publish the status that the training routes check before starting a job."""
        app.config['SYSNTHESIZE_STATUS'] = {"status_code": status_code, "message": message}

    def reject(message):
        """Mark the trainer idle with a user-facing ``message`` and build the route response."""
        set_status(200, message)
        return {"status_code": 200}

    def write_lines(path, lines):
        """Write ``lines`` to ``path`` as UTF-8 text, one entry per line."""
        with open(path, "w", encoding="utf-8") as file:
            file.writelines(f"{line}\n" for line in lines)

    @app.route('/change_processor', methods=["POST"])
    @cross_origin()
    def change_processor():
        """Toggle WUNJO_TORCH_DEVICE between "cpu" and "cuda" while idle; return the device."""
        current_processor = os.environ.get('WUNJO_TORCH_DEVICE', "cpu")
        if app.config['SYSNTHESIZE_STATUS'].get("status_code") == 200:
            if current_processor == "cpu":
                os.environ['WUNJO_TORCH_DEVICE'] = 'cuda'
                return {"current_processor": 'cuda'}
            else:
                os.environ['WUNJO_TORCH_DEVICE'] = 'cpu'
                return {"current_processor": 'cpu'}
        return {"current_processor": current_processor}

    # Create StringIO objects to capture the console output.
    # NOTE(review): these buffers are never truncated, so they grow with all output.
    console_stdout = StringIO()
    console_stderr = StringIO()
    sys.stdout = console_stdout  # prints
    sys.stderr = console_stderr  # https
    app.config["CONSOLE_LOG"] = []

    @app.route('/console_log', methods=['GET'])
    def console_log():
        """Return up to the last 100 non-empty captured stdout lines, newline-joined.

        Lines containing "127.0.0.1" are dropped; captured stderr is not served.
        """
        # Split the captured stdout into individual, filtered log lines
        new_logs = [
            log for log in console_stdout.getvalue().splitlines()
            if "127.0.0.1" not in log and log != ''
        ]
        app.config["CONSOLE_LOG"].extend(new_logs)
        # Truncate the logs list if it exceeds the maximum limit
        max_logs = 100
        if len(app.config["CONSOLE_LOG"]) > max_logs:
            app.config["CONSOLE_LOG"] = app.config["CONSOLE_LOG"][-max_logs:]
        logs_text = '\n'.join(app.config["CONSOLE_LOG"])
        app.config["CONSOLE_LOG"] = []
        return logs_text

    # Plain nested functions: a @staticmethod object is not callable before
    # Python 3.10, which made every training call fail.
    def tacotron2_train(hparams_path):
        """Import the Tacotron2 trainer from ``extension_folder`` and run it."""
        sys.path.insert(0, extension_folder)
        try:
            from tacotron2.train import main as tacotron2
        finally:
            # Restore sys.path even when the import fails.
            sys.path.pop(0)
        tacotron2(hparams_path=hparams_path)

    def waveglow_train(json_config):
        """Import the WaveGlow trainer from ``extension_folder`` and run it."""
        sys.path.insert(0, extension_folder)
        try:
            from waveglow.train import main as waveglow
        finally:
            # Restore sys.path even when the import fails.
            sys.path.pop(0)
        waveglow(json_config=json_config)

    def prepare_and_train_tacotron2(request_list):
        """Validate the Tacotron2 request, write split/config files and train synchronously."""
        checkpoint_file = request_list.get("checkpoint")
        audio_path = request_list.get("audio_path")
        mark_file = request_list.get("mark_path")
        charset = request_list.get("language")
        train_split = int(request_list.get("train_split"))
        tacotron2_config = yaml.safe_load(request_list.get("config"))

        if not checkpoint_file:
            checkpoint_file = None
        elif not os.path.isfile(checkpoint_file):
            # User set path to checkpoint file but it doesn't exist
            return reject("Не найден checkpoint файл")
        if not audio_path or not mark_file:
            # User did not set the paths
            return reject("Директория аудио или путь к разметке не установлены")
        if not os.path.isdir(audio_path) or not os.path.isfile(mark_file):
            # Audio path is not a directory or mark is not a file
            return reject("Директория аудио не найдена или не найден файл разметки")

        # utf-8-sig also accepts files saved with a BOM, which would otherwise
        # end up in the first audio file name.
        with open(mark_file, "r", encoding="utf-8-sig") as file:
            # Strip newlines and skip blank lines: a last line without "\n" would
            # otherwise be glued onto its neighbour after the shuffle below.
            mark_lines = [line.rstrip("\n") for line in file if line.strip()]
        mark_vector = mark_lines[0].split("|") if mark_lines else []
        if len(mark_vector) != 2:
            return reject("Не верный формат разметки. Каждая строка данных должна быть в формате: аудио.wav|text from audio")
        if mark_vector[0].split(".")[-1] not in ["wav", "mp3"]:
            return reject("Аудио файл не является .wav или .mp3. Каждая строка данных должна быть в формате: аудио.wav|text from audio")

        # Only create the output folder once the input has been validated.
        train_folder = os.path.join(save_train_voice_folder, current_time())
        os.makedirs(train_folder, exist_ok=True)

        random.shuffle(mark_lines)
        train_size = int(len(mark_lines) * (train_split / 100))
        validation_files = os.path.join(train_folder, "mark_test.txt")
        write_lines(validation_files, mark_lines[train_size:])
        training_files = os.path.join(train_folder, "mark_train.txt")
        write_lines(training_files, mark_lines[:train_size])

        # todo set absolute path
        tacotron2_config["audios_path"] = audio_path
        tacotron2_config["training_files"] = training_files
        tacotron2_config["validation_files"] = validation_files
        tacotron2_config["output_dir"] = train_folder
        tacotron2_config["charset"] = charset
        tacotron2_config["checkpoint"] = checkpoint_file
        config_folder = os.path.join(train_folder, "config")
        os.makedirs(config_folder, exist_ok=True)
        new_yaml_file_path = os.path.join(config_folder, "hparams.yaml")
        with open(new_yaml_file_path, "w", encoding="utf-8") as file:
            yaml.dump(tacotron2_config, file)

        # delete tts models from memory
        app.config['models'] = {}
        torch.cuda.empty_cache()
        try:
            tacotron2_train(hparams_path=new_yaml_file_path)
        except Exception as e:
            # NOTE(review): every training failure is reported as a GPU-memory
            # problem; the real exception is only printed to the captured console.
            print(e)
            set_status(200, "Недостаточно памяти графического процессора. Увеличьте размер памяти")
            return {"status_code": 400}
        set_status(200, "")
        return {"status_code": 200}

    @app.route('/tacotron2/', methods=["POST"])
    @cross_origin()
    def tacotron2_run():
        """Train Tacotron2 from the JSON request body (blocks until training ends).

        Returns {"status_code": 200} when training finished, when the request was
        rejected (reason stored in SYSNTHESIZE_STATUS["message"]) or when another
        job is running; {"status_code": 400} when training raised.
        """
        if app.config['SYSNTHESIZE_STATUS']["status_code"] != 200:
            return {"status_code": 200}
        clear_memory()
        request_list = request.get_json()
        set_status(400, "Происходит обучение модели Tacotron2")
        try:
            return prepare_and_train_tacotron2(request_list)
        except Exception as err:
            # Release the busy flag; otherwise an unexpected error (bad
            # train_split, invalid YAML, ...) would block all later requests.
            set_status(200, str(err))
            raise

    def prepare_and_train_waveglow(request_list):
        """Validate the WaveGlow request, write split/config files and train synchronously."""
        audio_path = request_list.get("audio_path")
        train_split = int(request_list.get("train_split"))
        if not audio_path:
            # User did not set the path
            return reject("Директория аудио не установлена")
        if not os.path.isdir(audio_path):
            # Audio path is not a directory
            return reject("Директория аудио не найдена или не найден файл разметки")
        # Parse the config before creating any folder so a bad config leaves nothing behind.
        waveglow_config = json.loads(request_list.get("config"))

        # Escape the user directory so "[" / "]" in its name are not glob syntax.
        audio_list = glob.glob(os.path.join(glob.escape(audio_path), "*.wav"))
        train_size = int(len(audio_list) * (train_split / 100))
        train_folder = os.path.join(save_train_voice_folder, current_time())
        os.makedirs(train_folder, exist_ok=True)

        random.shuffle(audio_list)
        # The held-out list is written, but only the training list is put in the config.
        write_lines(os.path.join(train_folder, "test_file.txt"), audio_list[train_size:])
        training_files = os.path.join(train_folder, "train_file.txt")
        write_lines(training_files, audio_list[:train_size])

        waveglow_config["train_config"]["output_directory"] = train_folder
        waveglow_config["data_config"]["training_files"] = training_files
        config_folder = os.path.join(train_folder, "config")
        os.makedirs(config_folder, exist_ok=True)
        new_json_file_path = os.path.join(config_folder, "config.json")
        with open(new_json_file_path, "w", encoding="utf-8") as file:
            json.dump(waveglow_config, file)

        # delete tts models from memory
        app.config['models'] = {}
        torch.cuda.empty_cache()
        try:
            waveglow_train(json_config=new_json_file_path)
        except Exception as e:
            # NOTE(review): every training failure is reported as a GPU-memory
            # problem; the real exception is only printed to the captured console.
            print(e)
            set_status(200, "Недостаточно памяти графического процессора. Увеличьте размер памяти")
            return {"status_code": 400}
        set_status(200, "")
        return {"status_code": 200}

    @app.route('/waveglow/', methods=["POST"])
    @cross_origin()
    def waveglow_run():
        """Train WaveGlow from the JSON request body (blocks until training ends).

        Same response contract as the Tacotron2 route.
        """
        if app.config['SYSNTHESIZE_STATUS']["status_code"] != 200:
            return {"status_code": 200}
        clear_memory()
        request_list = request.get_json()
        set_status(400, "Происходит обучение модели Waveglow")
        try:
            return prepare_and_train_waveglow(request_list)
        except Exception as err:
            # Release the busy flag; otherwise an unexpected error would block
            # all later requests.
            set_status(200, str(err))
            raise

    # NOTE(review): no matching insert happens in this function; presumably the
    # caller put extension_folder on sys.path before calling run() — confirm.
    sys.path.pop(0)