entity_fewrel.py (forked from Riroaki/FewShotEntityExtraction)
import os
import logging
from queue import Queue, Empty
import threading
from datetime import datetime
import json

import tagme

from config import TAGME_TOKEN, MAX_WORKERS
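# `config.py` ships separately; a minimal sketch of what it is assumed to
# provide (placeholder values, not taken from the original repository):
#   TAGME_TOKEN = 'your-gcube-token'  # D4Science gcube token for the TAGME API
#   MAX_WORKERS = 8                   # upper bound on annotation threads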
tagme.GCUBE_TOKEN = TAGME_TOKEN

# Attach a handler via basicConfig; without one, the INFO-level progress
# messages below would never reach the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Entity extraction(fewrel)')
logger.setLevel(logging.DEBUG)

queue = Queue()
def load_queue():
    """Queue every sentence that has not been annotated yet."""
    global queue, data
    for rel, sentence_list in data.items():
        for sentence_meta in sentence_list:
            if 'entities' not in sentence_meta:
                queue.put(sentence_meta)
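# `data` is assumed to follow the FewRel layout: relation id -> list of
# instances, e.g. {"P931": [{"tokens": ["Newton", "served", ...], "h": [...],
# "t": [...]}, ...], ...}. Only the 'tokens' field is read here; the script
# adds an 'entities' field alongside it.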
class Worker(threading.Thread):
    """Add entity tags for the FewRel dataset."""

    def __init__(self, idx: int):
        super(Worker, self).__init__()
        self._index = idx
        self._stop_event = threading.Event()

    def run(self):
        global queue
        while True:
            # Killed
            if self._stop_event.is_set():
                break
            # A non-blocking get avoids the race where another worker drains
            # the queue between a size check and a blocking get().
            try:
                sentence_meta = queue.get_nowait()
            except Empty:
                break
            # Extract entities from the sentence
            try:
                if 'entities' not in sentence_meta:
                    tokens = sentence_meta['tokens']
                    sentence = ' '.join(tokens)
                    sentence_annotations = tagme.annotate(sentence)
                    entities = []
                    for ann in sentence_annotations.annotations:
                        # Map the annotation's character span back to token
                        # positions in the whitespace-joined sentence.
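                        # Worked example (hypothetical annotation, not real
                        # TAGME output): for tokens ['Barack', 'Obama',
                        # 'was', 'born', 'in', 'Hawaii'], an annotation
                        # covering characters [0, 12) ('Barack Obama')
                        # yields start = 0 and end = 2, so index_end is
                        # exclusive.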
                        start, length = 0, 0
                        while length < ann.begin:
                            # +1 accounts for the space inserted by join()
                            length += len(tokens[start]) + 1
                            start += 1
                        end = start
                        while length < ann.end:
                            length += len(tokens[end]) + 1
                            end += 1
                        # Add entity information
                        entities.append({'index_begin': start,
                                         'index_end': end,
                                         'entity_id': ann.entity_id,
                                         'score': ann.score})
                    sentence_meta['entities'] = entities
                logger.info('{}, worker: {}, jobs remain: {}.'.format(
                    datetime.now(), self._index, queue.qsize()))
            except Exception as e:
                logger.warning(e)
                # Send the failed job back to the queue for a retry
                queue.put(sentence_meta)
        logger.info('Worker {} exited.'.format(self._index))

    def stop(self):
        self._stop_event.set()


if __name__ == '__main__':
    # A tuple keeps the processing order deterministic (a set would not)
    for dataset in ('train', 'val'):
        # Load data, resuming from a partially annotated file if one exists
        if os.path.exists('data/fewrel/{}_entity.json'.format(dataset)):
            with open('data/fewrel/{}_entity.json'.format(dataset),
                      'r') as f:
                data = json.load(f)
        else:
            with open('data/fewrel/{}.json'.format(dataset), 'r') as f:
                data = json.load(f)
        workers = []
        try:
            # Add sentences to queue
            load_queue()
            if queue.qsize() == 0:
                logger.info('No job left.')
                continue
            # Create workers: roughly one per 20 queued jobs, capped at
            # MAX_WORKERS; max(1, ...) prevents spawning zero workers when
            # fewer than 20 jobs remain.
            count = min(MAX_WORKERS, max(1, queue.qsize() // 20))
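            # Example: with 200 queued sentences and MAX_WORKERS = 8 (an
            # assumed config value), this yields min(8, max(1, 10)) = 8
            # workers.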
            for index in range(count):
                w = Worker(index)
                w.start()
                workers.append(w)
            # Wait until all jobs are finished
            for w in workers:
                w.join()
        except KeyboardInterrupt:
            logger.info('Stopped by user.')
            # Stop workers gracefully
            for w in workers:
                w.stop()
            for w in workers:
                w.join()
            logger.info('Jobs left: {}.'.format(queue.qsize()))
        # Save data (partial progress is kept after an interrupt)
        with open('data/fewrel/{}_entity.json'.format(dataset), 'w') as f:
            json.dump(data, f)
        logger.info('Saved data.')
    logger.info('Everything done.')
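
# Typical run (assuming config.py and the data/fewrel/ files are in place):
#   $ python entity_fewrel.py
# Writes data/fewrel/train_entity.json and data/fewrel/val_entity.json; the
# script can be re-run to resume after an interrupt, since sentences that
# already carry an 'entities' field are skipped.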