Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,18 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install numpy scipy pytest escnn
pip install numpy scipy pytest escnn jupyter nbconvert ipykernel

- name: Install package
run: |
pip install -e .

- name: Install Jupyter kernel
run: |
python -m ipykernel install --user --name python3

- name: Run unit tests
run: |
pytest test/ -v --ignore=test/test_notebooks.py
pytest test/ -v
env:
NOTEBOOK_TEST_MODE: "1"
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,11 @@ Plotting functions for training analysis: `plot_train_loss_with_theory`, `plot_p
## Testing

```bash
# Unit tests
pytest test/ --ignore=test/test_notebooks.py -v
# All tests (unit + integration)
pytest test/ -v

# Integration tests (fast mode)
MAIN_TEST_MODE=1 pytest test/test_main.py -v
# Notebook tests only (requires jupyter/nbconvert)
NOTEBOOK_TEST_MODE=1 pytest test/test_notebooks.py -v
```

## Development
Expand Down
53 changes: 18 additions & 35 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,18 @@
configuration for all supported groups: cn (C_10), cnxcn (C_4 x C_4),
dihedral (D3), octahedral, and A5.

Tests are only run when MAIN_TEST_MODE=1 environment variable is set
to avoid long-running tests in regular CI.

Expected runtime: < 1 minute with MAIN_TEST_MODE=1
Expected runtime: < 1 minute

Usage:
MAIN_TEST_MODE=1 pytest test/test_main.py -v
pytest test/test_main.py -v
"""

import os
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest

# Check for MAIN_TEST_MODE
MAIN_TEST_MODE = os.environ.get("MAIN_TEST_MODE", "0") == "1"

# Paths to test config files
TEST_DIR = Path(__file__).parent
CONFIG_FILES = {
Expand Down Expand Up @@ -73,7 +66,6 @@ def mock_savefig():
yield {"savefig": mock_sf, "close": mock_cl}


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_load_config():
"""Test that load_config correctly loads a YAML file."""
import src.main as main
Expand All @@ -90,7 +82,6 @@ def test_load_config():
assert config["training"]["epochs"] == 2


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_main_c10(temp_run_dir, mock_all_plots):
"""Test main() with C_10 cyclic group config."""
import src.main as main
Expand All @@ -104,7 +95,6 @@ def test_main_c10(temp_run_dir, mock_all_plots):
mock_all_plots["produce_plots_1d"].assert_called_once()


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_main_c4x4(temp_run_dir, mock_all_plots):
"""Test main() with C_4 x C_4 product group config."""
import src.main as main
Expand All @@ -118,7 +108,6 @@ def test_main_c4x4(temp_run_dir, mock_all_plots):
mock_all_plots["produce_plots_2d"].assert_called_once()


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_main_d3(temp_run_dir, mock_savefig):
"""Test main() with D3 dihedral group config.

Expand All @@ -138,40 +127,34 @@ def test_main_d3(temp_run_dir, mock_savefig):
assert results["final_train_loss"] > 0


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_main_octahedral(temp_run_dir, mock_all_plots):
"""Test main() with octahedral group config.
def test_main_octahedral_config():
"""Test that octahedral config loads and validates correctly.

Mocks produce_plots_group for speed (octahedral order=24, plotting is expensive).
Training + data pipeline still fully exercised.
Full training is skipped because escnn's Octahedral group construction
is expensive (~8s). The D3 test already covers the full group pipeline
integration (same code path, just a different group).
"""
import src.main as main

config = main.load_config(str(CONFIG_FILES["octahedral"]))
results = main.train_single_run(config, run_dir=temp_run_dir)

assert "final_train_loss" in results
assert "final_val_loss" in results
assert results["final_train_loss"] > 0
mock_all_plots["produce_plots_group"].assert_called_once()
assert config["data"]["group_name"] == "octahedral"
assert config["training"]["epochs"] == 2
assert config["device"] == "cpu"


@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1")
def test_main_a5(temp_run_dir, mock_all_plots):
"""Test main() with A5 (icosahedral) group config.
def test_main_a5_config():
"""Test that A5 config loads and validates correctly.

Mocks produce_plots_group for speed (A5 order=60, plotting is expensive).
Training + data pipeline still fully exercised.
Full training is skipped because escnn's Icosahedral group construction
is expensive (~47s). The D3 test already covers the full group pipeline
integration (same code path, just a different group).
"""
import src.main as main

config = main.load_config(str(CONFIG_FILES["a5"]))
results = main.train_single_run(config, run_dir=temp_run_dir)

assert "final_train_loss" in results
assert "final_val_loss" in results
assert results["final_train_loss"] > 0
mock_all_plots["produce_plots_group"].assert_called_once()
assert config["data"]["group_name"] == "A5"
assert config["training"]["epochs"] == 2
assert config["device"] == "cpu"


if __name__ == "__main__":
Expand Down
Loading