#=================================================
# ML_Project__Auditory Attention Detection (on a part of KULeuven Dataset)
# 3_2_AE_Concat_Function
# Foad Moslem (foad.moslem@gmail.com) - Researcher | Aerodynamics
# Using Python 3.11.4 & Spyder IDE
#=================================================
#%%
try:
    # Clear the console and reset the workspace when running under IPython (e.g., in Spyder)
    from IPython import get_ipython
    get_ipython().magic('clear')
    get_ipython().magic('reset -f')
except Exception:
    pass
#%%
#% Libraries
import torch.nn as nn
#% Function - Bidirectional LSTM (BiLSTM) and fully connected layers
input_size = 48*64 # The size of the concatenated feature map
hidden_size = 48 # The size of the hidden state of the BLSTM layer
num_layers = 1 # The number of layers of the BLSTM layer
direction_scale = 0.5 # Scale applied to hidden_size for each direction; with 0.5 the concatenated bidirectional output size equals hidden_size
num_spkr = 2 # The number of output classes (speaker 1 or speaker 2)
dropout = 0.25 # The dropout probability
# Bidirectional LSTM and fully connected layers
class blstm(nn.Module):
    def __init__(self):
        super().__init__()
        # Bidirectional LSTM layer
        self.blstm = nn.LSTM(input_size = input_size,
                             hidden_size = int(hidden_size * direction_scale),
                             num_layers = num_layers,
                             batch_first = True,
                             bidirectional = True)
        # Four fully connected layers; each in_features matches the previous out_features.
        # The BLSTM output size is 2 * int(hidden_size * direction_scale) = hidden_size.
        self.fc1 = nn.Linear(in_features = 2 * int(hidden_size * direction_scale), out_features = 2304)
        self.fc2 = nn.Linear(in_features = 2304, out_features = 128)
        self.fc3 = nn.Linear(in_features = 128, out_features = 32)
        self.fc4 = nn.Linear(in_features = 32, out_features = num_spkr)
        # Activation functions
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Pass the input through the BLSTM layer
        x, _ = self.blstm(x)
        # Optionally keep only the last time step of the BLSTM output
        # x = x[:, -1, :]
        # Pass the output through the four FC layers: ReLU and dropout after the first three, softmax after the last
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.dropout(self.relu(self.fc3(x)))
        x = self.softmax(self.fc4(x))
        return x
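#%%
#% Sanity check - a minimal illustrative sketch, not part of the original pipeline:
# instantiate the model and run a dummy forward pass to confirm that the layer
# sizes chain correctly. The batch size (4) and sequence length (10) below are
# arbitrary assumptions; only the feature dimension must equal input_size (48*64).
if __name__ == "__main__":
    import torch
    model = blstm()
    model.eval()  # disable dropout for the shape check
    dummy = torch.randn(4, 10, input_size)  # (batch, time, features)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([4, 10, 2]); uncommenting x[:, -1, :] above would give torch.Size([4, 2])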