Skip to content

Add Graph class #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pina/data/dataset.py
Original file line number Diff line number Diff line change
@@ -93,8 +93,7 @@ def __getitem__(self, idx):


class PinaGraphDataset(PinaDataset):
pass
'''

def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
@@ -113,7 +112,7 @@ def fetch_from_idx_list(self, idx):
to_return_dict[condition] = {k: Batch.from_data_list([v[i]
for i in cond_idx])
if isinstance(v, list)
else v[cond_idx]
else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:])
for k, v in data.items()
}
return to_return_dict
@@ -132,5 +131,4 @@ def get_all_data(self):
return self.fetch_from_idx_list(index)

def __getitem__(self, idx):
return self._getitem_func(idx)
'''
return self._getitem_func(idx)
360 changes: 269 additions & 91 deletions pina/graph.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pina/model/__init__.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
"AveragingNeuralOperator",
"LowRankNeuralOperator",
"Spline",
"GraphNeuralOperator"
]

from .feed_forward import FeedForward, ResidualFeedForward
@@ -20,3 +21,4 @@
from .avno import AveragingNeuralOperator
from .lno import LowRankNeuralOperator
from .spline import Spline
from .gno import GraphNeuralOperator
177 changes: 177 additions & 0 deletions pina/model/gno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import torch
from torch.nn import Tanh
from .layers import GNOBlock
from .base_no import KernelNeuralOperator


class GraphNeuralKernel(torch.nn.Module):
"""
TODO add docstring
"""

def __init__(
self,
width,
edge_features,
n_layers=2,
internal_n_layers=0,
internal_layers=None,
inner_size=None,
internal_func=None,
external_func=None,
shared_weights=False
):
"""
The Graph Neural Kernel constructor.
:param width: The width of the kernel.
:type width: int
:param edge_features: The number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:param external_func: The activation function applied to the output of the Graph Integral Layer.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
"""
super().__init__()
if external_func is None:
external_func = Tanh
if internal_func is None:
internal_func = Tanh

if shared_weights:
self.layers = GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
inner_size=inner_size,
internal_func=internal_func,
external_func=external_func)
self.n_layers = n_layers
self.forward = self.forward_shared
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like a lot this forward separation, is there a way to combine the two?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to have an efficient way to store parameters (avoid to use torch.nn.ModuleList with the same model repeated n_layer times), another possible solution is using an if in the forward. Otherwise I can define another 2 classes: one for the shared_weights and one for the non shared_weights. Let me know what how to proceed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can keep it like this for the moment, maybe two classes is the best but for a single model I would not care a lot

else:
self.layers = torch.nn.ModuleList(
[GNOBlock(
width=width,
edges_features=edge_features,
n_layers=internal_n_layers,
layers=internal_layers,
inner_size=inner_size,
internal_func=internal_func,
external_func=external_func
)
for _ in range(n_layers)]
)

def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are not shared.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
for layer in self.layers:
x = layer(x, edge_index, edge_attr)
return x

def forward_shared(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are shared.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
for _ in range(self.n_layers):
x = self.layers(x, edge_index, edge_attr)
return x


class GraphNeuralOperator(KernelNeuralOperator):
"""
TODO add docstring
"""

def __init__(
self,
lifting_operator,
projection_operator,
edge_features,
n_layers=10,
internal_n_layers=0,
inner_size=None,
internal_layers=None,
internal_func=None,
external_func=None,
shared_weights=True
):
"""
The Graph Neural Operator constructor.
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension.
:type lifting_operator: torch.nn.Module
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function.
:type projection_operator: torch.nn.Module
:param edge_features: Number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:type internal_func: torch.nn.Module
:param external_func: The activation function applied to the output of the Graph Integral Kernel.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
:type shared_weights: bool
"""

if internal_func is None:
internal_func = Tanh
if external_func is None:
external_func = Tanh

super().__init__(
lifting_operator=lifting_operator,
integral_kernels=GraphNeuralKernel(
width=lifting_operator.out_features,
edge_features=edge_features,
internal_n_layers=internal_n_layers,
inner_size=inner_size,
internal_layers=internal_layers,
external_func=external_func,
internal_func=internal_func,
n_layers=n_layers,
shared_weights=shared_weights
),
projection_operator=projection_operator
)

def forward(self, x):
"""
The forward pass of the Graph Neural Operator.
:param x: The input batch.
:type x: torch_geometric.data.Batch
"""
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
x = self.lifting_operator(x)
x = self.integral_kernels(x, edge_index, edge_attr)
x = self.projection_operator(x)
return x
2 changes: 2 additions & 0 deletions pina/model/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
"AVNOBlock",
"LowRankBlock",
"RBFBlock",
"GNOBlock"
]

from .convolution_2d import ContinuousConvBlock
@@ -31,3 +32,4 @@
from .avno_layer import AVNOBlock
from .lowrank_layer import LowRankBlock
from .rbf_layer import RBFBlock
from .gno_block import GNOBlock
87 changes: 87 additions & 0 deletions pina/model/layers/gno_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from torch_geometric.nn import MessagePassing


class GNOBlock(MessagePassing):
"""
TODO: Add documentation
"""

def __init__(
self,
width,
edges_features,
n_layers=2,
layers=None,
inner_size=None,
internal_func=None,
external_func=None
):
"""
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric.
:param width: The width of the hidden representation of the nodes features
:type width: int
:param edges_features: The number of edge features.
:type edges_features: int
:param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features.
:type n_layers: int
"""
from pina.model import FeedForward
super(GNOBlock, self).__init__(aggr='mean')
self.width = width
if layers is None and inner_size is None:
inner_size = width
self.dense = FeedForward(input_dimensions=edges_features,
output_dimensions=width ** 2,
n_layers=n_layers,
layers=layers,
inner_size=inner_size,
func=internal_func)
self.W = torch.nn.Linear(width, width)
self.func = external_func()

def message(self, x_j, edge_attr):
"""
This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class.
:param x_j: The node features of the neighboring.
:type x_j: torch.Tensor
:param edge_attr: The edge features.
:type edge_attr: torch.Tensor
:return: The message passed between the nodes of the graph.
:rtype: torch.Tensor
"""
x = self.dense(edge_attr).view(-1, self.width, self.width)
return torch.einsum('bij,bj->bi', x, x_j)

def update(self, aggr_out, x):
"""
This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class.
:param aggr_out: The aggregated messages.
:type aggr_out: torch.Tensor
:param x: The node features.
:type x: torch.Tensor
:return: The updated node features.
:rtype: torch.Tensor
"""
aggr_out = aggr_out + self.W(x)
return aggr_out

def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Integral Layer.
:param x: Node features.
:type x: torch.Tensor
:param edge_index: Edge index.
:type edge_index: torch.Tensor
:param edge_attr: Edge features.
:type edge_attr: torch.Tensor
:return: Output of a single iteration over the Graph Integral Layer.
:rtype: torch.Tensor
"""
return self.func(
self.propagate(edge_index, x=x, edge_attr=edge_attr)
)
125 changes: 125 additions & 0 deletions tests/test_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.operators import laplacian

def test_supervised_tensor_collector():
class SupervisedProblem(AbstractProblem):
output_variables = None
conditions = {
'data1' : Condition(input_points=torch.rand((10,2)),
output_points=torch.rand((10,2))),
'data2' : Condition(input_points=torch.rand((20,2)),
output_points=torch.rand((20,2))),
'data3' : Condition(input_points=torch.rand((30,2)),
output_points=torch.rand((30,2))),
}
problem = SupervisedProblem()
collector = problem.collector
for v in collector.conditions_name.values():
assert v in problem.conditions.keys()
assert all(collector._is_conditions_ready.values())

def test_pinn_collector():
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
torch.sin(input_.extract(['y']) * torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term

my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': [0, 1]
}),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_)
}

def poisson_sol(self, pts):
return -(torch.sin(pts.extract(['x']) * torch.pi) *
torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)

truth_solution = poisson_sol

problem = Poisson()
collector = problem.collector
for k,v in problem.conditions.items():
if isinstance(v, InputOutputPointsCondition):
assert collector._is_conditions_ready[k] == True
assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points']
else:
assert collector._is_conditions_ready[k] == False
assert collector.data_collections[k] == {}

boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
problem.discretise_domain(10, 'grid', locations=boundaries)
problem.discretise_domain(10, 'grid', locations='D')
assert all(collector._is_conditions_ready.values())
for k,v in problem.conditions.items():
if isinstance(v, DomainEquationCondition):
assert list(collector.data_collections[k].keys()) == ['input_points', 'equation']


def test_supervised_graph_collector():
pos = torch.rand((100,3))
x = [torch.rand((100,3)) for _ in range(10)]
graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
out_1 = torch.rand((10,100,3))
pos = torch.rand((50,3))
x = [torch.rand((50,3)) for _ in range(10)]
graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
out_2 = torch.rand((10,50,3))
class SupervisedProblem(AbstractProblem):
output_variables = None
conditions = {
'data1' : Condition(input_points=graph_list_1,
output_points=out_1),
'data2' : Condition(input_points=graph_list_2,
output_points=out_2),
}

problem = SupervisedProblem()
collector = problem.collector
assert all(collector._is_conditions_ready.values())
for v in collector.conditions_name.values():
assert v in problem.conditions.keys()
163 changes: 163 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import pytest
import torch
from pina.graph import RadiusGraph, KNNGraph


@pytest.mark.parametrize(
"x, pos",
[
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 2) for _ in range(3)],
[torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
(torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
]
)
def test_build_multiple_graph_multiple_val(x, pos):
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)

graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)


def test_build_single_graph_multiple_val():
x = torch.rand(10, 2)
pos = torch.rand(10, 3)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 1
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert len(graph.data) == 1
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)

x = torch.rand(10, 2)
pos = torch.rand(10, 3)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
assert len(graph.data) == 1
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert len(graph.data) == 1
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)


@pytest.mark.parametrize(
"pos",
[
([torch.rand(10, 3) for _ in range(3)]),
([torch.rand(10, 3) for _ in range(3)]),
(torch.rand(3, 10, 3)),
(torch.rand(3, 10, 3))
]
)
def test_build_single_graph_single_val(pos):
x = torch.rand(10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)
x = torch.rand(10, 2)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
data = graph.data
assert all(torch.isclose(d.x, x).all() for d in data)
assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
assert all(len(d.edge_index) == 2 for d in data)
assert all(d.edge_attr is not None for d in data)
assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data)


def test_additional_parameters_1():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
additional_parameters = {'y': torch.ones(3)}
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(d_.y == 1 for d_ in data)


@pytest.mark.parametrize(
"additional_parameters",
[
({'y': torch.rand(3, 10, 1)}),
({'y': [torch.rand(10, 1) for _ in range(3)]}),
]
)
def test_additional_parameters_2(additional_parameters):
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
additional_params=additional_parameters)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))

def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)

def build_edge_attr(x, pos, edge_index):
return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)

graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
custom_build_edge_attr=build_edge_attr)
assert len(graph.data) == 3
data = graph.data
assert all(hasattr(d, 'edge_attr') for d in data)
assert all(d.edge_attr.shape[1] == 4 for d in data)
assert all(torch.isclose(d.edge_attr,
build_edge_attr(d.x, d.pos, d.edge_index)).all()
for d in data)
129 changes: 129 additions & 0 deletions tests/test_model/test_gno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pytest
import torch
from pina.graph import KNNGraph
from pina.model import GraphNeuralOperator
from torch_geometric.data import Batch

x = [torch.rand(100, 6) for _ in range(10)]
pos = [torch.rand(100, 3) for _ in range(10)]
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6)
input_ = Batch.from_data_list(graph.data)


@pytest.mark.parametrize(
"shared_weights",
[
True,
False
]
)
def test_constructor(shared_weights):
lifting_operator = torch.nn.Linear(6, 16)
projection_operator = torch.nn.Linear(16, 3)
GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
internal_layers=[16, 16],
shared_weights=shared_weights)

GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
inner_size=16,
internal_n_layers=10,
shared_weights=shared_weights)

int_func = torch.nn.Softplus
ext_func = torch.nn.ReLU

GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
internal_n_layers=10,
shared_weights=shared_weights,
internal_func=int_func,
external_func=ext_func)


@pytest.mark.parametrize(
"shared_weights",
[
True,
False
]
)
def test_forward_1(shared_weights):
lifting_operator = torch.nn.Linear(6, 16)
projection_operator = torch.nn.Linear(16, 3)
model = GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
internal_layers=[16, 16],
shared_weights=shared_weights)
output_ = model(input_)
assert output_.shape == torch.Size([1000, 3])


@pytest.mark.parametrize(
"shared_weights",
[
True,
False
]
)
def test_forward_2(shared_weights):
lifting_operator = torch.nn.Linear(6, 16)
projection_operator = torch.nn.Linear(16, 3)
model = GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
inner_size=32,
internal_n_layers=2,
shared_weights=shared_weights)
output_ = model(input_)
assert output_.shape == torch.Size([1000, 3])


@pytest.mark.parametrize(
"shared_weights",
[
True,
False
]
)
def test_backward(shared_weights):
lifting_operator = torch.nn.Linear(6, 16)
projection_operator = torch.nn.Linear(16, 3)
model = GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
internal_layers=[16, 16],
shared_weights=shared_weights)
input_.x.requires_grad = True
output_ = model(input_)
l = torch.mean(output_)
l.backward()
assert input_.x.grad.shape == torch.Size([1000, 6])


@pytest.mark.parametrize(
"shared_weights",
[
True,
False
]
)
def test_backward_2(shared_weights):
lifting_operator = torch.nn.Linear(6, 16)
projection_operator = torch.nn.Linear(16, 3)
model = GraphNeuralOperator(lifting_operator=lifting_operator,
projection_operator=projection_operator,
edge_features=3,
inner_size=32,
internal_n_layers=2,
shared_weights=shared_weights)
input_.x.requires_grad = True
output_ = model(input_)
l = torch.mean(output_)
l.backward()
assert input_.x.grad.shape == torch.Size([1000, 6])