Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delaunay Lifting (pointcloud2graph) with new ShapeNet dataset #60

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/datasets/ShapeNet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
data_domain: point_cloud
data_type: ShapeNet
data_name: ShapeNet
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_features: 3
num_classes: 50
category: plane
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph
4 changes: 4 additions & 0 deletions configs/models/graph/graphsage.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
in_channels: -1 # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transform_type: "lifting"
transform_name: "GraphDelaunayLifting"
feature_lifting: ProjectionSum
6 changes: 6 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def load(self) -> torch_geometric.data.Dataset:
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

elif self.parameters.data_name in ["ShapeNet"]:
dataset = torch_geometric.datasets.ShapeNet(
root=root_data_dir,
include_normals=True,
)

else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
Expand Down
80 changes: 80 additions & 0 deletions modules/models/graph/graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from torch import Tensor
from torch_geometric.nn.models import GraphSAGE
from torch_geometric.utils import scatter


def global_mean_pool(x, batch=None, size=None) -> Tensor:
r"""Returns batch-wise graph-level-outputs by averaging node features
across the node dimension.

For a single graph :math:`\mathcal{G}_i`, its output is computed by

.. math::
\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.

Functional method of the
:class:`~torch_geometric.nn.aggr.MeanAggregation` module.

Parameters
----------
x : torch.Tensor
Node feature matrix :math:`\mathbf{X}`.
batch : torch.Tensor, optional
The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`,
which assigns each node to a specific example.
size : int, optional
The number of examples :math:`B`. Automatically calculated if not given.
"""
dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2

if batch is None:
return x.mean(dim=dim, keepdim=x.dim() <= 2)
return scatter(x, batch, dim=dim, dim_size=size, reduce="mean")


class GraphSAGEModel(torch.nn.Module):
r"""A simple GreaphSage model that runs over graph data.
Note that some parameters are defined by the considered dataset.

Parameters
----------
model_config : Dict | DictConfig
Model configuration.
dataset_config : Dict | DictConfig
Dataset configuration.
"""

def __init__(self, model_config, dataset_config):
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
super().__init__()
self.base_model = GraphSAGE(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
num_layers=n_layers,
)
self.pool = global_mean_pool

def forward(self, data):
r"""Forward pass of the model.

Parameters
----------
data : torch_geometric.data.Data
Input data.

Returns
-------
torch.Tensor
Output tensor.
"""
z = self.base_model(data.x, data.edge_index)
return self.pool(z)
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2graph.delaunay_lifting import (
GraphDelaunayLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -31,6 +34,7 @@
"OneHotDegreeFeatures": OneHotDegreeFeatures,
"NodeFeaturesToFloat": NodeFeaturesToFloat,
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
"GraphDelaunayLifting": GraphDelaunayLifting,
}


Expand Down
77 changes: 77 additions & 0 deletions modules/transforms/liftings/pointcloud2graph/delaunay_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
import torch_geometric
from torch_geometric.transforms import Delaunay
from torch_geometric.utils import to_undirected

from modules.transforms.liftings.pointcloud2graph.base import PointCloud2GraphLifting


class GraphDelaunayLifting(PointCloud2GraphLifting):
r"""Lifts point cloud to graph domain by considering k-nearest neighbors.

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transform = Delaunay()

def face_to_edge(self, data: torch_geometric.data.Data):
r"""Converts mesh faces to edges indices for both 2D and 3D meshes.

Parameters
----------
data : torch_geometric.data.Data
The input data to be converted.

Returns
-------
torch_geometric.data.Data
The converted data.
"""
if hasattr(data, "face"):
assert data.face is not None
face = data.face
if face.shape[0] == 3:
# 2D
edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
elif face.shape[0] == 4:
# 3D
edge_index = torch.cat(
[face[:2], face[1:3], face[2:], face[::3]], dim=1
)
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
else:
raise ValueError("Faces must be of dimension 2 or 3.")
data.edge_index = edge_index
return data

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to hypergraph domain by considering k-nearest neighbors.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""
num_nodes = data.x.shape[0]

# Step 1: Perform Delaunay Triangulation to get faces
data_delaunay = self.transform(data)
faces = data_delaunay.face
# Step 2: Create Edge List from faces
data = self.face_to_edge(data_delaunay)

# Step 3: Convert Edge List to edge_index format
edge_index = data.edge_index

return {"num_nodes": num_nodes, "edge_index": edge_index, "face": faces}
38 changes: 38 additions & 0 deletions test/transforms/liftings/pointcloud2graph/test_delaunay_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from torch_geometric.data import Data

from modules.transforms.liftings.pointcloud2graph.delaunay_lifting import (
GraphDelaunayLifting,
)


class TestDelaunayLifting:
"""Test the GraphDelaunayLifting class."""

def setup_method(self):
"""Set up the test."""
# Define the data and the lifting.
pos = torch.tensor(
[
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.5],
[1.0, 1.0, 0.5],
[0.5, 0.5, 1.0],
],
dtype=torch.float32,
)
x = torch.tensor(
[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]],
dtype=torch.float32,
)
self.data = Data(x=x, pos=pos)
self.lifting = GraphDelaunayLifting()

def test_lift_topology(self):
"""Test the lift_topology method."""

lifted = self.lifting.forward(self.data.clone())
assert lifted.num_nodes == 5, "The number of nodes is incorrect."
assert lifted.edge_index.shape == (2, 14), "The number of edges is incorrect."
assert lifted.face.shape == (4, 2), "The number of faces is incorrect."
358 changes: 358 additions & 0 deletions tutorials/pointcloud2graph/delaunay_lifting.ipynb

Large diffs are not rendered by default.

Loading