wiki_loader_sentences.py
from torch.utils.data import Dataset
from text_manipulation import word_model
from text_manipulation import extract_sentence_words
from pathlib2 import Path
import re
import wiki_utils
import os
import utils

logger = utils.setup_logger(__name__, 'train.log')

section_delimiter = "========"
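# The lines that mark section boundaries in the preprocessed Wikipedia files are
# built from this delimiter; wiki_utils.get_seperator_foramt() turns it into the
# regular expression used below and, when given a level range such as (1, 2),
# appears to restrict matching to those heading depths.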
def get_files(path):
    all_objects = Path(path).glob('**/*')
    files = [str(p) for p in all_objects if p.is_file()]
    return files


def get_cache_path(wiki_folder):
    cache_file_path = wiki_folder / 'paths_cache'
    return cache_file_path


def cache_wiki_filenames(wiki_folder):
    files = Path(wiki_folder).glob('*/*/*/*')
    cache_file_path = get_cache_path(wiki_folder)
    with cache_file_path.open('w+') as f:
        for file in files:
            f.write(unicode(file) + u'\n')


def clean_section(section):
    cleaned_section = section.strip('\n')
    return cleaned_section
def get_scections_from_text(txt, high_granularity=True):
    sections_to_keep_pattern = wiki_utils.get_seperator_foramt() if high_granularity else wiki_utils.get_seperator_foramt(
        (1, 2))
    if not high_granularity:
        # if low granularity is required, flatten the segments within segment level 2
        pattern_to_omit = wiki_utils.get_seperator_foramt((3, 999))
        txt = re.sub(pattern_to_omit, "", txt)

        # delete empty lines left behind by re.sub()
        sentences = [s for s in txt.strip().split("\n") if len(s) > 0 and s != "\n"]
        txt = '\n'.join(sentences).strip('\n')

    all_sections = re.split(sections_to_keep_pattern, txt)
    non_empty_sections = [s for s in all_sections if len(s) > 0]

    return non_empty_sections
def get_sections(path, high_granularity=True):
    file = open(str(path), "r")
    raw_content = file.read()
    file.close()

    # clean_txt = raw_content.decode('utf-8').strip()
    # sections = [clean_section(s) for s in get_scections_from_text(clean_txt, high_granularity)]
    sections = [clean_section(s) for s in get_scections_from_text(raw_content, high_granularity)]

    return sections
def read_wiki_file(path, word2vec=None, remove_preface_segment=True, ignore_list=False, remove_special_tokens=False,
                   return_as_sentences=True, high_granularity=True, only_letters=False):
    data = []
    targets = []
    all_sections = get_sections(path, high_granularity)
    required_sections = all_sections[1:] if remove_preface_segment and len(all_sections) > 0 else all_sections
    required_non_empty_sections = [section for section in required_sections if len(section) > 0 and section != "\n"]

    for section in required_non_empty_sections:
        sentences = section.split('\n')
        if sentences:
            for sentence in sentences:
                is_list_sentence = wiki_utils.get_list_token() + "." == sentence.encode('utf-8')
                if ignore_list and is_list_sentence:
                    continue
                if not return_as_sentences:
                    sentence_words = extract_sentence_words(sentence, remove_special_tokens=remove_special_tokens)
                    if 1 <= len(sentence_words):
                        data.append([word_model(word, word2vec) for word in sentence_words])
                    else:
                        # raise ValueError('Sentence in wikipedia file is empty')
                        logger.info('Sentence in wikipedia file is empty')
                else:  # for the annotation. keep sentence as is.
                    if only_letters:
                        sentence = re.sub('[^a-zA-Z0-9 ]+', '', sentence)
                        data.append(sentence)
                    else:
                        data.append(sentence)
            if data:
                targets.append(len(data) - 1)

    return data, targets, path
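
# Illustrative helper (a hedged sketch, not part of the original module):
# read_wiki_file stores in `targets` the index of the last sentence of every
# kept section, so a per-sentence 0/1 boundary vector can be derived from it.
# The helper name is hypothetical.
def targets_to_boundary_labels(num_sentences, targets):
    labels = [0] * num_sentences
    for last_sentence_index in targets:
        labels[last_sentence_index] = 1  # mark the final sentence of a segment
    return labels
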
class WikipediaDataSet(Dataset):
    def __init__(self, root, word2vec=None, train=True, manifesto=False,
                 folder=False, high_granularity=False, only_letters=False):

        if manifesto:
            self.textfiles = list(Path(root).glob('*'))
        else:
            if folder:
                self.textfiles = get_files(root)
            else:
                root_path = Path(root)
                cache_path = get_cache_path(root_path)
                if not cache_path.exists():
                    cache_wiki_filenames(root_path)
                self.textfiles = cache_path.read_text().splitlines()

        if len(self.textfiles) == 0:
            raise RuntimeError('Found 0 files in subfolders of: {}'.format(root))

        self.train = train
        self.root = root
        self.word2vec = word2vec
        self.high_granularity = high_granularity
        self.only_letters = only_letters
    def __getitem__(self, index):
        path = self.textfiles[index]

        return read_wiki_file(Path(path), self.word2vec, ignore_list=True, remove_special_tokens=True,
                              high_granularity=self.high_granularity, only_letters=self.only_letters)

    def __len__(self):
        return len(self.textfiles)
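
# Minimal usage sketch under stated assumptions (not part of the original
# module): 'data/wiki_727/test' is a hypothetical corpus path assumed to follow
# the */*/*/* layout expected by cache_wiki_filenames. Each item is a
# (sentences, segment_end_indices, file_path) triple, so the dataset is
# iterated directly here rather than batched through a DataLoader, which would
# need a custom collate function for variable-length items.
if __name__ == '__main__':
    dataset = WikipediaDataSet('data/wiki_727/test', word2vec=None,
                               high_granularity=False)
    sentences, targets, path = dataset[0]
    logger.info('%s: %d sentences, %d segments', path, len(sentences), len(targets))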