-
Notifications
You must be signed in to change notification settings - Fork 4
/
encoder.py
67 lines (57 loc) · 2.67 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
'''
Seq2VecEncoders for encoding mentions and entities.
'''
import torch.nn as nn
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper, BagOfEmbeddingsEncoder
from allennlp.modules.seq2vec_encoders import BertPooler
from overrides import overrides
from allennlp.nn.util import get_text_field_mask
class Pooler_for_cano_and_def(Seq2VecEncoder):
    """Encode a concatenated canonical-name + definition text into one vector.

    The text field is embedded with ``word_embedder``, dropout is applied to the
    token embeddings, and the result (with its padding mask) is reduced to a
    single vector per instance by AllenNLP's ``BertPooler``.
    """

    # Supported values of ``args.bert_name`` mapped to the pretrained-weight
    # identifier / directory handed to ``BertPooler``.
    _BERT_WEIGHT_PATHS = {
        'bert-base-uncased': 'bert-base-uncased',
        'biobert': './biobert/',
    }

    def __init__(self, args, word_embedder):
        """Build the pooler.

        Args:
            args: configuration object. ``args.bert_name`` selects the
                pretrained weights; ``args.word_embedding_dropout`` is the
                dropout probability applied to the token embeddings.
            word_embedder: module called on the text-field tensor dict to
                produce token embeddings (presumably an AllenNLP
                ``TextFieldEmbedder`` wrapping BERT -- confirm against caller).

        Raises:
            ValueError: if ``args.bert_name`` is not a supported model name.
        """
        super().__init__()
        self.args = args
        self.huggingface_nameloader()
        self.bertpooler_sec2vec = BertPooler(pretrained_model=self.bert_weight_filepath)
        self.word_embedder = word_embedder
        self.word_embedding_dropout = nn.Dropout(self.args.word_embedding_dropout)

    def huggingface_nameloader(self):
        """Resolve ``self.args.bert_name`` into ``self.bert_weight_filepath``.

        Raises:
            ValueError: if ``self.args.bert_name`` is not supported.
        """
        bert_name = self.args.bert_name
        try:
            self.bert_weight_filepath = self._BERT_WEIGHT_PATHS[bert_name]
        except KeyError:
            # Raise rather than call ``exit()``: ``exit`` is a site-module
            # helper meant for the interactive shell, and terminating the
            # process from library code prevents callers from handling it.
            raise ValueError(
                f'Currently not supported: {bert_name!r} '
                f'(expected one of {sorted(self._BERT_WEIGHT_PATHS)})'
            ) from None

    def forward(self, cano_and_def_concatnated_text):
        """Encode the concatenated canonical-name/definition text.

        Args:
            cano_and_def_concatnated_text: text-field tensor dict as produced
                by AllenNLP batching (assumed -- confirm against the reader).

        Returns:
            The ``BertPooler`` output for the batch (presumably
            ``(batch, hidden_size)`` -- confirm).
        """
        mask_sent = get_text_field_mask(cano_and_def_concatnated_text)
        entity_emb = self.word_embedder(cano_and_def_concatnated_text)
        entity_emb = self.word_embedding_dropout(entity_emb)
        return self.bertpooler_sec2vec(entity_emb, mask_sent)

    @overrides
    def get_output_dim(self) -> int:
        """Return the size of the pooled vector.

        ``Seq2VecEncoder`` requires this method. It delegates to ``BertPooler``
        so the value follows the loaded model's hidden size rather than being
        hard-coded.
        """
        return self.bertpooler_sec2vec.get_output_dim()
class Pooler_for_mention(Seq2VecEncoder):
    """Encode a mention (with its surrounding context) into one vector.

    The text field is embedded with ``word_embedder``, dropout is applied to the
    token embeddings, and the result (with its padding mask) is reduced to a
    single vector per instance by AllenNLP's ``BertPooler``.
    """

    # Supported values of ``args.bert_name`` mapped to the pretrained-weight
    # identifier / directory handed to ``BertPooler``.
    _BERT_WEIGHT_PATHS = {
        'bert-base-uncased': 'bert-base-uncased',
        'biobert': './biobert/',
    }

    def __init__(self, args, word_embedder):
        """Build the pooler.

        Args:
            args: configuration object. ``args.bert_name`` selects the
                pretrained weights; ``args.word_embedding_dropout`` is the
                dropout probability applied to the token embeddings.
            word_embedder: module called on the text-field tensor dict to
                produce token embeddings (presumably an AllenNLP
                ``TextFieldEmbedder`` wrapping BERT -- confirm against caller).

        Raises:
            ValueError: if ``args.bert_name`` is not a supported model name.
        """
        super().__init__()
        self.args = args
        self.huggingface_nameloader()
        self.bertpooler_sec2vec = BertPooler(pretrained_model=self.bert_weight_filepath)
        self.word_embedder = word_embedder
        self.word_embedding_dropout = nn.Dropout(self.args.word_embedding_dropout)

    def huggingface_nameloader(self):
        """Resolve ``self.args.bert_name`` into ``self.bert_weight_filepath``.

        Raises:
            ValueError: if ``self.args.bert_name`` is not supported.
        """
        bert_name = self.args.bert_name
        try:
            self.bert_weight_filepath = self._BERT_WEIGHT_PATHS[bert_name]
        except KeyError:
            # Raise rather than call ``exit()``: ``exit`` is a site-module
            # helper meant for the interactive shell, and terminating the
            # process from library code prevents callers from handling it.
            raise ValueError(
                f'Currently not supported: {bert_name!r} '
                f'(expected one of {sorted(self._BERT_WEIGHT_PATHS)})'
            ) from None

    def forward(self, contextualized_mention):
        """Encode a batch of contextualized mentions.

        Args:
            contextualized_mention: text-field tensor dict as produced by
                AllenNLP batching (assumed -- confirm against the reader).

        Returns:
            The ``BertPooler`` output for the batch (presumably
            ``(batch, hidden_size)`` -- confirm).
        """
        mask_sent = get_text_field_mask(contextualized_mention)
        mention_emb = self.word_embedder(contextualized_mention)
        mention_emb = self.word_embedding_dropout(mention_emb)
        return self.bertpooler_sec2vec(mention_emb, mask_sent)

    @overrides
    def get_output_dim(self) -> int:
        """Return the size of the pooled vector.

        Delegates to ``BertPooler`` so the value follows the loaded model's
        hidden size. It is still 768 for the BERT-base-sized models supported
        above, but it stays correct if a differently sized model is added.
        """
        return self.bertpooler_sec2vec.get_output_dim()