Skip to content

Commit

Permalink
Fix GENConv for the latest pyg version
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Jan 4, 2024
1 parent fe3689f commit 31c52cc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
16 changes: 11 additions & 5 deletions neuralogic/nn/module/gnn/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,39 @@ def __call__(self):
v_eps = v_eps.fixed()

e_proj = []
if self.edge_dim is not None:
if self.edge_dim is not None and self.out_channels != self.edge_dim:
e = R.get(f"{self.output_name}__gen_edge_proj")
e_proj = [
(e(V.I, V.J)[self.in_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J))
(e(V.I, V.J)[self.out_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J))
| Metadata(transformation=Transformation.IDENTITY),
e / 2 | Metadata(transformation=Transformation.IDENTITY),
]

channels = [self.in_channels]
channels = [self.out_channels]
for _ in range(self.num_layers - 1):
channels.append(self.out_channels * self.expansion)
channels.append(self.out_channels)

mlp = MLP(channels, self.output_name, f"{self.output_name}__gen_out", Transformation.IDENTITY)

j_feat = x(V.J)
i_feat = x(V.I)
if self.in_channels != self.out_channels:
j_feat = x(V.J)[self.out_channels, self.in_channels]
i_feat = x(V.I)[self.out_channels, self.in_channels]

return [
v_eps,
*e_proj,
(feat_sum(V.I, V.J) <= (x(V.J), e(V.J, V.I)))
(feat_sum(V.I, V.J) <= (j_feat, e(V.J, V.I)))
| Metadata(transformation=Transformation.RELU, combination=Combination.SUM),
feat_sum / 2 | Metadata(transformation=Transformation.IDENTITY),
(feat_agg(V.I) <= (feat_sum(V.I, V.J), eps))
| Metadata(
transformation=Transformation.IDENTITY, aggregation=self.aggregation, combination=Combination.SUM
),
feat_agg / 1 | Metadata(transformation=Transformation.IDENTITY),
(out(V.I) <= (x(V.I), feat_agg(V.I)))
(out(V.I) <= (i_feat, feat_agg(V.I)))
| Metadata(transformation=Transformation.IDENTITY, combination=Combination.SUM),
out / 1 | Metadata(transformation=Transformation.IDENTITY),
*mlp(),
Expand Down
10 changes: 7 additions & 3 deletions tests/test_gnn_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def test_gen_module(input_size, hidden_size):

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr=e)

gen = GENConv(input_size, hidden_size, aggr="mean", num_layers=1, eps=0)
gen = GENConv(input_size, hidden_size, aggr="mean", num_layers=1, eps=0, edge_dim=input_size)
for m in gen.mlp._modules.values():
m.bias = None

template = Template()
template += neuralogic.nn.module.GENConv(
input_size, hidden_size, "h", "f", "e", num_layers=1, aggregation=Aggregation.AVG, eps=0
input_size, hidden_size, "h", "f", "e", num_layers=1, aggregation=Aggregation.AVG, eps=0, edge_dim=input_size
)

model = template.build(
Expand All @@ -55,7 +55,11 @@ def test_gen_module(input_size, hidden_size):

parameters = model.parameters()
torch_parameters = [parameter.tolist() for parameter in gen.parameters()]
parameters["weights"][1] = [torch_parameters[0][i] for i in range(0, hidden_size)]

parameters["weights"][2] = [torch_parameters[0][i] for i in range(0, hidden_size)]
parameters["weights"][1] = [torch_parameters[1][i] for i in range(0, hidden_size)]
parameters["weights"][3] = [torch_parameters[2][i] for i in range(0, hidden_size)]
parameters["weights"][4] = torch_parameters[3][0]

model.load_state_dict(parameters)

Expand Down

0 comments on commit 31c52cc

Please sign in to comment.