# %%
import string

import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BertForMaskedLM,
    BertTokenizer,
    ElectraForMaskedLM,
    ElectraTokenizer,
    RobertaForMaskedLM,
    RobertaTokenizer,
    XLMRobertaForMaskedLM,
    XLMRobertaTokenizer,
    XLNetLMHeadModel,
    XLNetTokenizer,
)

# Load every tokenizer/model pair once at import time; .eval() switches the
# models to inference mode (disables dropout).
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()

xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased').eval()

xlmroberta_tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
xlmroberta_model = XLMRobertaForMaskedLM.from_pretrained('xlm-roberta-base').eval()

bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large').eval()

electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
electra_model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator').eval()

roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base').eval()
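
# All six models stay resident in memory for the whole run. A minimal sketch of
# optional GPU placement (an assumption -- the original script runs on CPU):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   bert_model = bert_model.to(device)  # repeat for the other five models;
#   # inputs must then be moved too, e.g. input_ids = input_ids.to(device)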
top_k = 10  # candidates pulled from each model before punctuation filtering


def decode(tokenizer, pred_idx, top_clean):
    """Map predicted token ids to a newline-separated list of clean words."""
    # Use a set for exact-match filtering; the original string concatenation
    # made `in` a substring test, which also dropped tokens such as 'PAD'.
    ignore_tokens = set(string.punctuation) | {'[PAD]'}
    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))  # strip BERT word-piece markers
    return '\n'.join(tokens[:top_clean])
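
# A quick sanity check of decode (a sketch, not in the original; ids are looked
# up at run time rather than hard-coded):
#
#   ids = bert_tokenizer.convert_tokens_to_ids(['hello', 'world', '.'])
#   decode(bert_tokenizer, ids, top_clean=3)  # -> 'hello\nworld' ('.' is filtered)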
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence with one <mask> placeholder and locate the mask."""
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # If <mask> is the last token, append a "." so that models don't predict punctuation.
    if tokenizer.mask_token == text_sentence.split()[-1]:
        text_sentence += ' .'
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
    return input_ids, mask_idx
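
# Example (a sketch; exact ids depend on the tokenizer version):
#
#   input_ids, mask_idx = encode(bert_tokenizer, 'The capital of France is <mask>.')
#   input_ids.shape   # torch.Size([1, n]) -- a batch of one tokenized sentence
#   mask_idx          # integer position of the [MASK] token within input_ids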
def get_all_predictions(text_sentence, top_clean=5):
    # ========================= BERT =================================
    print(text_sentence)  # debug trace of the sentence being processed
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= XLNET BASE =================================
    # XLNet has no [MASK] token; a permutation mask hides the target position
    # and a target mapping selects which position to predict.
    input_ids, mask_idx = encode(xlnet_tokenizer, text_sentence, False)
    perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
    perm_mask[:, :, mask_idx] = 1.0  # no token may attend to the masked position
    target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # shape [1, 1, seq_length]: predict one token
    target_mapping[0, 0, mask_idx] = 1.0  # our single prediction is the masked position
    with torch.no_grad():
        predict = xlnet_model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)[0]
    xlnet = decode(xlnet_tokenizer, predict[0, 0, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= XLM-ROBERTA BASE =================================
    input_ids, mask_idx = encode(xlmroberta_tokenizer, text_sentence, add_special_tokens=True)
    with torch.no_grad():
        predict = xlmroberta_model(input_ids)[0]
    xlm = decode(xlmroberta_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= BART =================================
    input_ids, mask_idx = encode(bart_tokenizer, text_sentence, add_special_tokens=True)
    with torch.no_grad():
        predict = bart_model(input_ids)[0]
    bart = decode(bart_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= ELECTRA =================================
    input_ids, mask_idx = encode(electra_tokenizer, text_sentence, add_special_tokens=True)
    with torch.no_grad():
        predict = electra_model(input_ids)[0]
    electra = decode(electra_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= ROBERTA =================================
    input_ids, mask_idx = encode(roberta_tokenizer, text_sentence, add_special_tokens=True)
    with torch.no_grad():
        predict = roberta_model(input_ids)[0]
    roberta = decode(roberta_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)

    return {'bert': bert,
            'xlnet': xlnet,
            'xlm': xlm,
            'bart': bart,
            'electra': electra,
            'roberta': roberta}
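

# A minimal usage sketch (not in the original file; the sentence is
# illustrative -- any text containing a single <mask> placeholder works):
if __name__ == '__main__':
    results = get_all_predictions('The capital of France is <mask>.', top_clean=5)
    for model_name, words in results.items():
        print(model_name, '->', ', '.join(words.splitlines()))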