-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path: model_helper.py
executable file
·45 lines (36 loc) · 1.64 KB
/
model_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import numpy as np
import matplotlib.pyplot as plt
def plot_attention_weights(encoder_inputs, attention_weights, ip_id2word, op_id2word, base_dir, filename=None):
    """
    Plots attention weights as a heatmap and saves the image under <base_dir>/results.

    :param encoder_inputs: Sequence of input word ids (list/numpy.ndarray)
    :param attention_weights: Sequence of (<word_id_at_decode_step_t>, <attention_weights_at_decode_step_t>) pairs
    :param ip_id2word: dict mapping input-side word ids to words
    :param op_id2word: dict mapping output-side word ids to words
    :param base_dir: directory in which the 'results' folder is created/used
    :param filename: name for the saved image; defaults to 'attention.png'
    :return: None
    """
    if len(attention_weights) == 0:
        print('Your attention weights was empty. No attention map saved to the disk. ' +
              '\nPlease check if the decoder produced a proper translation')
        return

    # Stack one attention vector per decoder step, then transpose so rows
    # correspond to encoder positions and columns to decoder steps.
    mats = []
    dec_inputs = []
    for dec_ind, attn in attention_weights:
        mats.append(attn.reshape(-1))
        dec_inputs.append(dec_ind)
    attention_mat = np.transpose(np.array(mats))

    fig, ax = plt.subplots(figsize=(32, 32))
    ax.imshow(attention_mat)

    ax.set_xticks(np.arange(attention_mat.shape[1]))
    ax.set_yticks(np.arange(attention_mat.shape[0]))
    # Id 0 is treated as a reserved/padding token and labelled "<Res>".
    ax.set_xticklabels([op_id2word[inp] if inp != 0 else "<Res>" for inp in dec_inputs])
    # np.asarray accepts plain lists as promised by the docstring, not only ndarrays.
    ax.set_yticklabels([ip_id2word[inp] if inp != 0 else "<Res>" for inp in np.asarray(encoder_inputs).ravel()])

    ax.tick_params(labelsize=32)
    ax.tick_params(axis='x', labelrotation=90)

    # makedirs(..., exist_ok=True) avoids the check-then-create race and
    # also creates any missing intermediate directories.
    results_dir = os.path.join(base_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)

    # Save via the figure handle (robust if other figures are open), then
    # close it so repeated calls do not accumulate open figures.
    fig.savefig(os.path.join(results_dir, filename if filename is not None else 'attention.png'))
    plt.close(fig)