-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathsolver_merge.py
186 lines (163 loc) · 7.88 KB
/
solver_merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torch
import time
import os
from Backup_pesq import com_mse_loss, com_mag_mse_loss, pesq_loss
import hdf5storage
import gc
from config_merge import *
# Module-level loss accumulators shared by Solver:
#   tr_batch - every individual batch loss seen (training, and CV MSE batches too)
#   tr_epoch / cv_epoch - per-epoch average train / cross-validation losses,
#   persisted to a .mat file each epoch by Solver.train().
tr_batch, tr_epoch,cv_epoch = [], [], []
class Solver(object):
    """Training driver for the merge model.

    Runs the epoch loop: one training pass plus one cross-validation pass per
    epoch, saves a resumable checkpoint every epoch, logs loss curves to a
    .mat file, halves the learning rate after 3 epochs without CV improvement
    (when ``half_lr``), and early-stops after 5 (when ``early_stop``).
    """

    def __init__(self, data, model, optimizer, args):
        """
        Args:
            data: dict with 'tr_loader' / 'cv_loader' data-loader wrappers
                  (each exposing ``get_data_loader()``).
            model: the network to optimize (moved to CUDA by the caller).
            optimizer: torch optimizer over the model parameters.
            args: namespace carrying the schedule/paths (epochs, half_lr,
                  early_stop, best_path, cp_path, loss_dir, print_freq,
                  is_conti, conti_path).
        """
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.loss_dir = args.loss_dir          # .mat file for loss curves
        self.model = model
        self.optimizer = optimizer
        self.epochs = args.epochs
        self.half_lr = args.half_lr            # halve LR on CV plateau
        self.early_stop = args.early_stop      # stop on prolonged CV plateau
        self.best_path = args.best_path        # best-model weights file
        self.cp_path = args.cp_path            # per-epoch checkpoint dir
        self.tr_loss = torch.Tensor(self.epochs)
        self.cv_loss = torch.Tensor(self.epochs)
        self.print_freq = args.print_freq      # batches between progress prints
        self.is_conti = args.is_conti          # resume from checkpoint?
        self.conti_path = args.conti_path
        self._reset()

    def _reset(self):
        """Initialize training state, either fresh or restored from a checkpoint."""
        if self.is_conti:
            # Resume: restore weights, optimizer state and loss bookkeeping.
            checkpoint = torch.load(self.conti_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch']
            self.prev_cv_loss = checkpoint['cv_loss']
            self.best_cv_loss = checkpoint['best_cv_loss']
        else:
            # Fresh run.
            self.start_epoch = 0
            self.prev_cv_loss = float("inf")
            self.best_cv_loss = float("inf")
        # Plateau tracking is always reset (also after a resume, as before).
        self.cv_no_impv = 0    # consecutive epochs without CV improvement
        self.having = False    # "halve the LR at the end of this epoch" flag

    def train(self):
        """Run the full train/validate loop over the remaining epochs."""
        for epoch in range(self.start_epoch, self.epochs):
            print("Begin to train.....")
            self.model.train()
            start = time.time()
            tr_avg_loss = self.run_one_epoch(epoch)
            print('-' * 90)
            print("End of Epoch %d, Time: %4f s, Train_Loss:%5f" % (int(epoch+1), time.time()-start, tr_avg_loss))
            print('-' * 90)
            # Cross validation (eval mode: BN and Dropout are off).
            print("Begin Cross Validation....")
            self.model.eval()
            cv_avg_loss = self.run_one_epoch(epoch, cross_valid=True)
            print('-' * 90)
            print("Time: %4fs, CV_Loss:%5f" % (time.time() - start, cv_avg_loss))
            print('-' * 90)
            # Save a full checkpoint every epoch so training can resume.
            check_point = {
                'epoch': epoch + 1,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'tr_loss': tr_avg_loss,
                'cv_loss': cv_avg_loss,
                'best_cv_loss': self.best_cv_loss,
            }
            torch.save(check_point, os.path.join(self.cp_path, 'checkpoint_early_exit_%dth.pth.tar' % (epoch+1)))
            self.tr_loss[epoch] = tr_avg_loss
            self.cv_loss[epoch] = cv_avg_loss
            tr_epoch.append(tr_avg_loss)
            cv_epoch.append(cv_avg_loss)
            # Persist the loss curves so far.
            loss = {'tr_loss': tr_epoch, 'cv_loss': cv_epoch}
            hdf5storage.savemat(self.loss_dir, loss)
            # Learning-rate schedule and early stop, driven by the CV plateau.
            if self.half_lr:
                if cv_avg_loss >= self.prev_cv_loss:
                    self.cv_no_impv += 1
                    if self.cv_no_impv == 3:
                        self.having = True
                    if self.cv_no_impv >= 5 and self.early_stop:
                        print("No improvement and apply early stop")
                        break
                else:
                    self.cv_no_impv = 0
            if self.having:
                # Halve the LR of every param group via the optimizer state dict.
                optim_state = self.optimizer.state_dict()
                for i in range(len(optim_state['param_groups'])):
                    optim_state['param_groups'][i]['lr'] = optim_state['param_groups'][i]['lr'] / 2.0
                self.optimizer.load_state_dict(optim_state)
                print('Learning rate adjusted to %5f' % (optim_state['param_groups'][0]['lr']))
                self.having = False
            self.prev_cv_loss = cv_avg_loss
            if cv_avg_loss < self.best_cv_loss:
                self.best_cv_loss = cv_avg_loss
                torch.save(self.model.state_dict(), self.best_path)
                print("Find better cv model, saving to %s" % os.path.split(self.best_path)[1])

    def run_one_epoch(self, epoch, cross_valid=False):
        """Iterate one epoch over the train (or CV) loader; return the mean batch loss.

        Assumes the loader yields at least one batch (the final division uses
        the last batch index).
        """
        def _batch(_, batch_info):
            # Features/labels carry real and imaginary channels in dim 1;
            # recover phases before compressing the magnitudes.
            batch_feat = batch_info.feats.cuda()
            batch_label = batch_info.labels.cuda()
            noisy_phase = torch.atan2(batch_feat[:, -1, :, :], batch_feat[:, 0, :, :])
            clean_phase = torch.atan2(batch_label[:, -1, :, :], batch_label[:, 0, :, :])
            batch_frame_mask_list = batch_info.frame_mask_list
            # Magnitude compression, selected by config_merge.feat_type.
            # Fixed: compare strings with '==', not identity 'is' (the
            # original relied on CPython string interning and triggers a
            # SyntaxWarning on modern Python).
            if feat_type == 'normal':
                batch_feat, batch_label = torch.norm(batch_feat, dim=1), torch.norm(batch_label, dim=1)
            elif feat_type == 'sqrt':
                batch_feat, batch_label = (torch.norm(batch_feat, dim=1)) ** 0.5, (
                    torch.norm(batch_label, dim=1)) ** 0.5
            elif feat_type == 'cubic':
                batch_feat, batch_label = (torch.norm(batch_feat, dim=1)) ** 0.3, (
                    torch.norm(batch_label, dim=1)) ** 0.3
            elif feat_type == 'log_1x':
                batch_feat, batch_label = torch.log(torch.norm(batch_feat, dim=1) + 1), \
                    torch.log(torch.norm(batch_label, dim=1) + 1)
            # Re-attach the (unmodified) phases to the compressed magnitudes.
            batch_feat = torch.stack((batch_feat*torch.cos(noisy_phase), batch_feat*torch.sin(noisy_phase)), dim=1)
            batch_label = torch.stack((batch_label*torch.cos(clean_phase), batch_label*torch.sin(clean_phase)), dim=1)
            esti_list = self.model(batch_feat)
            if not cross_valid:
                batch_loss = com_mag_mse_loss(esti_list, batch_label, batch_frame_mask_list)
                batch_loss_res = batch_loss.item()
                tr_batch.append(batch_loss_res)
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
            else:
                if is_pesq:
                    batch_loss_res = pesq_loss(esti_list, batch_label, batch_frame_mask_list)
                else:
                    batch_loss = com_mag_mse_loss(esti_list, batch_label, batch_frame_mask_list)
                    batch_loss_res = batch_loss.item()
                    # NOTE(review): CV MSE losses are appended to tr_batch,
                    # matching the original behavior — confirm this is intended.
                    tr_batch.append(batch_loss_res)
            return batch_loss_res

        start1 = time.time()
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        for batch_id, batch_info in enumerate(data_loader.get_data_loader()):
            batch_loss_res = _batch(batch_id, batch_info)
            total_loss += batch_loss_res
            gc.collect()
            if batch_id % self.print_freq == 0:
                print("Epoch:%d, Iter:%d, the average_loss:%5f, current_loss:%5f, %d ms/batch."
                      % (int(epoch + 1), int(batch_id), total_loss / (batch_id + 1), batch_loss_res,
                         1000 * (time.time() - start1) / (batch_id + 1)))
        return total_loss / (batch_id + 1)
from contextlib import contextmanager
@contextmanager
def set_default_tensor_type(tensor_type):
    """Temporarily switch torch's default tensor type; restore it on exit.

    The previous default is inferred from where a freshly created tensor
    lives: CUDA default -> ``torch.cuda.FloatTensor``, otherwise
    ``torch.FloatTensor``.

    Fixed: the restore now sits in a ``finally`` block, so an exception
    raised inside the ``with`` body no longer leaks the changed default.
    NOTE(review): ``torch.set_default_tensor_type`` is deprecated in recent
    torch releases; ``torch.set_default_dtype``/``set_default_device`` is
    the modern replacement.
    """
    if torch.tensor(0).is_cuda:
        old_tensor_type = torch.cuda.FloatTensor
    else:
        old_tensor_type = torch.FloatTensor
    torch.set_default_tensor_type(tensor_type)
    try:
        yield
    finally:
        torch.set_default_tensor_type(old_tensor_type)