-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathentity_utils.py
71 lines (61 loc) · 2.21 KB
/
entity_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import dgl
import torch as th
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
def load_data(data_name, get_norm=False, inv_target=False):
    """Load an RDF entity-classification dataset as a homogeneous graph.

    Parameters
    ----------
    data_name : str
        ``"aifb"``, ``"mutag"`` or ``"bgs"`` select the matching dataset.
        Any other value falls back to ``AMDataset``.
    get_norm : bool, optional
        When true, a per-edge normalization weight is stored under the
        edge feature ``"norm"`` for every canonical edge type and carried
        over into the returned homogeneous graph.
    inv_target : bool, optional
        When true, an extra tensor mapping homogeneous node IDs to
        positions within ``target_idx`` is appended to the result.

    Returns
    -------
    tuple
        ``(g, num_rels, num_classes, labels, train_idx, test_idx,
        target_idx)``, followed by the inverse-target mapping as an
        eighth element when ``inv_target`` is true.

        * ``g`` -- homogeneous graph with node fields ``"ntype"`` and
          ``"type_id"`` (renamed from ``dgl.NTYPE`` / ``dgl.NID``).
        * ``num_rels`` -- number of canonical edge types in the
          heterogeneous graph.
        * ``num_classes`` -- ``dataset.num_classes``.
        * ``labels`` -- the ``"labels"`` node feature of the prediction
          category (removed from the heterogeneous graph).
        * ``train_idx`` / ``test_idx`` -- indices of the nonzero entries
          of the category's ``"train_mask"`` / ``"test_mask"``.
        * ``target_idx`` -- IDs, in ``g``, of the nodes whose type is the
          prediction category.
    """
    if data_name == "aifb":
        dataset = AIFBDataset()
    elif data_name == "mutag":
        dataset = MUTAGDataset()
    elif data_name == "bgs":
        dataset = BGSDataset()
    else:
        # Every unrecognised name falls back to the AM dataset.
        dataset = AMDataset()

    # Load hetero-graph
    hg = dataset[0]
    num_rels = len(hg.canonical_etypes)
    category = dataset.predict_category
    num_classes = dataset.num_classes

    labels = hg.nodes[category].data.pop("labels")
    train_mask = hg.nodes[category].data.pop("train_mask")
    test_mask = hg.nodes[category].data.pop("test_mask")
    # squeeze(1) drops only the trailing coordinate axis of the (k, 1)
    # nonzero() result. A bare squeeze() would also drop the leading axis
    # when k == 1 and hand back a 0-d scalar instead of a 1-element vector.
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)

    if get_norm:
        # Calculate normalization weight for each edge,
        # 1. / d, d is the degree of the destination node
        for cetype in hg.canonical_etypes:
            hg.edges[cetype].data["norm"] = dgl.norm_by_dst(
                hg, cetype
            ).unsqueeze(1)
        edata = ["norm"]
    else:
        edata = None

    # get target category id
    category_id = hg.ntypes.index(category)

    g = dgl.to_homogeneous(hg, edata=edata)
    # Rename the fields as they can be changed by for example DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    node_ids = th.arange(g.num_nodes())

    # find out the target node ids in g
    loc = g.ndata["ntype"] == category_id
    target_idx = node_ids[loc]

    if not inv_target:
        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx

    # Map global node IDs to type-specific node IDs. This is required for
    # looking up type-specific labels in a minibatch.
    # NOTE: th.empty leaves the entries of non-target nodes uninitialized;
    # only positions listed in target_idx hold meaningful values.
    inv_target_map = th.empty((g.num_nodes(),), dtype=th.int64)
    inv_target_map[target_idx] = th.arange(
        0, target_idx.shape[0], dtype=inv_target_map.dtype
    )
    return (
        g,
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target_map,
    )