process_fewrel.py
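"""Tag the FewRel dataset with entity spans.

For each split (train/val), load data/fewrel/<split>.json, or a partially
tagged data/fewrel/<split>_entity.json left by an earlier run, then let a
pool of worker threads call get_entities() on every sentence, map the
returned character offsets back to token indices, and save the result to
data/fewrel/<split>_entity.json.
"""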
import os
import json
import logging
import threading
from queue import Empty, Queue
from datetime import datetime

from config import MAX_WORKERS
from get_entity import get_entities
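
# Assumed contract of get_entities() (defined in get_entity.py, which is
# not shown here): it returns a list of dicts whose 'start_pos' and
# 'end_pos' fields are character offsets into the space-joined sentence.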

logger = logging.getLogger('Few-rel')

# Shared queue of sentence dicts that still need entity tags.
Q = Queue()


def load_queue():
    """Queue every sentence that does not have entity tags yet."""
    for sentence_list in DATA.values():
        for sentence_meta in sentence_list:
            if 'entities' not in sentence_meta:
                Q.put(sentence_meta)


class Worker(threading.Thread):
    """Tag sentences pulled from the shared queue with entity spans."""

    def __init__(self, idx: int):
        super().__init__()
        self._index = idx
        self._stop_event = threading.Event()

    def run(self):
        while not self._stop_event.is_set():
            # Pull without blocking: checking qsize() before a blocking
            # get() is racy when several workers share one queue.
            try:
                sentence_meta = Q.get_nowait()
            except Empty:
                break
            try:
                if 'entities' not in sentence_meta:
                    tokens = sentence_meta['tokens']
                    sentence = ' '.join(tokens)
                    entities = get_entities(sentence)
                    for entity in entities:
                        # Map character offsets back to token indices;
                        # the +1 accounts for the joining space.
                        start, length = 0, 0
                        while length < entity['start_pos']:
                            length += len(tokens[start]) + 1
                            start += 1
                        end = start
                        while length < entity['end_pos']:
                            length += len(tokens[end]) + 1
                            end += 1
                        # index_end is exclusive: tokens[start:end]
                        entity['index_begin'] = start
                        entity['index_end'] = end
                    sentence_meta['entities'] = entities
                logger.info('%s, worker: %s, jobs remain: %s.',
                            datetime.now(), self._index, Q.qsize())
            except Exception as e:
                logger.warning(e)
                # Tagging failed; send the job back to the queue.
                Q.put(sentence_meta)
        logger.info('Worker %s exited.', self._index)

    def stop(self):
        self._stop_event.set()
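

# Design note: CPython's GIL serializes pure-Python bytecode, so this
# thread pool only gives a real speedup if get_entities() spends its time
# waiting on I/O (e.g. a network call to a tagging service).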
if __name__ == '__main__':
    # Without a handler configured, the INFO-level progress logs above
    # would be swallowed by the default WARNING threshold.
    logging.basicConfig(level=logging.INFO)
    for dataset in ('train', 'val'):
        # Load data, resuming from a partially tagged file if one exists
        if os.path.exists('data/fewrel/{}_entity.json'.format(dataset)):
            with open('data/fewrel/{}_entity.json'.format(dataset),
                      'r') as f:
                DATA = json.load(f)
        else:
            with open('data/fewrel/{}.json'.format(dataset), 'r') as f:
                DATA = json.load(f)
        workers = []
        try:
            # Add untagged sentences to the queue
            load_queue()
            if Q.qsize() == 0:
                logger.info('No job left.')
                continue
            # Create workers: roughly one per ten jobs, capped at
            # MAX_WORKERS, and at least one so small queues still run
            count = max(1, min(MAX_WORKERS, Q.qsize() // 10))
            workers = [Worker(index) for index in range(count)]
            for w in workers:
                w.start()
            # Wait till jobs finished
            for w in workers:
                w.join()
        except KeyboardInterrupt:
            logger.info('Stopped by user.')
            # Stop workers and wait for them before saving
            for w in workers:
                w.stop()
            for w in workers:
                w.join()
            logger.info('Jobs left: {}.'.format(Q.qsize()))
        # Save data; partial progress survives an interrupt and is
        # picked up on the next run
        with open('data/fewrel/{}_entity.json'.format(dataset), 'w') as f:
            json.dump(DATA, f)
        full_name = 'data/fewrel/{}.json'.format(dataset)
        logger.info('File `{}` processed.'.format(full_name))
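
# A sketch of one record after processing. The 'index_begin'/'index_end'
# fields are computed above; any other fields on an entity dict depend on
# get_entities() and are assumptions here (marked "..."):
#
#     {"tokens": ["Bill", "Gates", "founded", "Microsoft", "."],
#      "entities": [{"start_pos": 0, "end_pos": 10,
#                    "index_begin": 0, "index_end": 2, ...}]}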