-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
torch_speech_model.py
131 lines (110 loc) · 5.62 KB
/
torch_speech_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================
"""
@author: nl8590687 / Evelynn-n
声学模型基础功能模板定义
"""
import os
import time
import torch
from torch.utils.data import Dataset, DataLoader as TorchDataLoader
from data_loader import DataLoader
from speech_features.speech_features import SpeechFeatureMeta
class SpeechDataset(Dataset):
    """PyTorch ``Dataset`` that turns ASRT data-loader samples into fixed-size tensors.

    Each item is a tuple ``(x, y, input_length, label_length)``:

    * ``x`` -- float32 tensor of shape ``input_shape``; the extracted features
      fill the leading frames and the rest is zero padding.
    * ``y`` -- int16 tensor of length ``max_label_length``; label ids shifted by
      +1, zero-padded at the end.
    * ``input_length`` / ``label_length`` -- one-element float32 tensors.
    """

    def __init__(self, data_loader, speech_features, input_shape, max_label_length):
        """
        Args:
            data_loader: object providing ``get_data_count()`` and
                ``get_data(index) -> (wav_data, sample_rate, labels)``.
            speech_features: feature extractor providing ``run(wav_data, sample_rate)``.
            input_shape: shape of the padded feature tensor; ``input_shape[0]``
                is the maximum number of feature frames per sample.
            max_label_length: size of the padded label tensor.
        """
        self.data_loader = data_loader
        self.input_shape = input_shape
        self.speech_features = speech_features
        self.max_label_length = max_label_length
        self.data_count = self.data_loader.get_data_count()

    def __len__(self):
        return self.data_count

    def __getitem__(self, index):
        """Return the padded tensors for sample ``index``.

        Raises:
            ValueError: if the sample has more feature frames than
                ``input_shape[0]`` or more labels than ``max_label_length``.
        """
        wav_data, sample_rate, data_labels = self.data_loader.get_data(index)

        # Extract features and append a trailing channel axis of size 1.
        data_input = self.speech_features.run(wav_data, sample_rate)
        data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
        num_frames = data_input.shape[0]
        max_frames = self.input_shape[0]

        # Reject samples that do not fit the fixed-size buffers with a message
        # naming the sample, instead of an opaque shape-mismatch error raised by
        # the slice assignments further down.
        if num_frames > max_frames:
            raise ValueError(
                f"sample {index}: {num_frames} feature frames exceed input_shape[0]={max_frames}")
        if len(data_labels) > self.max_label_length:
            raise ValueError(
                f"sample {index}: {len(data_labels)} labels exceed "
                f"max_label_length={self.max_label_length}")

        # Input length reported to the loss: frames // pool_size plus the
        # remainder, capped at input_shape[0] // 8. The hard-coded 8 presumably
        # is the model's time-axis downsampling factor -- TODO confirm.
        max_out_len = max_frames // 8
        pool_size = max_frames // max_out_len
        inlen = min(num_frames // pool_size + num_frames % pool_size, max_out_len)

        # Zero-padded feature tensor of shape `input_shape`.
        x = torch.zeros(self.input_shape)
        x[:num_frames] = torch.tensor(data_input, dtype=torch.float32)

        # Zero-padded label tensor. Ids are shifted by +1, presumably so that 0
        # is free for padding / the CTC blank -- confirm against compute_loss.
        y = torch.zeros(self.max_label_length, dtype=torch.int16)
        y[:len(data_labels)] = torch.tensor(data_labels, dtype=torch.int16) + 1

        input_length = torch.tensor((inlen,), dtype=torch.float32)
        label_length = torch.tensor((len(data_labels),), dtype=torch.float32)
        return x, y, input_length, label_length
class ModelSpeech:
    """Training wrapper around a PyTorch ASRT acoustic model.

    Bundles the acoustic model with its feature extractor and provides
    training plus checkpoint save/load helpers.
    """

    def __init__(self, speech_model: torch.nn.Module, speech_features: SpeechFeatureMeta, max_label_length: int = 64):
        """Initialise the wrapper.

        Args:
            speech_model: acoustic model; this class uses its ``input_shape``,
                ``get_model()``, ``get_model_name()`` and ``compute_loss(...)``.
            speech_features: feature extractor applied to each waveform.
            max_label_length: length that label sequences are zero-padded to.
        """
        self.speech_model = speech_model
        self.trained_model = speech_model.get_model()
        self.speech_features = speech_features
        self.max_label_length = max_label_length

    def train(self, data_loader: DataLoader, epochs: int, batch_size: int, optimizer: torch.optim.Optimizer,
              device: str = 'cpu'):
        """Train the model and save a checkpoint after every epoch.

        Args:
            data_loader: ASRT data loader supplying the training samples.
            epochs: number of passes over the training data.
            batch_size: mini-batch size for the PyTorch ``DataLoader``.
            optimizer: optimizer already bound to the model's parameters.
            device: torch device string, e.g. ``'cpu'`` or ``'cuda:0'``.

        Checkpoints are written to ``save_models_torch/<model name>_epoch<N>.pth``.

        Raises:
            ValueError: if ``data_loader`` yields no samples.
        """
        speechdata = SpeechDataset(data_loader, self.speech_features, input_shape=self.speech_model.input_shape,
                                   max_label_length=self.max_label_length)
        batches = TorchDataLoader(speechdata, batch_size=batch_size, shuffle=True)
        num_steps = len(batches)
        if num_steps == 0:
            # Otherwise every epoch would end in a ZeroDivisionError when averaging.
            raise ValueError('[ASRT] training data is empty, nothing to train on')

        model = self.speech_model
        # Forward passes, the loss and checkpoints all go through
        # self.speech_model, so that module must live on `device`.
        # self.trained_model is moved too, as before; if both are the same
        # module, the second .to() finds everything already on `device`.
        model.to(device)
        self.trained_model.to(device)
        print('[ASRT] torch model successfully initialized to device: {}'.format(device))
        # Enable training-mode behaviour (dropout, batch-norm statistics) in case
        # the model was switched to eval() earlier.
        model.train()

        for epoch in range(epochs):
            print('[ASRT] Epoch {}/{}'.format(epoch + 1, epochs))
            epoch_loss = 0.0
            t0 = time.time()
            for iter_index, (x, y, input_length, label_length) in enumerate(batches, start=1):
                x = x.to(device)
                y = y.to(device)
                input_length = input_length.to(device).long()
                label_length = label_length.to(device).long()

                optimizer.zero_grad()
                y_pred = model(x)
                loss = model.compute_loss(y_pred, y, input_length, label_length)
                loss.backward()
                optimizer.step()

                cur_batch_loss = loss.item()
                epoch_loss += cur_batch_loss
                # Linear extrapolation of the epoch duration from the steps done so far.
                elapsed = time.time() - t0
                predict_total_time = elapsed * num_steps / iter_index
                predict_remain_time = predict_total_time - elapsed
                cur_avg_loss = epoch_loss / iter_index
                print("[ASRT]", f"{predict_remain_time:.2f}/{predict_total_time:.2f} s,",
                      f"step {iter_index}/{num_steps},", f"current loss: {cur_batch_loss:.4f}",
                      f"avg loss: {cur_avg_loss:.4f}", end="\r")

            # save_weight adds the directory and ".pth" itself; passing a full
            # path here used to produce save_models_torch/save_models_torch/*.pth.pth.
            self.save_weight(f"{self.speech_model.get_model_name()}_epoch{epoch + 1}")
            avg_loss = epoch_loss / num_steps
            total_time = time.time() - t0
            avg_time_per_step = total_time / num_steps
            print("[ASRT]", f"epoch {epoch + 1}/{epochs},", f"time cost: {total_time:.2f} s,",
                  f"{avg_time_per_step:.2f} s/step", f"avg loss: {avg_loss:.4f}")

    def save_weight(self, filename: str):
        """Save the model's state_dict to ``save_models_torch/<filename>.pth``.

        Args:
            filename: name relative to ``save_models_torch``, without the
                ``.pth`` extension. Missing directories are created.
        """
        save_filename = os.path.join('save_models_torch', filename + ".pth")
        save_dir = os.path.dirname(save_filename)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.speech_model.state_dict(), save_filename)

    def load_weight(self, filepath: str, map_location='cpu'):
        """Load a state_dict (e.g. one written by ``save_weight``) into the model.

        Args:
            filepath: full path to the checkpoint file, used as-is.
            map_location: passed to ``torch.load``. Defaults to CPU so checkpoints
                saved on a GPU also load on CPU-only machines; ``load_state_dict``
                copies the values into the model's existing parameters, so the
                model keeps its current device.
        """
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        state_dict = torch.load(filepath, map_location=map_location)
        self.speech_model.load_state_dict(state_dict)