-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathxmover_explainer.py
105 lines (78 loc) · 3.73 KB
/
xmover_explainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import sys
sys.path.append('./xmover')
import pandas as pd
import numpy as np
import shap
import os
from mosestokenizer import MosesDetokenizer, MosesTokenizer
import argparse
from scorer import XMOVERScorer
import numpy as np
import torch
import truecase
class XMoverWrapper():
    """SHAP-compatible model wrapper around ``XMOVERScorer``.

    Scores candidate translations against a single source sentence
    (``self.src_sent``) and exposes the tokenization / masking hooks that
    ``shap.Explainer`` uses to attribute the score to target-side tokens.
    """

    # Placeholder substituted for tokens that SHAP masks out.
    MASK_TOKEN = '[MASK]'

    def __init__(self, src_lang, tgt_lang, model_name, do_lower_case, language_model, mapping, device, ngram, bs):
        """Load the two mapping matrices and build the underlying scorer.

        Args:
            src_lang: source language code (used in the mapping file names).
            tgt_lang: target language code (used in the mapping file names
                and for Moses tokenization of translations).
            model_name: forwarded to ``XMOVERScorer``.
            do_lower_case: forwarded to ``XMOVERScorer``.
            language_model: forwarded to ``XMOVERScorer``.
            mapping: passed to ``compute_xmoverscore`` on every call.
            device: torch device the matrices are moved to; also forwarded
                to ``XMOVERScorer``.
            ngram: passed to ``compute_xmoverscore`` on every call.
            bs: batch size passed to ``compute_xmoverscore`` on every call.
        """
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.mapping = mapping
        self.device = device
        self.ngram = ngram
        self.bs = bs
        # The mapping files are resolved relative to the current working directory.
        prefix = './xmover/mapping/europarl-v7.' + src_lang + '-' + tgt_lang + '.2k.12.'
        self.projection = self._load_matrix(prefix + 'BAM.map', device)
        self.bias = self._load_matrix(prefix + 'GBDD.map', device)
        self.scorer = XMOVERScorer(model_name, language_model, do_lower_case, device)
        # Source sentence every translation is scored against; callers must
        # set it before invoking __call__.
        self.src_sent = None

    @staticmethod
    def _load_matrix(path, device):
        """Read a whitespace-delimited text matrix into a float tensor on *device*."""
        return torch.tensor(np.loadtxt(path), dtype=torch.float).to(device)

    def __call__(self, translations):
        """Score each candidate translation against ``self.src_sent``.

        Args:
            translations: iterable of rows whose first element is the
                translation string (e.g. ``[[sentence]]``).

        Returns:
            numpy array with one XMoverScore per translation.

        Raises:
            RuntimeError: if ``src_sent`` has not been set.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.src_sent is None:
            raise RuntimeError('src_sent must be set before scoring translations')
        candidates = [truecase.get_true_case(row[0]) for row in translations]
        sources = [self.src_sent] * len(candidates)
        scores = self.scorer.compute_xmoverscore(self.mapping, self.projection, self.bias, sources, candidates, self.ngram, self.bs)
        return np.array(scores)

    def tokenize_sent(self, sentence, lang):
        """Split *sentence* into Moses tokens for language *lang*."""
        with MosesTokenizer(lang) as tokenizer:
            return tokenizer(sentence)

    def detokenize(self, tokens, lang):
        """Join Moses *tokens* back into a sentence for language *lang*."""
        with MosesDetokenizer(lang) as detokenizer:
            return detokenizer(tokens)

    def build_feature(self, trans_sent):
        """Turn *trans_sent* into the one-row feature frame SHAP explains.

        Each Moses token gets its own column (labelled with, and holding,
        the token). Columns are positional, so a token that occurs several
        times keeps one column per occurrence instead of collapsing into one.
        """
        return self._tokens_frame(self.tokenize_sent(trans_sent, self.tgt_lang))

    @staticmethod
    def _tokens_frame(tokens):
        """Return a one-row DataFrame with one column per token, in order."""
        tokens = list(tokens)
        return pd.DataFrame([tokens], columns=tokens)

    def mask_model(self, mask, x):
        """SHAP masker: hide the tokens whose mask entry is falsy, then detokenize.

        Args:
            mask: per-token flags; truthy keeps the token.
            x: the token row produced from ``build_feature``.

        Returns:
            one-row, one-column DataFrame holding the masked sentence.
        """
        kept = self._apply_mask(mask, x)
        return pd.DataFrame([self.detokenize(kept, self.tgt_lang)])

    @classmethod
    def _apply_mask(cls, mask, tokens):
        """Replace every token whose mask entry is falsy with ``MASK_TOKEN``."""
        return [tok if keep else cls.MASK_TOKEN for keep, tok in zip(mask, tokens)]
class ExplainableXMover():
    """XMoverScore plus per-token SHAP explanations of the target sentence."""

    def __init__(self, src_lang, tgt_lang, model_name='bert-base-multilingual-cased', do_lower_case=False, language_model='gpt2', mapping='CLP', device='cuda:0', ngram=2, bs=32):
        """Build the underlying ``XMoverWrapper``; every argument is forwarded unchanged."""
        self.wrapper = XMoverWrapper(src_lang, tgt_lang, model_name, do_lower_case, language_model, mapping, device, ngram, bs)

    def __call__(self, src_sent, trans_sent):
        """Return the XMoverScore of *trans_sent* as a translation of *src_sent*."""
        self.wrapper.src_sent = src_sent
        return self.wrapper([[trans_sent]])[0]

    def explain(self, src_sent, trans_sent, plot=False):
        """Attribute the score of *trans_sent* to its tokens with SHAP.

        Args:
            src_sent: source-language sentence.
            trans_sent: target-language sentence whose score is explained.
            plot: if true, also draw a SHAP waterfall plot.

        Returns:
            list of ``(token, shap_value)`` pairs, one per feature column.
        """
        self.wrapper.src_sent = src_sent
        features = self.wrapper.build_feature(trans_sent)
        explainer = shap.Explainer(self.wrapper, self.wrapper.mask_model)
        explanation = explainer(features)[0]
        if plot:
            shap.waterfall_plot(explanation)
        # Pair values with the feature columns SHAP actually explained instead
        # of re-tokenizing the sentence: a fresh token list can differ in length
        # from the feature set (e.g. repeated tokens), which shifts the pairing.
        return list(zip(features.columns, explanation.values))
if __name__ == '__main__':
    # Small demo. Use CUDA when present, otherwise CPU, rather than relying on
    # the hard-coded 'cuda:0' default.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = ExplainableXMover('de', 'en', device=device)
    src = 'Er mag Hunde'
    trans = 'He dislikes dogs'
    score = model(src, trans)
    exps = model.explain(src, trans)
    print('\n =========')
    print(score)
    print(exps)