-
Notifications
You must be signed in to change notification settings - Fork 3
/
model.py
92 lines (83 loc) · 3.68 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
import fairseq
from conformer import ConformerBlock
from torch.nn.modules.transformer import _get_clones
from torch import Tensor
def sinusoidal_embedding(n_channels, dim):
    """Build a fixed sinusoidal positional-encoding table.

    The angle for position ``p`` and column ``i`` is
    ``p / 10000 ** (2 * (i // 2) / dim)``; even columns hold its sine and odd
    columns its cosine.

    Args:
        n_channels: number of positions (rows) in the table.
        dim: width of each positional vector (columns).

    Returns:
        A float32 tensor of shape ``(1, n_channels, dim)``.
    """
    # Vectorised construction: the previous nested list comprehension did
    # n_channels * dim Python-level operations at every model construction.
    # Angles are computed in float64 (matching Python float arithmetic) and
    # only then cast to float32, as the original FloatTensor conversion did.
    positions = torch.arange(n_channels, dtype=torch.float64).unsqueeze(1)
    exponents = (2 * (torch.arange(dim) // 2)).to(torch.float64) / dim
    pe = (positions / torch.pow(10000.0, exponents)).float()
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return pe.unsqueeze(0)
class MyConformer(nn.Module):
    """Stack of Conformer blocks with a learnable class token and a 2-way head.

    A fixed sinusoidal positional table is added to the input frames, a
    learnable class token is prepended to every sequence, the result runs
    through ``n_encoders`` cloned ``ConformerBlock`` layers, and the final
    hidden state of the class token is mapped to 2 logits by ``fc5``.
    """

    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16,
                 n_encoders=1, max_len=10000):
        """
        Args:
            emb_size: feature width of the input frames and of every block.
            heads: number of attention heads; ``dim_head = emb_size / heads``.
            ffmult: feed-forward expansion multiplier passed to ConformerBlock.
            exp_fac: convolution-module expansion factor.
            kernel_size: convolution-module kernel size.
            n_encoders: number of stacked (deep-copied) Conformer blocks.
            max_len: number of rows in the positional table, i.e. the longest
                input sequence (in frames) that ``forward`` accepts.
        """
        super(MyConformer, self).__init__()
        self.dim_head = int(emb_size / heads)
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders
        # Non-trainable positional table of shape (1, max_len, emb_size).
        self.positional_emb = nn.Parameter(sinusoidal_embedding(max_len, emb_size),
                                           requires_grad=False)
        self.encoder_blocks = _get_clones(
            ConformerBlock(dim=emb_size, dim_head=self.dim_head, heads=heads,
                           ff_mult=ffmult, conv_expansion_factor=exp_fac,
                           conv_kernel_size=kernel_size),
            n_encoders)
        # Learnable summary token, shape (1, emb_size); its output feeds fc5.
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x, device):
        """Classify a batch of frame sequences.

        Args:
            x: tensor of shape ``(batch, time, emb_size)``.
            device: unused; kept for call-site compatibility.

        Returns:
            ``(logits, attn_weights)`` where ``logits`` has shape
            ``(batch, 2)`` and ``attn_weights`` is a list with the second
            value returned by each encoder block, in layer order.

        Raises:
            RuntimeError: if ``time`` exceeds the positional table length.
        """
        n_frames = x.size(1)
        max_len = self.positional_emb.size(1)
        if n_frames > max_len:
            # Fail clearly instead of with an opaque broadcasting error.
            raise RuntimeError(
                f"input has {n_frames} frames but the positional table only "
                f"covers {max_len}")
        x = x + self.positional_emb[:, :n_frames, :]
        # Prepend the class token to every sequence in one batched op:
        # (1, emb) -> (batch, 1, emb), then concat -> (batch, 1 + time, emb).
        cls = self.class_token.unsqueeze(0).expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1)
        list_attn_weight = []
        for layer in self.encoder_blocks:
            x, attn_weight = layer(x)  # (batch, 1 + time, emb_size)
            list_attn_weight.append(attn_weight)
        embedding = x[:, 0, :]  # class-token state, (batch, emb_size)
        out = self.fc5(embedding)  # (batch, 2)
        return out, list_attn_weight
class SSLModel(nn.Module):  # W2V
    """Wrapper around a pre-trained fairseq wav2vec 2.0 (XLS-R) checkpoint."""

    def __init__(self, device, cp_path='xlsr2_300m.pt'):
        """
        Args:
            device: stored as ``self.device``; not used for placement here.
            cp_path: path of the pre-trained fairseq checkpoint to load.
        """
        super(SSLModel, self).__init__()
        models, _cfg, _task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = models[0]
        self.device = device
        # Feature width of the extracted embeddings. NOTE(review): hard-coded
        # for the default XLS-R 300M checkpoint; confirm if cp_path points at
        # a model with a different hidden size.
        self.out_dim = 1024

    def extract_feat(self, input_data):
        """Run the wrapped model and return its frame-level features.

        Args:
            input_data: waveform batch of shape ``(batch, length)``; a 3-D
                input ``(batch, length, channels)`` is reduced to channel 0.

        Returns:
            The ``'x'`` entry of the model output; the original comment
            describes it as ``(batch, frames, dim)``.
        """
        # Move the model to the input's device/dtype if they differ.
        param = next(self.model.parameters())
        if param.device != input_data.device or param.dtype != input_data.dtype:
            self.model.to(input_data.device, dtype=input_data.dtype)
            # Previously this forced train(), overriding an outer .eval() and
            # enabling dropout at inference. Keep the wrapped model in the
            # same mode as this wrapper instead.
            self.model.train(self.training)
        # The model expects input of shape (batch, length).
        input_tmp = input_data[:, :, 0] if input_data.ndim == 3 else input_data
        emb = self.model(input_tmp, mask=False, features_only=True)['x']
        return emb
class Model(nn.Module):
    """wav2vec 2.0 (XLS-R) front-end followed by a Conformer classifier."""

    def __init__(self, args, device):
        """
        Args:
            args: configuration object; reads ``emb_size``, ``num_encoders``,
                ``heads`` and ``kernel_size``.
            device: stored and forwarded to the SSL wrapper and the Conformer.
        """
        super().__init__()
        self.device = device
        # Pre-trained wav2vec 2.0 feature extractor.
        self.ssl_model = SSLModel(self.device)
        # Project SSL features to the Conformer width. Use the extractor's
        # declared width rather than repeating the literal 1024.
        self.LL = nn.Linear(self.ssl_model.out_dim, args.emb_size)
        print('W2V + Conformer')
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        self.conformer = MyConformer(emb_size=args.emb_size, n_encoders=args.num_encoders,
                                     heads=args.heads, kernel_size=args.kernel_size)

    def forward(self, x):
        """Return ``(logits, attention_weights)`` from the Conformer head."""
        # Pre-trained wav2vec model (fine-tuned end to end).
        x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
        x = self.LL(x_ssl_feat)  # (bs, frames, emb_size), e.g. (bs, 208, 256)
        # BatchNorm2d needs a channel axis: add a single dummy channel,
        # normalise + SELU, then drop it again.
        x = x.unsqueeze(dim=1)  # (bs, 1, frames, emb_size)
        x = self.first_bn(x)
        x = self.selu(x)
        x = x.squeeze(dim=1)
        out, attn_score = self.conformer(x, self.device)
        return out, attn_score