-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy path: enhance.py
112 lines (101 loc) · 5.47 KB
/
enhance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import argparse
import librosa
import os
import numpy as np
from istft import ISTFT
from aia_trans import aia_complex_trans_mag, aia_complex_trans_ri, dual_aia_trans_merge_crm
import soundfile as sf
# Restrict this script to the GPU with physical index 2; all `.cuda()` calls below land there.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
def enhance(args):
    """Decode every noisy wav in ``args.mix_file_path`` with the merged
    dual-branch AIA transformer (CRM) model and write the enhanced audio
    to ``args.esti_file_path``.

    Per utterance: RMS-normalize, STFT (320-point window, 160 hop),
    compress the magnitude (** 0.5) keeping the noisy phase, run the
    model, decompress the estimated magnitude (** 2), ISTFT, trim the
    padding, undo the normalization, and write the wav at ``args.fs``.

    Args:
        args: namespace with Model_path, mix_file_path, esti_file_path, fs.
    """
    model = dual_aia_trans_merge_crm()
    checkpoint = torch.load(args.Model_path)
    model.load_state_dict(checkpoint)
    print(model)
    model.eval()
    model.cuda()
    with torch.no_grad():
        mix_file_path = args.mix_file_path
        esti_file_path = args.esti_file_path
        # Create the output directory once, not on every loop iteration.
        os.makedirs(esti_file_path, exist_ok=True)
        file_list = os.listdir(mix_file_path)
        istft = ISTFT(filter_length=320, hop_length=160, window='hanning')
        for cnt, file_id in enumerate(file_list):
            feat_wav, _ = sf.read(os.path.join(mix_file_path, file_id))
            # RMS normalization gain; NOTE(review): an all-zero wav would divide by zero here.
            c = np.sqrt(len(feat_wav) / np.sum((feat_wav ** 2.0)))
            feat_wav = feat_wav * c
            wav_len = len(feat_wav)
            # Zero-pad so the length lands exactly on a frame boundary
            # (original `- 320 + 320` cancelled out; simplified).
            frame_num = int(np.ceil(wav_len / 160 + 1))
            fake_wav_len = (frame_num - 1) * 160
            left_sample = fake_wav_len - wav_len
            feat_wav = torch.FloatTensor(np.concatenate((feat_wav, np.zeros([left_sample])), axis=0))
            feat_x = torch.stft(feat_wav.unsqueeze(dim=0), n_fft=320, hop_length=160, win_length=320,
                                window=torch.hann_window(320)).permute(0, 3, 2, 1)
            noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
            # Square-root magnitude compression; phase is kept from the noisy input.
            feat_x_mag = (torch.norm(feat_x, dim=1)) ** 0.5
            feat_x = torch.stack((feat_x_mag * torch.cos(noisy_phase), feat_x_mag * torch.sin(noisy_phase)), dim=1)
            esti_x = model(feat_x.cuda())
            esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(esti_x[:, -1, :, :],
                                                                          esti_x[:, 0, :, :])
            # Undo the sqrt compression on the estimated magnitude.
            esti_mag = esti_mag ** 2
            esti_com = torch.stack((esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)), dim=1)
            esti_com = esti_com.cpu()
            esti_utt = istft(esti_com).squeeze().numpy()
            esti_utt = esti_utt[:wav_len]  # drop the padding added above
            esti_utt = esti_utt / c        # undo the RMS normalization
            sf.write(os.path.join(esti_file_path, file_id), esti_utt, args.fs)
            print(' The %d utterance has been decoded!' % (cnt + 1))
def enhance_ri(args):
    """Decode every noisy wav in ``args.mix_file_path`` with the AIA
    complex transformer (RI variant) and write the enhanced audio to
    ``args.esti_file_path``.

    Same pipeline as ``enhance`` but the checkpoint stores the weights
    under the ``'model_state_dict'`` key.

    Args:
        args: namespace with Model_path, mix_file_path, esti_file_path, fs.
    """
    model = aia_complex_trans_ri()
    checkpoint = torch.load(args.Model_path)['model_state_dict']
    model.load_state_dict(checkpoint)
    print(model)
    model.eval()
    model.cuda()
    with torch.no_grad():
        mix_file_path = args.mix_file_path
        esti_file_path = args.esti_file_path
        # Create the output directory once, not on every loop iteration.
        os.makedirs(esti_file_path, exist_ok=True)
        file_list = os.listdir(mix_file_path)
        istft = ISTFT(filter_length=320, hop_length=160, window='hanning')
        for cnt, file_id in enumerate(file_list):
            feat_wav, _ = sf.read(os.path.join(mix_file_path, file_id))
            # RMS normalization gain; NOTE(review): an all-zero wav would divide by zero here.
            c = np.sqrt(len(feat_wav) / np.sum((feat_wav ** 2.0)))
            feat_wav = feat_wav * c
            wav_len = len(feat_wav)
            # Zero-pad so the length lands exactly on a frame boundary
            # (original `- 320 + 320` cancelled out; simplified).
            frame_num = int(np.ceil(wav_len / 160 + 1))
            fake_wav_len = (frame_num - 1) * 160
            left_sample = fake_wav_len - wav_len
            feat_wav = torch.FloatTensor(np.concatenate((feat_wav, np.zeros([left_sample])), axis=0))
            feat_x = torch.stft(feat_wav.unsqueeze(dim=0), n_fft=320, hop_length=160, win_length=320,
                                window=torch.hann_window(320)).permute(0, 3, 2, 1)
            noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
            # Square-root magnitude compression; phase is kept from the noisy input.
            feat_x_mag = (torch.norm(feat_x, dim=1)) ** 0.5
            feat_x = torch.stack((feat_x_mag * torch.cos(noisy_phase), feat_x_mag * torch.sin(noisy_phase)), dim=1)
            esti_x = model(feat_x.cuda())
            esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(esti_x[:, -1, :, :],
                                                                          esti_x[:, 0, :, :])
            # Undo the sqrt compression on the estimated magnitude.
            esti_mag = esti_mag ** 2
            esti_com = torch.stack((esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)), dim=1)
            esti_com = esti_com.cpu()
            esti_utt = istft(esti_com).squeeze().numpy()
            esti_utt = esti_utt[:wav_len]  # drop the padding added above
            esti_utt = esti_utt / c        # undo the RMS normalization
            sf.write(os.path.join(esti_file_path, file_id), esti_utt, args.fs)
            print(' The %d utterance has been decoded!' % (cnt + 1))
if __name__ == '__main__':
    # Command-line entry point: parse paths/config and decode the test set.
    parser = argparse.ArgumentParser('Recovering audio')
    parser.add_argument('--mix_file_path', type=str, default='/home/yuguochen/DNS_NONBLIND_TEST/no_reverb_noisy/')
    parser.add_argument('--esti_file_path', type=str, default='./estimated_audio/dns_nonblind_test/aia_merge_dns300_best')
    # BUG FIX: `type=list` made argparse split a CLI value into single characters
    # (e.g. "--snr 10" -> ['1', '0']). `nargs='+'` with `type=int` parses
    # "--snr -5 0 5" correctly; the default is unchanged. NOTE(review): this
    # option is not read anywhere in this script — confirm it is still needed.
    parser.add_argument('--snr', type=int, nargs='+', default=[-5, 0, 5, 10, 15, 20])  # -5 -2 0 2 5
    parser.add_argument('--fs', type=int, default=16000,
                        help='The sampling rate of speech')
    parser.add_argument('--Model_path', type=str, default='./BEST_MODEL/aia_merge_dns300.pth.tar',
                        help='The place to save best model')
    args = parser.parse_args()
    print(args)
    enhance(args=args)