-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add train and test Python 3 scripts and modify the README
- Loading branch information
Showing
5 changed files
with
158 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: nl8590687
Program for testing the speech model of the speech recognition system.

Loads a saved ModelSpeech checkpoint from the ``model_speech`` directory and
runs ``TestModel`` on the dataset directory selected for the current OS.
"""
import platform as plat
import os

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

from SpeechModel22 import ModelSpeech

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Configure the TensorFlow session: let this process use up to 90% of the
# GPU memory (the previous comment said 70%, but the value in use is 0.9).
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.9
# Alternative: allocate GPU memory on demand instead of reserving it up front.
#config.gpu_options.allow_growth=True
set_session(tf.Session(config=config))

modelpath = 'model_speech'

# Make sure the model directory exists; exist_ok avoids the race between a
# separate existence check and the creation.
os.makedirs(modelpath, exist_ok=True)

# The dataset location differs between operating systems.
system_type = plat.system()
if system_type == 'Windows':
    datapath = 'E:\\语音数据集'
else:
    if system_type != 'Linux':
        print('*[Message] Unknown System\n')
    datapath = 'dataset'

# Joining with an empty component appends the platform's path separator, so
# file names can be concatenated directly onto modelpath below.
modelpath = os.path.join(modelpath, '')

ms = ModelSpeech(datapath)

ms.LoadModel(modelpath + 'speech_model22_e_0_step_327500.model')

# NOTE(review): this evaluates on the 'train' split with data_count=128;
# confirm that testing on training data is intended here.
ms.TestModel(datapath, str_dataset='train', data_count = 128, out_report = True)

# Examples: recognize a single wav file instead of running the test set.
#r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\ST-CMDS-20170001_1-OS\\20170001P00241I0053.wav')
#r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\ST-CMDS-20170001_1-OS\\20170001P00020I0087.wav')
#r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\train\\A11\\A11_167.WAV')
#r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav')
#print('*[提示] 语音识别结果:\n',r)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: nl8590687
Program for training the speech model of the speech recognition system.

Builds a ModelSpeech for the dataset directory selected for the current OS
and runs ``TrainModel``; the ``model_speech`` directory is created up front
so that saving checkpoints does not fail on a missing directory.
"""
import platform as plat
import os

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

from SpeechModel22 import ModelSpeech

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Configure the TensorFlow session: let this process use up to 90% of the
# GPU memory (the previous comment said 70%, but the value in use is 0.9).
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.9
# Alternative: allocate GPU memory on demand instead of reserving it up front.
#config.gpu_options.allow_growth=True
set_session(tf.Session(config=config))

modelpath = 'model_speech'

# Make sure the model directory exists; exist_ok avoids the race between a
# separate existence check and the creation.
os.makedirs(modelpath, exist_ok=True)

# The dataset location differs between operating systems.
system_type = plat.system()
if system_type == 'Windows':
    datapath = 'E:\\语音数据集'
else:
    if system_type != 'Linux':
        print('*[Message] Unknown System\n')
    datapath = 'dataset'

# Joining with an empty component appends the platform's path separator, so
# file names can be concatenated directly onto modelpath.
modelpath = os.path.join(modelpath, '')

ms = ModelSpeech(datapath)

# Uncomment to continue training from a previously saved checkpoint.
#ms.LoadModel(modelpath + 'speech_model22_e_0_step_327500.model')
ms.TrainModel(datapath, epoch = 50, batch_size = 4, save_step = 500)
|
||
|