-
Notifications
You must be signed in to change notification settings - Fork 0
/
Vocabulary.py
146 lines (92 loc) · 3.13 KB
/
Vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Vocabulary Class
class Vocabulary:
    """Word-frequency vocabulary with optional trimming and index mapping.

    Words are lower-cased both when counted (``addSentence``) and when looked
    up (``sentence2Sequence``). Four special tokens are always part of the
    vocabulary: ``<pad>``, ``<oov>``, ``<sos>`` and ``<eos>``. The mapper
    built by ``getVocabMapper`` assigns them indices 0-3; every other word
    follows in sorted order starting at index 4.
    """

    # Special tokens in mapper order: position in this tuple == index.
    _SPECIAL_TOKENS = ('<pad>', '<oov>', '<sos>', '<eos>')

    def __init__(self):
        # Occurrence count per word; special tokens start at 0.
        self.__voc_dict = {
            '<eos>': 0,
            '<oov>': 0,
            '<pad>': 0,
            '<sos>': 0
        }
        # Every word seen so far, plus the special tokens.
        self.Vocab = set(self._SPECIAL_TOKENS)
        # Sorted word list from the latest minCountTrim/wordCountTrim, or None.
        self.__trimmed_vocab = None
        # Built lazily by getVocabMapper(); reset to None whenever the trimmed
        # vocabulary changes so stale indices are never used.
        self.word2idx = None
        self.idx2word = None

    def __addWord(self, word):
        """Record one occurrence of ``word`` (already lower-cased)."""
        self.__voc_dict[word] = self.__voc_dict.get(word, 0) + 1
        self.Vocab.add(word)

    def addSentence(self, sentence):
        """Count every token in ``sentence`` after lower-casing it.

        ``sentence`` is iterated directly, so it should be a sequence of word
        tokens; a plain ``str`` would be counted character by character.
        An existing trimmed vocabulary or index mapper is NOT updated.
        """
        for word in sentence:
            self.__addWord(word.lower())

    def __storeTrimmed(self, words):
        """Save ``words`` as the trimmed vocabulary and drop the old mapper.

        Returns a copy of the stored (sorted) list.
        """
        self.__trimmed_vocab = sorted(words)
        # The index mapping is derived from the trimmed vocabulary; clear it
        # so sentence2Sequence/sequence2Sentence rebuild it on next use.
        self.word2idx = None
        self.idx2word = None
        return list(self.__trimmed_vocab)

    def minCountTrim(self, min_count):
        """Keep the special tokens plus every word counted >= ``min_count`` times.

        Returns the kept words as a sorted list and stores them as the
        trimmed vocabulary used by ``getVocabMapper``.
        """
        kept = set(self._SPECIAL_TOKENS)
        kept.update(word for word, count in self.__voc_dict.items()
                    if count >= min_count)
        return self.__storeTrimmed(kept)

    def wordCountTrim(self, num_words):
        """Keep at most ``num_words`` words, most frequent first.

        The four special tokens are always kept and count toward
        ``num_words``, so for ``num_words <= 4`` only they remain. Among
        words with equal counts, the one counted first is kept first.
        Returns the kept words as a sorted list and stores them as the
        trimmed vocabulary used by ``getVocabMapper``.
        """
        kept = set(self._SPECIAL_TOKENS)
        # sorted() is stable even with reverse=True, so equal counts keep
        # the dict's insertion order.
        ranked = sorted(self.__voc_dict.items(), key=lambda item: item[1],
                        reverse=True)
        for word, _count in ranked:
            if len(kept) >= num_words:
                break
            kept.add(word)
        return self.__storeTrimmed(kept)

    def sentence2Sequence(self, sentence):
        """Map tokens to indices; unknown words map to the ``<oov>`` index.

        Builds the mapper first if it does not exist yet.
        """
        if self.word2idx is None:
            self.getVocabMapper()
        oov = self.word2idx['<oov>']
        return [self.word2idx.get(word.lower(), oov) for word in sentence]

    def sequence2Sentence(self, sequence):
        """Map indices back to words.

        Builds the mapper first if it does not exist yet. Raises ``KeyError``
        for an index that is not in the mapper.
        """
        if self.idx2word is None:
            self.getVocabMapper()
        return [self.idx2word[idx] for idx in sequence]

    def getTrimmedVocab(self):
        """Return the stored trimmed vocabulary list, or None if never trimmed."""
        return self.__trimmed_vocab

    def getVocabMapper(self):
        """(Re)build and return ``(word2idx, idx2word)``.

        Special tokens take indices 0-3 in ``_SPECIAL_TOKENS`` order. The
        remaining words follow in sorted order from index 4. They come from
        the trimmed vocabulary if a trim was done, else from the full
        vocabulary.
        """
        specials = self._SPECIAL_TOKENS
        self.word2idx = {word: idx for idx, word in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        if self.__trimmed_vocab is None:
            vocab = self.Vocab
        else:
            vocab = self.__trimmed_vocab
        regular = [word for word in sorted(vocab) if word not in specials]
        for idx, word in enumerate(regular, start=len(specials)):
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        return self.word2idx, self.idx2word

    def getCountDictionary(self):
        """Return the live word -> count dictionary (not a copy)."""
        return self.__voc_dict