-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdatasets.py
More file actions
220 lines (171 loc) · 9.57 KB
/
datasets.py
File metadata and controls
220 lines (171 loc) · 9.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import os
import numpy as np
import torch
from torch.nn import functional as F
import dgl
from dgl import ops
from sklearn.metrics import roc_auc_score
from rewiring.fastrewiringKupdates import *
from rewiring.MinGapKupdates import *
from rewiring.fosr import *
from rewiring.spectral_utils import *
import nx_cugraph as nxcg
from torch_geometric.data import Data
import networkx as nx
from sklearn.metrics.cluster import normalized_mutual_info_score as NMI
import time
class Dataset:
    """Node-classification dataset loaded from ``data/<name>.npz`` with optional graph rewiring.

    The constructor loads node features, labels, edges and the train/val/test
    masks, optionally rewires the graph, optionally augments the node features,
    and moves everything to ``device``.

    Attributes set by ``__init__``:
        name, device: the constructor arguments of the same name.
        graph: the (possibly rewired) DGL graph on ``device``.
        node_features, labels: tensors on ``device``.
        train_idx_list, val_idx_list, test_idx_list: one index tensor per data split.
        num_data_splits: number of data splits (length of ``train_idx_list``).
        cur_data_split: index of the split currently served by ``train_idx`` /
            ``val_idx`` / ``test_idx``.
        num_node_features: second dimension of the (augmented) feature matrix.
        num_targets: 1 when there are exactly two classes, else the number of classes.
        loss_fn: ``F.binary_cross_entropy_with_logits`` when ``num_targets == 1``,
            else ``F.cross_entropy``.
        metric: ``'ROC AUC'`` when ``num_targets == 1``, else ``'accuracy'``.
        rewire_time: wall-clock seconds spent in the rewiring step.
    """

    def __init__(self, name, seed, max_iterations, rewiring_method='none', add_self_loops=False, device='cpu',
                 use_sgc_features=False, use_identity_features=False, use_adjacency_features=False,
                 do_not_use_original_features=False):
        """Load the dataset, rewire its graph and prepare tensors.

        Args:
            name: dataset name; ``-`` is replaced by ``_`` to build the file name
                ``data/<name>.npz``. If the name does not contain ``'directed'``
                the final graph is made bidirected.
            seed: forwarded to the proxy rewiring routines.
            max_iterations: forwarded as ``max_iter`` to the proxy rewiring routines;
                for ``'fosr'`` it is the number of single-iteration ``edge_rewire`` calls.
            rewiring_method: ``'proxydelmin'`` (alias ``'proxydeletemin'``),
                ``'proxydeletemax'``, ``'proxyaddmax'``, ``'proxyaddmin'``, ``'fosr'``
                or ``'none'``. Any other value results in no rewiring.
            add_self_loops: add self-loops to the final graph.
            device: device the graph and tensors are moved to.
            use_sgc_features, use_identity_features, use_adjacency_features,
            do_not_use_original_features: forwarded to ``augment_node_features``.

        Raises:
            ValueError: if the original features are dropped and no replacement
                feature type is requested.
        """
        if do_not_use_original_features and not any([use_sgc_features, use_identity_features,
                                                     use_adjacency_features]):
            raise ValueError('If original node features are not used, at least one of the arguments '
                             'use_sgc_features, use_identity_features, use_adjacency_features should be used.')

        print('Preparing data...')
        # An .npz archive keeps its file handle open until closed, so read every
        # array up front and release the handle before the (long) rewiring step.
        with np.load(os.path.join('data', f'{name.replace("-", "_")}.npz')) as data:
            node_features = torch.tensor(data['node_features'])
            labels = torch.tensor(data['node_labels'])
            edges = torch.tensor(data['edges'])
            edge_index = torch.tensor(data['edges'], dtype=torch.long).t().contiguous()
            train_masks = torch.tensor(data['train_masks'])
            val_masks = torch.tensor(data['val_masks'])
            test_masks = torch.tensor(data['test_masks'])

        graph = dgl.graph((edges[:, 0], edges[:, 1]), num_nodes=len(node_features), idtype=torch.int)
        # The rewiring routines operate on an undirected networkx graph.
        nxgraph = nx.Graph(dgl.to_networkx(graph))
        print(nxgraph)

        # Disabled community-structure evaluation (before rewiring):
        # nxcg_G = nxcg.from_networkx(nxgraph)
        # communities_before = list(nx.community.louvain_communities(nxcg_G, seed=seed))
        # cluster_dict_before = {node: i for i, cluster in enumerate(communities_before) for node in cluster}
        # cluster_list_before = [cluster_dict_before[node] for node in range(len(labels))]
        # self.nmiscoremod_before = NMI(cluster_list_before, labels.numpy())
        # print(f'NMI Score: {self.nmiscoremod_before:.4f}')

        start_algo = time.time()
        # 'proxydeletemin' is accepted as an alias so that the name matches the other
        # methods; the historical spelling 'proxydelmin' keeps working.
        if rewiring_method in ('proxydelmin', 'proxydeletemin'):
            print("Deleting edges to minimize spectral gap")
            newgraph = min_and_update_edges(nxgraph, rank_by_proxy_delete_min, "proxydeletemin", seed,
                                            max_iter=max_iterations, updating_period=5)
        elif rewiring_method == 'proxydeletemax':
            print("Deleting edges to maximize spectral gap")
            newgraph = process_and_update_edges(nxgraph, rank_by_proxy_delete, "proxydeletemax", seed,
                                                max_iter=max_iterations, updating_period=5)
        elif rewiring_method == 'proxyaddmax':
            print("Adding edges to maximize spectral gap")
            newgraph = process_and_update_edges(nxgraph, rank_by_proxy_add, "proxyaddmax", seed,
                                                max_iter=max_iterations, updating_period=5)
        elif rewiring_method == 'proxyaddmin':
            print("Adding edges to minimize spectral gap")
            newgraph = min_and_update_edges(nxgraph, rank_by_proxy_add_min, "proxyaddmin", seed,
                                            max_iter=max_iterations, updating_period=5)
        elif rewiring_method == 'fosr':
            # Imported here so this branch does not rely on a star-import to provide it.
            from torch_geometric.utils import to_networkx

            print("Applying FOSR")
            print("Converting to PyG dataset...")
            pydata = Data(x=node_features, edge_index=edge_index, y=labels)
            print(pydata)
            # NOTE(review): `tqdm` and `edge_rewire` are not imported explicitly in this
            # file; presumably they come from the `rewiring.*` star-imports -- confirm.
            for _ in tqdm(range(max_iterations)):
                # One rewiring step per call; only the updated edge index is kept.
                edge_index, _, _, _ = edge_rewire(pydata.edge_index.numpy(), num_iterations=1)
                pydata.edge_index = torch.tensor(edge_index)
            print(pydata)
            newgraph = to_networkx(pydata, to_undirected=True)
        else:
            # NOTE(review): an unrecognised method name also ends up here silently.
            print("No rewiring applied")
            newgraph = nxgraph

        if rewiring_method != 'none':
            # Materialise the edge list first: the graph is mutated while removing.
            newgraph.remove_edges_from(list(nx.selfloop_edges(newgraph)))
        end_algo = time.time()

        print(newgraph)
        # NOTE(review): the graph is always rebuilt from the undirected networkx graph,
        # so the original edge directions are not preserved here, even for datasets
        # whose name contains 'directed' -- confirm this is intended.
        graph = dgl.from_networkx(newgraph)
        print(graph)
        self.rewire_time = end_algo - start_algo
        print("Time taken to rewire the graph:", self.rewire_time)

        # Disabled community-structure evaluation (after rewiring):
        # nxcg_G = nxcg.from_networkx(newgraph)
        # communities_after = list(nx.community.louvain_communities(nxcg_G, seed=seed))
        # cluster_dict_after = {node: i for i, cluster in enumerate(communities_after) for node in cluster}
        # cluster_list_after = [cluster_dict_after[node] for node in range(len(labels))]
        # self.nmiscoremod_after = NMI(cluster_list_after, labels.numpy())
        # print(f'NMI Score: {self.nmiscoremod_after:.4f}')

        if 'directed' not in name:
            graph = dgl.to_bidirected(graph)
        if add_self_loops:
            graph = dgl.add_self_loop(graph)

        num_classes = len(labels.unique())
        # Two classes are treated as a single-logit binary problem.
        num_targets = 1 if num_classes == 2 else num_classes
        if num_targets == 1:
            # The binary loss used below takes floating-point targets.
            labels = labels.float()

        # One boolean mask per data split -> one index tensor per data split.
        train_idx_list = [torch.where(train_mask)[0] for train_mask in train_masks]
        val_idx_list = [torch.where(val_mask)[0] for val_mask in val_masks]
        test_idx_list = [torch.where(test_mask)[0] for test_mask in test_masks]

        node_features = self.augment_node_features(graph=graph,
                                                   node_features=node_features,
                                                   use_sgc_features=use_sgc_features,
                                                   use_identity_features=use_identity_features,
                                                   use_adjacency_features=use_adjacency_features,
                                                   do_not_use_original_features=do_not_use_original_features)

        self.name = name
        self.device = device
        self.graph = graph.to(device)
        self.node_features = node_features.to(device)
        self.labels = labels.to(device)
        self.train_idx_list = [train_idx.to(device) for train_idx in train_idx_list]
        self.val_idx_list = [val_idx.to(device) for val_idx in val_idx_list]
        self.test_idx_list = [test_idx.to(device) for test_idx in test_idx_list]
        self.num_data_splits = len(train_idx_list)
        self.cur_data_split = 0
        self.num_node_features = node_features.shape[1]
        self.num_targets = num_targets
        self.loss_fn = F.binary_cross_entropy_with_logits if num_targets == 1 else F.cross_entropy
        self.metric = 'ROC AUC' if num_targets == 1 else 'accuracy'

    @property
    def train_idx(self):
        """Training node indices of the current data split."""
        return self.train_idx_list[self.cur_data_split]

    @property
    def val_idx(self):
        """Validation node indices of the current data split."""
        return self.val_idx_list[self.cur_data_split]

    @property
    def test_idx(self):
        """Test node indices of the current data split."""
        return self.test_idx_list[self.cur_data_split]

    def next_data_split(self):
        """Advance to the next data split, wrapping around after the last one."""
        self.cur_data_split = (self.cur_data_split + 1) % self.num_data_splits

    def compute_metrics(self, logits):
        """Evaluate ``logits`` on the train/val/test nodes of the current split.

        Uses ROC AUC when ``num_targets == 1`` and accuracy of the arg-max
        prediction otherwise.

        Args:
            logits: model outputs indexed by node along the first dimension.

        Returns:
            dict with keys ``'train <metric>'``, ``'val <metric>'`` and
            ``'test <metric>'`` (in that order) mapping to Python floats.
        """
        if self.num_targets == 1:
            def score(idx):
                # float() accepts both NumPy scalars and plain Python floats.
                return float(roc_auc_score(y_true=self.labels[idx].cpu().numpy(),
                                           y_score=logits[idx].cpu().numpy()))
        else:
            preds = logits.argmax(axis=1)

            def score(idx):
                return (preds[idx] == self.labels[idx]).float().mean().item()

        splits = (('train', self.train_idx), ('val', self.val_idx), ('test', self.test_idx))
        return {f'{split} {self.metric}': score(idx) for split, idx in splits}

    @staticmethod
    def augment_node_features(graph, node_features, use_sgc_features, use_identity_features,
                              use_adjacency_features, do_not_use_original_features):
        """Build the final node feature matrix by concatenating the requested parts.

        Args:
            graph: graph providing ``num_nodes()`` and, for the optional parts, the
                structure used for SGC propagation and the adjacency matrix.
            node_features: original node features, one row per node.
            use_sgc_features: append features from ``compute_sgc_features``
                (always computed from the original features).
            use_identity_features: append an ``n x n`` identity matrix.
            use_adjacency_features: append the dense adjacency matrix of the graph
                with self-loops removed.
            do_not_use_original_features: start from an empty ``(n, 0)`` matrix
                instead of the original features.

        Returns:
            The concatenated feature tensor.
        """
        n = graph.num_nodes()
        original_node_features = node_features

        if do_not_use_original_features:
            # Zero-width placeholder so the concatenations below still work.
            node_features = torch.empty(n, 0)

        if use_sgc_features:
            sgc_features = Dataset.compute_sgc_features(graph, original_node_features)
            node_features = torch.cat([node_features, sgc_features], dim=1)

        if use_identity_features:
            node_features = torch.cat([node_features, torch.eye(n)], dim=1)

        if use_adjacency_features:
            graph_without_self_loops = dgl.remove_self_loop(graph)
            adj_matrix = graph_without_self_loops.adjacency_matrix().to_dense()
            node_features = torch.cat([node_features, adj_matrix], dim=1)

        return node_features

    @staticmethod
    def compute_sgc_features(graph, node_features, num_props=5):
        """Propagate ``node_features`` over the graph ``num_props`` times.

        Self-loops are first normalised to exactly one per node (removed, then
        re-added). Each edge gets the coefficient ``1 / sqrt(deg(u) * deg(v))``
        computed from the out-degrees, and every propagation step replaces a
        node's features with the coefficient-weighted sum over its incoming edges.

        Args:
            graph: DGL graph to propagate over.
            node_features: features to propagate, one row per node.
            num_props: number of propagation steps.

        Returns:
            The propagated feature tensor.
        """
        graph = dgl.remove_self_loop(graph)
        graph = dgl.add_self_loop(graph)

        degrees = graph.out_degrees().float()
        degree_edge_products = ops.u_mul_v(graph, degrees, degrees)
        norm_coefs = 1 / degree_edge_products ** 0.5

        for _ in range(num_props):
            node_features = ops.u_mul_e_sum(graph, node_features, norm_coefs)

        return node_features