-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
51 lines (40 loc) · 1.91 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
from layer import GraphConv, HigherOrderGraphConv
class GCN_2layer(nn.Module):
    """Two-layer graph convolutional network.

    The second (output) layer applies a softmax activation, so each output
    row is a probability distribution over classes (``out_features`` is the
    number of classes). With ``skip=True`` the raw input features ``X`` are
    additionally fed into the second layer as a skip connection.

    Args:
        in_features: dimensionality of the node input features.
        hidden_features: width of the hidden graph-conv layer.
        out_features: number of output classes.
        skip: if True, use a skip-connected second layer.
    """

    def __init__(self, in_features, hidden_features, out_features, skip = False):
        super().__init__()
        self.skip = skip
        self.gcl1 = GraphConv(in_features, hidden_features)
        if not self.skip:
            self.gcl2 = GraphConv(hidden_features, out_features, activation = 'softmax')
        else:
            # Skip variant also consumes the original input, hence the
            # extra skip_in_features argument.
            self.gcl_skip = GraphConv(hidden_features, out_features, activation = 'softmax',
                                      skip = self.skip, skip_in_features = in_features)

    def forward(self, A, X):
        """Run both graph-conv layers on adjacency ``A`` and features ``X``."""
        hidden = self.gcl1(A, X)
        if self.skip:
            return self.gcl_skip(A, hidden, X)
        return self.gcl2(A, hidden)
class HigherOrderGCN_2layer(nn.Module):
    """Two-layer higher-order GCN followed by a linear classifier.

    Both graph-conv layers use ``order``-hop propagation; the second applies
    a ReLU activation. A fully connected layer then maps the
    ``out_features``-dimensional node representations to ``nums_class``
    scores, and a softmax over dim 1 turns them into class probabilities.
    With ``skip=True`` the raw input ``X`` is also fed into the second
    graph-conv layer as a skip connection.

    Args:
        in_features: dimensionality of the node input features.
        hidden_features: width of the hidden graph-conv layer.
        out_features: width of the second graph-conv layer's output.
        nums_class: number of target classes.
        order: propagation order of the higher-order graph convolutions.
        skip: if True, use a skip-connected second layer.
    """

    def __init__(self, in_features, hidden_features, out_features, nums_class, order, skip = False):
        super().__init__()
        self.skip = skip
        self.gcl1 = HigherOrderGraphConv(in_features, hidden_features, order)
        if not self.skip:
            self.gcl2 = HigherOrderGraphConv(hidden_features, out_features, order, activation = 'relu')
        else:
            # Skip variant also consumes the original input, hence the
            # extra skip_in_features argument.
            self.gcl_skip = HigherOrderGraphConv(hidden_features, out_features, order,
                                                 activation = 'relu', skip = self.skip,
                                                 skip_in_features = in_features)
        # Final classifier head on top of the graph-conv features.
        self.fully_connected = torch.nn.Linear(out_features, nums_class)

    def forward(self, A, X):
        """Return softmax class probabilities for adjacency ``A``, features ``X``."""
        hidden = self.gcl1(A, X)
        if self.skip:
            hidden = self.gcl_skip(A, hidden, X)
        else:
            hidden = self.gcl2(A, hidden)
        logits = self.fully_connected(hidden)
        return torch.nn.functional.softmax(logits, dim=1)