Add Graph class #403
Changes from all commits: 1f32ace, 05faaaa, 6eba0cf, d591aee, dc87615, d79017e, a68b711, 9b7cdbf
@@ -0,0 +1,177 @@
import torch
from torch.nn import Tanh
from .layers import GNOBlock
from .base_no import KernelNeuralOperator


class GraphNeuralKernel(torch.nn.Module):
    """
    TODO add docstring
    """

    def __init__(
        self,
        width,
        edge_features,
        n_layers=2,
        internal_n_layers=0,
        internal_layers=None,
        inner_size=None,
        internal_func=None,
        external_func=None,
        shared_weights=False
    ):
        """
        The Graph Neural Kernel constructor.

        :param width: The width of the kernel.
        :type width: int
        :param edge_features: The number of edge features.
        :type edge_features: int
        :param n_layers: The number of kernel layers.
        :type n_layers: int
        :param internal_n_layers: The number of layers of the feed-forward
            neural network internal to each Kernel Layer.
        :type internal_n_layers: int
        :param internal_layers: The number of neurons in the hidden layer(s)
            of the feed-forward neural network inside each Kernel Layer.
        :type internal_layers: list | tuple
        :param inner_size: The size of the hidden layers of the feed-forward
            neural network inside each Kernel Layer.
        :type inner_size: int
        :param internal_func: The activation function used to compute the
            representation of the edge features in the Graph Integral Layer.
        :type internal_func: torch.nn.Module
        :param external_func: The activation function applied to the output
            of the Graph Integral Layer.
        :type external_func: torch.nn.Module
        :param shared_weights: If ``True``, the weights of the Graph Integral
            Layers are shared.
        :type shared_weights: bool
        """
        super().__init__()
        if external_func is None:
            external_func = Tanh
        if internal_func is None:
            internal_func = Tanh

        if shared_weights:
            # Shared weights: a single GNOBlock applied n_layers times.
            self.layers = GNOBlock(
                width=width,
                edges_features=edge_features,
                n_layers=internal_n_layers,
                layers=internal_layers,
                inner_size=inner_size,
                internal_func=internal_func,
                external_func=external_func)
            self.n_layers = n_layers
            self.forward = self.forward_shared
Review thread on this line:

- Reviewer: I don't like this forward separation much; is there a way to combine the two?
- Author: To store the parameters efficiently (avoiding a torch.nn.ModuleList with the same model repeated n_layers times), another possible solution is an if inside the forward. Otherwise I can define two more classes: one for shared weights and one for non-shared weights. Let me know how to proceed.
- Reviewer: We can keep it like this for the moment; maybe two classes is best, but for a single model I would not care a lot.
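For reference, a minimal sketch of the "if in the forward" alternative mentioned in the thread (hypothetical, not part of this PR; it assumes shared_weights and n_layers are stored in __init__):

    def forward(self, x, edge_index, edge_attr):
        # Hypothetical single forward with the branch inside, as proposed
        # above; assumes self.shared_weights is set in the constructor.
        if self.shared_weights:
            for _ in range(self.n_layers):
                x = self.layers(x, edge_index, edge_attr)
        else:
            for layer in self.layers:
                x = layer(x, edge_index, edge_attr)
        return x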
        else:
            # Independent weights: one GNOBlock per kernel layer.
            self.layers = torch.nn.ModuleList(
                [GNOBlock(
                    width=width,
                    edges_features=edge_features,
                    n_layers=internal_n_layers,
                    layers=internal_layers,
                    inner_size=inner_size,
                    internal_func=internal_func,
                    external_func=external_func
                )
                    for _ in range(n_layers)]
            )

    def forward(self, x, edge_index, edge_attr):
        """
        The forward pass of the Graph Neural Kernel, used when the weights
        are not shared.

        :param x: The input batch.
        :type x: torch.Tensor
        :param edge_index: The edge index.
        :type edge_index: torch.Tensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor
        """
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
        return x

    def forward_shared(self, x, edge_index, edge_attr):
        """
        The forward pass of the Graph Neural Kernel, used when the weights
        are shared.

        :param x: The input batch.
        :type x: torch.Tensor
        :param edge_index: The edge index.
        :type edge_index: torch.Tensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor
        """
        for _ in range(self.n_layers):
            x = self.layers(x, edge_index, edge_attr)
        return x

class GraphNeuralOperator(KernelNeuralOperator):
    """
    TODO add docstring
    """

    def __init__(
        self,
        lifting_operator,
        projection_operator,
        edge_features,
        n_layers=10,
        internal_n_layers=0,
        inner_size=None,
        internal_layers=None,
        internal_func=None,
        external_func=None,
        shared_weights=True
    ):
        """
        The Graph Neural Operator constructor.

        :param lifting_operator: The lifting operator, mapping the node
            features to their hidden dimension.
        :type lifting_operator: torch.nn.Module
        :param projection_operator: The projection operator, mapping the
            hidden representation of the node features to the output
            function.
        :type projection_operator: torch.nn.Module
        :param edge_features: The number of edge features.
        :type edge_features: int
        :param n_layers: The number of kernel layers.
        :type n_layers: int
        :param internal_n_layers: The number of layers of the feed-forward
            neural network internal to each Kernel Layer.
        :type internal_n_layers: int
        :param inner_size: The size of the hidden layers of the feed-forward
            neural network inside each Kernel Layer.
        :type inner_size: int
        :param internal_layers: The number of neurons in the hidden layer(s)
            of the feed-forward neural network inside each Kernel Layer.
        :type internal_layers: list | tuple
        :param internal_func: The activation function used to compute the
            representation of the edge features in the Graph Integral Layer.
        :type internal_func: torch.nn.Module
        :param external_func: The activation function applied to the output
            of the Graph Integral Layer.
        :type external_func: torch.nn.Module
        :param shared_weights: If ``True``, the weights of the Graph Integral
            Layers are shared.
        :type shared_weights: bool
        """
        if internal_func is None:
            internal_func = Tanh
        if external_func is None:
            external_func = Tanh

        super().__init__(
            lifting_operator=lifting_operator,
            integral_kernels=GraphNeuralKernel(
                width=lifting_operator.out_features,
                edge_features=edge_features,
                internal_n_layers=internal_n_layers,
                inner_size=inner_size,
                internal_layers=internal_layers,
                external_func=external_func,
                internal_func=internal_func,
                n_layers=n_layers,
                shared_weights=shared_weights
            ),
            projection_operator=projection_operator
        )

    def forward(self, x):
        """
        The forward pass of the Graph Neural Operator.

        :param x: The input batch.
        :type x: torch_geometric.data.Batch
        """
        x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
        x = self.lifting_operator(x)
        x = self.integral_kernels(x, edge_index, edge_attr)
        x = self.projection_operator(x)
        return x

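For context, a minimal usage sketch of GraphNeuralOperator (mirroring the unit tests later in this PR; shapes and arguments are taken from those tests):

import torch
from torch_geometric.data import Batch
from pina.graph import KNNGraph
from pina.model import GraphNeuralOperator

# Ten graphs with 100 nodes each: 6 node features, 3-dimensional coordinates.
x = [torch.rand(100, 6) for _ in range(10)]
pos = [torch.rand(100, 3) for _ in range(10)]
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6)
batch = Batch.from_data_list(graph.data)

model = GraphNeuralOperator(lifting_operator=torch.nn.Linear(6, 16),
                            projection_operator=torch.nn.Linear(16, 3),
                            edge_features=3)
out = model(batch)  # shape: torch.Size([1000, 3])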
@@ -0,0 +1,87 @@
import torch
from torch_geometric.nn import MessagePassing


class GNOBlock(MessagePassing):
    """
    TODO: Add documentation
    """

    def __init__(
        self,
        width,
        edges_features,
        n_layers=2,
        layers=None,
        inner_size=None,
        internal_func=None,
        external_func=None
    ):
        """
        Initialize the Graph Integral Layer, inheriting from the
        MessagePassing class of PyTorch Geometric.

        :param width: The width of the hidden representation of the node
            features.
        :type width: int
        :param edges_features: The number of edge features.
        :type edges_features: int
        :param n_layers: The number of layers in the feed-forward neural
            network used to compute the representation of the edge features.
        :type n_layers: int
        :param layers: The number of neurons in the hidden layer(s) of the
            feed-forward neural network.
        :type layers: list | tuple
        :param inner_size: The size of the hidden layers of the feed-forward
            neural network.
        :type inner_size: int
        :param internal_func: The activation function used inside the
            feed-forward neural network.
        :type internal_func: torch.nn.Module
        :param external_func: The activation function applied to the output
            of the layer.
        :type external_func: torch.nn.Module
        """
        from pina.model import FeedForward  # local import of the PINA FeedForward network
        super().__init__(aggr='mean')
        self.width = width
        if layers is None and inner_size is None:
            inner_size = width
        self.dense = FeedForward(input_dimensions=edges_features,
                                 output_dimensions=width ** 2,
                                 n_layers=n_layers,
                                 layers=layers,
                                 inner_size=inner_size,
                                 func=internal_func)
        self.W = torch.nn.Linear(width, width)
        self.func = external_func()

    def message(self, x_j, edge_attr):
        """
        Compute the message passed between the nodes of the graph.
        Overrides the default message function defined in the
        MessagePassing class.

        :param x_j: The node features of the neighboring nodes.
        :type x_j: torch.Tensor
        :param edge_attr: The edge features.
        :type edge_attr: torch.Tensor
        :return: The message passed between the nodes of the graph.
        :rtype: torch.Tensor
        """
        # Map each edge feature vector to a (width x width) kernel matrix
        # and apply it to the corresponding neighbor features.
        x = self.dense(edge_attr).view(-1, self.width, self.width)
        return torch.einsum('bij,bj->bi', x, x_j)

    def update(self, aggr_out, x):
        """
        Update the node features of the graph. Overrides the default
        update function defined in the MessagePassing class.

        :param aggr_out: The aggregated messages.
        :type aggr_out: torch.Tensor
        :param x: The node features.
        :type x: torch.Tensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        aggr_out = aggr_out + self.W(x)
        return aggr_out

    def forward(self, x, edge_index, edge_attr):
        """
        The forward pass of the Graph Integral Layer.

        :param x: Node features.
        :type x: torch.Tensor
        :param edge_index: Edge index.
        :type edge_index: torch.Tensor
        :param edge_attr: Edge features.
        :type edge_attr: torch.Tensor
        :return: Output of a single iteration over the Graph Integral Layer.
        :rtype: torch.Tensor
        """
        return self.func(
            self.propagate(edge_index, x=x, edge_attr=edge_attr)
        )

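To make the message computation concrete, a small standalone sketch of the per-edge kernel application used in GNOBlock.message (hypothetical shapes: width 4, 20 edges):

import torch

width, n_edges = 4, 20
kernels = torch.rand(n_edges, width * width)  # stand-in for self.dense(edge_attr)
x_j = torch.rand(n_edges, width)              # neighbor features, one row per edge
# Reshape each row into a (width x width) matrix and apply it to x_j.
msg = torch.einsum('bij,bj->bi', kernels.view(-1, width, width), x_j)
assert msg.shape == (n_edges, width)          # one message per edge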
@@ -0,0 +1,125 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.operators import laplacian


def test_supervised_tensor_collector():
    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            'data1': Condition(input_points=torch.rand((10, 2)),
                               output_points=torch.rand((10, 2))),
            'data2': Condition(input_points=torch.rand((20, 2)),
                               output_points=torch.rand((20, 2))),
            'data3': Condition(input_points=torch.rand((30, 2)),
                               output_points=torch.rand((30, 2))),
        }

    problem = SupervisedProblem()
    collector = problem.collector
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()
    assert all(collector._is_conditions_ready.values())


def test_pinn_collector():
    def laplace_equation(input_, output_):
        force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
                      torch.sin(input_.extract(['y']) * torch.pi))
        delta_u = laplacian(output_.extract(['u']), input_)
        return delta_u - force_term

    my_laplace = Equation(laplace_equation)
    in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
    out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])

    class Poisson(SpatialProblem):
        output_variables = ['u']
        spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

        conditions = {
            'gamma1': Condition(domain=CartesianDomain({'x': [0, 1], 'y': 1}),
                                equation=FixedValue(0.0)),
            'gamma2': Condition(domain=CartesianDomain({'x': [0, 1], 'y': 0}),
                                equation=FixedValue(0.0)),
            'gamma3': Condition(domain=CartesianDomain({'x': 1, 'y': [0, 1]}),
                                equation=FixedValue(0.0)),
            'gamma4': Condition(domain=CartesianDomain({'x': 0, 'y': [0, 1]}),
                                equation=FixedValue(0.0)),
            'D': Condition(domain=CartesianDomain({'x': [0, 1], 'y': [0, 1]}),
                           equation=my_laplace),
            'data': Condition(input_points=in_, output_points=out_)
        }

        def poisson_sol(self, pts):
            return -(torch.sin(pts.extract(['x']) * torch.pi) *
                     torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi ** 2)

        truth_solution = poisson_sol

    problem = Poisson()
    collector = problem.collector
    for k, v in problem.conditions.items():
        if isinstance(v, InputOutputPointsCondition):
            assert collector._is_conditions_ready[k]
            assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points']
        else:
            assert not collector._is_conditions_ready[k]
            assert collector.data_collections[k] == {}

    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    problem.discretise_domain(10, 'grid', locations=boundaries)
    problem.discretise_domain(10, 'grid', locations='D')
    assert all(collector._is_conditions_ready.values())
    for k, v in problem.conditions.items():
        if isinstance(v, DomainEquationCondition):
            assert list(collector.data_collections[k].keys()) == ['input_points', 'equation']


def test_supervised_graph_collector():
    pos = torch.rand((100, 3))
    x = [torch.rand((100, 3)) for _ in range(10)]
    graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
    out_1 = torch.rand((10, 100, 3))
    pos = torch.rand((50, 3))
    x = [torch.rand((50, 3)) for _ in range(10)]
    graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
    out_2 = torch.rand((10, 50, 3))

    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            'data1': Condition(input_points=graph_list_1,
                               output_points=out_1),
            'data2': Condition(input_points=graph_list_2,
                               output_points=out_2),
        }

    problem = SupervisedProblem()
    collector = problem.collector
    assert all(collector._is_conditions_ready.values())
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()
@@ -0,0 +1,163 @@
import pytest
import torch
from pina.graph import RadiusGraph, KNNGraph


@pytest.mark.parametrize(
    "x, pos",
    [
        ([torch.rand(10, 2) for _ in range(3)],
         [torch.rand(10, 3) for _ in range(3)]),
        ([torch.rand(10, 2) for _ in range(3)],
         [torch.rand(10, 3) for _ in range(3)]),
        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
    ]
)
def test_build_multiple_graph_multiple_val(x, pos):
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)

    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)

    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


def test_build_single_graph_multiple_val():
    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 1
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)

    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert len(graph.data) == 1
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)

    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    assert len(graph.data) == 1
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)

    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert len(graph.data) == 1
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


@pytest.mark.parametrize(
    "pos",
    [
        ([torch.rand(10, 3) for _ in range(3)]),
        ([torch.rand(10, 3) for _ in range(3)]),
        (torch.rand(3, 10, 3)),
        (torch.rand(3, 10, 3))
    ]
)
def test_build_single_graph_single_val(pos):
    x = torch.rand(10, 2)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)

    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)

    x = torch.rand(10, 2)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)

    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


def test_additional_parameters_1():
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    additional_parameters = {'y': torch.ones(3)}
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        additional_params=additional_parameters)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(hasattr(d, 'y') for d in data)
    assert all(d_.y == 1 for d_ in data)


@pytest.mark.parametrize(
    "additional_parameters",
    [
        ({'y': torch.rand(3, 10, 1)}),
        ({'y': [torch.rand(10, 1) for _ in range(3)]}),
    ]
)
def test_additional_parameters_2(additional_parameters):
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        additional_params=additional_parameters)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(hasattr(d, 'y') for d in data)


def test_custom_build_edge_attr_func():
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)

    def build_edge_attr(x, pos, edge_index):
        return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1)

    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        custom_build_edge_attr=build_edge_attr)
    assert len(graph.data) == 3
    data = graph.data
    assert all(hasattr(d, 'edge_attr') for d in data)
    assert all(d.edge_attr.shape[1] == 4 for d in data)
    assert all(torch.isclose(d.edge_attr,
                             build_edge_attr(d.x, d.pos, d.edge_index)).all()
               for d in data)
@@ -0,0 +1,129 @@
import pytest
import torch
from pina.graph import KNNGraph
from pina.model import GraphNeuralOperator
from torch_geometric.data import Batch

x = [torch.rand(100, 6) for _ in range(10)]
pos = [torch.rand(100, 3) for _ in range(10)]
graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6)
input_ = Batch.from_data_list(graph.data)


@pytest.mark.parametrize("shared_weights", [True, False])
def test_constructor(shared_weights):
    lifting_operator = torch.nn.Linear(6, 16)
    projection_operator = torch.nn.Linear(16, 3)
    GraphNeuralOperator(lifting_operator=lifting_operator,
                        projection_operator=projection_operator,
                        edge_features=3,
                        internal_layers=[16, 16],
                        shared_weights=shared_weights)

    GraphNeuralOperator(lifting_operator=lifting_operator,
                        projection_operator=projection_operator,
                        edge_features=3,
                        inner_size=16,
                        internal_n_layers=10,
                        shared_weights=shared_weights)

    int_func = torch.nn.Softplus
    ext_func = torch.nn.ReLU

    GraphNeuralOperator(lifting_operator=lifting_operator,
                        projection_operator=projection_operator,
                        edge_features=3,
                        internal_n_layers=10,
                        shared_weights=shared_weights,
                        internal_func=int_func,
                        external_func=ext_func)


@pytest.mark.parametrize("shared_weights", [True, False])
def test_forward_1(shared_weights):
    lifting_operator = torch.nn.Linear(6, 16)
    projection_operator = torch.nn.Linear(16, 3)
    model = GraphNeuralOperator(lifting_operator=lifting_operator,
                                projection_operator=projection_operator,
                                edge_features=3,
                                internal_layers=[16, 16],
                                shared_weights=shared_weights)
    output_ = model(input_)
    assert output_.shape == torch.Size([1000, 3])


@pytest.mark.parametrize("shared_weights", [True, False])
def test_forward_2(shared_weights):
    lifting_operator = torch.nn.Linear(6, 16)
    projection_operator = torch.nn.Linear(16, 3)
    model = GraphNeuralOperator(lifting_operator=lifting_operator,
                                projection_operator=projection_operator,
                                edge_features=3,
                                inner_size=32,
                                internal_n_layers=2,
                                shared_weights=shared_weights)
    output_ = model(input_)
    assert output_.shape == torch.Size([1000, 3])


@pytest.mark.parametrize("shared_weights", [True, False])
def test_backward(shared_weights):
    lifting_operator = torch.nn.Linear(6, 16)
    projection_operator = torch.nn.Linear(16, 3)
    model = GraphNeuralOperator(lifting_operator=lifting_operator,
                                projection_operator=projection_operator,
                                edge_features=3,
                                internal_layers=[16, 16],
                                shared_weights=shared_weights)
    input_.x.requires_grad = True
    output_ = model(input_)
    loss = torch.mean(output_)
    loss.backward()
    assert input_.x.grad.shape == torch.Size([1000, 6])


@pytest.mark.parametrize("shared_weights", [True, False])
def test_backward_2(shared_weights):
    lifting_operator = torch.nn.Linear(6, 16)
    projection_operator = torch.nn.Linear(16, 3)
    model = GraphNeuralOperator(lifting_operator=lifting_operator,
                                projection_operator=projection_operator,
                                edge_features=3,
                                inner_size=32,
                                internal_n_layers=2,
                                shared_weights=shared_weights)
    input_.x.requires_grad = True
    output_ = model(input_)
    loss = torch.mean(output_)
    loss.backward()
    assert input_.x.grad.shape == torch.Size([1000, 6])