Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Coerulatus committed Apr 30, 2024
1 parent b61fc89 commit bde562e
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_codebase.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: "Testing Codebase"

on:
on:
workflow_dispatch:
push:
branches: [main,github-actions-test]
Expand Down Expand Up @@ -54,7 +54,7 @@ jobs:
pip install -e .[all]
- name: Run tests for codebase [pytest]
run: |
pytest -n 2 --cov --cov-report=xml:coverage.xml test/transforms/feature_liftings test/transforms/liftings
pytest -n 2 --cov --cov-report=xml:coverage.xml test/transforms/feature_liftings test/transforms/liftings
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: "Testing Tutorials"

on:
on:
workflow_dispatch:
push:
branches: [main,github-actions-test]
Expand Down Expand Up @@ -54,6 +54,6 @@ jobs:
- name: Install main package
run: |
pip install -e .[all]
- name: Run tests for tutorials on all domains [pytest]
- name: Run tests for tutorials on all domains [pytest]
run: |
pytest test/tutorials/test_tutorials.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Official repository for the ICML 2024 Topological Deep Learning Challenge, hoste
- Please, check out the [main webpage of the challenge](https://pyt-team.github.io/packs/challenge.html) for the full description of the competition (motivation, submission requirements, evaluation, etc.)

## Brief Description
The main purpose of the challenge is to further expand the current scope and impact of Topological Deep Learning (TDL), enabling the exploration of its applicability in new contexts and scenarios. To do so, we propose participants to design and implement lifting mappings between different data structures and topological domains (point-clouds, graphs, hypergraphs, simplicial/cell/combinatorial complexes), potentially bridging the gap between TDL and all kinds of existing datasets.
The main purpose of the challenge is to further expand the current scope and impact of Topological Deep Learning (TDL), enabling the exploration of its applicability in new contexts and scenarios. To do so, we propose participants to design and implement lifting mappings between different data structures and topological domains (point-clouds, graphs, hypergraphs, simplicial/cell/combinatorial complexes), potentially bridging the gap between TDL and all kinds of existing datasets.


## General Guidelines
Expand Down
2 changes: 1 addition & 1 deletion modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from omegaconf import DictConfig

from modules.data.load.base import AbstractLoader
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.concat2geometric_dataset import ConcatToGeometricDataset
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
Expand Down
11 changes: 6 additions & 5 deletions modules/data/utils/concat2geometric_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch_geometric.data import Data, Dataset


class ConcatToGeometricDataset(Dataset):
def __init__(self, concat_dataset):
super().__init__()
Expand All @@ -11,18 +12,18 @@ def len(self):

def get(self, idx):
data = self.concat_dataset[idx]

x = data.x.float()
edge_index = data.edge_index
edge_attr = data.edge_attr
y = data.y
if len(x.shape)==1:
if len(x.shape) == 1:
x = x.unsqueeze(dim=1)
if len(edge_attr.shape)==1:
if len(edge_attr.shape) == 1:
edge_attr = edge_attr.unsqueeze(dim=1)
if len(y.shape)==1:
if len(y.shape) == 1:
y = y.unsqueeze(dim=1)

# Construct PyTorch Geometric Data object
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
return data
return data
27 changes: 21 additions & 6 deletions modules/transforms/liftings/graph2hypergraph/knn_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,28 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
# check for loops, since KNNGraph is inconsistent with nodes with equal features
if self.loop:
for i in range(num_nodes):
if not torch.any(torch.all(data_lifted.edge_index == torch.tensor([[i,i]]).T, dim=0)):
connected_nodes = data_lifted.edge_index[0, data_lifted.edge_index[1]==i]
dists = torch.sqrt(torch.sum((data.pos[connected_nodes]-data.pos[i].unsqueeze(0)**2),dim=1))
if not torch.any(
torch.all(data_lifted.edge_index == torch.tensor([[i, i]]).T, dim=0)
):
connected_nodes = data_lifted.edge_index[
0, data_lifted.edge_index[1] == i
]
dists = torch.sqrt(
torch.sum(
(data.pos[connected_nodes] - data.pos[i].unsqueeze(0) ** 2),
dim=1,
)
)
furthest = torch.argmax(dists)
idx = torch.where(torch.all(data_lifted.edge_index==torch.tensor([[connected_nodes[furthest],i]]).T, dim=0))[0]
data_lifted.edge_index[:,idx] = torch.tensor([[i,i]]).T

idx = torch.where(
torch.all(
data_lifted.edge_index
== torch.tensor([[connected_nodes[furthest], i]]).T,
dim=0,
)
)[0]
data_lifted.edge_index[:, idx] = torch.tensor([[i, i]]).T

incidence_1[data_lifted.edge_index[1], data_lifted.edge_index[0]] = 1
incidence_1 = torch.Tensor(incidence_1).to_sparse_coo()
return {
Expand Down
4 changes: 2 additions & 2 deletions modules/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pprint
import random
import shutil

import matplotlib.pyplot as plt
import networkx as nx
Expand All @@ -9,9 +10,8 @@
import torch
import torch_geometric
from matplotlib.patches import Polygon
import shutil

plt.rcParams['text.usetex']= True if shutil.which('latex') else False
plt.rcParams["text.usetex"] = True if shutil.which("latex") else False
rootutils.setup_root("./", indicator=".project-root", pythonpath=True)


Expand Down

0 comments on commit bde562e

Please sign in to comment.