From 37dad42f2a57bfd1603c5761c854df41a7a76404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 5 Dec 2024 15:09:15 -0800 Subject: [PATCH 1/4] Add .gitattributes --- .gitattributes | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..9030923a --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.ipynb linguist-vendored \ No newline at end of file From 58b19ab5b54034044d37893fec33873a59bd8c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 5 Dec 2024 18:44:05 -0800 Subject: [PATCH 2/4] Add possibility of defining custom metrics for the evaluator --- configs/evaluator/default.yaml | 2 +- configs/run.yaml | 4 +- topobenchmark/evaluator/__init__.py | 3 + topobenchmark/evaluator/metrics/__init__.py | 108 ++++++++++++++++++++ topobenchmark/evaluator/metrics/example.py | 87 ++++++++++++++++ topobenchmark/utils/config_resolvers.py | 17 +-- 6 files changed, 212 insertions(+), 9 deletions(-) create mode 100644 topobenchmark/evaluator/metrics/__init__.py create mode 100644 topobenchmark/evaluator/metrics/example.py diff --git a/configs/evaluator/default.yaml b/configs/evaluator/default.yaml index 67dcd386..8095f97c 100755 --- a/configs/evaluator/default.yaml +++ b/configs/evaluator/default.yaml @@ -6,5 +6,5 @@ num_classes: ${dataset.parameters.num_classes} # Automatically selects the default metrics depending on the task # Classification: [accuracy, precision, recall, auroc] # Regression: [mae, mse] -metrics: ${get_default_metrics:${evaluator.task}} +metrics: ${get_default_metrics:${evaluator.task},${oc.select:dataset.parameters.metrics,null}} # Select classification/regression config files to manually define the metrics \ No newline at end of file diff --git a/configs/run.yaml b/configs/run.yaml index 192caf21..bb9a396c 100755 --- a/configs/run.yaml +++ b/configs/run.yaml @@ -4,8 +4,8 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: graph/cocitation_cora - - model: graph/gcn_dgm + - dataset: graph/ZINC + - model: cell/topotune - transforms: ${get_default_transform:${dataset},${model}} #tree #${get_default_transform:${dataset},${model}} #no_transform - optimizer: default - loss: default diff --git a/topobenchmark/evaluator/__init__.py b/topobenchmark/evaluator/__init__.py index 923c8bf3..f03c6c9d 100755 --- a/topobenchmark/evaluator/__init__.py +++ b/topobenchmark/evaluator/__init__.py @@ -3,6 +3,8 @@ from torchmetrics.classification import AUROC, Accuracy, Precision, Recall from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError +from .metrics import ExampleRegressionMetric + # Define metrics METRICS = { "accuracy": Accuracy, @@ -11,6 +13,7 @@ "auroc": AUROC, "mae": MeanAbsoluteError, "mse": MeanSquaredError, + "example": ExampleRegressionMetric, } from .base import AbstractEvaluator # noqa: E402 diff --git a/topobenchmark/evaluator/metrics/__init__.py b/topobenchmark/evaluator/metrics/__init__.py new file mode 100644 index 00000000..7250366f --- /dev/null +++ b/topobenchmark/evaluator/metrics/__init__.py @@ -0,0 +1,108 @@ +"""Init file for custom metrics in evaluator module.""" + +import importlib +import inspect +import sys +from pathlib import Path +from typing import Any + + +class LoadManager: + """Manages automatic discovery and registration of loss classes.""" + + @staticmethod + def is_metric_class(obj: Any) -> bool: + """Check if an object is a valid 
metric class.
+
+        Parameters
+        ----------
+        obj : Any
+            The object to check if it's a valid metric class.
+
+        Returns
+        -------
+        bool
+            True if the object is a non-private subclass of
+            torchmetrics.Metric, False otherwise.
+        """
+        try:
+            from torchmetrics import Metric
+
+            return (
+                inspect.isclass(obj)
+                and not obj.__name__.startswith("_")
+                and issubclass(obj, Metric)
+                and obj is not Metric
+            )
+        except ImportError:
+            return False
+
+    @classmethod
+    def discover_metrics(cls, package_path: str) -> dict[str, type]:
+        """Dynamically discover all metric classes in the package.
+
+        Parameters
+        ----------
+        package_path : str
+            Path to the package's __init__.py file.
+
+        Returns
+        -------
+        Dict[str, Type]
+            Dictionary mapping metric class names to their corresponding class objects.
+        """
+        metrics = {}
+        package_dir = Path(package_path).parent
+
+        # Add parent directory to sys.path to ensure imports work
+        parent_dir = str(package_dir.parent)
+        if parent_dir not in sys.path:
+            sys.path.insert(0, parent_dir)
+
+        # Iterate through all .py files in the directory
+        for file_path in package_dir.glob("*.py"):
+            if file_path.stem == "__init__":
+                continue
+
+            try:
+                # Use importlib to safely import the module
+                module_name = f"{package_dir.stem}.{file_path.stem}"
+                module = importlib.import_module(module_name)
+
+                # Find all metric classes in the module
+                for name, obj in inspect.getmembers(module):
+                    if (
+                        cls.is_metric_class(obj)
+                        and obj.__module__ == module.__name__
+                    ):
+                        metrics[name] = obj  # noqa: PERF403
+
+            except ImportError as e:
+                print(f"Could not import module {module_name}: {e}")
+
+        return metrics
+
+
+# Dynamically create the load manager and discover metrics
+manager = LoadManager()
+CUSTOM_METRICS = manager.discover_metrics(__file__)
+CUSTOM_METRICS_list = list(CUSTOM_METRICS.keys())
+
+# Combine manual and discovered metrics
+all_metrics = {**CUSTOM_METRICS}
+
+# Generate __all__
+__all__ = [
+    "CUSTOM_METRICS",
+    "CUSTOM_METRICS_list",
+    *list(all_metrics.keys()),
+]
+
+# Update locals for direct import
+locals().update(all_metrics)
+
+# from .example import ExampleRegressionMetric
+
+# __all__ = [
+#     "ExampleRegressionMetric",
+# ]
diff --git a/topobenchmark/evaluator/metrics/example.py b/topobenchmark/evaluator/metrics/example.py
new file mode 100644
index 00000000..97ce9b57
--- /dev/null
+++ b/topobenchmark/evaluator/metrics/example.py
@@ -0,0 +1,87 @@
+"""Example metric module for the topobenchmark package."""
+
+from typing import Any
+
+import torch
+from torchmetrics import Metric
+from torchmetrics.functional.regression.mse import (
+    _mean_squared_error_compute,
+    _mean_squared_error_update,
+)
+
+
+class ExampleRegressionMetric(Metric):
+    r"""Example metric.
+
+    Parameters
+    ----------
+    squared : bool
+        Whether to compute the squared error (default: True).
+    num_outputs : int
+        The number of outputs.
+    **kwargs : Any
+        Additional keyword arguments.
+ """ + + is_differentiable = True + higher_is_better = False + full_state_update = False + + sum_squared_error: torch.Tensor + total: torch.Tensor + + def __init__( + self, + squared: bool = True, + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if not isinstance(squared, bool): + raise ValueError( + f"Expected argument `squared` to be a boolean but got {squared}" + ) + self.squared = squared + + if not (isinstance(num_outputs, int) and num_outputs > 0): + raise ValueError( + f"Expected num_outputs to be a positive integer but got {num_outputs}" + ) + self.num_outputs = num_outputs + + self.add_state( + "sum_squared_error", + default=torch.zeros(num_outputs), + dist_reduce_fx="sum", + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with predictions and targets. + + Parameters + ---------- + preds : torch.Tensor + Predictions from model. + target : torch.Tensor + Ground truth values. + """ + sum_squared_error, num_obs = _mean_squared_error_update( + preds, target, num_outputs=self.num_outputs + ) + + self.sum_squared_error += sum_squared_error + self.total += num_obs + + def compute(self) -> torch.Tensor: + """Compute mean squared error over state. + + Returns + ------- + torch.Tensor + Mean squared error. + """ + return _mean_squared_error_compute( + self.sum_squared_error, self.total, squared=self.squared + ) diff --git a/topobenchmark/utils/config_resolvers.py b/topobenchmark/utils/config_resolvers.py index cb44617c..e65e77f0 100644 --- a/topobenchmark/utils/config_resolvers.py +++ b/topobenchmark/utils/config_resolvers.py @@ -255,13 +255,15 @@ def infere_num_cell_dimensions(selected_dimensions, in_channels): return len(in_channels) -def get_default_metrics(task): +def get_default_metrics(task, metrics=None): r"""Get default metrics for a given task. Parameters ---------- task : str Task, either "classification" or "regression". + metrics : list, optional + List of metrics to be used. If None, the default metrics will be used. Returns ------- @@ -273,9 +275,12 @@ def get_default_metrics(task): ValueError If the task is invalid. 
""" - if "classification" in task: - return ["accuracy", "precision", "recall", "auroc"] - elif "regression" in task: - return ["mse", "mae"] + if metrics is not None: + return metrics else: - raise ValueError(f"Invalid task {task}") + if "classification" in task: + return ["accuracy", "precision", "recall", "auroc"] + elif "regression" in task: + return ["mse", "mae"] + else: + raise ValueError(f"Invalid task {task}") From 9271ec4cd2d61975613c9368799b69cdc17d3125 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 5 Dec 2024 19:50:33 -0800 Subject: [PATCH 3/4] Add tests --- test/evaluator/test_evaluator.py | 38 ++++++++++++++++++++++++---- test/utils/test_config_resolvers.py | 3 +++ topobenchmark/evaluator/evaluator.py | 8 ++++-- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/test/evaluator/test_evaluator.py b/test/evaluator/test_evaluator.py index e09eb579..eecc59a0 100644 --- a/test/evaluator/test_evaluator.py +++ b/test/evaluator/test_evaluator.py @@ -1,6 +1,6 @@ """ Test the TBEvaluator class.""" import pytest - +import torch from topobenchmark.evaluator import TBEvaluator class TestTBEvaluator: @@ -8,8 +8,36 @@ class TestTBEvaluator: def setup_method(self): """ Setup the test.""" - self.evaluator_multilable = TBEvaluator(task="multilabel classification") - self.evaluator_regression = TBEvaluator(task="regression") + self.classification_metrics = ["accuracy", "precision", "recall", "auroc"] + self.evaluator_classification = TBEvaluator(task="classification", num_classes=3, metrics=self.classification_metrics) + self.evaluator_multilabel = TBEvaluator(task="multilabel classification", num_classes=2, metrics=self.classification_metrics) + self.regression_metrics = ["example", "mae"] + self.evaluator_regression = TBEvaluator(task="regression", num_classes=1, metrics=self.regression_metrics) with pytest.raises(ValueError): - TBEvaluator(task="wrong") - repr = self.evaluator_multilable.__repr__() \ No newline at end of file + TBEvaluator(task="wrong", num_classes=2, metrics=self.classification_metrics) + + def test_repr(self): + """Test the __repr__ method.""" + assert "TBEvaluator" in self.evaluator_classification.__repr__() + assert "TBEvaluator" in self.evaluator_multilabel.__repr__() + assert "TBEvaluator" in self.evaluator_regression.__repr__() + + def test_update_and_compute(self): + """Test the update and compute methods.""" + self.evaluator_classification.update({"logits": torch.randn(10, 3), "labels": torch.randint(0, 3, (10,))}) + out = self.evaluator_classification.compute() + for metric in self.classification_metrics: + assert metric in out + self.evaluator_multilabel.update({"logits": torch.randn(10, 2), "labels": torch.randint(0, 2, (10, 2))}) + out = self.evaluator_multilabel.compute() + for metric in self.classification_metrics: + assert metric in out + self.evaluator_regression.update({"logits": torch.randn(10, 1), "labels": torch.randn(10,)}) + out = self.evaluator_regression.compute() + for metric in self.regression_metrics: + assert metric in out + + def test_reset(self): + """Test the reset method.""" + self.evaluator_multilabel.reset() + self.evaluator_regression.reset() diff --git a/test/utils/test_config_resolvers.py b/test/utils/test_config_resolvers.py index 9137de1a..6da4697f 100644 --- a/test/utils/test_config_resolvers.py +++ b/test/utils/test_config_resolvers.py @@ -117,6 +117,9 @@ def test_infer_num_cell_dimensions(self): def test_get_default_metrics(self): """Test get_default_metrics.""" + out = 
get_default_metrics("classification", ["accuracy", "precision"]) + assert out == ["accuracy", "precision"] + out = get_default_metrics("classification") assert out == ["accuracy", "precision", "recall", "auroc"] diff --git a/topobenchmark/evaluator/evaluator.py b/topobenchmark/evaluator/evaluator.py index 8206f87e..c091ca62 100755 --- a/topobenchmark/evaluator/evaluator.py +++ b/topobenchmark/evaluator/evaluator.py @@ -37,6 +37,7 @@ def __init__(self, task, **kwargs): elif self.task == "multilabel classification": parameters = {"num_classes": kwargs["num_classes"]} parameters["task"] = "multilabel" + parameters["num_labels"] = kwargs["num_classes"] metric_names = kwargs["metrics"] elif self.task == "regression": @@ -44,7 +45,7 @@ def __init__(self, task, **kwargs): metric_names = kwargs["metrics"] else: - raise ValueError(f"Invalid task {kwargs['task']}") + raise ValueError(f"Invalid task {task}") metrics = {} for name in metric_names: @@ -83,7 +84,10 @@ def update(self, model_out: dict): if self.task == "regression": self.metrics.update(preds, target.unsqueeze(1)) - elif self.task == "classification": + elif ( + self.task == "classification" + or self.task == "multilabel classification" + ): self.metrics.update(preds, target) else: From 48f9fcf8ea7cbe6a25c5c406b15b0b961eff8588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Tue, 17 Dec 2024 15:32:06 -0800 Subject: [PATCH 4/4] Update README --- README.md | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 35c7cfeb..e596d9c2 100755 --- a/README.md +++ b/README.md @@ -92,6 +92,8 @@ python -m topobenchmark model=cell/cwn dataset=graph/MUTAG The same CLI override mechanism also applies when modifying more finer configurations within a `CONFIG GROUP`. Please, refer to the official [`hydra`documentation](https://hydra.cc/docs/intro/) for further details. + + ## :bike: Experiments Reproducibility To reproduce Table 1 from the [`TopoBenchmark: A Framework for Benchmarking Topological Deep Learning`](https://arxiv.org/pdf/2406.06642) paper, please run the following command: @@ -116,6 +118,7 @@ We list the neural networks trained and evaluated by `TopoBenchmark`, organized | GAT | [Graph Attention Networks](https://openreview.net/pdf?id=rJXMpikCZ) | | GIN | [How Powerful are Graph Neural Networks?](https://openreview.net/pdf?id=ryGs6iA5Km) | | GCN | [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907v4) | +| GraphMLP | [Graph-MLP: Node Classification without Message Passing in Graph](https://arxiv.org/pdf/2106.04051) | ### Simplicial complexes | Model | Reference | @@ -145,7 +148,7 @@ We list the neural networks trained and evaluated by `TopoBenchmark`, organized ### Combinatorial complexes | Model | Reference | | --- | --- | -| GCCN | [Generalized Combinatorial Complex Neural Networks](https://arxiv.org/pdf/2410.06530) | +| GCCN | [TopoTune: A Framework for Generalized Combinatorial Complex Neural Networks](https://arxiv.org/pdf/2410.06530) | ## :bulb: TopoTune @@ -178,12 +181,17 @@ python -m topobenchmark \ To use a single augmented Hasse graph expansion, use `model={domain}/topotune_onehasse` instead of `model={domain}/topotune`. -To specify a set of neighborhoods (routes) on the complex, use a list of neighborhoods each specified as `\[\[{source_rank}, {destination_rank}\], {neighborhood}\]`. 
Currently, the following options for `{neighborhood}` are supported: -- `up_laplacian`, from rank $r$ to $r$ -- `down_laplacian`, from rank $r$ to $r$ -- `boundary`, from rank $r$ to $r-1$ -- `coboundary`, from rank $r$ to $r+1$ -- `adjacency`, from rank $r$ to $r$ (stand-in for `up_adjacency`, as `down_adjacency` not yet supported in TopoBenchmark) +To specify a set of neighborhoods on the complex, use a list of neighborhoods each specified as a string of the form +`r-{neighborhood}-k`, where $k$ represents the source cell rank, and $r$ is the number of ranks up or down that the selected `{neighborhood}` considers. Currently, the following options for `{neighborhood}` are supported: +- `up_laplacian`, between cells of rank $k$ through $k+r$ cells. +- `down_laplacian`, between cells of rank $k$ through $k-r$ cells. +- `hodge_laplacian`, between cells of rank $k$ through both $k-r$ and $k+r$ cells. +- `up_adjacency`, between cells of rank $k$ through $k+r$ cells. +- `down_adjacency`, between cells of rank $k$ through $k-r$ cells. +- `up_incidence`, from rank $k$ to $k+r$. +- `down_incidence`, from rank $k$ to $k-r$. + +The number $r$ can be omitted, in which case $r=1$ by default (e.g. `up_incidence-k` represents the incidence from rank $k$ to $k+1$). ### Using backbone models from any package @@ -235,16 +243,18 @@ We list the liftings used in `TopoBenchmark` to transform datasets. Here, a _lif -## Data Transformations +
+ Data Transformations | Transform | Description | Reference | | --- | --- | --- | | Message Passing Homophily | Higher-order homophily measure for hypergraphs | [Source](https://arxiv.org/abs/2310.07684) | | Group Homophily | Higher-order homophily measure for hypergraphs that considers groups of predefined sizes | [Source](https://arxiv.org/abs/2103.11818) | +
## :books: Datasets - +### Graphs | Dataset | Task | Description | Reference | | --- | --- | --- | --- | | Cora | Classification | Cocitation dataset. | [Source](https://link.springer.com/article/10.1023/A:1009953814988) | @@ -264,14 +274,14 @@ We list the liftings used in `TopoBenchmark` to transform datasets. Here, a _lif | US-county-demos | Regression | In turn each node attribute is used as the target label. | [Source](https://arxiv.org/pdf/2002.08274) | | ZINC | Regression | Graph-level regression. | [Source](https://pubs.acs.org/doi/10.1021/ci3001277) | - - - -## :hammer_and_wrench: Development - -To join the development of `TopoBenchmark`, you should install the library in dev mode. - -For this, you can create an environment using conda or docker. Please, follow the steps in :jigsaw: Get Started. +### Hypergraphs +| Dataset | Task | Description | Reference | +| --- | --- | --- | --- | +| Cora-Cocitation | Classification | Cocitation dataset. | [Source](https://proceedings.neurips.cc/paper_files/paper/2019/file/1efa39bcaec6f3900149160693694536-Paper.pdf) | +| Citeseer-Cocitation | Classification | Cocitation dataset. | [Source](https://proceedings.neurips.cc/paper_files/paper/2019/file/1efa39bcaec6f3900149160693694536-Paper.pdf) | +| PubMed-Cocitation | Classification | Cocitation dataset. | [Source](https://proceedings.neurips.cc/paper_files/paper/2019/file/1efa39bcaec6f3900149160693694536-Paper.pdf) | +| Cora-Coauthorship | Classification | Cocitation dataset. | [Source](https://proceedings.neurips.cc/paper_files/paper/2019/file/1efa39bcaec6f3900149160693694536-Paper.pdf) | +| DBLP-Coauthorship | Classification | Cocitation dataset. | [Source](https://proceedings.neurips.cc/paper_files/paper/2019/file/1efa39bcaec6f3900149160693694536-Paper.pdf) |
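
One practical consequence of PATCH 2 worth spelling out: any public `torchmetrics.Metric` subclass placed in a module under `topobenchmark/evaluator/metrics/` is picked up by `LoadManager.discover_metrics` and collected into `CUSTOM_METRICS`, and the patch additionally wires `ExampleRegressionMetric` into the evaluator's `METRICS` dict under the key `"example"`. A minimal sketch of such a drop-in file follows; the file name `my_error.py`, the class `MyAbsoluteError`, and its body are illustrative assumptions, not code from the patch.

```python
# Hypothetical drop-in file, e.g. topobenchmark/evaluator/metrics/my_error.py
# (file and class names are illustrative, not part of the patch).
import torch
from torchmetrics import Metric


class MyAbsoluteError(Metric):
    """Toy regression metric: mean absolute error accumulated across batches."""

    is_differentiable = True
    higher_is_better = False
    full_state_update = False

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Distributed-safe accumulators, reduced by summation across processes.
        self.add_state("abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate absolute errors and the number of observations."""
        self.abs_error += torch.abs(preds.flatten() - target.flatten()).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """Return the mean absolute error over everything seen so far."""
        return self.abs_error / self.total
```

As far as these patches show, making a discovered metric selectable by name in the evaluator configs still requires adding a key-to-class entry in `METRICS` in `topobenchmark/evaluator/__init__.py`, as PATCH 2 does for `"example"`.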
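The override path added in PATCH 2 and exercised by the tests in PATCH 3 can also be driven directly from Python. The sketch below mirrors the new tests rather than adding to them: an explicit metric list bypasses the per-task defaults in `get_default_metrics`, and the `"example"` key resolves to `ExampleRegressionMetric` when the evaluator is built. In Hydra configs, the same list would be supplied under `dataset.parameters.metrics`, which `configs/evaluator/default.yaml` now forwards through `${oc.select:dataset.parameters.metrics,null}`. Import paths follow the files touched by the patch and should be treated as a usage sketch, not additional test code.

```python
import torch

from topobenchmark.evaluator import TBEvaluator
from topobenchmark.utils.config_resolvers import get_default_metrics

# An explicit metric list takes precedence over the per-task defaults.
metrics = get_default_metrics("regression", ["example", "mae"])
assert metrics == ["example", "mae"]

# Without the override, the regression defaults are returned.
assert get_default_metrics("regression") == ["mse", "mae"]

# "example" maps to ExampleRegressionMetric in topobenchmark.evaluator.METRICS.
evaluator = TBEvaluator(task="regression", num_classes=1, metrics=metrics)
evaluator.update({"logits": torch.randn(10, 1), "labels": torch.randn(10)})
print(evaluator.compute())  # contains entries for "example" and "mae"
```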