# commonsense.py
# forked from SAP-samples/acl2019-commonsense
import sys
#sys.path.append("/home/ubuntu/bertviz/")
from bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer
from bertviz.attention_details import AttentionDetailsData, show, get_attention_details
import argparse
from tqdm import trange, tqdm
import data_processors as processors
import numpy as np
from fuzzywuzzy import fuzz

def contains_word(s, w):
    # return True if w occurs as a substring of s
    return s.find(w) > -1

def computeMaximumAttentionScore(activity):
    """
    Compute the Maximum Attention Score (MAS) given the self-attention of a transformer network architecture.

    Parameters
    ----------
    activity : ndarray
        Tensor of attentions across layers and heads

    Returns
    -------
    ndarray
        MAS values in range [0, 1]
    """
    MAS_res = np.zeros((activity.shape[2]))
    # loop over layers and heads
    for x in range(0, activity.shape[0]):
        for y in range(0, activity.shape[1]):
            # get the candidate word with max attention and its attention value
            max_idx = np.argmax(activity[x, y, :])
            max_val = np.max(activity[x, y, :])
            # accumulate only the maximal attention per (layer, head); all non-max values are masked out
            MAS_res[max_idx] += max_val
    # now normalize the attentions
    MAS_res /= np.sum(MAS_res)
    return MAS_res

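# A minimal illustration of what computeMaximumAttentionScore does (the toy tensor
# below is hypothetical, not data from the task): for every (layer, head) pair the
# candidate with the largest attention receives that attention mass, and the
# accumulated scores are normalized to sum to 1.
#
#   toy = np.zeros((2, 2, 2))
#   toy[:, :, 0] = 0.6   # candidate 0 (the true word) wins in every layer/head
#   toy[:, :, 1] = 0.4   # candidate 1 (the decoy) never wins
#   computeMaximumAttentionScore(toy)  # -> array([1., 0.])
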
def analyzeAttentionSingleTupleDecoy(model, tokenizer, data, guid_dict=None, select_guid=None, num_layers=12, num_heads=12, do_debug=False):
    """
    Extracts the attention of target words, e.g. the groundtruth and the decoy word.
    Note: If a target word is split into multiple tokens, only the maximum attention is considered.

    Parameters
    ----------
    model : bertviz.pytorch_pretrained_bert.modeling.BertModel
        BERT model from BERT visualization that provides access to attention
    tokenizer : bertviz.pytorch_pretrained_bert.tokenization.BertTokenizer
        BERT tokenizer
    data : InputItems[]
        List of InputItems containing the WNLI/PDP data
    guid_dict : dictionary
        Dictionary that maps unique ids to data indices. Default None
    select_guid : int
        GUID of the example for which the attentions are to be extracted
    num_layers : int
        Number of layers. Default 12
    num_heads : int
        Number of attention heads. Default 12
    do_debug : boolean
        Toggle for printing debug information. Default False

    Returns
    -------
    activity : ndarray
        Count matrix keeping track of which layer and head is associated with the maximum attention
    attention : ndarray
        Attention matrix (#layers, #heads, 1 + #decoys), containing for each layer and head the attention for the true word and the decoys, respectively
    """
    problem_list = set([])
    activity = np.zeros((num_layers, num_heads))
    if select_guid is None:
        elements = range(0, len(data))
    else:
        assert(guid_dict is not None)
        assert(select_guid is not None)
        elements = [guid_dict[select_guid]]
    for idx in elements:
        sentence_a = data[idx].text_a
        sentence_b = data[idx].text_b
        groundtruth = data[idx].groundtruth
        guid = data[idx].guid
        decoy = data[idx].decoy
        if groundtruth is not None:
            details_data = AttentionDetailsData(model, tokenizer)
            tokens_a, tokens_b, queries, keys, atts = details_data.get_data(sentence_a, sentence_b)
            attentions = get_attention_details(tokens_a, tokens_b, queries, keys, atts)
            groundtruth_tokens = tokenizer.tokenize(data[idx].groundtruth)
            activity = np.zeros((num_layers, num_heads, len(decoy) + 1))
            attention_matrix = np.zeros((num_layers, num_heads, len(decoy) + 1))
            reference_idx = data[idx].reference_idx
            if tokenizer.tokenize(groundtruth)[0] not in sentence_a and tokenizer.tokenize(groundtruth)[0] not in sentence_b:
                print('Wrong annotation: ' + sentence_a + ' | ' + groundtruth + ' | ' + sentence_b)
                continue
            for layer_id in range(0, num_layers):
                for head_id in range(0, num_heads):
                    attention_pairwise = np.asarray(attentions['ab']['att'][layer_id][head_id])
                    correct_activity = 0
                    indices = []
                    # determine attention for the correct word
                    # check if the correct word is in sentence_a OR sentence_b
                    if contains_word(sentence_a, groundtruth_tokens[0]):
                        # check if target is single or multi-token
                        if len(tokenizer.tokenize(groundtruth)) == 1:
                            # some answers may not match exactly or contain misspellings, e.g. plural piece(s), so fuzzy matching is necessary
                            ratios = [fuzz.ratio(groundtruth, token) for token in tokens_a]
                            best_match_idx = ratios.index(max(ratios))
                            correct_activity = attention_pairwise[best_match_idx, reference_idx]
                            indices.append(best_match_idx)
                        # target stretches over multiple tokens
                        else:
                            groundtruth_split = tokenizer.tokenize(groundtruth)
                            local_attention = []
                            for f in groundtruth_split:
                                if len(f) > 1:
                                    try:
                                        local_attention.append(attention_pairwise[tokens_a.index(f), reference_idx])
                                        indices.append(tokens_a.index(f))
                                    except ValueError:
                                        problem_list.add(guid)
                            # keep max attention
                            if len(local_attention) > 0:
                                correct_activity = np.max(local_attention)
                    else:
                        # check if target is single or multi-token
                        if len(tokenizer.tokenize(groundtruth)) == 1:
                            correct_activity = attention_pairwise[reference_idx, tokens_b.index(groundtruth)]
                            indices.append(tokens_b.index(groundtruth))
                        # target stretches over multiple tokens
                        else:
                            groundtruth_split = tokenizer.tokenize(groundtruth)
                            local_attention = []
                            for f in groundtruth_split:
                                if len(f) > 1:
                                    local_attention.append(attention_pairwise[reference_idx, tokens_b.index(f)])
                                    indices.append(tokens_b.index(f))
                            # keep max attention
                            correct_activity = np.max(local_attention)
                    # determine attention for the decoy words
                    decoy_attention = []
                    if contains_word(sentence_a, groundtruth_tokens[0]):
                        for k in decoy:
                            # check if target is single or multi-token
                            if len(tokenizer.tokenize(k)) == 1:
                                # some answers may not match exactly or contain misspellings, e.g. plural piece(s), so fuzzy matching is necessary
                                ratios = [fuzz.ratio(k, token) for token in tokens_a]
                                best_match_idx = ratios.index(max(ratios))
                                decoy_attention.append(attention_pairwise[best_match_idx, reference_idx])
                                indices.append(best_match_idx)
                            else:
                                decoy_split = tokenizer.tokenize(k)
                                local_attention = []
                                for f in decoy_split:
                                    if len(f) > 1:
                                        try:
                                            local_attention.append(attention_pairwise[tokens_a.index(f), reference_idx])
                                            indices.append(tokens_a.index(f))
                                        except ValueError:
                                            problem_list.add(guid)
                                if len(local_attention) > 0:
                                    decoy_attention.append(np.max(local_attention))
                                else:
                                    decoy_attention.append(0)
                    else:
                        for k in decoy:
                            # check if target is single or multi-token
                            if len(tokenizer.tokenize(k)) == 1:
                                decoy_attention.append(attention_pairwise[reference_idx, tokens_b.index(k)])
                            else:
                                decoy_split = tokenizer.tokenize(k)
                                local_attention = []
                                for f in decoy_split:
                                    if len(f) > 1:
                                        # some answers may not match exactly or contain misspellings, e.g. plural piece(s), so fuzzy matching is necessary
                                        ratios = [fuzz.ratio(f, token) for token in tokens_b]
                                        best_match_idx = ratios.index(max(ratios))
                                        local_attention.append(attention_pairwise[reference_idx, best_match_idx])
                                        indices.append(best_match_idx)
                                if len(local_attention) > 0:
                                    decoy_attention.append(np.max(local_attention))
                                else:
                                    decoy_attention.append(0)
                    # index 0 holds the true word, the remaining entries the decoys
                    attn = [correct_activity] + decoy_attention
                    activity[layer_id, head_id, np.argmax(attn)] += 1
                    attention_matrix[layer_id, head_id, :] = np.asarray(attn[:])
    if do_debug and len(problem_list) > 0:
        print('Problems with the following guids: ' + str(problem_list))
    return activity, attention_matrix

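# Usage sketch (the variable names below are placeholders; in this script the model,
# tokenizer and data items are created in main()):
#
#   _, attention = analyzeAttentionSingleTupleDecoy(model, tokenizer, items, guid_dict, select_guid=0)
#   mas = computeMaximumAttentionScore(attention)  # mas[0]: true word, mas[1:]: decoys
#   answered_correctly = np.argmax(mas) == 0
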
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name", choices=['PDP', 'MNLI', 'pdp', 'mnli'],
                        required=True,
                        help="The name of the task to run: pdp or mnli.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--debug", action='store_true',
                        help="Set this flag if you want to print debug info.")
    args = parser.parse_args()

    if args.task_name.lower() == 'pdp':
        processor = processors.XMLPDPProcessor()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
        all_ambigious, guid_dict = processor.get_train_items(args.data_dir, select_type='ambigious')
        all_guids = set([])
        for i in range(0, len(all_ambigious)):
            all_guids.add(all_ambigious[i].guid)
    elif args.task_name.lower() == 'mnli':
        processor = processors.XMLMnliProcessor()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
        all_ambigious, guid_dict = processor.get_train_items(args.data_dir, select_type='ambigious')
        all_guids = set([])
        for i in range(0, len(all_ambigious)):
            all_guids.add(all_ambigious[i].guid)

    model = BertModel.from_pretrained(args.bert_model)

    MAS_list = []
    counter = np.zeros((2))
    for i in trange(0, len(all_guids)):
        _, attention = analyzeAttentionSingleTupleDecoy(model, tokenizer, all_ambigious, guid_dict, select_guid=i, do_debug=args.debug)
        ref_word = tokenizer.tokenize(all_ambigious[i].text_b)[all_ambigious[i].reference_idx]
        MAS = computeMaximumAttentionScore(attention)
        MAS_list.append(MAS[0])
        # now count how many times the maximum MAS is assigned to the true word vs. a decoy
        if np.argmax(MAS) == 0:
            counter[0] += 1
        else:
            counter[1] += 1
        if args.debug:
            print(str(all_ambigious[i].guid) + ' | ' + str(MAS) + ' | ' + all_ambigious[i].text_a + ' ' + all_ambigious[i].text_b + ' || >' + ref_word + '<, ' + str(all_ambigious[i].groundtruth) + ', ' + str(all_ambigious[i].decoy))

    print(args.task_name.upper() + " Accuracy: " + str(counter[0] / np.sum(counter)))


if __name__ == "__main__":
    main()
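
# Example invocation (data directory and model name are placeholders):
#   python commonsense.py --data_dir ./data --bert_model bert-base-uncased --task_name pdp --do_lower_case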