forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sampler.py
152 lines (129 loc) · 6.09 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import numpy as np
import dgl
import torch
from torch.utils.data import IterableDataset, DataLoader
def compact_and_copy(frontier, seeds):
    """Turn ``frontier`` into a block over ``seeds``, carrying edge features along.

    The block is created with ``dgl.to_block(frontier, seeds)``.  Every edge
    feature column of ``frontier`` except ``dgl.EID`` is then copied onto the
    block, with rows selected by the block's own ``dgl.EID`` column.

    Returns the resulting block.
    """
    compacted = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        # The block already owns a dgl.EID column; it is the index used to
        # pick rows below, so it must not be overwritten.
        if name != dgl.EID:
            compacted.edata[name] = values[compacted.edata[dgl.EID]]
    return compacted
class ItemToItemBatchSampler(IterableDataset):
    """Endless iterable of ``(heads, tails, neg_tails)`` item-ID batches.

    ``heads`` are item IDs drawn uniformly at random, ``tails`` are the items
    reached from each head by an item -> user -> item random walk, and
    ``neg_tails`` are item IDs drawn uniformly at random.
    """

    def __init__(self, g, user_type, item_type, batch_size):
        """Store the graph and look up the two edge types linking users and items.

        For each direction, the first edge type listed by ``g.metagraph()``
        between the two node types is used.
        """
        self.batch_size = batch_size
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        user_to_item = list(g.metagraph()[user_type][item_type])
        self.user_to_item_etype = user_to_item[0]
        item_to_user = list(g.metagraph()[item_type][user_type])
        self.item_to_user_etype = item_to_user[0]

    def __iter__(self):
        """Yield batches forever; each batch holds at most ``batch_size`` triples."""
        while True:
            batch_shape = (self.batch_size,)
            heads = torch.randint(
                0, self.g.number_of_nodes(self.item_type), batch_shape)
            # random_walk returns a tuple; element 0 holds the traces, and
            # column 2 is the node reached after both hops of the metapath.
            walk_result = dgl.sampling.random_walk(
                self.g,
                heads,
                metapath=[self.item_to_user_etype, self.user_to_item_etype])
            tails = walk_result[0][:, 2]
            neg_tails = torch.randint(
                0, self.g.number_of_nodes(self.item_type), batch_shape)
            # Drop the triples whose walk ended in -1 (per the DGL docs, the
            # padding value for walks that could not be completed).
            completed = (tails != -1)
            yield heads[completed], tails[completed], neg_tails[completed]
class NeighborSampler:
    """Builds multi-layer blocks around item nodes using PinSAGE samplers."""

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors, num_layers):
        """Store the graph metadata and create one ``PinSAGESampler`` per layer.

        All ``num_layers`` samplers are constructed with the same arguments.
        """
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]

        layer_samplers = []
        for _ in range(num_layers):
            layer_samplers.append(
                dgl.sampling.PinSAGESampler(
                    g, item_type, user_type, random_walk_length,
                    random_walk_restart_prob, num_random_walks, num_neighbors))
        self.samplers = layer_samplers

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):
        """Sample one block per layer, starting from ``seeds``.

        Each sampler is called on the current seeds to obtain a frontier, the
        frontier is converted into a block, and that block's source nodes
        become the seeds of the next sampler.

        When ``heads`` is given, any frontier edges going from ``heads`` to
        ``tails`` or from ``heads`` to ``neg_tails`` are removed before the
        block is built.

        Returns the list of blocks in reverse sampling order: the block
        sampled last comes first.
        """
        sampled = []
        for sampler in self.samplers:
            frontier = sampler(seeds)
            if heads is not None:
                src = torch.cat([heads, heads])
                dst = torch.cat([tails, neg_tails])
                # With return_uv=True the third element holds the edge IDs.
                eids = frontier.edge_ids(src, dst, return_uv=True)[2]
                if len(eids) > 0:
                    frontier = dgl.remove_edges(frontier, eids)
            block = compact_and_copy(frontier, seeds)
            seeds = block.srcdata[dgl.NID]
            sampled.append(block)
        # Blocks were collected in sampling order; callers get them reversed.
        sampled.reverse()
        return sampled

    def sample_from_item_pairs(self, heads, tails, neg_tails):
        """Build positive/negative pair graphs and the blocks for their nodes.

        Returns ``(pos_graph, neg_graph, blocks)``.
        """
        num_items = self.g.number_of_nodes(self.item_type)
        # One graph holds only the positive (head, tail) connections, the
        # other only the negative (head, neg_tail) connections.
        pos_graph = dgl.graph((heads, tails), num_nodes=num_items)
        neg_graph = dgl.graph((heads, neg_tails), num_nodes=num_items)
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])

        # The compacted graph's dgl.NID column is used as the seed set.
        seeds = pos_graph.ndata[dgl.NID]
        blocks = self.sample_blocks(seeds, heads, tails, neg_tails)
        return pos_graph, neg_graph, blocks
def assign_simple_node_features(ndata, g, ntype, assign_id=False):
    """
    Fill ``ndata`` with the feature columns of ``ntype`` nodes from graph ``g``.

    For every column stored on ``g.nodes[ntype]``, the rows selected by
    ``ndata[dgl.NID]`` are written to ``ndata`` under the same column name.
    The ``dgl.NID`` column itself is copied only when ``assign_id`` is true.
    """
    for col in g.nodes[ntype].data.keys():
        skip_id_column = (col == dgl.NID) and not assign_id
        if not skip_id_column:
            # Looked up on every pass: with assign_id=True this loop may
            # itself replace the dgl.NID column.
            rows = ndata[dgl.NID]
            ndata[col] = g.nodes[ntype].data[col][rows]
def assign_textual_node_features(ndata, textset, ntype):
    """
    Store numericalized tokens from a text dataset as node features.

    For every field of ``textset``, the processed tokens are stored in
    ``ndata`` under the field's name, and the lengths returned alongside them
    under ``field_name + '__len'``.

    ndata : node feature storage
        Must contain a ``dgl.NID`` column; its values are used as indices
        into ``textset``.  The new features are written into this storage.
    textset : torchtext.data.Dataset
        A torchtext dataset whose number of examples is the same as that
        of nodes in the original graph.
    ntype : unused
        Not read by this function.
    """
    example_ids = ndata[dgl.NID].numpy()
    for field_name, field in textset.fields.items():
        raw_examples = [getattr(textset[idx], field_name) for idx in example_ids]
        tokens, lengths = field.process(raw_examples)
        # Transpose unless the field already reports batch_first layout.
        ndata[field_name] = tokens if field.batch_first else tokens.t()
        ndata[field_name + '__len'] = lengths
def assign_features_to_blocks(blocks, g, textset, ntype):
    """Copy graph and text features onto the two ends of a block list.

    Both the source nodes of the first block and the destination nodes of the
    last block receive the plain node features from ``g`` followed by the
    textual features from ``textset``.
    """
    # Source side of the first block is handled first, then the destination
    # side of the last block.
    for node_data in (blocks[0].srcdata, blocks[-1].dstdata):
        assign_simple_node_features(node_data, g, ntype)
        assign_textual_node_features(node_data, textset, ntype)
class PinSAGECollator:
    """Collate functions that turn sampled item IDs into feature-carrying blocks."""

    def __init__(self, sampler, g, ntype, textset):
        """Keep references to the neighbor sampler, graph, node type and text dataset."""
        self.textset = textset
        self.g = g
        self.ntype = ntype
        self.sampler = sampler

    def collate_train(self, batches):
        """Build the training inputs from the first element of ``batches``.

        ``batches[0]`` must unpack into ``(heads, tails, neg_tails)``.
        Returns ``(pos_graph, neg_graph, blocks)`` with node features assigned
        to the blocks.
        """
        heads, tails, neg_tails = batches[0]
        # Construct the multilayer neighborhood via the PinSAGE sampler.
        pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(
            heads, tails, neg_tails)
        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
        return pos_graph, neg_graph, blocks

    def collate_test(self, samples):
        """Build the blocks for ``samples``, converted to a ``LongTensor`` of seeds."""
        seeds = torch.LongTensor(samples)
        blocks = self.sampler.sample_blocks(seeds)
        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
        return blocks