LSTM.py (forked from UKPLab/sentence-transformers)
import json
import os
from typing import List

import torch
from torch import nn


class LSTM(nn.Module):
    """
    Bidirectional LSTM running over word embeddings.
    """
    def __init__(self, word_embedding_dimension: int, hidden_dim: int, num_layers: int = 1, dropout: float = 0, bidirectional: bool = True):
        super().__init__()
        self.config_keys = ['word_embedding_dimension', 'hidden_dim', 'num_layers', 'dropout', 'bidirectional']
        self.word_embedding_dimension = word_embedding_dimension
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        # Per-token output dimension: hidden_dim per direction, doubled when bidirectional
        self.embeddings_dimension = hidden_dim
        if self.bidirectional:
            self.embeddings_dimension *= 2

        self.encoder = nn.LSTM(word_embedding_dimension, hidden_dim, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional, batch_first=True)

    def forward(self, features):
        token_embeddings = features['token_embeddings']
        # Clamp to at least 1 so empty sequences do not break pack_padded_sequence
        sentence_lengths = torch.clamp(features['sentence_lengths'], min=1)

        # pack_padded_sequence requires the lengths tensor to live on the CPU
        packed = nn.utils.rnn.pack_padded_sequence(token_embeddings, sentence_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed = self.encoder(packed)
        # self.encoder returns (packed_output, (h_n, c_n)); unpad the packed output back to a padded batch
        unpack = nn.utils.rnn.pad_packed_sequence(packed[0], batch_first=True)[0]
        features.update({'token_embeddings': unpack})
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.embeddings_dimension

    def tokenize(self, text: str) -> List[int]:
        # Tokenization is handled by the preceding word-embedding module in the pipeline
        raise NotImplementedError()

    def save(self, output_path: str):
        with open(os.path.join(output_path, 'lstm_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)
        torch.save(self.state_dict(), os.path.join(output_path, 'pytorch_model.bin'))

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    @staticmethod
    def load(input_path: str):
        with open(os.path.join(input_path, 'lstm_config.json'), 'r') as fIn:
            config = json.load(fIn)

        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only machines
        weights = torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location='cpu')
        model = LSTM(**config)
        model.load_state_dict(weights)
        return model
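

# A minimal usage sketch, assuming the surrounding sentence-transformers models API
# (WordEmbeddings, Pooling, SentenceTransformer); the GloVe file path and hidden_dim
# below are placeholders for illustration, not part of this module.
#
# from sentence_transformers import SentenceTransformer, models
#
# word_embedding_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')
# lstm = models.LSTM(word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(), hidden_dim=1024)
# pooling_model = models.Pooling(lstm.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
#
# model = SentenceTransformer(modules=[word_embedding_model, lstm, pooling_model])
# embeddings = model.encode(['This is a sentence.', 'This is another one.'])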