This repository has been archived by the owner on Nov 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_kg.py
138 lines (128 loc) · 4.36 KB
/
main_kg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import matplotlib.pyplot as plt
import networkx as nx
import json
import pickle
# Load the pretrained REBEL relation-extraction model and its tokenizer.
# NOTE: downloads weights from the Hugging Face hub on first run.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
def extract_relations_from_model_output(text):
    """Parse REBEL-style generated text into a list of relation triplets.

    The model linearizes triplets as
    ``<triplet> head <subj> tail <obj> relation`` (possibly repeated, and a
    head may be followed by several ``<subj> tail <obj> relation`` pairs).
    The special tokens ``<s>``, ``<pad>`` and ``</s>`` are stripped first.

    Args:
        text: Raw decoded model output, special tokens included.

    Returns:
        A list of dicts with keys ``'head'``, ``'type'`` and ``'tail'``.
    """
    relations = []
    # Fix: the original unpacked four values into
    # (relation, subject, relation, object_) — `relation` was assigned twice.
    # Initialize each accumulator exactly once.
    subject, relation, object_ = '', '', ''
    current = 'x'  # which field incoming plain tokens belong to
    text_replaced = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            # A new triplet starts: flush the previous one if it is complete.
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # The same head can yield several (tail, relation) pairs.
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            # REBEL's order is head -> tail -> relation, so 't' collects the
            # head, 's' collects the tail, and 'o' collects the relation text.
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the final triplet if all three parts were filled.
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip(),
        })
    return relations
# knowledge base class
class KB():
    """In-memory store of unique (head, type, tail) relation triplets."""

    def __init__(self):
        # Insertion-ordered list of relation dicts.
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True when head, type and tail all coincide."""
        keys = ("head", "type", "tail")
        return all(r1[k] == r2[k] for k in keys)

    def exists_relation(self, r1):
        """Return True if an equal relation is already stored."""
        return any(self.are_relations_equal(r1, known) for known in self.relations)

    def add_relation(self, r):
        """Store *r* unless an equal relation is already present."""
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        """Return the stored relations as a fresh list (despite the name,
        this does not print anything)."""
        return [r for r in self.relations]
def from_small_text_to_kb(text, verbose=False):
    """Run REBEL over *text* and collect the extracted triplets into a KB.

    Args:
        text: Input passage to extract relations from.
        verbose: When True, print the tokenized model inputs.

    Returns:
        A KB instance holding the deduplicated relations.
    """
    kb = KB()
    model_inputs = tokenizer(text, padding=True, truncation=True,
                             return_tensors='pt')
    if verbose:
        print(model_inputs)
    # Generation settings: three candidate linearizations per input are
    # decoded and merged into the same KB.
    gen_kwargs = {
        "max_length": 520,
        "length_penalty": 11,
        "num_return_sequences": 3,
    }
    generated_tokens = model.generate(**model_inputs, **gen_kwargs)
    # Keep special tokens — the parser relies on <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for pred in decoded_preds:
        for rel in extract_relations_from_model_output(pred):
            kb.add_relation(rel)
    return kb
# --- Script body: run relation extraction over a slice of the news list ---
tmp_event = []  # NOTE(review): appears unused below; kept for compatibility.
# NOTE(review): pickle.load can execute arbitrary code — only load pickles
# produced locally by a trusted process, never from external sources.
with open('news_list.pkl', 'rb') as file:
    data = pickle.load(file)
data = data[100:700]  # process documents 100..699 only
relations_find_all = []
for i in data:
    kb = from_small_text_to_kb(i, verbose=True)
    # Fix: the original called kb.print() twice and discarded the first
    # result; a single call is enough.
    relations_find_all.extend(kb.print())
# --- Build a node/edge structure from the triplets and export it as JSON ---
G = nx.DiGraph()
graph_data = {
    "nodes": {},
    "edges": []
}
# Map each entity label to a stable integer node id (first-seen order).
node_ids = {}
node_id_counter = 0
for rel in relations_find_all:
    head = rel['head']
    tail = rel['tail']
    rel_type = rel['type']
    # Deduplicated registration: head and tail are handled identically.
    for label in (head, tail):
        if label not in node_ids:
            node_ids[label] = node_id_counter
            node_id_counter += 1
        if node_ids[label] not in graph_data["nodes"]:
            graph_data["nodes"][node_ids[label]] = {"label": label, "category": "related"}
    graph_data["edges"].append({
        "from": node_ids[head],
        "to": node_ids[tail],
        "label": rel_type,
        "category": "related",
    })
with open('graph_data_from_kg.json', 'w') as json_file:
    json.dump(graph_data, json_file, indent=4)
# Fix: the message previously said 'graph_data.json', but the file actually
# written above is 'graph_data_from_kg.json'.
print("Graph data has been stored in graph_data_from_kg.json")
# Mirror the triplets into a networkx digraph for downstream analysis.
for relation in relations_find_all:
    G.add_edge(relation['head'], relation['tail'], relation_type=relation['type'])