layers.py
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax
from modules import MessageNorm, MLP
from ogb.graphproppred.mol_encoder import BondEncoder


class GENConv(nn.Module):
r"""
Description
-----------
Generalized Message Aggregator was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"
Parameters
----------
in_dim: int
Input size.
out_dim: int
Output size.
aggregator: str
Type of aggregation. Default is 'softmax'.
beta: float
A continuous variable called an inverse temperature. Default is 1.0.
learn_beta: bool
Whether beta is a learnable variable or not. Default is False.
p: float
Initial power for power mean aggregation. Default is 1.0.
learn_p: bool
Whether p is a learnable variable or not. Default is False.
msg_norm: bool
Whether message normalization is used. Default is False.
learn_msg_scale: bool
Whether s is a learnable scaling factor or not in message normalization. Default is False.
mlp_layers: int
The number of MLP layers. Default is 1.
eps: float
A small positive constant in message construction function. Default is 1e-7.
"""
def __init__(
self,
in_dim,
out_dim,
aggregator="softmax",
beta=1.0,
learn_beta=False,
p=1.0,
learn_p=False,
msg_norm=False,
learn_msg_scale=False,
mlp_layers=1,
eps=1e-7,
):
super(GENConv, self).__init__()
self.aggr = aggregator
self.eps = eps
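        # Hidden sizes for the update MLP: in_dim -> 2 * in_dim for each extra layer -> out_dim.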
channels = [in_dim]
for _ in range(mlp_layers - 1):
channels.append(in_dim * 2)
channels.append(out_dim)
self.mlp = MLP(channels)
self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
self.beta = (
nn.Parameter(torch.Tensor([beta]), requires_grad=True)
if learn_beta and self.aggr == "softmax"
else beta
)
self.p = (
nn.Parameter(torch.Tensor([p]), requires_grad=True)
if learn_p
else p
)
        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
with g.local_scope():
            # Node and edge feature sizes need to match.
g.ndata["h"] = node_feats
g.edata["h"] = self.edge_encoder(edge_feats)
g.apply_edges(fn.u_add_e("h", "h", "m"))
if self.aggr == "softmax":
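                # SoftMax_Agg_beta: make messages strictly positive, weight each one
                # by softmax(beta * m) over the incoming edges, then sum.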
g.edata["m"] = F.relu(g.edata["m"]) + self.eps
g.edata["a"] = edge_softmax(g, g.edata["m"] * self.beta)
g.update_all(
lambda edge: {"x": edge.data["m"] * edge.data["a"]},
fn.sum("x", "m"),
)
elif self.aggr == "power":
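                # PowerMean_Agg_p: mean of m^p over incoming edges followed by the
                # inverse power 1/p; clamping keeps the messages strictly positive.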
minv, maxv = 1e-7, 1e1
torch.clamp_(g.edata["m"], minv, maxv)
g.update_all(
lambda edge: {"x": torch.pow(edge.data["m"], self.p)},
fn.mean("x", "m"),
)
                torch.clamp_(g.ndata["m"], minv, maxv)
                # Apply the inverse power so the result is the power mean
                # ((1 / |N(i)|) * sum_j m_j ** p) ** (1 / p).
                g.ndata["m"] = torch.pow(g.ndata["m"], 1 / self.p)
else:
raise NotImplementedError(
f"Aggregator {self.aggr} is not supported."
)
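            # Optionally rescale the aggregated message against the node features
            # (MessageNorm), then apply a residual connection followed by the MLP.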
if self.msg_norm is not None:
g.ndata["m"] = self.msg_norm(node_feats, g.ndata["m"])
feats = node_feats + g.ndata["m"]
return self.mlp(feats)
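

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): runs one GENConv layer on a
# tiny random graph. The graph, the hidden size, and the three integer
# bond-feature columns below are illustrative assumptions only; OGB's
# BondEncoder expects categorical bond attributes as long integers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl

    num_nodes, num_edges, hidden = 4, 6, 16
    src = torch.randint(0, num_nodes, (num_edges,))
    dst = torch.randint(0, num_nodes, (num_edges,))
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    # Node features are assumed to be pre-projected to the hidden size.
    node_feats = torch.randn(num_nodes, hidden)
    # All-zero indices are valid categories for BondEncoder's embedding tables.
    edge_feats = torch.zeros(num_edges, 3, dtype=torch.long)

    conv = GENConv(in_dim=hidden, out_dim=hidden, aggregator="softmax")
    out = conv(g, node_feats, edge_feats)
    print(out.shape)  # expected: torch.Size([4, 16])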