-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbeam.py
38 lines (34 loc) · 1.4 KB
/
beam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from itertools import groupby
import math
import numpy as np
# Taken from Vinayak A.
# (https://stackoverflow.com/questions/51422776/how-do-i-add-ctc-beam-search-decoder-in-crnn-model-pytorch)
# TODO make beam_k cmd line arg
def beam_search(output, TRG, beam_k=15):
    """Decode ``output`` with TensorFlow CTC beam search and return the top paths as tokens.

    Args:
        output: Model output exposing ``.cpu().detach().numpy()`` (presumably
            a PyTorch tensor -- TODO confirm). It is handed to
            ``tf.nn.ctc_beam_search_decoder`` as time-major input:
            ``output.shape[0]`` is used as every sequence's length and
            ``output.shape[1]`` as the batch size.
        TRG: Object whose ``vocab.itos`` maps an index to a token string
            (presumably a torchtext field -- verify against the caller).
        beam_k: Number of top paths requested from the decoder.

    Returns:
        A list with one entry per decoded path, best path first. Each entry
        is the list of token runs obtained by splitting that path's tokens
        on ``"<pad>"``, ``"<eos>"`` and ``"<unk>"`` (the separator tokens
        themselves are dropped).

    Side effect:
        When the best path for batch item 0 is more than 0.1 (in probability)
        ahead of the runner-up, prints the first token run of each of those
        two paths, prefixed by its probability.
    """
    separators = ("<pad>", "<eos>", "<unk>")

    decodes, log_probs = tf.nn.ctc_beam_search_decoder(
        inputs=output.cpu().detach().numpy(),
        sequence_length=np.full(
            output.shape[1], output.shape[0], dtype=np.int32),
        top_paths=beam_k)

    top_k_decodes = []
    for path in decodes[:beam_k]:
        # NOTE(review): TF treats class num_classes - 1 as the CTC blank; the
        # "- 1" shift presumably matches how the model's classes line up with
        # TRG.vocab -- confirm against the model definition.
        tokens = [TRG.vocab.itos[class_id - 1]
                  for class_id in path.values.numpy()]
        # groupby yields alternating runs of separator / non-separator tokens;
        # keep only the non-separator runs.
        runs = [list(run)
                for is_separator, run in groupby(
                    tokens, lambda tok: tok in separators)
                if not is_separator]
        top_k_decodes.append(runs)

    # log_probs is (batch, top_paths): [0, j] is the log-probability of path j
    # for batch item 0. This check does not depend on any single path, so it
    # runs once, and each printed probability is paired with its own path.
    if len(top_k_decodes) >= 2 and log_probs.shape[0] > 0:
        best = math.exp(log_probs[0, 0].numpy())
        runner_up = math.exp(log_probs[0, 1].numpy())
        if best - runner_up > 0.1:
            for prob, runs in ((best, top_k_decodes[0]),
                               (runner_up, top_k_decodes[1])):
                text = " >> ".join(runs[0]) if runs else ""
                print(f"{prob:.3f}: {text}")

    return top_k_decodes