-
Notifications
You must be signed in to change notification settings - Fork 4
/
load.py
99 lines (85 loc) · 3.12 KB
/
load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import re
import os
import unicodedata
from config import MAX_LENGTH, save_dir
# Indices reserved for the "SOS", "EOS" and "PAD" entries of Voc.index2word.
SOS_token, EOS_token, PAD_token = 0, 1, 2
class Voc:
    """Vocabulary mapping words to integer indices, with occurrence counts.

    Indices 0, 1 and 2 are reserved for "SOS", "EOS" and "PAD" (the same
    values as the module-level SOS_token / EOS_token / PAD_token constants).
    Real words are numbered from 3 upward in first-seen order.

    The special tokens appear only in ``index2word``; they are not entries
    of ``word2index`` or ``word2count``.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}   # word -> index
        self.word2count = {}   # word -> number of times added
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
        self.n_words = 3  # Count SOS, EOS and PAD

    def addSentence(self, sentence):
        """Add every token of *sentence*, splitting on single spaces.

        Because the split is on ' ' (not arbitrary whitespace), consecutive
        spaces produce an empty-string token that is added like any word.
        """
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Count *word*, assigning it the next free index the first time it is seen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        self.word2index[word] = self.n_words
        self.word2count[word] = 1
        self.index2word[self.n_words] = word
        self.n_words += 1
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Remove combining marks from *s* after canonical (NFD) decomposition.

    Accented letters decompose into a base letter plus combining marks, so
    e.g. "é" becomes "e". Only characters in Unicode category 'Mn' are
    dropped; anything that does not decompose that way (e.g. "ß") is kept,
    so the result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue  # non-spacing combining mark (accent, diaeresis, ...)
        kept.append(ch)
    return ''.join(kept)
# Lowercase, trim, and replace non-letter characters (other than . ! ?) with spaces
def normalizeString(s):
    """Normalize *s* into space-separated lowercase tokens.

    Steps:
      1. lowercase, trim, and strip accents via unicodeToAscii;
      2. insert a space before each '.', '!' or '?' so it becomes its own token;
      3. replace every run of characters other than ASCII letters and .!? with a space;
      4. collapse whitespace runs to one space and trim again.
    """
    text = unicodeToAscii(s.lower().strip())
    for pattern, replacement in (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    ):
        text = re.sub(pattern, replacement, text)
    return text.strip()
def readVocs(corpus, corpus_name):
    """Read *corpus* as consecutive line pairs and create an empty Voc.

    Every line is stripped of surrounding whitespace. Lines are then grouped
    two at a time: lines 1-2 form the first pair, lines 3-4 the second, and
    so on. The text is NOT normalized here (normalizeString is not applied).

    Args:
        corpus: path to a UTF-8 text file with an even number of lines.
        corpus_name: name given to the returned Voc.

    Returns:
        (voc, pairs): voc is a fresh, empty Voc; pairs is a list of
        [first_line, second_line] lists.

    Raises:
        ValueError: if the file has an odd number of lines.
    """
    print("Reading lines...")
    # Explicit UTF-8 so the result does not depend on the platform's default
    # locale encoding.
    with open(corpus, encoding='utf-8') as f:
        lines = [line.strip() for line in f]
    if len(lines) % 2:
        # Without this check a dangling last line would surface as a bare
        # StopIteration from next(); report the real problem instead.
        raise ValueError(
            "{!s} has an odd number of lines ({:d}); expected line pairs".format(
                corpus, len(lines)))
    it = iter(lines)
    # zip(it, it) pulls two items from the same iterator per step,
    # producing consecutive, non-overlapping pairs.
    pairs = [[first, second] for first, second in zip(it, it)]
    voc = Voc(corpus_name)
    return voc, pairs
def filterPair(p, max_length=None):
    """Return True if both sentences of pair *p* have fewer than *max_length* words.

    Words are counted by splitting on single spaces. The comparison is
    strict (``<``) so that input sequences keep room for the EOS_token
    appended after the last word.

    Args:
        p: a pair whose first two items are strings.
        max_length: word limit. If None, uses MAX_LENGTH imported from config.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    return (len(p[0].split(' ')) < max_length and
            len(p[1].split(' ')) < max_length)


def filterPairs(pairs, max_length=None):
    """Return the pairs that pass filterPair, in their original order.

    Args:
        pairs: iterable of sentence pairs.
        max_length: forwarded to filterPair. If None, uses MAX_LENGTH from config.
    """
    return [pair for pair in pairs if filterPair(pair, max_length)]
def prepareData(corpus, corpus_name):
    """Build the vocabulary and the filtered sentence pairs, then cache both to disk.

    The steps are:
      - read pairs with readVocs;
      - drop pairs rejected by filterPairs;
      - count every word of the remaining pairs into the Voc;
      - save both objects with torch.save as voc.tar and pairs.tar under
        <save_dir>/training_data/<corpus_name>/. These are the paths that
        loadPrepareData reads back.

    Returns:
        (voc, pairs) after filtering and counting.
    """
    voc, pairs = readVocs(corpus, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.n_words)
    directory = os.path.join(save_dir, 'training_data', corpus_name)
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    torch.save(voc, os.path.join(directory, 'voc.tar'))
    torch.save(pairs, os.path.join(directory, 'pairs.tar'))
    return voc, pairs
def loadPrepareData(corpus):
    """Load cached (voc, pairs) for *corpus*, building them if no cache exists.

    The cache key is the corpus file name up to its first '.', e.g.
    "data/movie_lines.txt" -> "movie_lines". The cached objects are read from
    <save_dir>/training_data/<corpus_name>/voc.tar and pairs.tar. If either
    file is missing, prepareData rebuilds and re-saves both.

    Returns:
        (voc, pairs)
    """
    # os.path.basename honours the platform's separator(s), not only '/'.
    # On POSIX it yields exactly corpus.split('/')[-1], so existing cache
    # names are unchanged.
    corpus_name = os.path.basename(corpus).split('.')[0]
    directory = os.path.join(save_dir, 'training_data', corpus_name)
    try:
        print("Start loading training data ...")
        # NOTE(review): torch.load unpickles arbitrary objects (voc is a Voc
        # instance), so only load caches this code wrote itself. Newer torch
        # releases default to weights_only=True, which may reject the pickled
        # Voc; confirm against the torch version in use.
        voc = torch.load(os.path.join(directory, 'voc.tar'))
        pairs = torch.load(os.path.join(directory, 'pairs.tar'))
    except FileNotFoundError:
        print("Saved data not found, start preparing training data ...")
        voc, pairs = prepareData(corpus, corpus_name)
    return voc, pairs