Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixture of Gaussians + MST lifting (Pointcloud to Hypergraph) #45

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,7 @@ cython_debug/
#.idea/

# VS Code
.vscode/
.vscode/

tutorials/pointcloud2hypergraph/data
tutorials/pointcloud2hypergraph/modules/transforms/liftings/pointcloud2hypergraph/processed
12 changes: 12 additions & 0 deletions configs/datasets/geo_shapes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: geo_shapes
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
dim: 1
num_classes: 2
num_samples: 24
num_features: 1
task: classification
loss_type: cross_entropy
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transform_type: 'lifting'
transform_name: "MoGMSTLifting"
min_components: null
max_components: null
random_state: null
43 changes: 43 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_random_shape_point_cloud,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -204,3 +205,45 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
r"""Loader for point-cloud dataset.

Parameters
----------
parameters: DictConfig
Configuration parameters
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)

if "data_name" not in self.cfg:
self.cfg["data_name"] = "shapes"
if "num_points" not in self.cfg:
self.cfg["num_points"] = 24
if "num_classes" not in self.cfg:
self.cfg["num_classes"] = 2

root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.cfg["data_dir"])
self.data_dir = os.path.join(root_data_dir, self.cfg["data_name"])

def load(self) -> torch_geometric.data.Dataset:
r"""Load point-cloud dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
if self.cfg["data_name"] == "shapes":
return load_random_shape_point_cloud(
num_points=self.cfg["num_points"], num_classes=self.cfg["num_classes"]
)
return None
29 changes: 21 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import toponetx.datasets.graph as graph
import torch
import torch_geometric
import torch_geometric.transforms as T
from topomodelx.utils.sparse import from_sparse
from torch_geometric.data import Data
from torch_geometric.datasets import GeometricShapes
from torch_sparse import coalesce


Expand Down Expand Up @@ -50,16 +52,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand All @@ -83,6 +85,17 @@ def generate_zero_sparse_connectivity(m, n):
return torch.sparse_coo_tensor((m, n)).coalesce()


def load_random_shape_point_cloud(seed=None, num_points=64, num_classes=2):
"""Create a toy point cloud dataset"""
rng = np.random.default_rng(seed)
dataset = GeometricShapes(root="data/GeometricShapes")
dataset.transform = T.SamplePoints(num=num_points)
data = dataset[rng.integers(40)]
data.y = rng.integers(num_classes, size=num_points)
data.x = torch.tensor(rng.integers(2, size=(num_points, 1)), dtype=torch.float)
return data


def load_cell_complex_dataset(cfg):
r"""Loads cell complex datasets."""

Expand Down
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2hypergraph.mogmst_lifting import (
MoGMSTLifting,
)

TRANSFORMS = {
# PointCloud -> Hypergraph
"MoGMSTLifting": MoGMSTLifting,
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
Expand Down
105 changes: 105 additions & 0 deletions modules/transforms/liftings/pointcloud2hypergraph/mogmst_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np
import torch
import torch_geometric
from networkx import from_numpy_array, minimum_spanning_tree
from sklearn.metrics import pairwise_distances
from sklearn.mixture import GaussianMixture

from modules.transforms.liftings.pointcloud2hypergraph.base import (
PointCloud2HypergraphLifting,
)


class MoGMSTLifting(PointCloud2HypergraphLifting):
def __init__(
self, min_components=None, max_components=None, random_state=None, **kwargs
):
super().__init__(**kwargs)
if min_components is not None:
assert (
min_components > 0
), "Minimum number of components should be at least 1"
if max_components is not None:
assert (
max_components > 0
), "Maximum number of components should be at least 1"
if min_components is not None and max_components is not None:
assert min_components <= max_components, (
"Minimum number of components must be lower or equal to the"
" maximum number of components."
)
self.min_components = min_components
self.max_components = max_components
self.random_state = random_state

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
# Find a mix of Gaussians
labels, num_components, means = self.find_mog(data.pos.numpy())

# Create MST
distance_matrix = pairwise_distances(means)
original_graph = from_numpy_array(distance_matrix)
mst = minimum_spanning_tree(original_graph)

# Create hypergraph incidence
number_of_points = data.pos.shape[0]
incidence = torch.zeros((number_of_points, 2 * num_components))

# Add to which Gaussian the points belong to
nodes = torch.arange(0, number_of_points, dtype=torch.int32)
lbls = torch.tensor(labels, dtype=torch.int32)
values = torch.ones(number_of_points)
incidence[nodes, lbls] = values

# Add neighbours in MST
for i, j in mst.edges():
mask_i = labels == i
mask_j = labels == j
incidence[mask_i, num_components + j] = 1
incidence[mask_j, num_components + i] = 1

incidence = incidence.clone().detach().to_sparse_coo()
return {
"incidence_hyperedges": incidence,
"num_hyperedges": 2 * num_components,
"x_0": data.x,
}

def find_mog(self, data) -> tuple[np.ndarray, int, np.ndarray]:
if self.min_components is not None and self.max_components is not None:
possible_num_components = range(
self.min_components, self.max_components + 1
)
elif self.min_components is None and self.max_components is None:
possible_num_components = [
2**i for i in range(1, int(np.log2(data.shape[0] / 2)) + 1)
]
else:
if self.min_components is not None:
num_components = self.min_components
elif self.max_components is not None:
num_components = self.max_components
else:
# Cannot happen
num_components = 1

gm = GaussianMixture(
n_components=num_components, random_state=self.random_state
)
labels = gm.fit_predict(data)
return labels, num_components, gm.means_

best_score = float("inf")
best_labels = None
best_num_components = 0
means = None
for i in possible_num_components:
gm = GaussianMixture(n_components=i, random_state=self.random_state)
labels = gm.fit_predict(data)
score = gm.aic(data)
if score < best_score:
best_score = score
best_labels = labels
best_num_components = i
means = gm.means_
return best_labels, best_num_components, means
2 changes: 1 addition & 1 deletion modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def sort_vertices_ccw(vertices):
n_hyperedges = incidence.shape[1]
vertices += [i + n_vertices for i in range(n_hyperedges)]
indices = incidence.indices()
edges = np.array([indices[1].numpy(), indices[0].numpy() + n_vertices]).T
edges = np.array([indices[0].numpy(), indices[1].numpy() + n_vertices]).T
pos_n = [[i, 0] for i in range(n_vertices)]
pos_he = [[i, 1] for i in range(n_hyperedges)]
pos = pos_n + pos_he
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test the message passing module."""

import numpy as np
import torch
from torch_geometric.data import Data

from modules.transforms.liftings.pointcloud2hypergraph.mogmst_lifting import (
MoGMSTLifting,
)


class TestMoGMSTLifting:
"""Test the MoGMSTLifting class."""

def setup_method(self):
# Load the graph
x = torch.tensor([[0.0] for i in range(8)])
y = torch.tensor([0 for i in range(8)], dtype=torch.int32)
pos = torch.tensor(
[
[-1.44, -1.55],
[-2, -2],
[-1.18, -2.38],
[-1.26, 3.28],
[-0.59, 3.68],
[-0.7, 3.33],
[0.52, 0.09],
[0.16, 0.45],
]
)
self.data = Data(x=x, pos=pos, y=torch.tensor(y))

# Initialise the HypergraphKHopLifting class
self.lifting = MoGMSTLifting(min_components=3, random_state=0)

def test_find_mog(self):
labels, num_components, means = self.lifting.find_mog(
self.data.clone().pos.numpy()
)

assert num_components == 3, "Wrong number of components"

assert (
labels[0] == labels[1] == labels[2]
and labels[3] == labels[4] == labels[5]
and labels[6] == labels[7]
and labels[0] != labels[3]
and labels[3] != labels[6]
and labels[0] != labels[6]
), "Labels have not been assigned correctly"

def test_lift_topology(self):
# Test the lift_topology method
lifted_data_k = self.lifting.forward(self.data.clone())

expected_n_hyperedges = 6

assert (
lifted_data_k.num_hyperedges == expected_n_hyperedges
), "Wrong number of hyperedges (k=1)"

incidence_np = lifted_data_k.incidence_hyperedges.to_dense().numpy()
asg_inc = incidence_np[:, :3]
mst_inc = incidence_np[:, 3:]

assert (
(
(asg_inc[:3] == asg_inc[0]).all()
and (asg_inc[3:6] == asg_inc[3]).all()
and (asg_inc[6] == asg_inc[7]).all()
)
and not (asg_inc[0] == asg_inc[3]).all()
and not (asg_inc[0] == asg_inc[6]).all()
and not (asg_inc[3] == asg_inc[6]).all()
), "Something went wrong with point assignment to means"

assert (
(mst_inc[:6] == mst_inc[0]).all()
and (mst_inc[6] == mst_inc[7]).all()
and np.sum(mst_inc[0]) == 1
and np.sum(mst_inc[6]) == 2
), "Something went wrong with MST calculation/incidence matrix creation"
Loading
Loading