diff --git a/tests/__pycache__/__init__.cpython-38.pyc b/tests/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..6be8d5c Binary files /dev/null and b/tests/__pycache__/__init__.cpython-38.pyc differ diff --git a/tests/__pycache__/test_dice.cpython-38-pytest-6.2.4.pyc b/tests/__pycache__/test_dice.cpython-38-pytest-6.2.4.pyc index 853fdf4..c7ad467 100644 Binary files a/tests/__pycache__/test_dice.cpython-38-pytest-6.2.4.pyc and b/tests/__pycache__/test_dice.cpython-38-pytest-6.2.4.pyc differ diff --git a/tests/__pycache__/test_meshnet.cpython-38-pytest-6.2.4.pyc b/tests/__pycache__/test_meshnet.cpython-38-pytest-6.2.4.pyc new file mode 100644 index 0000000..b97588e Binary files /dev/null and b/tests/__pycache__/test_meshnet.cpython-38-pytest-6.2.4.pyc differ diff --git a/tests/test_meshnet.py b/tests/test_meshnet.py new file mode 100644 index 0000000..af24076 --- /dev/null +++ b/tests/test_meshnet.py @@ -0,0 +1,34 @@ +import os +import torch +import pytest +from app.code.executor.meshnet import enMesh_checkpoint + +def test_modelAE_from_production_config(): + # Emulate production: get the modelAE.json from app/code/executor/ + # Import the package to use its __file__ attribute. + import app.code.executor + config_file_path = os.path.join(os.path.dirname(app.code.executor.__file__), "modelAE.json") + + # Instantiate the model using production parameters. + model = enMesh_checkpoint( + in_channels=1, + n_classes=3, + channels=5, + config_file=config_file_path + ) + + # Create a small random input tensor. + x = torch.randn(1, 1, 16, 16, 16) + + # Test in train mode (which uses checkpoint_sequential) + model.train() + output_train = model(x) + expected_shape = (1, 3, 16, 16, 16) # Expecting n_classes=3 as output channels. + assert output_train.shape == expected_shape, f"Train mode output shape mismatch: expected {expected_shape}, got {output_train.shape}" + + # Test in eval mode (using inference mode) + model.eval() + with torch.no_grad(): + output_eval = model(x) + assert output_eval.shape == expected_shape, f"Eval mode output shape mismatch: expected {expected_shape}, got {output_eval.shape}" +