model.py
from torch import nn
import torch

from config import device


class LSTM(nn.Module):
    """
    LSTM language model with an embedding layer, optionally initialised from
    pretrained embedding weights.
    """

    def __init__(self, n_vocab, padding_idx, embedding_weights=None, embedding_size=50,
                 hidden_size=128, num_layers=2) -> None:
        super(LSTM, self).__init__()
        self.n_vocab = n_vocab
        if embedding_weights is not None:
            # Fine-tune pretrained embeddings; their dimension must match embedding_size.
            self.embedding = nn.Embedding.from_pretrained(embedding_weights, freeze=False,
                                                          padding_idx=padding_idx)
        else:
            self.embedding = nn.Embedding(num_embeddings=n_vocab, embedding_dim=embedding_size,
                                          padding_idx=padding_idx)
        self.LSTM = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=0.2)
        self.dropout = nn.Dropout(0.2)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, n_vocab)

    def forward(self, x, seq_len, state):
        # x: (max_len, batch) token indices; seq_len: per-sequence lengths on the CPU.
        x = self.embedding(x)
        x = self.dropout(x)
        # Pack so the LSTM skips padded positions; sequences need not be pre-sorted.
        packed = torch.nn.utils.rnn.pack_padded_sequence(x, seq_len, enforce_sorted=False)
        output, (h, c) = self.LSTM(packed, state)
        # Unpack back to a padded tensor of shape (max_len, batch, hidden_size).
        x, lengths = torch.nn.utils.rnn.pad_packed_sequence(output)
        x = self.fc(x)
        return x, (h, c)

    def init_state(self, batch_size):
        # Zero-initialised hidden and cell states, one per layer.
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))
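
A minimal usage sketch (not part of the original file): the vocabulary size, padding index, batch shape, and sequence lengths below are illustrative assumptions, and only `device` is taken from the repository's `config` module.

import torch
from config import device
from model import LSTM

# Hypothetical sizes, for illustration only.
n_vocab, padding_idx, batch_size, max_len = 1000, 0, 4, 12

model = LSTM(n_vocab=n_vocab, padding_idx=padding_idx).to(device)

# Dummy batch of token indices shaped (max_len, batch); lengths stay on the CPU
# because pack_padded_sequence expects CPU lengths.
tokens = torch.randint(1, n_vocab, (max_len, batch_size)).to(device)
lengths = torch.tensor([12, 10, 7, 5])
state = model.init_state(batch_size)

logits, state = model(tokens, lengths, state)  # logits: (max_len, batch, n_vocab)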