-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:geometric-intelligence/TopoBenchmark into dev
- Loading branch information
Showing
11 changed files
with
282 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.ipynb linguist-vendored |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,43 @@ | ||
""" Test the TBEvaluator class.""" | ||
import pytest | ||
|
||
import torch | ||
from topobenchmark.evaluator import TBEvaluator | ||
|
||
class TestTBEvaluator:
    """Test the TBEvaluator class."""

    def setup_method(self):
        """Set up one evaluator per supported task for every test.

        Also checks that constructing an evaluator with an unknown
        task name raises ``ValueError``.
        """
        self.classification_metrics = ["accuracy", "precision", "recall", "auroc"]
        self.evaluator_classification = TBEvaluator(
            task="classification",
            num_classes=3,
            metrics=self.classification_metrics,
        )
        self.evaluator_multilabel = TBEvaluator(
            task="multilabel classification",
            num_classes=2,
            metrics=self.classification_metrics,
        )
        self.regression_metrics = ["example", "mae"]
        self.evaluator_regression = TBEvaluator(
            task="regression",
            num_classes=1,
            metrics=self.regression_metrics,
        )
        # An unrecognized task name must be rejected at construction time.
        with pytest.raises(ValueError):
            TBEvaluator(
                task="wrong",
                num_classes=2,
                metrics=self.classification_metrics,
            )

    def test_repr(self):
        """Test the __repr__ method."""
        assert "TBEvaluator" in self.evaluator_classification.__repr__()
        assert "TBEvaluator" in self.evaluator_multilabel.__repr__()
        assert "TBEvaluator" in self.evaluator_regression.__repr__()

    def test_update_and_compute(self):
        """Test that update/compute produce every configured metric."""
        self.evaluator_classification.update(
            {"logits": torch.randn(10, 3), "labels": torch.randint(0, 3, (10,))}
        )
        out = self.evaluator_classification.compute()
        for metric in self.classification_metrics:
            assert metric in out
        self.evaluator_multilabel.update(
            {"logits": torch.randn(10, 2), "labels": torch.randint(0, 2, (10, 2))}
        )
        out = self.evaluator_multilabel.compute()
        for metric in self.classification_metrics:
            assert metric in out
        self.evaluator_regression.update(
            {"logits": torch.randn(10, 1), "labels": torch.randn(10,)}
        )
        out = self.evaluator_regression.compute()
        for metric in self.regression_metrics:
            assert metric in out

    def test_reset(self):
        """Test the reset method."""
        self.evaluator_multilabel.reset()
        self.evaluator_regression.reset()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Init file for custom metrics in evaluator module.""" | ||
|
||
import importlib | ||
import inspect | ||
import sys | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
|
||
class LoadManager:
    """Manages automatic discovery and registration of metric classes."""

    @staticmethod
    def is_metric_class(obj: Any) -> bool:
        """Check if an object is a valid metric class.

        Parameters
        ----------
        obj : Any
            The object to check.

        Returns
        -------
        bool
            True if the object is a public (non-underscore-prefixed)
            subclass of ``torchmetrics.Metric`` other than ``Metric``
            itself; False otherwise, including when ``torchmetrics``
            is not installed.
        """
        try:
            from torchmetrics import Metric

            return (
                inspect.isclass(obj)
                and not obj.__name__.startswith("_")
                and issubclass(obj, Metric)
                and obj is not Metric
            )
        except ImportError:
            # torchmetrics unavailable: nothing can qualify as a metric.
            return False

    @classmethod
    def discover_metrics(cls, package_path: str) -> dict[str, type]:
        """Dynamically discover all metric classes in the package.

        Parameters
        ----------
        package_path : str
            Path to the package's __init__.py file.

        Returns
        -------
        dict[str, type]
            Dictionary mapping metric class names to their
            corresponding class objects.
        """
        metrics: dict[str, type] = {}
        package_dir = Path(package_path).parent

        # Add parent directory to sys.path to ensure imports work.
        parent_dir = str(package_dir.parent)
        if parent_dir not in sys.path:
            sys.path.insert(0, parent_dir)

        # Iterate through all .py files in the directory.
        for file_path in package_dir.glob("*.py"):
            if file_path.stem == "__init__":
                continue

            module_name = f"{package_dir.stem}.{file_path.stem}"
            try:
                # Use importlib to safely import the module.
                module = importlib.import_module(module_name)
            except ImportError as e:
                # Best-effort discovery: report and keep scanning.
                print(f"Could not import module {module_name}: {e}")
                continue

            # Register only classes defined in this module (skip re-exports).
            for name, obj in inspect.getmembers(module):
                if cls.is_metric_class(obj) and obj.__module__ == module.__name__:
                    metrics[name] = obj

        return metrics
|
||
|
||
# Build the metric registry by scanning this package for Metric subclasses.
manager = LoadManager()
CUSTOM_METRICS = manager.discover_metrics(__file__)
CUSTOM_METRICS_list = list(CUSTOM_METRICS.keys())

# Single merged registry; kept separate so manual entries can be added later.
all_metrics = dict(CUSTOM_METRICS)

# Public API: the registry, its name list, and each discovered class.
__all__ = [
    "CUSTOM_METRICS",
    "CUSTOM_METRICS_list",
    *all_metrics,
]

# Make each discovered class importable directly from this package.
locals().update(all_metrics)
Oops, something went wrong.