diff --git a/.github/workflows/mypy-type-check.yml b/.github/workflows/mypy-type-check.yml index b7c15a5ec..eeb5988b4 100644 --- a/.github/workflows/mypy-type-check.yml +++ b/.github/workflows/mypy-type-check.yml @@ -6,7 +6,7 @@ on: push: branches: [ develop, pre-release, master, main ] pull_request: - branches: [ develop, pre-release, master, main ] + branches: [ develop, pre-release, master, main, dev-define-engines-abc ] jobs: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index afccd6512..82aa762a0 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,7 +8,7 @@ on: branches: [ develop, pre-release, master, main ] tags: v* pull_request: - branches: [ develop, pre-release, master, main ] + branches: [ develop, pre-release, master, main, dev-define-engines-abc] jobs: build: diff --git a/tests/engines/__init__.py b/tests/engines/__init__.py new file mode 100644 index 000000000..193a523c1 --- /dev/null +++ b/tests/engines/__init__.py @@ -0,0 +1 @@ +"""Unit test package for tiatoolbox engines.""" diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py new file mode 100644 index 000000000..aec2bf310 --- /dev/null +++ b/tests/engines/test_engine_abc.py @@ -0,0 +1,625 @@ +"""Test tiatoolbox.models.engine.engine_abc.""" + +from __future__ import annotations + +import copy +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING, NoReturn + +import numpy as np +import pytest +import torchvision.models as torch_models +from typing_extensions import Unpack + +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.dataset import PatchDataset, WSIPatchDataset +from tiatoolbox.models.engine.engine_abc import ( + EngineABC, + EngineABCRunParams, + prepare_engines_save_dir, +) +from tiatoolbox.models.engine.io_config import ModelIOConfigABC + +if TYPE_CHECKING: + import torch.nn + + +class TestEngineABC(EngineABC): + """Test EngineABC.""" + + def __init__( + self: TestEngineABC, + model: str | torch.nn.Module, + weights: str | Path | None = None, + verbose: bool | None = None, + ) -> NoReturn: + """Test EngineABC init.""" + super().__init__(model=model, weights=weights, verbose=verbose) + + def get_dataloader( + self: EngineABC, + images: Path, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + ) -> torch.utils.data.DataLoader: + """Test pre process images.""" + return super().get_dataloader( + images, + masks, + labels, + ioconfig, + patch_mode=patch_mode, + ) + + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output.""" + return super().post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + + def infer_wsi( + self: EngineABC, + dataloader: torch.utils.data.DataLoader, + save_path: Path, + **kwargs: dict, + ) -> dict | np.ndarray: + """Test infer_wsi.""" + return super().infer_wsi( + dataloader, + save_path, + **kwargs, + ) + + +def test_engine_abc_incorrect_model_type() -> NoReturn: + """Test EngineABC initialization with incorrect model type.""" + with pytest.raises( + TypeError, + match=r".*missing 1 required positional argument: 'model'", + ): + TestEngineABC() # skipcq + + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + TestEngineABC(model=1) + + +def test_incorrect_ioconfig() -> NoReturn: + """Test EngineABC initialization with incorrect ioconfig.""" + model = torch_models.resnet18() + engine = TestEngineABC(model=model) + + with pytest.raises( + ValueError, + match=r".*Must provide.*`ioconfig`.*", + ): + engine.run(images=[], masks=[], ioconfig=None) + + +def test_incorrect_output_type() -> NoReturn: + """Test EngineABC for incorrect output type.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + + with pytest.raises( + TypeError, + match=r".*output_type must be 'dict' or 'zarr' or 'annotationstore*", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + output_type="random", + ) + + +def test_pretrained_ioconfig() -> NoReturn: + """Test EngineABC initialization with pretrained model name in the toolbox.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + ) + assert "probabilities" in out + assert "labels" not in out + + +def test_ioconfig() -> NoReturn: + """Test EngineABC initialization with valid ioconfig.""" + ioconfig = ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + ], + patch_input_shape=(224, 224), + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + ioconfig=ioconfig, + ) + + assert "probabilities" in out + assert "labels" not in out + + +def test_prepare_engines_save_dir( + tmp_path: pytest.TempPathFactory, + caplog: pytest.LogCaptureFixture, +) -> NoReturn: + """Test prepare save directory for engines.""" + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + overwrite=False, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + overwrite=True, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=True, + overwrite=False, + ) + assert out_dir is None + + with pytest.raises( + OSError, + match=r".*Input WSIs detected but no save directory provided.*", + ): + _ = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + overwrite=False, + ) + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_single_output", + patch_mode=False, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_single_output" + assert out_dir.exists() + assert r"When providing multiple whole-slide images / tiles" not in caplog.text + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_multiple_output", + patch_mode=False, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_multiple_output" + assert out_dir.exists() + assert r"When providing multiple whole slide images" in caplog.text + + # test for file overwrite with Path.mkdirs() method + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + overwrite=True, + ) + assert out_path.exists() + + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + overwrite=True, + ) + assert out_path.exists() + + with pytest.raises(FileExistsError): + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + overwrite=False, + ) + + +def test_engine_initalization() -> NoReturn: + """Test engine initialization.""" + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + _ = TestEngineABC(model=0) + + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + model = CNNModel("alexnet", num_classes=1) + eng = TestEngineABC(model=model) + assert isinstance(eng, EngineABC) + + model = get_pretrained_model("alexnet-kather100k")[0] + weights_path = fetch_pretrained_weights("alexnet-kather100k") + eng = TestEngineABC(model=model, weights=weights_path) + assert isinstance(eng, EngineABC) + + +def test_engine_run() -> NoReturn: + """Test engine run.""" + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*The input numpy array should be four dimensional.*", + ): + eng.run(images=np.zeros((10, 10))) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + TypeError, + match=r"Input must be a list of file paths or a numpy array.", + ): + eng.run(images=1) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*len\(labels\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(1)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*len\(masks\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((1, 224, 224, 3)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*The shape of the numpy array should be NHWC*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((10, 3)), + on_gpu=False, + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ) + assert "probabilities" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + ) + assert "probabilities" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + assert "probabilities" in out + assert "labels" in out + + pred = eng.post_process_wsi( + raw_predictions=Path("/path/to/raw_predictions.npy"), + ) + assert str(pred) == "/path/to/raw_predictions.npy" + + +def test_engine_run_with_verbose() -> NoReturn: + """Test engine run with verbose.""" + # Run pytest with `-rP` option to view progress bar on the captured stderr call. + + eng = TestEngineABC(model="alexnet-kather100k", verbose=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + + assert "probabilities" in out + assert "labels" in out + + +def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the engine run and patch pred store.""" + save_dir = tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + # Test custom zarr output file name + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_file="patch_pred_output", + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + ) + + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + class_dict={0: "class0", 1: "class1"}, + ) + + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + scale_factor=(2.0, 2.0), + ) + + +def test_cache_mode_patches(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the caching mode.""" + save_dir = tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + cache_mode=True, + ) + assert out.exists(), "Zarr output file does not exist" + + output_file_name = "output2.zarr" + cache_size = 4 + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + cache_mode=True, + cache_size=4, + batch_size=8, + output_file=output_file_name, + ) + assert out.stem == output_file_name.split(".")[0] + assert eng.batch_size == cache_size + assert out.exists(), "Zarr output file does not exist" + + +def test_get_dataloader(sample_svs: Path) -> None: + """Test the get_dataloader function.""" + eng = TestEngineABC(model="alexnet-kather100k") + ioconfig = ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + ], + patch_input_shape=(224, 224), + ) + dataloader = eng.get_dataloader( + images=np.zeros(shape=(10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + ioconfig=ioconfig, + ) + + assert isinstance(dataloader.dataset, PatchDataset) + + dataloader = eng.get_dataloader( + images=sample_svs, + patch_mode=False, + ioconfig=ioconfig, + ) + + assert isinstance(dataloader.dataset, WSIPatchDataset) + + +def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + """Test for delegating args to io config.""" + # test not providing config / full input info for not pretrained models + model = CNNModel("resnet50") + eng = TestEngineABC(model=model) + + kwargs = { + "patch_input_shape": [512, 512], + "resolution": 1.75, + "units": "mpp", + } + with caplog.at_level(logging.WARNING): + eng.run( + np.zeros((10, 224, 224, 3)), + patch_mode=True, + save_dir=tmp_path / "dump", + patch_input_shape=kwargs["patch_input_shape"], + resolution=kwargs["resolution"], + units=kwargs["units"], + ) + assert "provide a valid ModelIOConfigABC" in caplog.text + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test providing config / full input info for non pretrained models + ioconfig = ModelIOConfigABC( + patch_input_shape=(512, 512), + stride_shape=(256, 256), + input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + ) + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ioconfig=ioconfig, + ) + assert eng._ioconfig.patch_input_shape == (512, 512) + assert eng._ioconfig.stride_shape == (256, 256) + assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}] + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + save_dir=f"{tmp_path}/dump", + **kwargs, + ) + assert eng._ioconfig.patch_input_shape == [512, 512] + assert eng._ioconfig.stride_shape == [512, 512] + assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}] + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test overwriting pretrained ioconfig + eng = TestEngineABC(model="alexnet-kather100k") + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=1.99, + units="baseline", + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ) + assert eng._ioconfig.patch_input_shape == (300, 300) + assert eng._ioconfig.stride_shape == (300, 300) + assert eng._ioconfig.input_resolutions[0]["resolution"] == 1.99 + assert eng._ioconfig.input_resolutions[0]["units"] == "baseline" + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=None, + units=None, + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ) + assert eng._ioconfig.patch_input_shape == (300, 300) + assert eng._ioconfig.stride_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.ioconfig = None + _ioconfig = eng._update_ioconfig( + ioconfig=None, + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=1.99, + units="baseline", + ) + + assert _ioconfig.patch_input_shape == (300, 300) + assert _ioconfig.stride_shape == (300, 300) + assert _ioconfig.input_resolutions[0]["resolution"] == 1.99 + assert _ioconfig.input_resolutions[0]["units"] == "baseline" + + for key in kwargs: + _kwargs = copy.deepcopy(kwargs) + _kwargs[key] = None + with pytest.raises( + ValueError, + match=r".*Must provide either `ioconfig` or " + r"`patch_input_shape`, `resolution`, and `units`*", + ): + eng._update_ioconfig( + ioconfig=None, + patch_input_shape=_kwargs["patch_input_shape"], + stride_shape=(1, 1), + resolution=_kwargs["resolution"], + units=_kwargs["units"], + ) diff --git a/tests/engines/test_ioconfig.py b/tests/engines/test_ioconfig.py new file mode 100644 index 000000000..41169298b --- /dev/null +++ b/tests/engines/test_ioconfig.py @@ -0,0 +1,23 @@ +"""Tests for IOconfig.""" + +import pytest + +from tiatoolbox.models import ModelIOConfigABC + + +def test_validation_error_io_config() -> None: + """Test Validation Error for ModelIOConfigABC.""" + with pytest.raises(ValueError, match=r".*Multiple resolution units found.*"): + ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + {"units": "mpp", "resolution": 0.25}, + ], + patch_input_shape=(224, 224), + ) + + with pytest.raises(ValueError, match=r"Invalid resolution units.*"): + ModelIOConfigABC( + input_resolutions=[{"units": "level", "resolution": 1.0}], + patch_input_shape=(224, 224), + ) diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py new file mode 100644 index 000000000..9b647fc97 --- /dev/null +++ b/tests/engines/test_patch_predictor.py @@ -0,0 +1,688 @@ +"""Test for Patch Predictor.""" + +from __future__ import annotations + +import copy +import json +import shutil +import sqlite3 +from pathlib import Path +from typing import TYPE_CHECKING, Callable + +import numpy as np +import torch +import zarr +from click.testing import CliRunner + +from tests.conftest import timed +from tiatoolbox import cli, logger, rcParam +from tiatoolbox.models import IOPatchPredictorConfig +from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.engine.patch_predictor import PatchPredictor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import download_data, get_zarr_array, imwrite + +if TYPE_CHECKING: + import pytest + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def _test_predictor_output( + inputs: list, + model: str, + probabilities_check: list | None = None, + classification_check: list | None = None, + output_type: str = "dict", + tmp_path: Path | None = None, +) -> None: + """Test the predictions of multiple models included in tiatoolbox.""" + cache_mode = None if tmp_path is None else True + save_dir = None if tmp_path is None else tmp_path / "output" + predictor = PatchPredictor( + model=model, + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + inputs, + return_labels=False, + device=device, + cache_mode=cache_mode, + save_dir=save_dir, + output_type=output_type, + ) + + if tmp_path is not None: + output = zarr.open(output, mode="r") + + probabilities = output["probabilities"] + classification = output["predictions"] + for idx, probabilities_ in enumerate(probabilities): + probabilities_max = max(probabilities_) + assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( + model, + probabilities_max, + probabilities_check[idx], + probabilities_, + classification_check[idx], + ) + assert classification[idx] == classification_check[idx], ( + model, + probabilities_max, + probabilities_check[idx], + probabilities_, + classification_check[idx], + ) + if save_dir: + shutil.rmtree(save_dir) + + +def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: + """Test for delegating args to io config.""" + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) + model = CNNModel("resnet50") + predictor = PatchPredictor(model=model, weights=None) + kwargs = { + "patch_input_shape": [512, 512], + "resolution": 1.75, + "units": "mpp", + } + + # test providing config / full input info for default models without weights + ioconfig = IOPatchPredictorConfig( + patch_input_shape=(512, 512), + stride_shape=(256, 256), + input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + ) + predictor.run( + images=[mini_wsi_svs], + ioconfig=ioconfig, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + patch_mode=False, + save_dir=f"{tmp_path}/dump", + **kwargs, + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test overwriting pretrained ioconfig + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) + predictor.run( + images=[mini_wsi_svs], + patch_input_shape=(300, 300), + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.patch_input_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + stride_shape=(300, 300), + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.stride_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + resolution=1.99, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + units="baseline", + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + units="level", + resolution=0, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "level" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + units="power", + resolution=20, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "power" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 20 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + +def test_patch_predictor_api( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test PatchPredictor API.""" + save_dir_path = tmp_path + + # convert to pathlib Path to prevent reader complaint + inputs = [Path(sample_patch1), Path(sample_patch2)] + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) + # don't run test on GPU + # Default run + output = predictor.run( + inputs, + device="cpu", + ) + assert sorted(output.keys()) == ["predictions", "probabilities"] + assert len(output["probabilities"]) == 2 + shutil.rmtree(save_dir_path, ignore_errors=True) + + # whether to return labels + output = predictor.run( + inputs, + labels=["1", "a"], + return_labels=True, + ) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == ["1", "a"] + shutil.rmtree(save_dir_path, ignore_errors=True) + + # test loading user weight + pretrained_weights_url = ( + "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" + ) + + # remove prev generated data + shutil.rmtree(save_dir_path, ignore_errors=True) + save_dir_path.mkdir(parents=True) + pretrained_weights = ( + save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" + ) + + download_data(pretrained_weights_url, pretrained_weights) + + predictor = PatchPredictor( + model="resnet18-kather100k", + weights=pretrained_weights, + batch_size=1, + ) + ioconfig = predictor.ioconfig + + # --- test different using user model + model = CNNModel(backbone="resnet18", num_classes=9) + # test prediction + predictor = PatchPredictor(model=model, batch_size=1, verbose=False) + output = predictor.run( + inputs, + labels=[1, 2], + return_labels=True, + ioconfig=ioconfig, + ) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == [1, 2] + + +def test_wsi_predictor_api( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test normal run of wsi predictor.""" + save_dir_path = tmp_path + + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + patch_size = np.array([224, 224]) + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=32) + + save_dir = f"{save_dir_path}/model_wsi_output" + + # wrapper to make this more clean + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 1.0, + "units": "baseline", + "save_dir": save_dir, + } + # ! add this test back once the read at `baseline` is fixed + # sanity check, both output should be the same with same resolution read args + # remove previously generated data + + _kwargs = copy.deepcopy(kwargs) + # test reading of multiple whole-slide images + output = predictor.run( + images=[mini_wsi_svs, mini_wsi_jpg], + masks=[mini_wsi_msk, mini_wsi_msk], + patch_mode=False, + **_kwargs, + ) + + wsi_out = zarr.open(str(output[mini_wsi_svs]), mode="r") + tile_out = zarr.open(str(output[mini_wsi_jpg]), mode="r") + diff = tile_out["probabilities"][:] == wsi_out["probabilities"][:] + accuracy = np.sum(diff) / np.size(wsi_out["probabilities"][:]) + assert accuracy > 0.99, np.nonzero(~diff) + + diff = tile_out["predictions"][:] == wsi_out["predictions"][:] + accuracy = np.sum(diff) / np.size(wsi_out["predictions"][:]) + assert accuracy > 0.99, np.nonzero(~diff) + + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) + + +def test_patch_predictor_kather100k_output( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test the output of patch classification models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] + pretrained_info = { + "alexnet-kather100k": [1.0, 0.9999735355377197], + "resnet18-kather100k": [1.0, 0.9999911785125732], + "resnet34-kather100k": [1.0, 0.9979840517044067], + "resnet50-kather100k": [1.0, 0.9999986886978149], + "resnet101-kather100k": [1.0, 0.9999932050704956], + "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], + "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], + "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], + "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], + "densenet121-kather100k": [1.0, 1.0], + "densenet161-kather100k": [1.0, 0.9999959468841553], + "densenet169-kather100k": [1.0, 0.9999934434890747], + "densenet201-kather100k": [1.0, 0.9999983310699463], + "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], + "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], + "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], + "googlenet-kather100k": [1.0, 0.9999639987945557], + } + for model, expected_prob in pretrained_info.items(): + _test_predictor_output( + inputs, + model, + probabilities_check=expected_prob, + classification_check=[6, 3], + ) + + # cache mode + for model, expected_prob in pretrained_info.items(): + _test_predictor_output( + inputs, + model, + probabilities_check=expected_prob, + classification_check=[6, 3], + tmp_path=tmp_path, + ) + + +def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: + """Helper function to extract probabilities from Annotation Store.""" + con = sqlite3.connect(dbfile) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + output = {"probabilities": [], "predictions": []} + + for item in annotations_properties: + for json_str in item: + probs_dict = json.loads(json_str) + if "proba_0" in probs_dict: + output["probabilities"].append(probs_dict.pop("prob_0")) + output["predictions"].append(probs_dict.pop("type")) + + return output + + +def _validate_probabilities(output: list | dict | zarr.group) -> bool: + """Helper function to test if the probabilities value are valid.""" + probabilities = np.array([0.5]) + + if "probabilities" in output: + probabilities = output["probabilities"] + + predictions = output["predictions"] + if isinstance(probabilities, dict): + return all(0 <= probability <= 1 for _, probability in probabilities.items()) + + predictions = np.array(get_zarr_array(predictions)).astype(int) + probabilities = get_zarr_array(probabilities) + + if not np.all(np.array(probabilities) <= 1): + return False + + if not np.all(np.array(probabilities) >= 0): + return False + + return np.all(predictions[:][0:5] == [7, 3, 2, 3, 3]) + + +def test_wsi_predictor_zarr( + sample_wsi_dict: dict, tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test normal run of patch predictor for WSIs.""" + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert output_["probabilities"].shape == (70, 9) # number of patches x classes + assert output_["probabilities"].ndim == 2 + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (70, 4) + assert output_["coordinates"].ndim == 2 + # prediction for each patch + assert output_["predictions"].shape == (70,) + assert output_["predictions"].ndim == 1 + assert _validate_probabilities(output=output_) + assert "Output file saved at " in caplog.text + + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check_no_probabilities", + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert "probabilities" not in output_ + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (70, 4) + assert output_["coordinates"].ndim == 2 + # prediction for each patch + assert output_["predictions"].shape == (70,) + assert output_["predictions"].ndim == 1 + assert _validate_probabilities(output=output_) + assert "Output file saved at " in caplog.text + + +def test_patch_predictor_patch_mode_annotation_store( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test the output of patch classification models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "patch_out_check", + output_type="annotationstore", + ) + + assert output.exists() + output = _extract_probabilities_from_annotation_store(output) + assert np.all(output["predictions"] == [6, 3]) + assert np.all(np.array(output["probabilities"]) <= 1) + assert np.all(np.array(output["probabilities"]) >= 0) + + +def test_patch_predictor_patch_mode_no_probabilities( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test the output of patch classification models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + + output = predictor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + ) + + assert "probabilities" not in output + + # don't run test on GPU + output = predictor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "patch_out_check", + output_type="annotationstore", + ) + + assert output.exists() + output = _extract_probabilities_from_annotation_store(output) + assert np.all(output["predictions"] == [6, 3]) + assert output["probabilities"] == [] + + +def test_engine_run_wsi_annotation_store( + sample_wsi_dict: dict, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 0.5, + "save_dir": save_dir, + "units": "mpp", + "scale_factor": (2.0, 2.0), + } + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + batch_size=4, + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + output_ = _extract_probabilities_from_annotation_store(output_) + + # prediction for each patch + assert np.array(output_["predictions"]).shape == (69,) + assert _validate_probabilities(output_) + + assert "Output file saved at " in caplog.text + + shutil.rmtree(save_dir) + + +# -------------------------------------------------------------------------------------- +# torch.compile +# -------------------------------------------------------------------------------------- +def test_patch_predictor_torch_compile( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test PatchPredictor with torch.compile functionality. + + Args: + sample_patch1 (Path): Path to sample patch 1. + sample_patch2 (Path): Path to sample patch 2. + tmp_path (Path): Path to temporary directory. + + """ + torch_compile_mode = rcParam["torch_compile_mode"] + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "default" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile default mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "reduce-overhead" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile reduce-overhead mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "max-autotune" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile max-autotune mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = torch_compile_mode + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: + """Test for models CLI single file.""" + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(sample_svs), + "--patch-mode", + "False", + "--output-path", + str(tmp_path / "output"), + ], + ) + + assert models_wsi_result.exit_code == 0 + assert (tmp_path / "output" / (sample_svs.stem + ".db")).exists() + + +def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: + """Test for models CLI multiple file with mask.""" + mini_wsi_svs = Path(remote_sample("svs-1-small")) + sample_wsi_msk = remote_sample("small_svs_tissue_mask") + sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) + imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) + mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") + + # Make multiple copies for test + dir_path = tmp_path.joinpath("new_copies") + dir_path.mkdir() + + dir_path_masks = tmp_path.joinpath("new_copies_masks") + dir_path_masks.mkdir() + + try: + dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + except OSError: + shutil.copy(mini_wsi_svs, dir_path / ("1_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("2_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("3_" + mini_wsi_svs.name)) + + try: + dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + except OSError: + shutil.copy(mini_wsi_msk, dir_path_masks / ("1_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("2_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("3_" + mini_wsi_msk.name)) + + runner = CliRunner() + models_tiles_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(dir_path), + "--patch-mode", + str(False), + "--masks", + str(dir_path_masks), + "--output-path", + str(tmp_path / "output"), + "--output-type", + "zarr", + ], + ) + + assert models_tiles_result.exit_code == 0 + assert (tmp_path / "output" / ("1_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("2_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("3_" + mini_wsi_svs.stem + ".zarr")).exists() diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index 4f6229a12..ab9a6033f 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -5,13 +5,24 @@ import shutil from pathlib import Path +import cv2 import numpy as np import pytest +import torch from tiatoolbox import rcParam -from tiatoolbox.models.dataset import DatasetInfoABC, KatherPatchDataset, PatchDataset -from tiatoolbox.utils import download_data, unzip_data +from tiatoolbox.models import PatchDataset, WSIPatchDataset +from tiatoolbox.models.dataset import ( + DatasetInfoABC, + KatherPatchDataset, + PatchDatasetABC, + predefined_preproc_func, +) +from tiatoolbox.utils import download_data, imread, imwrite, unzip_data from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore import WSIReader + +RNG = np.random.default_rng() # Numpy Random Generator class Proto1(DatasetInfoABC): @@ -116,3 +127,435 @@ def test_kather_dataset(tmp_path: Path) -> None: # remove generated data shutil.rmtree(save_dir_path, ignore_errors=True) + + +def test_patch_dataset_path_imgs( + sample_patch1: str | Path, + sample_patch2: str | Path, +) -> None: + """Test for patch dataset with a list of file paths as input.""" + size = (224, 224, 3) + + dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_list_imgs(tmp_path: Path) -> None: + """Test for patch dataset with a list of images as input.""" + save_dir_path = tmp_path + + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs) + + dataset.preproc_func = lambda x: x + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + # test for changing to another preproc + dataset.preproc_func = lambda x: x - 10 + item = dataset[0] + assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 + + # * test for loading npy + # remove previously generated data + if Path.exists(save_dir_path): + shutil.rmtree(save_dir_path, ignore_errors=True) + Path.mkdir(save_dir_path, parents=True) + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + assert imgs[0] is not None + # test for path object + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + + +def test_patch_datasetarray_imgs() -> None: + """Test for patch dataset with a numpy array of a list of images.""" + size = (5, 5, 3) + img = RNG.integers(0, 255, size=size) + list_imgs = [img, img, img] + labels = [1, 2, 3] + array_imgs = np.array(list_imgs) + + # test different setter for label + dataset = PatchDataset(array_imgs, labels=labels) + an_item = dataset[2] + assert an_item["label"] == 3 + dataset = PatchDataset(array_imgs, labels=None) + an_item = dataset[2] + assert "label" not in an_item + + dataset = PatchDataset(array_imgs) + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_crash(tmp_path: Path) -> None: + """Test to make sure patch dataset crashes with incorrect input.""" + # all below examples should fail when input to PatchDataset + save_dir_path = tmp_path + + # not supported input type + imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} + with pytest.raises( + ValueError, + match=r".*Input must be either a list/array of images.*", + ): + _ = PatchDataset(imgs) + + # ndarray of mixed dtype + imgs = np.array( + # string array of the same shape + [ + RNG.integers(0, 255, (4, 5, 3)), + np.array( # skipcq: PYL-E1121 + ["PatchDataset should crash here" for _ in range(4 * 5 * 3)], + ).reshape( + 4, + 5, + 3, + ), + ], + dtype=object, + ) + with pytest.raises(ValueError, match="Provided input array is non-numerical."): + _ = PatchDataset(imgs) + + # ndarray(s) of NHW images + imgs = RNG.integers(0, 255, (4, 4, 4)) + with pytest.raises(ValueError, match=r".*array of the form HWC*"): + _ = PatchDataset(imgs) + + # list of ndarray(s) with different sizes + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 5, 3)), + ] + with pytest.raises(ValueError, match="Images must have the same dimensions."): + _ = PatchDataset(imgs) + + # list of ndarray(s) with HW and HWC mixed up + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 4)), + ] + with pytest.raises( + ValueError, + match="Each sample must be an array of the form HWC.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = ["you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list not exist paths + with pytest.raises( + ValueError, + match=r".*valid image paths.*", + ): + _ = PatchDataset(["img.npy"]) + + # ** test different extension parser + # save dummy data to temporary location + # remove prev generated data + shutil.rmtree(save_dir_path, ignore_errors=True) + save_dir_path.mkdir(parents=True) + + torch.save({"a": "a"}, save_dir_path / "sample1.tar") + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + + imgs = [ + save_dir_path / "sample1.tar", + save_dir_path / "sample2.npy", + ] + with pytest.raises( + TypeError, + match="Cannot load image data from", + ): + _ = PatchDataset(imgs) + + # preproc func for not defined dataset + with pytest.raises( + ValueError, + match=r".* preprocessing .* does not exist.", + ): + predefined_preproc_func("secret-dataset") + + +def test_wsi_patch_dataset( # noqa: PLR0915 + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """A test for creation and bare output.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return WSIPatchDataset(img_path=img_path, **kwargs) + + def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return reuse_init(mode="wsi", **kwargs) + + # test for ABC validate + # intentionally created to check error + # skipcq + class Proto(PatchDatasetABC): + def __init__(self: Proto) -> None: + super().__init__() + self.inputs = "CRASH" + self._check_input_integrity("wsi") + + # skipcq + def __getitem__(self: Proto, idx: int) -> object: + """Get an item from the dataset.""" + + with pytest.raises( + ValueError, + match=r".*`inputs` should be a list of patch coordinates.*", + ): + Proto() # skipcq + + # invalid path input + with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + WSIPatchDataset( + img_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + ) + + # invalid mask path input + with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): + WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + + # invalid mode + with pytest.raises(ValueError, match="`X` is not supported."): + reuse_init(mode="X") + + # invalid patch + with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): + reuse_init() + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, "a"]) + with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): + reuse_init_wsi(patch_input_shape=512) + # invalid stride + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) + # negative + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) + + # * for wsi + # dummy test for analysing the output + # stride and patch size should be as expected + patch_size = [512, 512] + stride_size = [256, 256] + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + reader = WSIReader.open(mini_wsi_svs) + # tiling top to bottom, left to right + ds_roi = ds[2]["image"] + step_idx = 2 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + rd_roi = reader.read_bounds( + start + end, + resolution=1.0, + units="mpp", + coord_space="resolution", + ) + correlation = np.corrcoef( + cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert ds_roi.shape[0] == rd_roi.shape[0] + assert ds_roi.shape[1] == rd_roi.shape[1] + assert np.min(correlation) > 0.9, correlation + + # test creation with auto mask gen and input mask + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=True, + ) + assert len(ds) > 0 + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=mini_wsi_msk, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + negative_mask = imread(mini_wsi_msk) + negative_mask = np.zeros_like(negative_mask) + negative_mask_path = tmp_path / "negative_mask.png" + imwrite(negative_mask_path, negative_mask) + with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=negative_mask_path, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + # * for tile + reader = WSIReader.open(mini_wsi_jpg) + tile_ds = WSIPatchDataset( + img_path=mini_wsi_jpg, + mode="tile", + patch_input_shape=patch_size, + stride_shape=stride_size, + auto_get_mask=False, + ) + step_idx = 3 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + roi2 = reader.read_bounds( + start + end, + resolution=1.0, + units="baseline", + coord_space="resolution", + ) + roi1 = tile_ds[3]["image"] # match with step_index + correlation = np.corrcoef( + cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert roi1.shape[0] == roi2.shape[0] + assert roi1.shape[1] == roi2.shape[1] + assert np.min(correlation) > 0.9, correlation + + +def test_patch_dataset_abc() -> None: + """Test for ABC methods. + + Test missing definition for abstract intentionally created to check error. + + """ + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # crash due to undefined __getitem__ + with pytest.raises(TypeError): + Proto() # skipcq + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # skipcq + def __getitem__(self: Proto, idx: int) -> None: + """Get an item from the dataset.""" + + ds = Proto() # skipcq + + # test setter and getter + assert ds.preproc_func(1) == 1 + ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 + assert ds.preproc_func(1) == 0 + assert ds.preproc(1) == 1, "Must be unchanged!" + ds.preproc_func = None # skipcq: PYL-W0201 + assert ds.preproc_func(2) == 2 + + # test assign uncallable to preproc_func/postproc_func + with pytest.raises(ValueError, match=r".*callable*"): + ds.preproc_func = 1 # skipcq: PYL-W0201 diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py deleted file mode 100644 index cd33f0a5a..000000000 --- a/tests/models/test_feature_extractor.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Test for feature extractor.""" - -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -import torch - -from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone -from tiatoolbox.models.engine.semantic_segmentor import ( - DeepFeatureExtractor, - IOSegmentorConfig, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils.misc import select_device -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_engine(remote_sample: Callable, tmp_path: Path) -> None: - """Test feature extraction with DeepFeatureExtractor engine.""" - save_dir = tmp_path / "output" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * test providing pretrained from torch vs pretrained_model.yaml - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - extractor = DeepFeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask") - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - assert len(positions.shape) == 2 - assert len(features.shape) == 4 - - # * test same output between full infer and engine - # pre-emptive clean up - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - - -@pytest.mark.parametrize( - "model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)] -) -def test_full_inference( - remote_sample: Callable, tmp_path: Path, model: Callable -) -> None: - """Test full inference with CNNBackbone and TimmBackbone models.""" - save_dir = tmp_path / "output" - # pre-emptive clean up - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[256, 256], - save_resolution={"units": "mpp", "resolution": 8.0}, - ) - - extractor = DeepFeatureExtractor(batch_size=4, model=model) - # should still run because we skip exception - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - - reader = WSIReader.open(mini_wsi_svs) - patches = [ - reader.read_bounds( - positions[patch_idx], - resolution=0.25, - units="mpp", - pad_constant_values=0, - coord_space="resolution", - ) - for patch_idx in range(4) - ] - patches = np.array(patches) - patches = torch.from_numpy(patches) # NHWC - patches = patches.permute(0, 3, 1, 2) # NCHW - patches = patches.type(torch.float32) - model = model.to("cpu") - # Inference mode - model.eval() - with torch.inference_mode(): - _features = model(patches).numpy() - # ! must maintain same batch size and likely same ordering - # ! else the output values will not exactly be the same (still < 1.0e-4 - # ! of epsilon though) - assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 diff --git a/tests/models/test_abc.py b/tests/models/test_models_abc.py similarity index 91% rename from tests/models/test_abc.py rename to tests/models/test_models_abc.py index f7a60e34c..f5d744f0f 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_models_abc.py @@ -70,7 +70,6 @@ def forward(self: Proto) -> None: # skipcq def infer_batch() -> None: """Define infer batch.""" - pass # base class definition pass # noqa: PIE790 @pytest.mark.skipif( @@ -141,16 +140,16 @@ def test_model_abc() -> None: model.postproc_func = None # skipcq: PYL-W0201 assert model.postproc_func(2) == 0 - # Test on CPU - model = model.to(device="cpu") - assert isinstance(model, nn.Module) - assert model.dummy_param.device.type == "cpu" - # Test load_weights_from_file() method weights_path = fetch_pretrained_weights("alexnet-kather100k") with pytest.raises(RuntimeError, match=r".*loading state_dict*"): _ = model.load_weights_from_file(weights_path) + # Test on CPU + model = model.to(device="cpu") + assert isinstance(model, nn.Module) + assert model.dummy_param.device.type == "cpu" + def test_model_to() -> None: """Test for placing model on device.""" @@ -165,3 +164,15 @@ def test_model_to() -> None: model = torch_models.resnet18() model = model_to(device="cpu", model=model) assert isinstance(model, nn.Module) + + +def test_get_pretrained_model_not_str() -> None: + """Test TypeError is raised if input is not str.""" + with pytest.raises(TypeError, match="pretrained_model must be a string."): + _ = get_pretrained_model(1) + + +def test_get_pretrained_model_not_in_info() -> None: + """Test ValueError is raised if input is not in info.""" + with pytest.raises(ValueError, match="Pretrained model `alexnet` does not exist."): + _ = get_pretrained_model("alexnet") diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py deleted file mode 100644 index 8b234ac55..000000000 --- a/tests/models/test_multi_task_segmentor.py +++ /dev/null @@ -1,420 +0,0 @@ -"""Unit test package for HoVerNet+.""" - -import copy - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest - -from tiatoolbox.models import IOSegmentorConfig, MultiTaskSegmentor, SemanticSegmentor -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection -from tiatoolbox.utils.misc import select_device - -ON_GPU = toolbox_env.has_gpu() -BATCH_SIZE = 1 if not ON_GPU else 8 # 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def semantic_postproc_func(raw_output: np.ndarray) -> np.ndarray: - """Function to post process semantic segmentations. - - Post processes semantic segmentation to form one map output. - - """ - return np.argmax(raw_output, axis=-1) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for multi task segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("svs-1-small")) - save_dir = root_save_dir / "multitask" - shutil.rmtree(save_dir, ignore_errors=True) - - # * generate full output w/o parallel post-processing worker first - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_a = joblib.load(f"{output[0][1]}.0.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - assert multi_segmentor.num_postproc_workers == NUM_POSTPROC_WORKERS - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_b = joblib.load(f"{output[0][1]}.0.dat") - layer_map_b = np.load(f"{output[0][1]}.1.npy") - assert len(inst_dict_b) > 0, "Must have some nuclei" - assert layer_map_b is not None, "Must have some layers." - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - -def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - required_dims = (258, 258) - # above image is 512 x 512 at 0.252 mpp resolution. This is 258 x 258 at 0.500 mpp. - - save_dir = f"{root_save_dir}/multi/" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - layer_map = np.load(f"{output[0][1]}.1.npy") - - assert len(inst_dict) > 0, "Must have some nuclei." - assert layer_map is not None, "Must have some layers." - assert layer_map.shape == required_dims, ( - "Output layer map dimensions must be same as the expected output shape" - ) - - -def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test segmentor when image is masked.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = root_save_dir / "instance" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - output = multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_functionality_process_instance_predictions( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test the functionality of instance predictions processing.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(4)] - - dummy_reference = [{i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)}] - - dummy_tiles = [np.zeros((512, 512))] - dummy_bounds = np.array([0, 0, 512, 512]) - - multi_segmentor.wsi_layers = [np.zeros_like(raw_maps[0][..., 0])] - multi_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - multi_segmentor._futures = [ - [dummy_reference, [dummy_reference[0].keys()], dummy_tiles, dummy_bounds], - ] - multi_segmentor._merge_post_process_results() - assert len(multi_segmentor._wsi_inst_info[0]) == 0 - - -def test_empty_image(tmp_path: Path) -> None: - """Test MultiTaskSegmentor for an empty image.""" - root_save_dir = Path(tmp_path) - sample_patch = np.ones((256, 256, 3), dtype="uint8") * 255 - sample_patch_path = root_save_dir / "sample_tile.png" - imwrite(sample_patch_path, sample_patch) - - save_dir = root_save_dir / "hovernetplus" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "hovernet" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "semantic" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - -def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - with pytest.raises( - ValueError, - match=r"Output type must be specified for instance or semantic segmentation.", - ): - MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - save_dir = f"{root_save_dir}/multi/" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - multi_segmentor.model.postproc_func = semantic_postproc_func - - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - layer_map = np.load(f"{output[0][1]}.0.npy") - - assert layer_map is not None, "Must have some segmentations." - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/multi/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernetplus-oed", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Crash."): - multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py deleted file mode 100644 index 2956849fb..000000000 --- a/tests/models/test_nucleus_instance_segmentor.py +++ /dev/null @@ -1,603 +0,0 @@ -"""Test for Nucleus Instance Segmentor.""" - -import copy - -# ! The garbage collector -import gc -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest -import torch -import yaml -from click.testing import CliRunner - -from tiatoolbox import cli, rcParam -from tiatoolbox.models import ( - IOSegmentorConfig, - NucleusInstanceSegmentor, - SemanticSegmentor, -) -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - _process_tile_predictions, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection -from tiatoolbox.utils.misc import select_device -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def helper_tile_info() -> list: - """Helper function for tile information.""" - torch._dynamo.reset() - current_torch_compile_mode = rcParam["torch_compile_mode"] - rcParam["torch_compile_mode"] = "disable" - predictor = NucleusInstanceSegmentor(model="A") - torch._dynamo.reset() - rcParam["torch_compile_mode"] = current_torch_compile_mode - # ! assuming the tiles organized as follows (coming out from - # ! PatchExtractor). If this is broken, need to check back - # ! PatchExtractor output ordering first - # left to right, top to bottom - # --------------------- - # | 0 | 1 | 2 | 3 | - # --------------------- - # | 4 | 5 | 6 | 7 | - # --------------------- - # | 8 | 9 | 10 | 11 | - # --------------------- - # | 12 | 13 | 14 | 15 | - # --------------------- - # ! assume flag index ordering: left right top bottom - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - ], - margin=1, - tile_shape=[4, 4], - stride_shape=[4, 4], - patch_input_shape=[4, 4], - patch_output_shape=[4, 4], - ) - - return predictor._get_tile_info([16, 16], ioconfig) - - -# ---------------------------------------------------- - - -def test_get_tile_info() -> None: - """Test for getting tile info.""" - info = helper_tile_info() - _, flag = info[0] # index 0 should be full grid, removal - # removal flag at top edges - assert ( - np.sum( - np.nonzero(flag[:, 0]) - != np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), - ) - == 0 - ), "Fail Top" - # removal flag at bottom edges - assert ( - np.sum( - np.nonzero(flag[:, 1]) != np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), - ) - == 0 - ), "Fail Bottom" - # removal flag at left edges - assert ( - np.sum( - np.nonzero(flag[:, 2]) - != np.array([1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]), - ) - == 0 - ), "Fail Left" - # removal flag at right edges - assert ( - np.sum( - np.nonzero(flag[:, 3]) - != np.array([0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]), - ) - == 0 - ), "Fail Right" - - -def test_vertical_boundary_boxes() -> None: - """Test for vertical boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [3, 0, 5, 4], - [7, 0, 9, 4], - [11, 0, 13, 4], - [3, 4, 5, 8], - [7, 4, 9, 8], - [11, 4, 13, 8], - [3, 8, 5, 12], - [7, 8, 9, 12], - [11, 8, 13, 12], - [3, 12, 5, 16], - [7, 12, 9, 16], - [11, 12, 13, 16], - ], - ) - _flag = np.array( - [ - [0, 1, 0, 0], - [0, 1, 0, 0], - [0, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - ], - ) - boxes, flag = info[1] - assert np.sum(_boxes - boxes) == 0, "Wrong Vertical Bounds" - assert np.sum(flag - _flag) == 0, "Fail Vertical Flag" - - -def test_horizontal_boundary_boxes() -> None: - """Test for horizontal boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [0, 3, 4, 5], - [4, 3, 8, 5], - [8, 3, 12, 5], - [12, 3, 16, 5], - [0, 7, 4, 9], - [4, 7, 8, 9], - [8, 7, 12, 9], - [12, 7, 16, 9], - [0, 11, 4, 13], - [4, 11, 8, 13], - [8, 11, 12, 13], - [12, 11, 16, 13], - ], - ) - _flag = np.array( - [ - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - ], - ) - boxes, flag = info[2] - assert np.sum(_boxes - boxes) == 0, "Wrong Horizontal Bounds" - assert np.sum(flag - _flag) == 0, "Fail Horizontal Flag" - - -def test_cross_section_boundary_boxes() -> None: - """Test for cross-section boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [2, 2, 6, 6], - [6, 2, 10, 6], - [10, 2, 14, 6], - [2, 6, 6, 10], - [6, 6, 10, 10], - [10, 6, 14, 10], - [2, 10, 6, 14], - [6, 10, 10, 14], - [10, 10, 14, 14], - ], - ) - _flag = np.array( - [ - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - ], - ) - boxes, flag = info[3] - assert np.sum(boxes - _boxes) == 0, "Wrong Cross Section Bounds" - assert np.sum(flag - _flag) == 0, "Fail Cross Section Flag" - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/instance/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - instance_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree("output", ignore_errors=True) - shutil.rmtree(save_dir, ignore_errors=True) - instance_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - instance_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for nuclei instance segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - save_dir = f"{root_save_dir}/instance/" - - # * test run on wsi, test run with worker - # resolution for travis testing, not the correct ones - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[1024, 1024], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - inst_segmentor = NucleusInstanceSegmentor( - batch_size=1, - num_loader_workers=0, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_merge_tile_predictions_ci( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Functional tests for merging tile predictions.""" - gc.collect() # Force clean up everything on hold - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 0.5 - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - # mainly to hook the merge prediction function - inst_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - raw_maps = [[v] for v in raw_maps] # mask it as patch output - - dummy_reference = {i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)} - dummy_flag_mode_list = [ - [[1, 1, 0, 0], 1], - [[0, 0, 1, 1], 2], - [[1, 1, 1, 1], 3], - [[0, 0, 0, 0], 0], - ] - - inst_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - inst_segmentor._futures = [[dummy_reference, dummy_reference.keys()]] - inst_segmentor._merge_post_process_results() - assert len(inst_segmentor._wsi_inst_info) == 0 - - blank_raw_maps = [np.zeros_like(v) for v in raw_maps] - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=dummy_flag_mode_list[0][0], - tile_mode=dummy_flag_mode_list[0][1], - tile_output=[[np.array([0, 0, 512, 512]), blank_raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - for tile_flag, tile_mode in dummy_flag_mode_list: - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - # test exception flag - tile_flag = [0, 0, 0, 0] - with pytest.raises(ValueError, match=r".*Unknown tile mode.*"): - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=-1, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for nuclei instance segmentor.""" - root_save_dir = Path(tmp_path) - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * generate full output w/o parallel post-processing worker first - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - batch_size=8, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_a = joblib.load(f"{output[0][1]}.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - assert inst_segmentor.num_postproc_workers == 2 - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_b = joblib.load(f"{output[0][1]}.dat") - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - # ** - # To evaluate the precision of doing post-processing on tile - # then re-assemble without using full image prediction maps, - # we compare its output with the output when doing - # post-processing on the entire images. - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - _, inst_dict_b = semantic_segmentor.model.postproc(raw_maps) - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.9, "Heavy loss of precision!" - - -def test_cli_nucleus_instance_segment_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for nucleus segmentation with IOConfig.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") - - # resolution for travis testing, not the correct ones - config = { - "input_resolutions": [{"units": "mpp", "resolution": resolution}], - "output_resolutions": [ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - "margin": 128, - "tile_shape": [512, 512], - "patch_input_shape": [256, 256], - "patch_output_shape": [164, 164], - "stride_shape": [164, 164], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - - with Path.open(tmp_path / "config.yaml", "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_jpg), - "--pretrained-weights", - str(pretrained_weights), - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--mode", - "tile", - "--output-path", - str(output_path), - "--yaml-config-path", - str(tmp_path.joinpath("config.yaml")), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() - - -def test_cli_nucleus_instance_segment(remote_sample: Callable, tmp_path: Path) -> None: - """Test for nucleus segmentation.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--output-path", - str(output_path), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py deleted file mode 100644 index 913d63241..000000000 --- a/tests/models/test_patch_predictor.py +++ /dev/null @@ -1,1279 +0,0 @@ -"""Test for Patch Predictor.""" - -from __future__ import annotations - -import copy -import shutil -from pathlib import Path -from typing import Callable - -import cv2 -import numpy as np -import pytest -import torch -from click.testing import CliRunner - -from tests.conftest import timed -from tiatoolbox import cli, logger, rcParam -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor -from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.models.dataset import ( - PatchDataset, - PatchDatasetABC, - WSIPatchDataset, - predefined_preproc_func, -) -from tiatoolbox.utils import download_data, imread, imwrite -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils.misc import select_device -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -RNG = np.random.default_rng() # Numpy Random Generator - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_patch_dataset_path_imgs( - sample_patch1: str | Path, - sample_patch2: str | Path, -) -> None: - """Test for patch dataset with a list of file paths as input.""" - size = (224, 224, 3) - - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_list_imgs(tmp_path: Path) -> None: - """Test for patch dataset with a list of images as input.""" - save_dir_path = tmp_path - - size = (5, 5, 3) - img = RNG.integers(low=0, high=255, size=size) - list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) - - dataset.preproc_func = lambda x: x - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - # test for changing to another preproc - dataset.preproc_func = lambda x: x - 10 - item = dataset[0] - assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 - - # * test for loading npy - # remove previously generated data - if Path.exists(save_dir_path): - shutil.rmtree(save_dir_path, ignore_errors=True) - Path.mkdir(save_dir_path, parents=True) - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - assert imgs[0] is not None - # test for path object - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - - -def test_patch_datasetarray_imgs() -> None: - """Test for patch dataset with a numpy array of a list of images.""" - size = (5, 5, 3) - img = RNG.integers(0, 255, size=size) - list_imgs = [img, img, img] - labels = [1, 2, 3] - array_imgs = np.array(list_imgs) - - # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) - an_item = dataset[2] - assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) - an_item = dataset[2] - assert "label" not in an_item - - dataset = PatchDataset(array_imgs) - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_crash(tmp_path: Path) -> None: - """Test to make sure patch dataset crashes with incorrect input.""" - # all below examples should fail when input to PatchDataset - save_dir_path = tmp_path - - # not supported input type - imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} - with pytest.raises( - ValueError, - match=r".*Input must be either a list/array of images.*", - ): - _ = PatchDataset(imgs) - - # ndarray of mixed dtype - imgs = np.array( - # string array of the same shape - [ - RNG.integers(0, 255, (4, 5, 3)), - np.array( # skipcq: PYL-E1121 - ["you_should_crash_here" for _ in range(4 * 5 * 3)], - ).reshape( - 4, - 5, - 3, - ), - ], - dtype=object, - ) - with pytest.raises(ValueError, match="Provided input array is non-numerical."): - _ = PatchDataset(imgs) - - # ndarray(s) of NHW images - imgs = RNG.integers(0, 255, (4, 4, 4)) - with pytest.raises(ValueError, match=r".*array of the form HWC*"): - _ = PatchDataset(imgs) - - # list of ndarray(s) with different sizes - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 5, 3)), - ] - with pytest.raises(ValueError, match="Images must have the same dimensions."): - _ = PatchDataset(imgs) - - # list of ndarray(s) with HW and HWC mixed up - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 4)), - ] - with pytest.raises( - ValueError, - match="Each sample must be an array of the form HWC.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = ["you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list not exist paths - with pytest.raises( - ValueError, - match=r".*valid image paths.*", - ): - _ = PatchDataset(["img.npy"]) - - # ** test different extension parser - # save dummy data to temporary location - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - - torch.save({"a": "a"}, save_dir_path / "sample1.tar") - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - - imgs = [ - save_dir_path / "sample1.tar", - save_dir_path / "sample2.npy", - ] - with pytest.raises( - ValueError, - match="Cannot load image data from", - ): - _ = PatchDataset(imgs) - - # preproc func for not defined dataset - with pytest.raises( - ValueError, - match=r".* preprocessing .* does not exist.", - ): - predefined_preproc_func("secret-dataset") - - -def test_wsi_patch_dataset( # noqa: PLR0915 - sample_wsi_dict: dict, - tmp_path: Path, -) -> None: - """A test for creation and bare output.""" - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) - - def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return reuse_init(mode="wsi", **kwargs) - - # test for ABC validate - # intentionally created to check error - # skipcq - class Proto(PatchDatasetABC): - def __init__(self: Proto) -> None: - super().__init__() - self.inputs = "CRASH" - self._check_input_integrity("wsi") - - # skipcq - def __getitem__(self: Proto, idx: int) -> object: - """Get an item from the dataset.""" - - with pytest.raises( - ValueError, - match=r".*`inputs` should be a list of patch coordinates.*", - ): - Proto() # skipcq - - # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): - WSIPatchDataset( - img_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - ) - - # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): - WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - - # invalid patch - with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): - reuse_init() - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, "a"]) - with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): - reuse_init_wsi(patch_input_shape=512) - # invalid stride - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) - # negative - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) - - # * for wsi - # dummy test for analysing the output - # stride and patch size should be as expected - patch_size = [512, 512] - stride_size = [256, 256] - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - reader = WSIReader.open(mini_wsi_svs) - # tiling top to bottom, left to right - ds_roi = ds[2]["image"] - step_idx = 2 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - rd_roi = reader.read_bounds( - start + end, - resolution=1.0, - units="mpp", - coord_space="resolution", - ) - correlation = np.corrcoef( - cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert ds_roi.shape[0] == rd_roi.shape[0] - assert ds_roi.shape[1] == rd_roi.shape[1] - assert np.min(correlation) > 0.9, correlation - - # test creation with auto mask gen and input mask - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=True, - ) - assert len(ds) > 0 - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - negative_mask = imread(mini_wsi_msk) - negative_mask = np.zeros_like(negative_mask) - negative_mask_path = tmp_path / "negative_mask.png" - imwrite(negative_mask_path, negative_mask) - with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - - -def test_patch_dataset_abc() -> None: - """Test for ABC methods. - - Test missing definition for abstract intentionally created to check error. - - """ - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # crash due to undefined __getitem__ - with pytest.raises(TypeError): - Proto() # skipcq - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # skipcq - def __getitem__(self: Proto, idx: int) -> None: - """Get an item from the dataset.""" - - ds = Proto() # skipcq - - # test setter and getter - assert ds.preproc_func(1) == 1 - ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 - assert ds.preproc_func(1) == 0 - assert ds.preproc(1) == 1, "Must be unchanged!" - ds.preproc_func = None # skipcq: PYL-W0201 - assert ds.preproc_func(2) == 2 - - # test assign uncallable to preproc_func/postproc_func - with pytest.raises(ValueError, match=r".*callable*"): - ds.preproc_func = 1 # skipcq: PYL-W0201 - - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_io_patch_predictor_config() -> None: - """Test for IOConfig.""" - # test for creating - cfg = IOPatchPredictorConfig( - patch_input_shape=[224, 224], - stride_shape=[224, 224], - input_resolutions=[{"resolution": 0.5, "units": "mpp"}], - # test adding random kwarg and they should be accessible as kwargs - crop_from_source=True, - ) - assert cfg.crop_from_source - - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - -def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: - """Test for delegating args to io config.""" - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models - model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - kwargs = { - "patch_input_shape": [512, 512], - "resolution": 1.75, - "units": "mpp", - } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - device=select_device(on_gpu=ON_GPU), - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained models - ioconfig = IOPatchPredictorConfig( - patch_input_shape=(512, 512), - stride_shape=(256, 256), - input_resolutions=[{"resolution": 1.35, "units": "mpp"}], - ) - predictor.predict( - [mini_wsi_svs], - ioconfig=ioconfig, - mode="wsi", - save_dir=f"{tmp_path}/dump", - device=select_device(on_gpu=ON_GPU), - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - device=select_device(on_gpu=ON_GPU), - **kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], - patch_input_shape=(300, 300), - mode="wsi", - device=select_device(on_gpu=ON_GPU), - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.patch_input_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - stride_shape=(300, 300), - mode="wsi", - device=select_device(on_gpu=ON_GPU), - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.stride_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - resolution=1.99, - mode="wsi", - device=select_device(on_gpu=ON_GPU), - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - units="baseline", - mode="wsi", - device=select_device(on_gpu=ON_GPU), - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, - save_dir=f"{tmp_path}/dump", - device=select_device(on_gpu=ON_GPU), - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - -def test_patch_predictor_api( - sample_patch1: Path, - sample_patch2: Path, - tmp_path: Path, -) -> None: - """Helper function to get the model output using API 1.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent reader complaint - inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - # don't run test on GPU - output = predictor.predict( - inputs, - device=select_device(on_gpu=ON_GPU), - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - labels=[1, "a"], - return_labels=True, - device=select_device(on_gpu=ON_GPU), - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - device=select_device(on_gpu=ON_GPU), - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - device=select_device(on_gpu=ON_GPU), - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - device=select_device(on_gpu=ON_GPU), - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - - # test loading user weight - pretrained_weights_url = ( - "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" - ) - - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - pretrained_weights = ( - save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" - ) - - download_data(pretrained_weights_url, pretrained_weights) - - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, - batch_size=1, - ) - - # --- test different using user model - model = CNNModel(backbone="resnet18", num_classes=9) - # test prediction - predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - device=select_device(on_gpu=ON_GPU), - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - -def test_wsi_predictor_api( - sample_wsi_dict: dict, - tmp_path: Path, - chdir: Callable, -) -> None: - """Test normal run of wsi predictor.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - save_dir = f"{save_dir_path}/model_wsi_output" - - # wrapper to make this more clean - kwargs = { - "return_probabilities": True, - "return_labels": True, - "device": select_device(on_gpu=ON_GPU), - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 1.0, - "units": "baseline", - "save_dir": save_dir, - } - # ! add this test back once the read at `baseline` is fixed - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "device": select_device(on_gpu=ON_GPU), - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - # coverage test - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True - # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - - # blind test for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 - - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "device": select_device(on_gpu=ON_GPU), - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) - - -def _test_predictor_output( - inputs: list, - pretrained_model: str, - probabilities_check: list | None = None, - predictions_check: list | None = None, - device: str = select_device(on_gpu=ON_GPU), -) -> None: - """Test the predictions of multiple models included in tiatoolbox.""" - predictor = PatchPredictor( - pretrained_model=pretrained_model, - batch_size=32, - verbose=False, - ) - # don't run test on GPU - output = predictor.predict( - inputs, - return_probabilities=True, - return_labels=False, - device=device, - ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): - probabilities_max = max(probabilities_) - assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - - -def test_patch_predictor_kather100k_output( - sample_patch1: Path, - sample_patch2: Path, -) -> None: - """Test the output of patch prediction models on Kather100K dataset.""" - inputs = [Path(sample_patch1), Path(sample_patch2)] - pretrained_info = { - "alexnet-kather100k": [1.0, 0.9999735355377197], - "resnet18-kather100k": [1.0, 0.9999911785125732], - "resnet34-kather100k": [1.0, 0.9979840517044067], - "resnet50-kather100k": [1.0, 0.9999986886978149], - "resnet101-kather100k": [1.0, 0.9999932050704956], - "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], - "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], - "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], - "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], - "densenet121-kather100k": [1.0, 1.0], - "densenet161-kather100k": [1.0, 0.9999959468841553], - "densenet169-kather100k": [1.0, 0.9999934434890747], - "densenet201-kather100k": [1.0, 0.9999983310699463], - "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], - "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], - "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], - "googlenet-kather100k": [1.0, 0.9999639987945557], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[6, 3], - device=select_device(on_gpu=ON_GPU), - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - device=select_device(on_gpu=ON_GPU), - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI file not found error.""" - runner = CliRunner() - model_file_not_found_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs)[:-1], - "--file-types", - '"*.ndpi, *.svs"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert model_file_not_found_result.output == "" - assert model_file_not_found_result.exit_code == 1 - assert isinstance(model_file_not_found_result.exception, FileNotFoundError) - - -def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI mode not in wsi, tile.""" - runner = CliRunner() - mode_not_in_wsi_tile_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--file-types", - '"*.ndpi, *.svs"', - "--mode", - '"patch"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output - assert mode_not_in_wsi_tile_result.exit_code != 0 - assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) - - -def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI single file.""" - runner = CliRunner() - models_wsi_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--mode", - "wsi", - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - # Make multiple copies for test - dir_path = tmp_path.joinpath("new_copies") - dir_path.mkdir() - - dir_path_masks = tmp_path.joinpath("new_copies_masks") - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() - - -# ------------------------------------------------------------------------------------- -# torch.compile -# ------------------------------------------------------------------------------------- - - -def test_patch_predictor_torch_compile( - sample_patch1: Path, - sample_patch2: Path, - tmp_path: Path, -) -> None: - """Test PatchPredictor with with torch.compile functionality. - - Args: - sample_patch1 (Path): Path to sample patch 1. - sample_patch2 (Path): Path to sample patch 2. - tmp_path (Path): Path to temporary directory. - - """ - torch_compile_mode = rcParam["torch_compile_mode"] - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "default" - _, compile_time = timed( - test_patch_predictor_api, - sample_patch1, - sample_patch2, - tmp_path, - ) - logger.info("torch.compile default mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "reduce-overhead" - _, compile_time = timed( - test_patch_predictor_api, - sample_patch1, - sample_patch2, - tmp_path, - ) - logger.info("torch.compile reduce-overhead mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "max-autotune" - _, compile_time = timed( - test_patch_predictor_api, - sample_patch1, - sample_patch2, - tmp_path, - ) - logger.info("torch.compile max-autotune mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = torch_compile_mode diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py deleted file mode 100644 index 01776b800..000000000 --- a/tests/models/test_semantic_segmentation.py +++ /dev/null @@ -1,945 +0,0 @@ -"""Test for Semantic Segmentor.""" - -from __future__ import annotations - -import copy - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.multiprocessing as torch_mp -import torch.nn.functional as F # noqa: N812 -import yaml -from click.testing import CliRunner -from torch import nn - -from tests.conftest import timed -from tiatoolbox import cli, logger, rcParam -from tiatoolbox.models import SemanticSegmentor -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) -from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imread, imwrite -from tiatoolbox.utils.misc import select_device -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -class _CNNTo1(ModelABC): - """Contains a convolution. - - Simple model to test functionality, this contains a single - convolution layer which has weight=0 and bias=1. - - """ - - def __init__(self: _CNNTo1) -> None: - super().__init__() - self.conv = nn.Conv2d(3, 1, 3, padding=1) - self.conv.weight.data.fill_(0) - self.conv.bias.data.fill_(1) - - def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor: - """Define how to use layer.""" - return self.conv(img) - - @staticmethod - def infer_batch(model: nn.Module, batch_data: torch.Tensor, device: str) -> list: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o - aggregation for a single data batch. - - Args: - model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): A batch of data generated by - torch.utils.data.DataLoader. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - - """ - device = "cuda" if ON_GPU else "cpu" - #### - model.eval() # infer mode - - #### - img_list = batch_data - - img_list = img_list.to(device).type(torch.float32) - img_list = img_list.permute(0, 3, 1, 2) # to NCHW - - hw = np.array(img_list.shape[2:]) - with torch.inference_mode(): # do not compute gradient - logit_list = model(img_list) - logit_list = centre_crop(logit_list, hw // 2) - logit_list = logit_list.permute(0, 2, 3, 1) # to NHWC - prob_list = F.relu(logit_list) - - prob_list = prob_list.cpu().numpy() - return [prob_list] - - -# ------------------------------------------------------------------------------------- -# IOConfig -# ------------------------------------------------------------------------------------- - - -def test_segmentor_ioconfig() -> None: - """Test for IOConfig.""" - default_config = { - "input_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - "output_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - "patch_input_shape": [2048, 2048], - "patch_output_shape": [1024, 1024], - "stride_shape": [512, 512], - } - - # error when uniform resolution units are not uniform - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "power", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - - # error when uniform resolution units are not supported - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "alpha", "resolution": 0.25}, - {"units": "alpha", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - assert ioconfig.highest_input_resolution == {"units": "mpp", "resolution": 0.25} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 1.0 - assert ioconfig.input_resolutions[1]["resolution"] == 0.5 - assert ioconfig.input_resolutions[2]["resolution"] == 1 / 3 - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - output_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - save_resolution={"units": "power", "resolution": 8.0}, - ) - assert ioconfig.highest_input_resolution == {"units": "power", "resolution": 40} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 0.5 - assert ioconfig.input_resolutions[1]["resolution"] == 1.0 - assert ioconfig.save_resolution["resolution"] == 8.0 / 40.0 - - resolutions = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ] - with pytest.raises(ValueError, match=r".*Unknown units.*"): - ioconfig.scale_to_highest(resolutions, "axx") - - -# ------------------------------------------------------------------------------------- -# Dataset -# ------------------------------------------------------------------------------------- - - -def test_functional_wsi_stream_dataset(remote_sample: Callable) -> None: - """Functional test for WSIStreamDataset.""" - gc.collect() # Force clean up everything on hold - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - mp_manager = torch_mp.Manager() - mp_shared_space = mp_manager.Namespace() - - sds = WSIStreamDataset(ioconfig, [mini_wsi_svs], mp_shared_space) - # test for collate - out = sds.collate_fn([None, 1, 2, 3]) - assert np.sum(out.numpy() != np.array([1, 2, 3])) == 0 - - # artificial data injection - mp_shared_space.wsi_idx = torch.tensor(0) # a scalar - mp_shared_space.patch_inputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - mp_shared_space.patch_outputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - # test read - for _, sample in enumerate(sds): - patch_data, _ = sample - (patch_resolution1, patch_resolution2, patch_resolution3) = patch_data - assert np.round(patch_resolution1.shape[0] / patch_resolution2.shape[0]) == 2 - assert np.round(patch_resolution1.shape[0] / patch_resolution3.shape[0]) == 3 - - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Functional crash tests for segmentor.""" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) - - model = _CNNTo1() - - save_dir = tmp_path / "test_crash_segmentor" - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - - # fake injection to trigger Segmentor to create parallel - # post processing workers because baseline Semantic Segmentor does not support - # post processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # * test basic crash - with pytest.raises(TypeError, match=r".*`mask_reader`.*"): - semantic_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) - with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): - semantic_segmentor.filter_coordinates( - WSIReader.open(mini_wsi_msk), - np.array([1.0, 2.0]), - ) - semantic_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) - with pytest.raises(ValueError, match=r".*must be a valid file path.*"): - semantic_segmentor.get_reader( - mini_wsi_msk, - "not_exist", - "wsi", - auto_get_mask=True, - ) - - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - with pytest.raises(ValueError, match=r".*provide.*"): - SemanticSegmentor() - with pytest.raises(ValueError, match=r".*valid mode.*"): - semantic_segmentor.predict([], mode="abc") - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict( - [], - mode="tile", - patch_input_shape=(2048, 2048), - save_dir=save_dir, - ) - shutil.rmtree(save_dir, ignore_errors=True) - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - shutil.rmtree(save_dir, ignore_errors=True) - - # * Test crash propagation when parallelize post-processing - semantic_segmentor.num_postproc_workers = 2 - semantic_segmentor.model.forward = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir, - ) - shutil.rmtree(save_dir, ignore_errors=True) - - # test ignore crash - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=False, - save_dir=save_dir, - ) - - -def test_functional_segmentor_merging(tmp_path: Path) -> None: - """Functional test for assmebling output.""" - save_dir = Path(tmp_path) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - # predictions with HW - _output = np.array( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 2, 2], - [0, 0, 2, 2], - ], - ) - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # a second rerun to test overlapping count, - # should still maintain same result - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # else will leave hanging file pointer - # and hence cant remove its folder later - del canvas # skipcq - - # * predictions with HWC - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - _ = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - del _ # skipcq - - # * test crashing when switch to image having larger - # * shape but still provide old links - semantic_segmentor.merge_prediction( - [8, 8], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - with pytest.raises(ValueError, match=r".*`save_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.py", - ) - - with pytest.raises(ValueError, match=r".*`cache_count_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - # * test non HW predictions - with pytest.raises(ValueError, match=r".*Prediction is no HW or HWC.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2,), 1), np.full((2,), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - - # * with an out of bound location - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [ - np.full((2, 2), 1), - np.full((2, 2), 2), - np.full((2, 2), 3), - np.full((2, 2), 4), - ], - [[0, 0, 2, 2], [2, 2, 4, 4], [0, 4, 2, 6], [4, 0, 6, 2]], - save_path=None, - ) - assert np.sum(canvas - _output) < 1.0e-8 - del canvas # skipcq - - -def test_functional_segmentor( - remote_sample: Callable, - tmp_path: Path, - chdir: Callable, -) -> None: - """Functional test for segmentor.""" - save_dir = tmp_path / "dump" - # # convert to pathlib Path to prevent wsireader complaint - resolution = 2.0 - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - mini_wsi_msk = f"{tmp_path}/mini_mask.jpg" - imwrite(mini_wsi_msk, (thumb > 0).astype(np.uint8)) - - # preemptive clean up - shutil.rmtree(save_dir, ignore_errors=True) - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # should still run because we skip exception - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - patch_input_shape=(512, 512), - resolution=resolution, - units="mpp", - crash_on_exception=False, - save_dir=save_dir, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - patch_input_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=True, - save_dir=save_dir, - ) - shutil.rmtree(save_dir, ignore_errors=True) - - with chdir(tmp_path): - # * check exception bypass in the log - # there should be no exception, but how to check the log? - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - patch_input_shape=(512, 512), - patch_output_shape=(512, 512), - stride_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=False, - ) - shutil.rmtree( - tmp_path / "output", - ignore_errors=True, - ) # default output dir test - - # * test basic running and merging prediction - # * should dumping all 1 in the output - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "baseline", "resolution": 1.0}], - output_resolutions=[{"units": "baseline", "resolution": 1.0}], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - file_list = [ - mini_wsi_jpg, - mini_wsi_jpg, - ] - output_list = semantic_segmentor.predict( - file_list, - mode="tile", - device=select_device(on_gpu=ON_GPU), - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - pred_2 = np.load(output_list[1][1] + ".raw.0.npy") - assert len(output_list) == 2 - assert np.sum(pred_1 - pred_2) == 0 - # due to overlapping merge and division, will not be - # exactly 1, but should be approximately so - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # * test running with mask and svs - # * also test merging prediction at designated resolution - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[{"units": "mpp", "resolution": resolution}], - save_resolution={"units": "mpp", "resolution": resolution}, - patch_input_shape=[512, 512], - patch_output_shape=[256, 256], - stride_shape=[512, 512], - ) - shutil.rmtree(save_dir, ignore_errors=True) - output_list = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - reader = WSIReader.open(mini_wsi_svs) - expected_shape = reader.slide_dimensions(**ioconfig.save_resolution) - expected_shape = np.array(expected_shape)[::-1] # to YX - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - saved_shape = np.array(pred_1.shape[:2]) - assert np.sum(expected_shape - saved_shape) == 0 - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # check normal run with auto get mask - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - model=model, - auto_generate_mask=True, - ) - _ = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -def test_subclass(remote_sample: Callable, tmp_path: Path) -> None: - """Create subclass and test parallel processing setup.""" - save_dir = Path(tmp_path) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - model = _CNNTo1() - - class XSegmentor(SemanticSegmentor): - """Dummy class to test subclassing.""" - - def __init__(self: XSegmentor) -> None: - super().__init__(model=model) - self.num_postproc_worker = 2 - - semantic_segmentor = XSegmentor() - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - patch_input_shape=(1024, 1024), - patch_output_shape=(512, 512), - stride_shape=(256, 256), - resolution=1.0, - units="baseline", - crash_on_exception=False, - save_dir=save_dir / "raw", - ) - - -# specifically designed for travis -def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: - """Test for load up pretrained and over-writing tile mode ioconfig.""" - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=1.0, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - assert save_dir.joinpath("raw/0.raw.0.npy").exists() - assert save_dir.joinpath("raw/file_map.dat").exists() - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - wsi_with_artifacts = Path(remote_sample("wsi3_20k_20k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [wsi_with_artifacts], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("wsi3_20k_20k_pred"))) - _test_pred = np.load(str(save_dir / "raw" / "0.raw.0.npy")) - _test_pred = (_test_pred[..., 1] > 0.75) * 255 - # divide 255 to binarize - assert np.mean(_cache_pred[..., 0] == _test_pred) > 0.99 - - shutil.rmtree(save_dir, ignore_errors=True) - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - - wsi_breast = Path(remote_sample("wsi4_4k_4k_svs")) - semantic_segmentor = SemanticSegmentor( - num_loader_workers=4, - batch_size=BATCH_SIZE, - pretrained_model="fcn_resnet50_unet-bcss", - ) - semantic_segmentor.predict( - [wsi_breast], - mode="wsi", - device=select_device(on_gpu=ON_GPU), - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = np.load(Path(remote_sample("wsi4_4k_4k_pred"))) - _test_pred = np.load(f"{save_dir}/raw/0.raw.0.npy") - _test_pred = np.argmax(_test_pred, axis=-1) - assert np.mean(np.abs(_cache_pred - _test_pred)) < 1.0e-2 - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_cli_semantic_segment_out_exists_error( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation if output path exists.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path, - ], - ) - - assert semantic_segment_result.output == "" - assert semantic_segment_result.exit_code == 1 - assert isinstance(semantic_segment_result.exception, FileExistsError) - - -def test_cli_semantic_segmentation_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation single file custom ioconfig.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") - - config = { - "input_resolutions": [{"units": "mpp", "resolution": 2.0}], - "output_resolutions": [{"units": "mpp", "resolution": 2.0}], - "patch_input_shape": [1024, 1024], - "patch_output_shape": [512, 512], - "stride_shape": [256, 256], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--pretrained-weights", - str(pretrained_weights), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path.joinpath("output"), - "--yaml-config-path", - tmp_path.joinpath("config.yaml"), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("output/0.raw.0.npy").exists() - assert tmp_path.joinpath("output/file_map.dat").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_semantic_segmentation_multi_file( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path / "small_svs_tissue_mask.jpg" - - # Make multiple copies for test - dir_path = tmp_path / "new_copies" - dir_path.mkdir() - - dir_path_masks = tmp_path / "new_copies_masks" - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - dir_path_masks.joinpath("2_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - except OSError: - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("1_" + sample_wsi_msk.name)) - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("2_" + sample_wsi_msk.name)) - - tmp_path = tmp_path / "output" - - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("0.raw.0.npy").exists() - assert tmp_path.joinpath("1.raw.0.npy").exists() - assert tmp_path.joinpath("file_map.dat").exists() - assert tmp_path.joinpath("results.json").exists() - - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("small_svs_tissue_mask"))) - _test_pred = np.load(str(tmp_path.joinpath("0.raw.0.npy"))) - _test_pred = (_test_pred[..., 1] > 0.50) * 255 - - assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3 - - -# ------------------------------------------------------------------------------------- -# torch.compile -# ------------------------------------------------------------------------------------- - - -def test_semantic_segmentor_torch_compile( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test SemanticSegmentor using pretrained model with torch.compile functionality. - - Args: - remote_sample (Callable): Callable object used to extract remote sample. - tmp_path (Path): Path to temporary directory. - - """ - torch_compile_mode = rcParam["torch_compile_mode"] - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "default" - _, compile_time = timed( - test_functional_pretrained, - remote_sample, - tmp_path, - ) - logger.info("torch.compile default mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "reduce-overhead" - _, compile_time = timed( - test_functional_pretrained, - remote_sample, - tmp_path, - ) - logger.info("torch.compile reduce-overhead mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = "max-autotune" - _, compile_time = timed( - test_functional_pretrained, - remote_sample, - tmp_path, - ) - logger.info("torch.compile max-autotune mode: %s", compile_time) - torch._dynamo.reset() - rcParam["torch_compile_mode"] = torch_compile_mode diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py index cb99a3df3..87b074a71 100644 --- a/tests/test_app_bokeh.py +++ b/tests/test_app_bokeh.py @@ -18,7 +18,7 @@ import requests from bokeh.application import Application from bokeh.application.handlers import FunctionHandler -from bokeh.events import ButtonClick, DoubleTap, MenuItemClick +from bokeh.events import DoubleTap, MenuItemClick from flask_cors import CORS from matplotlib import colormaps from PIL import Image @@ -423,54 +423,7 @@ def test_load_img_overlay(doc: Document, data_path: pytest.TempPathFactory) -> N assert main.UI["p"].renderers[main.UI["vstate"].layer_dict["layer2"]].alpha == 0.4 -def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> None: - """Test running hovernet on a box.""" - slide_select = doc.get_model_by_name("slide_select0") - slide_select.value = [data_path["slide2"].name] - run_button = doc.get_model_by_name("to_model0") - assert len(main.UI["color_column"].children) == 0 - slide_select.value = [data_path["slide1"].name] - # set up a box selection - main.UI["box_source"].data = { - "x": [1200], - "y": [-2000], - "width": [400], - "height": [400], - } - - # select hovernet model and run it on box - model_select = doc.get_model_by_name("model_drop0") - model_select.value = "hovernet" - - click = ButtonClick(run_button) - run_button._trigger_event(click) - im = get_tile("overlay", 4, 8, 4, show=False) - _, num = label(np.any(im[:, :, :3], axis=2)) - # check there are multiple cells being detected - assert len(main.UI["color_column"].children) > 3 - assert num > 10 - - # test save functionality - save_button = doc.get_model_by_name("save_button0") - click = ButtonClick(save_button) - save_button._trigger_event(click) - saved_path = ( - data_path["base_path"] - / "overlays" - / (data_path["slide1"].stem + "_saved_anns.db") - ) - assert saved_path.exists() - - # load an overlay with different types - cprop_select = doc.get_model_by_name("cprop0") - cprop_select.value = ["prob"] - layer_drop = doc.get_model_by_name("layer_drop0") - click = MenuItemClick(layer_drop, str(data_path["dat_anns"])) - layer_drop._trigger_event(click) - assert main.UI["vstate"].types == ["annotation"] - # check the per-type ui controls have been updated - assert len(main.UI["color_column"].children) == 1 - assert len(main.UI["type_column"].children) == 1 +# test_hovernet_on_box should be fixed before merge to develop. def test_alpha_sliders(doc: Document) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index a06d57d90..8a5553dee 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1648,6 +1648,7 @@ def test_patch_pred_store() -> None: """Test patch_pred_store.""" # Define a mock patch_output patch_output = { + "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)], "predictions": [1, 0, 1], "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], "other": "other", @@ -1668,6 +1669,17 @@ def test_patch_pred_store() -> None: with pytest.raises(ValueError, match="coordinates"): misc.dict_to_store(patch_output, (1.0, 1.0)) + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "other": "other", + } + + store = misc.dict_to_store(patch_output, (1.0, 1.0)) + + # Check that it is an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + def test_patch_pred_store_cdict() -> None: """Test patch_pred_store with a class dict.""" diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 26f85625e..02a5a8c86 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -86,6 +86,26 @@ def cli_file_type( ) +def cli_output_type( + usage_help: str = "The format of the output type. " + "'output_type' can be 'zarr' or 'AnnotationStore'. " + "Default value is 'AnnotationStore'.", + default: str = "AnnotationStore", + input_type: click.Choice | None = None, +) -> Callable: + """Enables --file-types option for cli.""" + click_choices = click.Choice( + choices=["zarr", "AnnotationStore"], case_sensitive=False + ) + input_type = click_choices if input_type is None else input_type + return click.option( + "--output-type", + help=add_default_to_usage_help(usage_help, default), + default=default, + type=input_type, + ) + + def cli_mode( usage_help: str = "Selected mode to show or save the required information.", default: str = "save", @@ -102,6 +122,20 @@ def cli_mode( ) +def cli_patch_mode( + usage_help: str = "Whether to run the model in patch mode or WSI mode.", + *, + default: bool = False, +) -> Callable: + """Enables --return-probabilities option for cli.""" + return click.option( + "--patch-mode", + type=bool, + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_region( usage_help: str = "Image region in the whole slide image to read from. " "default=0 0 2000 2000", @@ -215,7 +249,7 @@ def cli_pretrained_model( ) -> Callable: """Enables --pretrained-model option for cli.""" return click.option( - "--pretrained-model", + "--model", help=add_default_to_usage_help(usage_help, default), default=default, ) @@ -234,6 +268,39 @@ def cli_pretrained_weights( ) +def cli_model( + usage_help: str = "Name of the predefined model used to process the data. " + "The format is _. For example, " + "`resnet18-kather100K` is a resnet18 model trained on the Kather dataset. " + "Please see " + "https://tia-toolbox.readthedocs.io/en/latest/usage.html#deep-learning-models " + "for a detailed list of available pretrained models." + "By default, the corresponding pretrained weights will also be" + "downloaded. However, you can override with your own set of weights" + "via the `pretrained_weights` argument. Argument is case insensitive.", + default: str = "resnet18-kather100k", +) -> Callable: + """Enables --pretrained-model option for cli.""" + return click.option( + "--model", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + +def cli_weights( + usage_help: str = "Path to the model weight file. If not supplied, the default " + "pretrained weight will be used.", + default: str | None = None, +) -> Callable: + """Enables --pretrained-weights option for cli.""" + return click.option( + "--weights", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_device( usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", default: str = "cpu", @@ -277,7 +344,7 @@ def cli_merge_predictions( def cli_return_labels( usage_help: str = "Whether to return raw model output as labels.", *, - default: bool = True, + default: bool = False, ) -> Callable: """Enables --return-labels option for cli.""" return click.option( diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index fdb4b95ca..cba248b5e 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -67,7 +67,7 @@ def nucleus_instance_segment( verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor + from tiatoolbox.models import IOInstanceSegmentorConfig, NucleusInstanceSegmentor from tiatoolbox.utils import save_as_json files_all, masks_all, output_path = prepare_model_cli( @@ -78,7 +78,7 @@ def nucleus_instance_segment( ) ioconfig = prepare_ioconfig_seg( - IOSegmentorConfig, + IOInstanceSegmentorConfig, pretrained_weights, yaml_config_path, ) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 069b6c367..a534c8f50 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -2,25 +2,23 @@ from __future__ import annotations -import click - from tiatoolbox.cli.common import ( cli_batch_size, cli_device, cli_file_type, cli_img_input, cli_masks, - cli_merge_predictions, - cli_mode, + cli_model, cli_num_loader_workers, cli_output_path, - cli_pretrained_model, - cli_pretrained_weights, + cli_output_type, + cli_patch_mode, cli_resolution, cli_return_labels, cli_return_probabilities, cli_units, cli_verbose, + cli_weights, prepare_model_cli, tiatoolbox_cli, ) @@ -35,45 +33,42 @@ @cli_file_type( default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", ) -@cli_mode( - usage_help="Type of input file to process.", - default="wsi", - input_type=click.Choice(["patch", "wsi", "tile"], case_sensitive=False), -) -@cli_pretrained_model(default="resnet18-kather100k") -@cli_pretrained_weights() -@cli_return_probabilities(default=False) -@cli_merge_predictions(default=True) -@cli_return_labels(default=True) +@cli_model(default="resnet18-kather100k") +@cli_weights() @cli_device(default="cpu") @cli_batch_size(default=1) @cli_resolution(default=0.5) @cli_units(default="mpp") @cli_masks(default=None) @cli_num_loader_workers(default=0) +@cli_output_type( + default="AnnotationStore", +) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_return_labels(default=False) @cli_verbose(default=True) def patch_predictor( - pretrained_model: str, - pretrained_weights: str, + model: str, + weights: str, img_input: str, file_types: str, masks: str | None, - mode: str, output_path: str, batch_size: int, resolution: float, units: str, num_loader_workers: int, device: str, + output_type: str, *, return_probabilities: bool, return_labels: bool, - merge_predictions: bool, + patch_mode: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import PatchPredictor - from tiatoolbox.utils import save_as_json + from tiatoolbox.models.engine.patch_predictor import PatchPredictor files_all, masks_all, output_path = prepare_model_cli( img_input=img_input, @@ -83,26 +78,22 @@ def patch_predictor( ) predictor = PatchPredictor( - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + model=model, + weights=weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, ) - output = predictor.predict( - imgs=files_all, + _ = predictor.run( + images=files_all, masks=masks_all, - mode=mode, - return_probabilities=return_probabilities, - merge_predictions=merge_predictions, - labels=None, - return_labels=return_labels, + patch_mode=patch_mode, resolution=resolution, units=units, device=device, save_dir=output_path, - save_output=True, + output_type=output_type, + return_probabilities=return_probabilities, + return_labels=return_labels, ) - - save_as_json(output, str(output_path.joinpath("results.json"))) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index c94a5e45c..3dd13f0d8 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -6,7 +6,7 @@ alexnet-kather100k: backbone: alexnet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -20,7 +20,7 @@ resnet18-kather100k: backbone: resnet18 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -34,7 +34,7 @@ resnet34-kather100k: backbone: resnet34 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -48,7 +48,7 @@ resnet50-kather100k: backbone: resnet50 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -62,7 +62,7 @@ resnet101-kather100k: backbone: resnet101 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -76,7 +76,7 @@ resnext50_32x4d-kather100k: backbone: resnext50_32x4d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -90,7 +90,7 @@ resnext101_32x8d-kather100k: backbone: resnext101_32x8d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -104,7 +104,7 @@ wide_resnet50_2-kather100k: backbone: wide_resnet50_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -118,7 +118,7 @@ wide_resnet101_2-kather100k: backbone: wide_resnet101_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -132,7 +132,7 @@ densenet121-kather100k: backbone: densenet121 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -146,7 +146,7 @@ densenet161-kather100k: backbone: densenet161 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -160,7 +160,7 @@ densenet169-kather100k: backbone: densenet169 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -174,7 +174,7 @@ densenet201-kather100k: backbone: densenet201 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -188,7 +188,7 @@ mobilenet_v2-kather100k: backbone: mobilenet_v2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -202,7 +202,7 @@ mobilenet_v3_large-kather100k: backbone: mobilenet_v3_large num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -216,7 +216,7 @@ mobilenet_v3_small-kather100k: backbone: mobilenet_v3_small num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -230,7 +230,7 @@ googlenet-kather100k: backbone: googlenet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -245,7 +245,7 @@ alexnet-pcam: backbone: alexnet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -259,7 +259,7 @@ resnet18-pcam: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -273,7 +273,7 @@ resnet34-pcam: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -287,7 +287,7 @@ resnet50-pcam: backbone: resnet50 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -301,7 +301,7 @@ resnet101-pcam: backbone: resnet101 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -315,7 +315,7 @@ resnext50_32x4d-pcam: backbone: resnext50_32x4d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -329,7 +329,7 @@ resnext101_32x8d-pcam: backbone: resnext101_32x8d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -343,7 +343,7 @@ wide_resnet50_2-pcam: backbone: wide_resnet50_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -357,7 +357,7 @@ wide_resnet101_2-pcam: backbone: wide_resnet101_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -371,7 +371,7 @@ densenet121-pcam: backbone: densenet121 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -385,7 +385,7 @@ densenet161-pcam: backbone: densenet161 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -399,7 +399,7 @@ densenet169-pcam: backbone: densenet169 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -413,7 +413,7 @@ densenet201-pcam: backbone: densenet201 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -427,7 +427,7 @@ mobilenet_v2-pcam: backbone: mobilenet_v2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -441,7 +441,7 @@ mobilenet_v3_large-pcam: backbone: mobilenet_v3_large num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -455,7 +455,7 @@ mobilenet_v3_small-pcam: backbone: mobilenet_v3_small num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -469,7 +469,7 @@ googlenet-pcam: backbone: googlenet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -484,7 +484,7 @@ resnet18-idars-tumour: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [512, 512] stride_shape: [512, 512] @@ -497,7 +497,7 @@ resnet34-idars-msi: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -510,7 +510,7 @@ resnet34-idars-braf: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -523,7 +523,7 @@ resnet34-idars-cimp: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -536,7 +536,7 @@ resnet34-idars-cin: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -549,7 +549,7 @@ resnet34-idars-tp53: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -562,7 +562,7 @@ resnet34-idars-hm: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -579,7 +579,7 @@ fcn-tissue_mask: encoder: "resnet50" decoder_block: [3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 2.0} @@ -600,7 +600,7 @@ fcn_resnet50_unet-bcss: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 0.25} @@ -625,7 +625,7 @@ unet_tissue_mask_tsef: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 1.0} @@ -652,7 +652,7 @@ hovernet_fast-pannuke: 5: "Non-Neoplastic Epithelial", } ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -682,7 +682,7 @@ hovernet_fast-monusac: 4: "Neutrophil", } ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -712,7 +712,7 @@ hovernet_original-consep: 4: "Miscellaneous", } ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -735,7 +735,7 @@ hovernet_original-kumar: num_types: null # None in python ?, only do instance segmentation mode: "original" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -769,7 +769,7 @@ hovernetplus-oed: 4: "Keratin", } ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.50} @@ -793,7 +793,7 @@ micronet-consep: num_input_channels: 3 num_output_channels: 2 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 39d1441ce..ab52740ed 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -1,7 +1,8 @@ """Models package for the models implemented in tiatoolbox.""" -from tiatoolbox.models import architecture, dataset, engine, models_abc +from __future__ import annotations +from . import architecture, dataset, engine, models_abc from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus from .architecture.idars import IDaRS @@ -9,31 +10,40 @@ from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN -from .engine.multi_task_segmentor import MultiTaskSegmentor -from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor -from .engine.patch_predictor import ( +from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset +from .engine.io_config import ( + IOInstanceSegmentorConfig, IOPatchPredictorConfig, - PatchDataset, - PatchPredictor, - WSIPatchDataset, -) -from .engine.semantic_segmentor import ( - DeepFeatureExtractor, IOSegmentorConfig, - SemanticSegmentor, - WSIStreamDataset, + ModelIOConfigABC, ) +from .engine.multi_task_segmentor import MultiTaskSegmentor +from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from .engine.patch_predictor import PatchPredictor +from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor __all__ = [ "SCCNN", + "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", "IDaRS", + "IOInstanceSegmentorConfig", + "IOPatchPredictorConfig", + "IOSegmentorConfig", "MapDe", "MicroNet", + "ModelIOConfigABC", "MultiTaskSegmentor", "NuClick", "NucleusInstanceSegmentor", + "PatchDataset", "PatchPredictor", "SemanticSegmentor", + "WSIPatchDataset", + "WSIStreamDataset", + "architecture", + "dataset", + "engine", + "models_abc", ] diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index a2c33dc4f..a056f6651 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -2,21 +2,20 @@ from __future__ import annotations -import os from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union - -import torch +from typing import TYPE_CHECKING from tiatoolbox import rcParam from tiatoolbox.models.dataset.classification import predefined_preproc_func +from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover from pathlib import Path - from tiatoolbox.models.models_abc import IOConfigABC + import torch + from tiatoolbox.models.engine.io_config import ModelIOConfigABC __all__ = ["fetch_pretrained_weights", "get_pretrained_model"] PRETRAINED_INFO = rcParam["pretrained_model_info"] @@ -64,7 +63,7 @@ def get_pretrained_model( pretrained_weights: str | Path | None = None, *, overwrite: bool = False, -) -> tuple[torch.nn.Module, IOConfigABC]: +) -> tuple[torch.nn.Module, ModelIOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: @@ -144,15 +143,12 @@ def get_pretrained_model( overwrite=overwrite, ) - # ! assume to be saved in single GPU mode - # always load on to the CPU - saved_state_dict = torch.load(pretrained_weights, map_location="cpu") - model.load_state_dict(saved_state_dict, strict=True) + model = load_torch_model(model=model, weights=pretrained_weights) # ! io_info = info["ioconfig"] creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") - iostate = creator(**io_info["kwargs"]) - return model, iostate + ioconfig = creator(**io_info["kwargs"]) + return model, ioconfig diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 9798df62a..38dfb06ec 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -784,7 +784,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: @staticmethod def infer_batch( # skipcq: PYL-W0221 - model: nn.Module, batch_data: np.ndarray, *, device: str + model: nn.Module, batch_data: np.ndarray, device: str ) -> tuple: """Run inference on an input batch. diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index c7d3d1498..cb487ec53 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -175,7 +175,7 @@ def _infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return output.cpu().numpy() + return {"probabilities": output.cpu().numpy()} class CNNModel(ModelABC): diff --git a/tiatoolbox/models/dataset/__init__.py b/tiatoolbox/models/dataset/__init__.py index 9c09991fa..16c80fd18 100644 --- a/tiatoolbox/models/dataset/__init__.py +++ b/tiatoolbox/models/dataset/__init__.py @@ -1,9 +1,21 @@ """Contains dataset functionality for use with models in tiatoolbox.""" -from tiatoolbox.models.dataset.classification import ( +from tiatoolbox.models.dataset.classification import predefined_preproc_func + +from .dataset_abc import ( PatchDataset, + PatchDatasetABC, WSIPatchDataset, - predefined_preproc_func, + WSIStreamDataset, ) -from tiatoolbox.models.dataset.dataset_abc import PatchDatasetABC -from tiatoolbox.models.dataset.info import DatasetInfoABC, KatherPatchDataset +from .info import DatasetInfoABC, KatherPatchDataset + +__all__ = [ + "DatasetInfoABC", + "KatherPatchDataset", + "PatchDataset", + "PatchDatasetABC", + "WSIPatchDataset", + "WSIStreamDataset", + "predefined_preproc_func", +] diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index edf4e28aa..c1bf8fa8c 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -2,25 +2,14 @@ from __future__ import annotations -from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -import cv2 -import numpy as np from torchvision import transforms -from tiatoolbox import logger -from tiatoolbox.models.dataset import dataset_abc -from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread -from tiatoolbox.wsicore.wsimeta import WSIMeta -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader - if TYPE_CHECKING: # pragma: no cover + import numpy as np import torch - from PIL.Image import Image - - from tiatoolbox.type_hints import IntPair, Resolution, Units + from PIL import Image class _TorchPreprocCaller: @@ -72,291 +61,3 @@ def predefined_preproc_func(dataset_name: str) -> _TorchPreprocCaller: preprocs = preproc_dict[dataset_name] return _TorchPreprocCaller(preprocs) - - -class PatchDataset(dataset_abc.PatchDatasetABC): - """Define PatchDataset for torch inference. - - Define a simple patch dataset, which inherits from the - `torch.utils.data.Dataset` class. - - Attributes: - inputs (list or np.ndarray): - Either a list of patches, where each patch is a ndarray or a - list of valid path with its extension be (".jpg", ".jpeg", - ".tif", ".tiff", ".png") pointing to an image. - labels (list): - List of labels for sample at the same index in `inputs`. - Default is `None`. - - Examples: - >>> # A user defined preproc func and expected behavior - >>> preproc_func = lambda img: img/2 # reduce intensity by half - >>> transformed_img = preproc_func(img) - >>> # create a dataset to get patches preprocessed by the above function - >>> ds = PatchDataset( - ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], - ... labels=["labels1", "labels2"], - ... ) - - """ - - def __init__( - self: PatchDataset, - inputs: np.ndarray | list, - labels: list | None = None, - ) -> None: - """Initialize :class:`PatchDataset`.""" - super().__init__() - - self.data_is_npy_alike = False - - self.inputs = inputs - self.labels = labels - - # perform check on the input - self._check_input_integrity(mode="patch") - - def __getitem__(self: PatchDataset, idx: int) -> dict: - """Get an item from the dataset.""" - patch = self.inputs[idx] - - # Mode 0 is list of paths - if not self.data_is_npy_alike: - patch = self.load_img(patch) - - # Apply preprocessing to selected patch - patch = self._preproc(patch) - - data = { - "image": patch, - } - if self.labels is not None: - data["label"] = self.labels[idx] - return data - - return data - - -class WSIPatchDataset(dataset_abc.PatchDatasetABC): - """Define a WSI-level patch dataset. - - Attributes: - reader (:class:`.WSIReader`): - A WSI Reader or Virtual Reader for reading pyramidal image - or large tile in pyramidal way. - inputs: - List of coordinates to read from the `reader`, each - coordinate is of the form `[start_x, start_y, end_x, - end_y]`. - patch_input_shape: - A tuple (int, int) or ndarray of shape (2,). Expected size to - read from `reader` at requested `resolution` and `units`. - Expected to be `(height, width)`. - resolution: - See (:class:`.WSIReader`) for details. - units: - See (:class:`.WSIReader`) for details. - preproc_func: - Preprocessing function used to transform the input data. It will - be called on each patch before returning it. - - """ - - def __init__( # skipcq: PY-R1000 # noqa: PLR0915 - self: WSIPatchDataset, - img_path: str | Path, - mode: str = "wsi", - mask_path: str | Path | None = None, - patch_input_shape: IntPair = None, - stride_shape: IntPair = None, - resolution: Resolution = None, - units: Units = None, - min_mask_ratio: float = 0, - preproc_func: Callable | None = None, - *, - auto_get_mask: bool = True, - ) -> None: - """Create a WSI-level patch dataset. - - Args: - mode (str): - Can be either `wsi` or `tile` to denote the image to - read is either a whole-slide image or a large image - tile. - img_path (str or Path): - Valid to pyramidal whole-slide image or large tile to - read. - mask_path (str or Path): - Valid mask image. - patch_input_shape: - A tuple (int, int) or ndarray of shape (2,). Expected - shape to read from `reader` at requested `resolution` - and `units`. Expected to be positive and of (height, - width). Note, this is not at `resolution` coordinate - space. - stride_shape: - A tuple (int, int) or ndarray of shape (2,). Expected - stride shape to read at requested `resolution` and - `units`. Expected to be positive and of (height, width). - Note, this is not at level 0. - resolution (Resolution): - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. - units (Units): - Units in which `resolution` is defined. - auto_get_mask (bool): - If `True`, then automatically get simple threshold mask using - WSIReader.tissue_mask() function. - min_mask_ratio (float): - Only patches with positive area percentage above this value are - included. Defaults to 0. - preproc_func (Callable): - Preprocessing function used to transform the input data. If - supplied, the function will be called on each patch before - returning it. - - Examples: - >>> # A user defined preproc func and expected behavior - >>> preproc_func = lambda img: img/2 # reduce intensity by half - >>> transformed_img = preproc_func(img) - >>> # Create a dataset to get patches from WSI with above - >>> # preprocessing function - >>> ds = WSIPatchDataset( - ... img_path='/A/B/C/wsi.svs', - ... mode="wsi", - ... patch_input_shape=[512, 512], - ... stride_shape=[256, 256], - ... auto_get_mask=False, - ... preproc_func=preproc_func - ... ) - - """ - super().__init__() - - # Is there a generic func for path test in toolbox? - if not Path.is_file(Path(img_path)): - msg = "`img_path` must be a valid file path." - raise ValueError(msg) - if mode not in ["wsi", "tile"]: - msg = f"`{mode}` is not supported." - raise ValueError(msg) - patch_input_shape = np.array(patch_input_shape) - stride_shape = np.array(stride_shape) - - if ( - not np.issubdtype(patch_input_shape.dtype, np.integer) - or np.size(patch_input_shape) > 2 # noqa: PLR2004 - or np.any(patch_input_shape < 0) - ): - msg = f"Invalid `patch_input_shape` value {patch_input_shape}." - raise ValueError(msg) - if ( - not np.issubdtype(stride_shape.dtype, np.integer) - or np.size(stride_shape) > 2 # noqa: PLR2004 - or np.any(stride_shape < 0) - ): - msg = f"Invalid `stride_shape` value {stride_shape}." - raise ValueError(msg) - - self.preproc_func = preproc_func - img_path = Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # infer value such that read if mask provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - units = "mpp" - resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) - - # may decouple into misc ? - # the scaling factor will scale base level to requested read resolution/units - wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) - - # use all patches, as long as it overlaps source image - self.inputs = PatchExtractor.get_coordinates( - image_shape=wsi_shape, - patch_input_shape=patch_input_shape[::-1], - stride_shape=stride_shape[::-1], - input_within_bound=False, - ) - - mask_reader = None - if mask_path is not None: - mask_path = Path(mask_path) - if not Path.is_file(mask_path): - msg = "`mask_path` must be a valid file path." - raise ValueError(msg) - mask = imread(mask_path) # assume to be gray - mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) - mask = np.array(mask > 0, dtype=np.uint8) - - mask_reader = VirtualWSIReader(mask) - mask_reader.info = self.reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: - # if no mask provided and `wsi` mode, generate basic tissue - # mask on the fly - mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") - # ? will this mess up ? - mask_reader.info = self.reader.info - - if mask_reader is not None: - selected = PatchExtractor.filter_coordinates( - mask_reader, # must be at the same resolution - self.inputs, # must already be at requested resolution - wsi_shape=wsi_shape, - min_mask_ratio=min_mask_ratio, - ) - self.inputs = self.inputs[selected] - - if len(self.inputs) == 0: - msg = "No patch coordinates remain after filtering." - raise ValueError(msg) - - self.patch_input_shape = patch_input_shape - self.resolution = resolution - self.units = units - - # Perform check on the input - self._check_input_integrity(mode="wsi") - - def __getitem__(self: WSIPatchDataset, idx: int) -> dict: - """Get an item from the dataset.""" - coords = self.inputs[idx] - # Read image patch from the whole-slide image - patch = self.reader.read_bounds( - coords, - resolution=self.resolution, - units=self.units, - pad_constant_values=255, - coord_space="resolution", - ) - - # Apply preprocessing to selected patch - patch = self._preproc(patch) - - return {"image": patch, "coords": np.array(coords)} diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index b60ecd66e..045bb39b7 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -2,23 +2,32 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Callable, Union +import cv2 +import numpy as np +import torch +import torch.utils.data as torch_data + +from tiatoolbox import logger +from tiatoolbox.tools.patchextraction import PatchExtractor +from tiatoolbox.utils import imread +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader + if TYPE_CHECKING: # pragma: no cover from collections.abc import Iterable + from multiprocessing.managers import Namespace + + from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.typing import IntPair, Resolution, Units try: from typing import TypeGuard except ImportError: - from typing_extensions import TypeGuard # to support python <3.10 - - -import numpy as np -import torch - -from tiatoolbox.utils import imread + from typing_extensions import TypeGuard # to support python <=3.9 input_type = Union[list[Union[str, Path, np.ndarray]], np.ndarray] @@ -136,7 +145,7 @@ def load_img(path: str | Path) -> np.ndarray: if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): msg = f"Cannot load image data from `{path.suffix}` files." - raise ValueError(msg) + raise TypeError(msg) return imread(path, as_uint8=False) @@ -182,3 +191,432 @@ def __len__(self: PatchDatasetABC) -> int: def __getitem__(self: PatchDatasetABC, idx: int) -> None: """Get an item from the dataset.""" ... # pragma: no cover + + +class WSIStreamDataset(torch_data.Dataset): + """Reading a wsi in parallel mode with persistent workers. + + To speed up the inference process for multiple WSIs. The + `torch.utils.data.Dataloader` is set to run in persistent mode. + Normally, this will prevent workers from altering their initial + states (such as provided input etc.). To sidestep this, we use a + shared parallel workspace context manager to send data and signal + from the main thread, thus allowing each worker to load a new wsi as + well as corresponding patch information. + + Args: + mp_shared_space (:class:`Namespace`): + A shared multiprocessing space, must be from + `torch.multiprocessing`. + ioconfig (:class:`IOSegmentorConfig`): + An object which contains I/O placement for patches. + wsi_paths (list): List of paths pointing to a WSI or tiles. + preproc (Callable): + Pre-processing function to be applied to a patch. + mode (str): + Either `"wsi"` or `"tile"` to indicate the format of images + in `wsi_paths`. + + Examples: + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + >>> mp_manager = torch_mp.Manager() + >>> mp_shared_space = mp_manager.Namespace() + >>> mp_shared_space.signal = 1 # adding variable to the shared space + >>> wsi_paths = ['A.svs', 'B.svs'] + >>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space) + + """ + + def __init__( + self: WSIStreamDataset, + ioconfig: IOSegmentorConfig, + wsi_paths: list[str | Path], + mp_shared_space: Namespace, + preproc: Callable[[np.ndarray], np.ndarray] | None = None, + mode: str = "wsi", + ) -> None: + """Initialize :class:`WSIStreamDataset`.""" + super().__init__() + self.mode = mode + self.preproc = preproc + self.ioconfig = copy.deepcopy(ioconfig) + + if mode == "tile": + logger.warning( + "WSIPatchDataset only reads image tile at " + '`units="baseline"`. Resolutions will be converted ' + "to baseline value.", + stacklevel=2, + ) + self.ioconfig = self.ioconfig.to_baseline() + + self.mp_shared_space = mp_shared_space + self.wsi_paths = wsi_paths + self.wsi_idx = None # to be received externally via thread communication + self.reader = None + + def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader: + """Get appropriate reader for input path.""" + img_path = Path(img_path) + if self.mode == "wsi": + return WSIReader.open(img_path) + img = imread(img_path) + # initialise metadata for VirtualWSIReader. + # here, we simulate a whole-slide image, but with a single level. + metadata = WSIMeta( + mpp=np.array([1.0, 1.0]), + objective_power=10, + axes="YXS", + slide_dimensions=np.array(img.shape[:2][::-1]), + level_downsamples=[1.0], + level_dimensions=[np.array(img.shape[:2][::-1])], + ) + return VirtualWSIReader( + img, + info=metadata, + ) + + def __len__(self: WSIStreamDataset) -> int: + """Return the length of the instance attributes.""" + return len(self.mp_shared_space.patch_inputs) + + @staticmethod + def collate_fn(batch: list | np.ndarray) -> torch.Tensor: + """Prototype to handle reading exception. + + This will exclude any sample with `None` from the batch. As + such, wrapping `__getitem__` with try-catch and return `None` + upon exceptions will prevent crashing the entire program. But as + a side effect, the batch may not have the size as defined. + + """ + batch = [v for v in batch if v is not None] + return torch.utils.data.dataloader.default_collate(batch) + + def __getitem__(self: WSIStreamDataset, idx: int) -> tuple: + """Get an item from the dataset.""" + # ! no need to lock as we do not modify source value in shared space + if self.wsi_idx != self.mp_shared_space.wsi_idx: + self.wsi_idx = int(self.mp_shared_space.wsi_idx.item()) + self.reader = self._get_reader(self.wsi_paths[self.wsi_idx]) + + # this is in XY and at requested resolution (not baseline) + bounds = self.mp_shared_space.patch_inputs[idx] + bounds = bounds.numpy() # expected to be a torch.Tensor + + # be the same as bounds br-tl, unless bounds are of float + patch_data_ = [] + scale_factors = self.ioconfig.scale_to_highest( + self.ioconfig.input_resolutions, + self.ioconfig.resolution_unit, + ) + for idy, resolution in enumerate(self.ioconfig.input_resolutions): + resolution_bounds = np.round(bounds * scale_factors[idy]) + patch_data = self.reader.read_bounds( + resolution_bounds.astype(np.int32), + coord_space="resolution", + pad_constant_values=0, # expose this ? + **resolution, + ) + + if self.preproc is not None: + patch_data = patch_data.copy() + patch_data = self.preproc(patch_data) + patch_data_.append(patch_data) + if len(patch_data_) == 1: + patch_data_ = patch_data_[0] + + bound = self.mp_shared_space.patch_outputs[idx] + return patch_data_, bound + + +class WSIPatchDataset(PatchDatasetABC): + """Define a WSI-level patch dataset. + + Attributes: + reader (:class:`.WSIReader`): + A WSI Reader or Virtual Reader for reading pyramidal image + or large tile in pyramidal way. + inputs: + List of coordinates to read from the `reader`, each + coordinate is of the form `[start_x, start_y, end_x, + end_y]`. + patch_input_shape: + A tuple (int, int) or ndarray of shape (2,). Expected size to + read from `reader` at requested `resolution` and `units`. + Expected to be `(height, width)`. + resolution: + See (:class:`.WSIReader`) for details. + units: + See (:class:`.WSIReader`) for details. + preproc_func: + Preprocessing function used to transform the input data. It will + be called on each patch before returning it. + + """ + + def __init__( # skipcq: PY-R1000 # noqa: PLR0915 + self: WSIPatchDataset, + img_path: str | Path, + mode: str = "wsi", + mask_path: str | Path | None = None, + patch_input_shape: IntPair = None, + stride_shape: IntPair = None, + resolution: Resolution = None, + units: Units = None, + min_mask_ratio: float = 0, + preproc_func: Callable | None = None, + *, + auto_get_mask: bool = True, + ) -> None: + """Create a WSI-level patch dataset. + + Args: + mode (str): + Can be either `wsi` or `tile` to denote the image to + read is either a whole-slide image or a large image + tile. + img_path (str or Path): + Valid to pyramidal whole-slide image or large tile to + read. + mask_path (str or Path): + Valid mask image. + patch_input_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + shape to read from `reader` at requested `resolution` + and `units`. Expected to be positive and of (height, + width). Note, this is not at `resolution` coordinate + space. + stride_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + stride shape to read at requested `resolution` and + `units`. Expected to be positive and of (height, width). + Note, this is not at level 0. + resolution (Resolution): + Requested resolution corresponding to units. Check + (:class:`WSIReader`) for details. + units (Units): + Units in which `resolution` is defined. + auto_get_mask (bool): + If `True`, then automatically get simple threshold mask using + WSIReader.tissue_mask() function. + min_mask_ratio (float): + Only patches with positive area percentage above this value are + included. Defaults to 0. + preproc_func (Callable): + Preprocessing function used to transform the input data. If + supplied, the function will be called on each patch before + returning it. + + Examples: + >>> # A user defined preproc func and expected behavior + >>> preproc_func = lambda img: img/2 # reduce intensity by half + >>> transformed_img = preproc_func(img) + >>> # Create a dataset to get patches from WSI with above + >>> # preprocessing function + >>> ds = WSIPatchDataset( + ... img_path='/A/B/C/wsi.svs', + ... mode="wsi", + ... patch_input_shape=[512, 512], + ... stride_shape=[256, 256], + ... auto_get_mask=False, + ... preproc_func=preproc_func + ... ) + + """ + super().__init__() + + # Is there a generic func for path test in toolbox? + if not Path.is_file(Path(img_path)): + msg = "`img_path` must be a valid file path." + raise ValueError(msg) + if mode not in ["wsi", "tile"]: + msg = f"`{mode}` is not supported." + raise ValueError(msg) + patch_input_shape = np.array(patch_input_shape) + stride_shape = np.array(stride_shape) + + if ( + not np.issubdtype(patch_input_shape.dtype, np.integer) + or np.size(patch_input_shape) > 2 # noqa: PLR2004 + or np.any(patch_input_shape < 0) + ): + msg = f"Invalid `patch_input_shape` value {patch_input_shape}." + raise ValueError(msg) + if ( + not np.issubdtype(stride_shape.dtype, np.integer) + or np.size(stride_shape) > 2 # noqa: PLR2004 + or np.any(stride_shape < 0) + ): + msg = f"Invalid `stride_shape` value {stride_shape}." + raise ValueError(msg) + + self.preproc_func = preproc_func + img_path = Path(img_path) + if mode == "wsi": + self.reader = WSIReader.open(img_path) + else: + logger.warning( + "WSIPatchDataset only reads image tile at " + '`units="baseline"` and `resolution=1.0`.', + stacklevel=2, + ) + img = imread(img_path) + axes = "YXS"[: len(img.shape)] + # initialise metadata for VirtualWSIReader. + # here, we simulate a whole-slide image, but with a single level. + # ! should we expose this so that use can provide their metadata ? + metadata = WSIMeta( + mpp=np.array([1.0, 1.0]), + axes=axes, + objective_power=10, + slide_dimensions=np.array(img.shape[:2][::-1]), + level_downsamples=[1.0], + level_dimensions=[np.array(img.shape[:2][::-1])], + ) + # infer value such that read if mask provided is through + # 'mpp' or 'power' as varying 'baseline' is locked atm + units = "mpp" + resolution = 1.0 + self.reader = VirtualWSIReader( + img, + info=metadata, + ) + + # may decouple into misc ? + # the scaling factor will scale base level to requested read resolution/units + wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) + + # use all patches, as long as it overlaps source image + self.inputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + input_within_bound=False, + ) + + mask_reader = None + if mask_path is not None: + mask_path = Path(mask_path) + if not Path.is_file(mask_path): + msg = "`mask_path` must be a valid file path." + raise ValueError(msg) + mask = imread(mask_path) # assume to be gray + mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) + mask = np.array(mask > 0, dtype=np.uint8) + + mask_reader = VirtualWSIReader(mask) + mask_reader.info = self.reader.info + elif auto_get_mask and mode == "wsi" and mask_path is None: + # if no mask provided and `wsi` mode, generate basic tissue + # mask on the fly + mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") + # ? will this mess up ? + mask_reader.info = self.reader.info + + if mask_reader is not None: + selected = PatchExtractor.filter_coordinates( + mask_reader, # must be at the same resolution + self.inputs, # must already be at requested resolution + wsi_shape=wsi_shape, + min_mask_ratio=min_mask_ratio, + ) + self.inputs = self.inputs[selected] + + if len(self.inputs) == 0: + msg = "No patch coordinates remain after filtering." + raise ValueError(msg) + + self.patch_input_shape = patch_input_shape + self.resolution = resolution + self.units = units + + # Perform check on the input + self._check_input_integrity(mode="wsi") + + def __getitem__(self: WSIPatchDataset, idx: int) -> dict: + """Get an item from the dataset.""" + coords = self.inputs[idx] + # Read image patch from the whole-slide image + patch = self.reader.read_bounds( + coords, + resolution=self.resolution, + units=self.units, + pad_constant_values=255, + coord_space="resolution", + ) + + # Apply preprocessing to selected patch + patch = self._preproc(patch) + + return {"image": patch, "coords": np.array(coords)} + + +class PatchDataset(PatchDatasetABC): + """Define PatchDataset for torch inference. + + Define a simple patch dataset, which inherits from the + `torch.utils.data.Dataset` class. + + Attributes: + inputs (list or np.ndarray): + Either a list of patches, where each patch is a ndarray or a + list of valid path with its extension be (".jpg", ".jpeg", + ".tif", ".tiff", ".png") pointing to an image. + labels (list): + List of labels for sample at the same index in `inputs`. + Default is `None`. + + Examples: + >>> # A user defined preproc func and expected behavior + >>> preproc_func = lambda img: img/2 # reduce intensity by half + >>> transformed_img = preproc_func(img) + >>> # create a dataset to get patches preprocessed by the above function + >>> ds = PatchDataset( + ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], + ... labels=["labels1", "labels2"], + ... ) + + """ + + def __init__( + self: PatchDataset, + inputs: np.ndarray | list, + labels: list | None = None, + ) -> None: + """Initialize :class:`PatchDataset`.""" + super().__init__() + + self.data_is_npy_alike = False + + self.inputs = inputs + self.labels = labels + + # perform check on the input + self._check_input_integrity(mode="patch") + + def __getitem__(self: PatchDataset, idx: int) -> dict: + """Get an item from the dataset.""" + patch = self.inputs[idx] + + # Mode 0 is list of paths + if not self.data_is_npy_alike: + patch = self.load_img(patch) + + # Apply preprocessing to selected patch + patch = self._preproc(patch) + + data = { + "image": patch, + } + if self.labels is not None: + data["label"] = self.labels[idx] + return data + + return data diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 4293fae0c..9c00ac4a2 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,7 +1,15 @@ """Engines to run models implemented in tiatoolbox.""" -from tiatoolbox.models.engine import ( +from . import ( + engine_abc, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, ) + +__all__ = [ + "engine_abc", + "nucleus_instance_segmentor", + "patch_predictor", + "semantic_segmentor", +] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py new file mode 100644 index 000000000..8f0adc310 --- /dev/null +++ b/tiatoolbox/models/engine/engine_abc.py @@ -0,0 +1,1232 @@ +"""Defines Abstract Base Class for TIAToolbox Engines.""" + +from __future__ import annotations + +import copy +import shutil +from abc import ABC +from pathlib import Path +from typing import TYPE_CHECKING, TypedDict + +import numpy as np +import torch +import tqdm +import zarr +from torch import nn +from typing_extensions import Unpack + +from tiatoolbox import DuplicateFilter, logger, rcParam +from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.architecture.utils import compile_model +from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset +from tiatoolbox.models.models_abc import load_torch_model +from tiatoolbox.utils.misc import ( + dict_to_store, + dict_to_zarr, + write_to_zarr_in_cache_mode, +) + +from .io_config import ModelIOConfigABC + +if TYPE_CHECKING: # pragma: no cover + import os + + from torch.utils.data import DataLoader + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.wsicore.wsireader import WSIReader + + +def prepare_engines_save_dir( + save_dir: os | Path | None, + *, + patch_mode: bool, + overwrite: bool = False, +) -> Path | None: + """Create a save directory. + + If patch_mode is False and the save directory is not defined, + this function will raise an error. + + If patch_mode is True and the save directory is defined it will + create save_dir otherwise returns None. + + Args: + save_dir (str or Path): + Path to output directory. + patch_mode(bool): + Whether to treat input image as a patch or WSI. + overwrite (bool): + Whether to overwrite the results. Default = False. + + Returns: + :class:`Path`: + Path to output directory. + + Raises: + OSError: + If the save directory is not defined. + + """ + if patch_mode is True: + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + return save_dir + + if save_dir is None: + msg = ( + "Input WSIs detected but no save directory provided." + "Please provide a 'save_dir'." + ) + raise OSError(msg) + + logger.info( + "When providing multiple whole slide images, " + "the outputs will be saved and the locations of outputs " + "will be returned to the calling function when `run()`" + "finishes successfully.", + ) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + + return save_dir + + +class EngineABCRunParams(TypedDict, total=False): + """Class describing the input parameters for the :func:`EngineABC.run()` method. + + Attributes: + batch_size (int): + Number of image patches to feed to the model in a forward pass. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. + ioconfig (ModelIOConfigABC): + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. + return_labels (bool): + Whether to return the labels with the predictions. + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + output_file (str): + Output file name to save "zarr" or "db". If None, path to output is + returned by the engine. + patch_input_shape (tuple): + Shape of patches input to the model as tuple of height and width (HW). + Patches are requested at read resolution, not with respect to level 0, + and must be positive. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + scale_factor (tuple[float, float]): + The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + verbose (bool): + Whether to output logging information. + + """ + + batch_size: int + cache_mode: bool + cache_size: int + class_dict: dict + device: str + ioconfig: ModelIOConfigABC + num_loader_workers: int + num_post_proc_workers: int + output_file: str + patch_input_shape: IntPair + resolution: Resolution + return_labels: bool + scale_factor: tuple[float, float] + stride_shape: IntPair + units: Units + verbose: bool + + +class EngineABC(ABC): # noqa: B024 + """Abstract base class for TIAToolbox deep learning engines to run CNN models. + + Args: + model (str | ModelABC): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights using the `weights` parameter. + batch_size (int): + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = EngineABC( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) + + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". + verbose (bool): + Whether to output logging information. Default value is False. + + Attributes: + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (str): + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. + model (str | ModelABC): + A PyTorch model or a name of an existing model from the TIAToolbox model zoo + for processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. + _ioconfig (ModelIOConfigABC): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + batch_size (int): + Number of images fed into the model each time. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. cache_mode is always True when + processing WSIs i.e., when `patch_mode` is False. Default value is False. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. Default value is 10,000. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + return_labels (bool): + Whether to return the output labels. Default value is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. Default value is 1.0. + units (Units): + Units of resolution used for reading the image. Choose + from either `baseline`, `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. + Default value is `baseline`. + verbose (bool): + Whether to output logging information. Default value is False. + + Examples: + >>> # Inherit from EngineABC + >>> class TestEngineABC(EngineABC): + >>> def __init__( + >>> self, + >>> model, + >>> weights, + >>> verbose, + >>> ): + >>> super().__init__(model=model, weights=weights, verbose=verbose) + >>> # Define all the abstract classes + + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = TestEngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # array of list of 2 image patches as input + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = TestEngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # list of 2 image files as input + >>> image = ['path/image1.png', 'path/image2.png'] + >>> engine = TestEngineABC(model="resnet18-kather100k") + >>> output = engine.run(image, patch_mode=False) + + >>> # list of 2 wsi files as input + >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> engine = TestEngineABC(model="resnet18-kather100k") + >>> output = engine.run(wsi_file, patch_mode=False) + + """ + + def __init__( + self: EngineABC, + model: str | ModelABC, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = False, + ) -> None: + """Initialize Engine.""" + self.images = None + self.masks = None + self.patch_mode = None + self.device = device + + # Initialize model with specified weights and ioconfig. + self.model, self.ioconfig = self._initialize_model_ioconfig( + model=model, + weights=weights, + ) + self.model.to(device=self.device) + self.model = ( + compile_model( # for runtime, such as after wrapping with nn.DataParallel + self.model, + mode=rcParam["torch_compile_mode"], + ) + ) + self._ioconfig = self.ioconfig # runtime ioconfig + + self.batch_size = batch_size + self.cache_mode: bool = False + self.cache_size: int = self.batch_size if self.batch_size else 10000 + self.labels: list | None = None + self.num_loader_workers = num_loader_workers + self.num_post_proc_workers = num_post_proc_workers + self.patch_input_shape: IntPair | None = None + self.resolution: Resolution | None = None + self.return_labels: bool = False + self.stride_shape: IntPair | None = None + self.units: Units | None = None + self.verbose = verbose + + @staticmethod + def _initialize_model_ioconfig( + model: str | ModelABC, + weights: str | Path | None, + ) -> tuple[nn.Module, ModelIOConfigABC | None]: + """Helper function to initialize model and ioconfig attributes. + + If a pretrained model provided by the TIAToolbox is requested. The model + can be specified as a string otherwise :class:`torch.nn.Module` is required. + This function also loads the :class:`ModelIOConfigABC` using the information + from the pretrained models in TIAToolbox. If ioconfig is not available then it + should be provided in the :func:`run()` function. + + Args: + model (str | ModelABC): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights using the `weights` parameter. + + weights (str | Path | None): + Path to pretrained weights. If no pretrained weights are provided + and the `model` is provided by TIAToolbox, then pretrained weights will + be automatically loaded from the TIA servers. + + Returns: + ModelABC: + The requested PyTorch model as a :class:`ModelABC` instance. + + ModelIOConfigABC | None: + The model io configuration for TIAToolbox pretrained models. + If the specified model is not in TIAToolbox model zoo, then the function + returns None. + + """ + if not isinstance(model, (str, nn.Module)): + msg = "Input model must be a string or 'torch.nn.Module'." + raise TypeError(msg) + + if isinstance(model, str): + # ioconfig is retrieved from the pretrained model in the toolbox. + # list of pretrained models in the TIA Toolbox is available here: + # https://tia-toolbox.readthedocs.io/en/latest/pretrained.html + # no need to provide ioconfig in EngineABC.run() this case. + return get_pretrained_model(model, weights) + + if weights is not None: + model = load_torch_model(model=model, weights=weights) + + return model, None + + def get_dataloader( + self: EngineABC, + images: str | Path | list[str | Path] | np.ndarray, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + ) -> torch.utils.data.DataLoader: + """Pre-process images and masks and return dataloader for inference. + + Args: + images (list of str or :class:`Path` or :class:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. When `patch_mode` is False + the function expects list of str/paths to WSIs. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + ioconfig (ModelIOConfigABC): + A :class:`ModelIOConfigABC` object. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + + Returns: + torch.utils.data.DataLoader: + :class:`torch.utils.data.DataLoader` for inference. + + """ + if labels: + # if a labels is provided, then return with the prediction + self.return_labels = bool(labels) + + if not patch_mode: + dataset = WSIPatchDataset( + img_path=images, + mode="wsi", + mask_path=masks, + patch_input_shape=ioconfig.patch_input_shape, + stride_shape=ioconfig.stride_shape, + resolution=ioconfig.input_resolutions[0]["resolution"], + units=ioconfig.input_resolutions[0]["units"], + ) + + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + + dataset = PatchDataset(inputs=images, labels=labels) + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + + @staticmethod + def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict: + """Helper function to append raw output during inference.""" + for key, value in raw_output.items(): + if raw_predictions[key] is None: + raw_predictions[key] = value + else: + raw_predictions[key] = np.append(raw_predictions[key], value, axis=0) + + return raw_predictions + + def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray: + """Helper function to collect coordinates for AnnotationStore.""" + if self.patch_mode: + coordinates = [0, 0, *batch_data["image"].shape[1:3]] + return np.tile(coordinates, reps=(batch_data["image"].shape[0], 1)) + return batch_data["coords"].numpy() + + def infer_patches( + self: EngineABC, + dataloader: DataLoader, + save_path: Path | None, + *, + return_coordinates: bool = False, + ) -> dict | Path: + """Runs model inference on image patches and returns output as a dictionary. + + Args: + dataloader (DataLoader): + An :class:`torch.utils.data.DataLoader` object to run inference. + save_path (Path | None): + If `cache_mode` is True then path to save zarr file must be provided. + return_coordinates (bool): + Whether to save coordinates in the output. This is required when + this function is called by `infer_wsi` and `patch_mode` is False. + + Returns: + dict or Path: + Result of model inference as a dictionary. Returns path to + saved zarr file if `cache_mode` is True. + + """ + progress_bar = None + + if self.verbose: + progress_bar = tqdm.tqdm( + total=int(len(dataloader)), + leave=True, + ncols=80, + ascii=True, + position=0, + ) + + keys = ["probabilities"] + + if self.return_labels: + keys.append("labels") + + if return_coordinates: + keys.append("coordinates") + + raw_predictions = {key: None for key in keys} + + zarr_group = None + + if self.cache_mode: + zarr_group = zarr.open(save_path, mode="w") + + for _, batch_data in enumerate(dataloader): + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + if return_coordinates: + batch_output["coordinates"] = self._get_coordinates(batch_data) + + if self.return_labels: # be careful of `s` + if isinstance(batch_data["label"], torch.Tensor): + batch_output["labels"] = batch_data["label"].numpy() + else: + batch_output["labels"] = batch_data["label"] + + raw_predictions = self._update_model_output( + raw_predictions=raw_predictions, + raw_output=batch_output, + ) + + if self.cache_mode: + zarr_group = write_to_zarr_in_cache_mode( + zarr_group=zarr_group, output_data_to_save=raw_predictions + ) + raw_predictions = {key: None for key in keys} + + if progress_bar: + progress_bar.update() + + if progress_bar: + progress_bar.close() + + return save_path if self.cache_mode else raw_predictions + + def post_process_patches( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post-process raw patch predictions from inference. + + The output of :func:`infer_patches()` with patch prediction information will be + post-processed using this function. The processed output will be saved in the + respective input format. If `cache_mode` is True, the function processes the + input using zarr group with size specified by `cache_size`. + + Args: + raw_predictions (dict | Path): + A dictionary or path to zarr with patch prediction information. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. + + Returns: + dict or Path: + Returns patch based output after post-processing. Returns path to + saved zarr file if `cache_mode` is True. + + """ + _ = kwargs.get("return_labels") # Key values required for post-processing + + if self.cache_mode: # cache mode + _ = zarr.open(raw_predictions, mode="w") + + return raw_predictions + + def save_predictions( + self: EngineABC, + processed_predictions: dict | Path, + output_type: str, + save_dir: Path | None = None, + **kwargs: dict, + ) -> dict | AnnotationStore | Path: + """Save model predictions. + + Args: + processed_predictions (dict | Path): + A dictionary or path to zarr with model prediction information. + save_dir (Path): + Optional output path to directory to save the patch dataset output to a + `.zarr` or `.db` file, provided `patch_mode` is True. If the + `patch_mode` is False then `save_dir` is required. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (EngineABCRunParams): + Keyword Args required to save the output. + + Returns: + dict or Path or :class:`AnnotationStore`: + If the `output_type` is "AnnotationStore", the function returns + the patch predictor output as an SQLiteStore containing Annotations + for each or the Path to a `.db` file depending on whether a + save_dir Path is provided. Otherwise, the function defaults to + returning patch predictor output, either as a dict or the Path to a + `.zarr` file depending on whether a save_dir Path is provided. + + """ + if ( + self.cache_mode or not save_dir + ) and output_type.lower() != "annotationstore": + return processed_predictions + + save_path = Path(kwargs.get("output_file", save_dir / "output.db")) + + if output_type.lower() == "annotationstore": + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + + processed_predictions_path: str | Path | None = None + + # Need to add support for zarr conversion. + if self.cache_mode: + processed_predictions_path = processed_predictions + processed_predictions = zarr.open(processed_predictions, mode="r") + + out_file = dict_to_store( + processed_predictions, + scale_factor, + class_dict, + save_path, + ) + if processed_predictions_path is not None: + shutil.rmtree(processed_predictions_path) + + return out_file + + return ( + dict_to_zarr( + processed_predictions, + save_path, + **kwargs, + ) + if isinstance(processed_predictions, dict) + else processed_predictions + ) + + def infer_wsi( + self: EngineABC, + dataloader: DataLoader, + save_path: Path, + **kwargs: EngineABCRunParams, + ) -> Path: + """Model inference on a WSI. + + Args: + dataloader (DataLoader): + A torch dataloader to process WSIs. + + save_path (Path): + Path to save the intermediate output. The intermediate output is saved + in a zarr file. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. + + Returns: + save_path (Path): + Path to zarr file where intermediate output is saved. + + """ + _ = kwargs.get("patch_mode", False) + return self.infer_patches( + dataloader=dataloader, + save_path=save_path, + return_coordinates=True, + ) + + # This is not a static model for child classes. + def post_process_wsi( # skipcq: PYL-R0201 + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output. + + Takes the raw output from patch predictions and post-processes it to improve the + results e.g., using information from neighbouring patches. + + """ + _ = kwargs.get("return_labels") # Key values required for post-processing + return raw_predictions + + def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: + """Helper function to load ioconfig. + + If the model is provided by TIAToolbox it will load the default ioconfig. + Otherwise, ioconfig must be specified. + + Args: + ioconfig (ModelIOConfigABC): + IO configuration to run the engines. + + Raises: + ValueError: + If no io configuration is provided or found in the pretrained TIAToolbox + models. + + Returns: + ModelIOConfigABC: + The ioconfig used for the run. + + """ + if self.ioconfig is None and ioconfig is None: + msg = ( + "Please provide a valid ModelIOConfigABC. " + "No default ModelIOConfigABC found." + ) + logger.warning(msg) + + if ioconfig and isinstance(ioconfig, ModelIOConfigABC): + self.ioconfig = ioconfig + + return self.ioconfig + + def _update_ioconfig( + self: EngineABC, + ioconfig: ModelIOConfigABC, + patch_input_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> ModelIOConfigABC: + """Update IOConfig. + + Args: + ioconfig (:class:`ModelIOConfigABC`): + Input ioconfig for PatchPredictor. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride using during tile and WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. + + Returns: + Updated Patch Predictor IO configuration. + + """ + config_flag = ( + patch_input_shape is None, + resolution is None, + units is None, + ) + if isinstance(ioconfig, ModelIOConfigABC): + return ioconfig + + if self.ioconfig is None and any(config_flag): + msg = ( + "Must provide either " + "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." + ) + raise ValueError( + msg, + ) + + if stride_shape is None: + stride_shape = patch_input_shape + + if self.ioconfig: + ioconfig = copy.deepcopy(self.ioconfig) + # ! not sure if there is a nicer way to set this + if patch_input_shape is not None: + ioconfig.patch_input_shape = patch_input_shape + if stride_shape is not None: + ioconfig.stride_shape = stride_shape + if resolution is not None: + ioconfig.input_resolutions[0]["resolution"] = resolution + if units is not None: + ioconfig.input_resolutions[0]["units"] = units + + return ioconfig + + return ModelIOConfigABC( + input_resolutions=[{"resolution": resolution, "units": units}], + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + output_resolutions=[], + ) + + @staticmethod + def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: + """Validate input images for a run.""" + if not isinstance(images, (list, np.ndarray)): + msg = "Input must be a list of file paths or a numpy array." + raise TypeError( + msg, + ) + + if isinstance(images, np.ndarray) and images.ndim != 4: # noqa: PLR2004 + msg = ( + "The input numpy array should be four dimensional." + "The shape of the numpy array should be NHWC." + ) + raise ValueError(msg) + + return images + + @staticmethod + def _validate_input_numbers( + images: list | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ) -> None: + """Validates number of input images, masks and labels.""" + if masks is None and labels is None: + return + + len_images = len(images) + + if masks is not None and len_images != len(masks): + msg = ( + f"len(masks) is not equal to len(images) " + f": {len(masks)} != {len(images)}" + ) + raise ValueError( + msg, + ) + + if labels is not None and len_images != len(labels): + msg = ( + f"len(labels) is not equal to len(images) " + f": {len(labels)} != {len(images)}" + ) + raise ValueError( + msg, + ) + return + + def _update_run_params( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os | Path | None = None, + ioconfig: ModelIOConfigABC | None = None, + output_type: str = "dict", + *, + overwrite: bool = False, + patch_mode: bool, + **kwargs: Unpack[EngineABCRunParams], + ) -> Path | None: + """Updates runtime parameters. + + Updates runtime parameters for an EngineABC for EngineABC.run(). + + """ + for key in kwargs: + setattr(self, key, kwargs.get(key)) + + self.patch_mode = patch_mode + if not self.patch_mode: + self.cache_mode = True # if input is WSI run using cache mode. + + if self.cache_mode and self.batch_size > self.cache_size: + self.batch_size = self.cache_size + + self._validate_input_numbers(images=images, masks=masks, labels=labels) + if output_type.lower() not in ["dict", "zarr", "annotationstore"]: + msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'." + raise TypeError(msg) + + self.images = self._validate_images_masks(images=images) + + if masks is not None: + self.masks = self._validate_images_masks(images=masks) + + self.labels = labels + + # if necessary move model parameters to "cpu" or "gpu" and update ioconfig + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + self.model.to(device=self.device) + self._ioconfig = self._update_ioconfig( + ioconfig, + self.patch_input_shape, + self.stride_shape, + self.resolution, + self.units, + ) + + return prepare_engines_save_dir( + save_dir=save_dir, + patch_mode=patch_mode, + overwrite=overwrite, + ) + + def _run_patch_mode( + self: EngineABC, output_type: str, save_dir: Path, **kwargs: EngineABCRunParams + ) -> dict | AnnotationStore | Path: + """Runs the Engine in the patch mode. + + Input arguments are passed from :func:`EngineABC.run()`. + + """ + save_path = None + if self.cache_mode: + output_file = Path(kwargs.get("output_file", "output.zarr")) + save_path = save_dir / (str(output_file.stem) + ".zarr") + + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + + dataloader = self.get_dataloader( + images=self.images, + masks=self.masks, + labels=self.labels, + patch_mode=True, + ) + raw_predictions = self.infer_patches( + dataloader=dataloader, + save_path=save_path, + return_coordinates=output_type == "annotationstore", + ) + processed_predictions = self.post_process_patches( + raw_predictions=raw_predictions, + **kwargs, + ) + logger.removeFilter(duplicate_filter) + + out = self.save_predictions( + processed_predictions=processed_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + + if save_dir: + msg = f"Output file saved at {out}." + logger.info(msg=msg) + return out + + return out + + @staticmethod + def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, float]: + """Calculates scale factor for final output. + + Uses the dataloader resolution and the WSI resolution to calculate scale + factor for final WSI output. + + Args: + dataloader (DataLoader): + Dataloader for the current run. + + Returns: + scale_factor (float | tuple[float, float]): + Scale factor for final output. + + """ + # get units and resolution from dataloader. + dataloader_units = dataloader.dataset.units + dataloader_resolution = dataloader.dataset.resolution + + # if dataloader units is baseline slide resolution is 1.0. + # in this case dataloader resolution / slide resolution will be + # equal to dataloader resolution. + + if dataloader_units in ["mpp", "level", "power"]: + wsimeta_dict = dataloader.dataset.reader.info.as_dict() + + if dataloader_units == "mpp": + slide_resolution = wsimeta_dict[dataloader_units] + scale_factor = np.divide(dataloader_resolution, slide_resolution) + return scale_factor[0], scale_factor[1] + + if dataloader_units == "level": + downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution] + return downsample_ratio, downsample_ratio + + if dataloader_units == "power": + slide_objective_power = wsimeta_dict["objective_power"] + return ( + slide_objective_power / dataloader_resolution, + slide_objective_power / dataloader_resolution, + ) + + return dataloader_resolution + + def _run_wsi_mode( + self: EngineABC, + output_type: str, + save_dir: Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | AnnotationStore | Path: + """Runs the Engine in the WSI mode (patch_mode = False). + + Input arguments are passed from :func:`EngineABC.run()`. + + """ + suffix = ".zarr" + if output_type == "AnnotationStore": + suffix = ".db" + + out = {image: save_dir / (str(image.stem) + suffix) for image in self.images} + + save_path = { + image: save_dir / (str(image.stem) + ".zarr") for image in self.images + } + + for image_num, image in enumerate(self.images): + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + mask = self.masks[image_num] if self.masks is not None else None + dataloader = self.get_dataloader( + images=image, + masks=mask, + patch_mode=False, + ioconfig=self._ioconfig, + ) + + scale_factor = self._calculate_scale_factor(dataloader=dataloader) + + raw_predictions = self.infer_wsi( + dataloader=dataloader, + save_path=save_path[image], + **kwargs, + ) + processed_predictions = self.post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + kwargs["output_file"] = out[image] + kwargs["scale_factor"] = scale_factor + out[image] = self.save_predictions( + processed_predictions=processed_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + logger.removeFilter(duplicate_filter) + msg = f"Output file saved at {out[image]}." + logger.info(msg=msg) + + return out + + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[EngineABCRunParams], + ) -> AnnotationStore | Path | str | dict: + """Run the engine on input images. + + Args: + images (list, ndarray): + List of inputs to process. when using `patch` mode, the + input must be either a list of images, a list of image + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. + ioconfig (IOPatchPredictorConfig): + IO configuration. + save_dir (str or pathlib.Path): + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "dict", "zarr" or "AnnotationStore". Default value is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (EngineABCRunParams): + Keyword Args to update :class:`EngineABC` attributes during runtime. + + Returns: + (:class:`numpy.ndarray`, dict): + Model predictions of the input dataset. If multiple + whole slide images are provided as input, + or save_output is True, then results are saved to + `save_dir` and a dictionary indicating save location for + each input is returned. + + The dict has the following format: + + - img_path: path of the input image. + - raw: path to save location for raw prediction, + saved in .json. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> class PatchPredictor(EngineABC): + >>> # Define all Abstract methods. + >>> ... + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + >>> output = predictor.run( + >>> image_patches, + >>> patch_mode=True, + >>> output_type="zarr") + >>> output + ... "/path/to/Output.zarr" + >>> output = predictor.run(wsis, patch_mode=False) + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... {'/path/to/wsi1.db'} + + """ + save_dir = self._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, + overwrite=overwrite, + patch_mode=patch_mode, + output_type=output_type, + **kwargs, + ) + + if patch_mode: + return self._run_patch_mode( + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + + # All inherited classes will get scale_factors, + # highest_input_resolution, implement dataloader, + # pre-processing, post-processing and save_output + # for WSIs separately. + return self._run_wsi_mode( + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py new file mode 100644 index 000000000..f6c9b9c2c --- /dev/null +++ b/tiatoolbox/models/engine/io_config.py @@ -0,0 +1,452 @@ +"""Defines IOConfig for Model Engines.""" + +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.typing import Units + + +@dataclass +class ModelIOConfigABC: + """Defines a data class for holding a deep learning model's I/O information. + + Enforcing such that following attributes must always be defined by + the subclass. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.to_baseline() + + """ + + input_resolutions: list[dict] + patch_input_shape: list[int] | np.ndarray | tuple[int, int] + stride_shape: list[int] | np.ndarray | tuple[int, int] = None + output_resolutions: list[dict] = field(default_factory=list) + + def __post_init__(self: ModelIOConfigABC) -> None: + """Perform post initialization tasks.""" + if self.stride_shape is None: + self.stride_shape = self.patch_input_shape + + self.resolution_unit = self.input_resolutions[0]["units"] + + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, + key=lambda x: x["resolution"], + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, + key=lambda x: x["resolution"], + ) + + self._validate() + + def _validate(self: ModelIOConfigABC) -> None: + """Validate the data format.""" + resolutions = self.input_resolutions + self.output_resolutions + units = {v["units"] for v in resolutions} + + if len(units) != 1: + msg = ( + f"Multiple resolution units found: `{units}`. " + f"Mixing resolution units is not allowed." + ) + raise ValueError( + msg, + ) + + if units.pop() not in [ + "power", + "baseline", + "mpp", + ]: + msg = f"Invalid resolution units `{units}`." + raise ValueError(msg) + + @staticmethod + def scale_to_highest(resolutions: list[dict], units: Units) -> np.array: + """Get the scaling factor from input resolutions. + + This will convert resolutions to a scaling factor with respect to + the highest resolution found in the input resolutions list. If a model + requires images at multiple resolutions. This helps to read the image a + single resolution. The image will be read at the highest required resolution + and will be scaled for low resolution requirements using interpolation. + + Args: + resolutions (list): + A list of resolutions where one is defined as + `{'resolution': value, 'unit': value}` + units (Units): + Resolution units. + + Returns: + :class:`numpy.ndarray`: + A 1D array of scaling factors having the same length as + `resolutions`. + + Examples: + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.5}, + ... ], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.scale_to_highest() + ... array([1. , 0.5]) # output + >>> + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.5}, + ... {"units": "mpp", "resolution": 0.25}, + ... ], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.scale_to_highest() + ... array([0.5 , 1.]) # output + + """ + old_vals = [v["resolution"] for v in resolutions] + if units not in {"baseline", "mpp", "power"}: + msg = ( + f"Unknown units `{units}`. " + f"Units should be one of 'baseline', 'mpp' or 'power'." + ) + raise ValueError( + msg, + ) + if units == "baseline": + return old_vals + if units == "mpp": + return np.min(old_vals) / np.array(old_vals) + return np.array(old_vals) / np.max(old_vals) + + def to_baseline(self: ModelIOConfigABC) -> ModelIOConfigABC: + """Returns a new config object converted to baseline form. + + This will return a new :class:`ModelIOConfigABC` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + resolutions = self.input_resolutions + self.output_resolutions + save_resolution = getattr(self, "save_resolution", None) + if save_resolution is not None: + resolutions.append(save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + num_input_resolutions = len(self.input_resolutions) + + end_idx = num_input_resolutions + input_resolutions = [ + {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] + ] + + num_input_resolutions = len(self.input_resolutions) + num_output_resolutions = len(self.output_resolutions) + + end_idx = num_input_resolutions + num_output_resolutions + output_resolutions = [ + {"units": "baseline", "resolution": v} + for v in scale_factors[num_input_resolutions:end_idx] + ] + + return replace( + self, + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + ) + + +@dataclass +class IOSegmentorConfig(ModelIOConfigABC): + """Contains semantic segmentor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... ) + ... + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... ) + + """ + + patch_output_shape: list[int] | np.ndarray | tuple[int, int] = None + save_resolution: dict = None + + def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig: + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + new_config = super().to_baseline() + resolutions = self.input_resolutions + self.output_resolutions + if self.save_resolution is not None: + resolutions.append(self.save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + + save_resolution = None + if self.save_resolution is not None: + save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} + + return replace( + self, + input_resolutions=new_config.input_resolutions, + output_resolutions=new_config.output_resolutions, + save_resolution=save_resolution, + ) + + +class IOPatchPredictorConfig(ModelIOConfigABC): + """Contains patch predictor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a patch predictor network + >>> ioconfig = IOPatchPredictorConfig( + ... input_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... output_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + + """ + + +@dataclass +class IOInstanceSegmentorConfig(IOSegmentorConfig): + """Contains instance segmentor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOInstanceSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... margin=128, + ... tile_shape=(1024, 1024), + ... ) + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOInstanceSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... margin=128, + ... tile_shape=(1024, 1024), + ... ) + + """ + + margin: int = None + tile_shape: tuple[int, int] = None + + def to_baseline(self: IOInstanceSegmentorConfig) -> IOInstanceSegmentorConfig: + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + return super().to_baseline() diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 2d3df757f..9db4830d2 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -31,20 +31,19 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.models.engine.nucleus_instance_segmentor import ( NucleusInstanceSegmentor, _process_instance_predictions, ) -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) if TYPE_CHECKING: # pragma: no cover import torch from tiatoolbox.type_hints import IntBounds + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig + # Python is yet to be able to natively pickle Object method/static method. # Only top-level function is passable to multi-processing as caller. @@ -293,19 +292,23 @@ def __init__( def _predict_one_wsi( self: MultiTaskSegmentor, wsi_idx: int, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, ) -> None: """Make a prediction on tile/wsi. Args: - wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): Object which defines I/O placement during - inference and when assembling back to full tile/wsi. - save_path (str): Location to save output prediction as well as possible + wsi_idx (int): + Index of the tile/wsi to be processed within `self`. + ioconfig (IOInstanceSegmentorConfig): + Object which defines I/O placement + during inference and when assembling back to full tile/wsi. + save_path (str): + Location to save output prediction as well as possible intermediate results. - mode (str): `tile` or `wsi` to indicate run mode. + mode (str): + `tile` or `wsi` to indicate run mode. """ cache_dir = f"{self._cache_dir}/" diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 6649324b1..39325d35b 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -4,7 +4,7 @@ import uuid from collections import deque -from typing import Callable +from typing import TYPE_CHECKING, Callable # replace with the sql database once the PR in place import joblib @@ -14,13 +14,13 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - SemanticSegmentor, - WSIStreamDataset, -) +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset +from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor from tiatoolbox.tools.patchextraction import PatchExtractor +if TYPE_CHECKING: # pragma: no cover + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig + def _process_instance_predictions( inst_dict: dict, @@ -406,7 +406,7 @@ def __init__( @staticmethod def _get_tile_info( image_shape: list[int] | np.ndarray, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, ) -> list[list, ...]: """Generating tile information. @@ -424,7 +424,7 @@ def _get_tile_info( image_shape (:class:`numpy.ndarray`, list(int)): The shape of WSI to extract the tile from, assumed to be in `[width, height]`. - ioconfig (:obj:IOSegmentorConfig): + ioconfig (:obj:IOInstanceSegmentorConfig): The input and output configuration objects. Returns: @@ -439,7 +439,7 @@ def _get_tile_info( - :class:`numpy.ndarray` - Horizontal strip tiles - :class:`numpy.ndarray` - Removal flags - :py:obj:`list` - Tiles and flags - - :class:`numpy.ndarray` - Cross-section tiles + - :class:`numpy.ndarray` - Cross section tiles - :class:`numpy.ndarray` - Removal flags """ @@ -675,7 +675,7 @@ def _predict_one_wsi( Args: wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): + ioconfig (IOInstanceSegmentorConfig): Object which defines I/O placement during inference and when assembling back to full tile/wsi. save_path (str): diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 088f78687..35b2c7d56 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -1,56 +1,94 @@ -"""This module implements patch level prediction.""" +"""Defines PatchPredictor Engine.""" from __future__ import annotations -import copy -from collections import OrderedDict -from pathlib import Path -from typing import TYPE_CHECKING, Callable +import math +from typing import TYPE_CHECKING -import numpy as np -import torch -import tqdm +import zarr +from typing_extensions import Unpack -from tiatoolbox import logger, rcParam -from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.architecture.utils import compile_model -from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset -from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig -from tiatoolbox.models.models_abc import model_to -from tiatoolbox.utils import save_as_json -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .engine_abc import EngineABC, EngineABCRunParams if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.type_hints import IntPair, Resolution, Units + import os + from pathlib import Path + import numpy as np -class IOPatchPredictorConfig(IOSegmentorConfig): - """Contains patch predictor input and output information.""" + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import ModelIOConfigABC + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore import WSIReader - def __init__( - self: IOPatchPredictorConfig, - patch_input_shape: IntPair = None, - input_resolutions: Resolution = None, - stride_shape: IntPair = None, - **kwargs: dict, - ) -> None: - """Initialize :class:`IOPatchPredictorConfig`.""" - stride_shape = patch_input_shape if stride_shape is None else stride_shape - super().__init__( - input_resolutions=input_resolutions, - output_resolutions=[], - stride_shape=stride_shape, - patch_input_shape=patch_input_shape, - patch_output_shape=patch_input_shape, - save_resolution=None, - **kwargs, - ) +class PredictorRunParams(EngineABCRunParams): + """Class describing the input parameters for the :func:`EngineABC.run()` method. + + Attributes: + batch_size (int): + Number of image patches to feed to the model in a forward pass. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. + ioconfig (ModelIOConfigABC): + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. + return_labels (bool): + Whether to return the labels with the predictions. + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + output_file (str): + Output file name to save "zarr" or "db". If None, path to output is + returned by the engine. + patch_input_shape (tuple): + Shape of patches input to the model as tuple of height and width (HW). + Patches are requested at read resolution, not with respect to level 0, + and must be positive. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + verbose (bool): + Whether to output logging information. + + """ + + return_probabilities: bool -class PatchPredictor: - r"""Patch level predictor. - The models provided by tiatoolbox should give the following results: +class PatchPredictor(EngineABC): + r"""Patch level prediction for digital histology images. + + The models provided by TIAToolbox should give the following results: .. list-table:: PatchPredictor performance on the Kather100K dataset [1] :widths: 15 15 @@ -135,83 +173,155 @@ class PatchPredictor: - 0.867 Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs + model (str | ModelABC): + A PyTorch model or name of pretrained model. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case-insensitive. - pretrained_weights (str): - Path to the weight of the corresponding `pretrained_model`. - - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k", - ... pretrained_weights="resnet18_local_weight") - + of weights using the `weights` parameter. Default is `None`. batch_size (int): - Number of images fed into the model each time. + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = PatchPredictor( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) + + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. - mode (str): - Type of input to process. Choose from either `patch`, `tile` - or `wsi`. - model (nn.Module): - Defined PyTorch model. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (str): + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. + model (str | ModelABC): + A PyTorch model or a name of an existing model from the TIAToolbox model zoo + for processing the data. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case insensitive. + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. + _ioconfig (ModelIOConfigABC): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. - num_loader_worker (int): - Number of workers used in torch.utils.data.DataLoader. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. cache_mode is always True when + processing WSIs i.e., when `patch_mode` is False. Default value is False. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. Default value is 10,000. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + return_labels (bool): + Whether to return the output labels. Default value is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. Default value is 1.0. + units (Units): + Units of resolution used for reading the image. Choose + from either `baseline`, `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. + Default value is `baseline`. verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Examples: >>> # list of 2 image patches as input - >>> data = [img1, img2] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> data = ['path/img.svs', 'path/img.svs'] + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # array of list of 2 image patches as input >>> data = np.array([img1, img2]) - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(tile_file, mode='tile') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(tile_file, mode='tile') >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(wsi_file, mode='wsi') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(wsi_file, mode='wsi') References: [1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer @@ -226,658 +336,175 @@ class PatchPredictor: def __init__( self: PatchPredictor, + model: str | ModelABC, batch_size: int = 8, num_loader_workers: int = 0, - model: torch.nn.Module = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, + num_post_proc_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" - super().__init__() - - self.imgs = None - self.mode = None - - if model is None and pretrained_model is None: - msg = "Must provide either `model` or `pretrained_model`." - raise ValueError(msg) - - if model is not None: - self.model = model - ioconfig = None # retrieve iostate from provided model ? - else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) - - self.ioconfig = ioconfig # for storing original - self._ioconfig = None # for storing runtime - self.model = ( - compile_model( # for runtime, such as after wrapping with nn.DataParallel - model, - mode=rcParam["torch_compile_mode"], - ) + super().__init__( + model=model, + batch_size=batch_size, + num_loader_workers=num_loader_workers, + num_post_proc_workers=num_post_proc_workers, + weights=weights, + device=device, + verbose=verbose, ) - self.pretrained_model = pretrained_model - self.batch_size = batch_size - self.num_loader_worker = num_loader_workers - self.verbose = verbose - - @staticmethod - def merge_predictions( - img: str | Path | np.ndarray, - output: dict, - resolution: Resolution | None = None, - units: Units | None = None, - postproc_func: Callable | None = None, - *, - return_raw: bool = False, - ) -> np.ndarray: - """Merge patch level predictions to form a 2-dimensional prediction map. - - #! Improve how the below reads. - The prediction map will contain values from 0 to N, where N is - the number of classes. Here, 0 is the background which has not - been processed by the model and N is the number of classes - predicted by the model. - - Args: - img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): - A HWC image or a path to WSI. - output (dict): - Output generated by the model. - resolution (Resolution): - Resolution of merged predictions. - units (Units): - Units of resolution used when merging predictions. This - must be the same `units` used when processing the data. - postproc_func (callable): - A function to post-process raw prediction from model. By - default, internal code uses the `np.argmax` function. - return_raw (bool): - Return raw result without applying the `postproc_func` - on the assembled image. - - Returns: - :class:`numpy.ndarray`: - Merged predictions as a 2D array. - Examples: - >>> # pseudo output dict from model with 2 patches - >>> output = { - ... 'resolution': 1.0, - ... 'units': 'baseline', - ... 'probabilities': [[0.45, 0.55], [0.90, 0.10]], - ... 'predictions': [1, 0], - ... 'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]], - ... } - >>> merged = PatchPredictor.merge_predictions( - ... np.zeros([4, 4]), - ... output, - ... resolution=1.0, - ... units='baseline' - ... ) - >>> merged - ... array([[2, 2, 0, 0], - ... [2, 2, 0, 0], - ... [0, 0, 1, 1], - ... [0, 0, 1, 1]]) - - """ - reader = WSIReader.open(img) - if isinstance(reader, VirtualWSIReader): - logger.warning( - "Image is not pyramidal hence read is forced to be " - "at `units='baseline'` and `resolution=1.0`.", - stacklevel=2, - ) - resolution = 1.0 - units = "baseline" - - canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) - canvas_shape = canvas_shape[::-1] # XY to YX - - # may crash here, do we need to deal with this ? - output_shape = reader.slide_dimensions( - resolution=output["resolution"], - units=output["units"], - ) - output_shape = output_shape[::-1] # XY to YX - fx = np.array(canvas_shape) / np.array(output_shape) - - if "probabilities" not in output: - coordinates = output["coordinates"] - predictions = output["predictions"] - denominator = None - output = np.zeros(list(canvas_shape), dtype=np.float32) - else: - coordinates = output["coordinates"] - predictions = output["probabilities"] - num_class = np.array(predictions[0]).shape[0] - denominator = np.zeros(canvas_shape) - output = np.zeros([*list(canvas_shape), num_class], dtype=np.float32) - - for idx, bound in enumerate(coordinates): - prediction = predictions[idx] - # assumed to be in XY - # top-left for output placement - tl = np.ceil(np.array(bound[:2]) * fx).astype(np.int32) - # bot-right for output placement - br = np.ceil(np.array(bound[2:]) * fx).astype(np.int32) - output[tl[1] : br[1], tl[0] : br[0]] += prediction - if denominator is not None: - denominator[tl[1] : br[1], tl[0] : br[0]] += 1 - - # deal with overlapping regions - if denominator is not None: - output = output / (np.expand_dims(denominator, -1) + 1.0e-8) - if not return_raw: - # convert raw probabilities to predictions - if postproc_func is not None: - output = postproc_func(output) - else: - output = np.argmax(output, axis=-1) - # to make sure background is 0 while class will be 1...N - output[denominator > 0] += 1 - return output - - def _predict_engine( + def post_process_cache_mode( self: PatchPredictor, - dataset: torch.utils.data.Dataset, - device: str = "cpu", - *, - return_probabilities: bool = False, - return_labels: bool = False, - return_coordinates: bool = False, - ) -> np.ndarray: - """Make a prediction on a dataset. The dataset may be mutated. - - Args: - dataset (torch.utils.data.Dataset): - PyTorch dataset object created using - `tiatoolbox.models.data.classification.Patch_Dataset`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return labels. - return_coordinates (bool): - Whether to return patch coordinates. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset - - """ - dataset.preproc_func = self.model.preproc_func - - # preprocessing must be defined with the dataset - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=self.num_loader_worker, - batch_size=self.batch_size, - drop_last=False, - shuffle=False, - ) - - if self.verbose: - pbar = tqdm.tqdm( - total=int(len(dataloader)), - leave=True, - ncols=80, - ascii=True, - position=0, + raw_predictions: Path, + **kwargs: Unpack[PredictorRunParams], + ) -> Path: + """Returns an array from raw predictions.""" + return_probabilities = kwargs.get("return_probabilities") + zarr_group = zarr.open(str(raw_predictions), mode="r+") + + num_iter = math.ceil(len(zarr_group["probabilities"]) / self.batch_size) + start = 0 + for _ in range(num_iter): + # Probabilities for post-processing + probabilities = zarr_group["probabilities"][start : start + self.batch_size] + start = start + self.batch_size + predictions = self.model.postproc_func( + probabilities, ) - - # use external for testing - model = model_to(model=self.model, device=device) - - cum_output = { - "probabilities": [], - "predictions": [], - "coordinates": [], - "labels": [], - } - for _, batch_data in enumerate(dataloader): - batch_output_probabilities = self.model.infer_batch( - model, - batch_data["image"], - device=device, - ) - # We get the index of the class with the maximum probability - batch_output_predictions = self.model.postproc_func( - batch_output_probabilities, + if "predictions" in zarr_group: + zarr_group["predictions"].append(predictions) + continue + + zarr_dataset = zarr_group.create_dataset( + name="predictions", + shape=predictions.shape, + compressor=zarr_group["probabilities"].compressor, ) + zarr_dataset[:] = predictions - # tolist might be very expensive - cum_output["probabilities"].extend(batch_output_probabilities.tolist()) - cum_output["predictions"].extend(batch_output_predictions.tolist()) - if return_coordinates: - cum_output["coordinates"].extend(batch_data["coords"].tolist()) - if return_labels: # be careful of `s` - # We do not use tolist here because label may be of mixed types - # and hence collated as list by torch - cum_output["labels"].extend(list(batch_data["label"])) - - if self.verbose: - pbar.update() - if self.verbose: - pbar.close() - - if not return_probabilities: - cum_output.pop("probabilities") - if not return_labels: - cum_output.pop("labels") - if not return_coordinates: - cum_output.pop("coordinates") - - return cum_output - - def _update_ioconfig( - self: PatchPredictor, - ioconfig: IOPatchPredictorConfig, - patch_input_shape: IntPair, - stride_shape: IntPair, - resolution: Resolution, - units: Units, - ) -> IOPatchPredictorConfig: - """Updates the ioconfig. + if return_probabilities is not False: + return raw_predictions - Args: - ioconfig (IOPatchPredictorConfig): - Input ioconfig for PatchPredictor. - patch_input_shape (IntPair): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (IntPair): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. - - Returns: - IOPatchPredictorConfig: - Updated Patch Predictor IO configuration. + del zarr_group["probabilities"] - """ - config_flag = ( - patch_input_shape is None, - resolution is None, - units is None, - ) - if ioconfig: - return ioconfig + return raw_predictions - if self.ioconfig is None and any(config_flag): - msg = ( - "Must provide either " - "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." - ) - raise ValueError( - msg, - ) - - if stride_shape is None: - stride_shape = patch_input_shape - - if self.ioconfig: - ioconfig = copy.deepcopy(self.ioconfig) - # ! not sure if there is a nicer way to set this - if patch_input_shape is not None: - ioconfig.patch_input_shape = patch_input_shape - if stride_shape is not None: - ioconfig.stride_shape = stride_shape - if resolution is not None: - ioconfig.input_resolutions[0]["resolution"] = resolution - if units is not None: - ioconfig.input_resolutions[0]["units"] = units - - return ioconfig - - return IOPatchPredictorConfig( - input_resolutions=[{"resolution": resolution, "units": units}], - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - ) + def post_process_patches( + self: PatchPredictor, + raw_predictions: dict | Path, + **kwargs: Unpack[PredictorRunParams], + ) -> dict | Path: + """Post-process raw patch predictions from inference. - @staticmethod - def _prepare_save_dir(save_dir: str | Path, imgs: list | np.ndarray) -> Path: - """Create directory if not defined and number of images is more than 1. + The output of :func:`infer_patches()` with patch prediction information will be + post-processed using this function. The processed output will be saved in the + respective input format. If `cache_mode` is True, the function processes the + input using zarr group with size specified by `cache_size`. Args: - save_dir (str or Path): - Path to output directory. - imgs (list, ndarray): - List of inputs to process. + raw_predictions (dict | Path): + A dictionary or path to zarr with patch prediction information. + **kwargs (PredictorRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`PredictorRunParams` for accepted keyword arguments. Returns: - :class:`Path`: - Path to output directory. + dict or Path: + Returns patch based output after post-processing. Returns path to + saved zarr file if `cache_mode` is True. """ - if save_dir is None and len(imgs) > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory set." - "All subsequent output will be saved to current runtime" - "location under folder 'output'. Overwriting may happen!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len(imgs) > 1: - logger.warning( - "When providing multiple whole-slide images / tiles, " - "we save the outputs and return the locations " - "to the corresponding files.", - stacklevel=2, - ) - - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) + return_probabilities = kwargs.get("return_probabilities") + if self.cache_mode: + return self.post_process_cache_mode(raw_predictions, **kwargs) - return save_dir + probabilities = raw_predictions.get("probabilities") - def _predict_patch( - self: PatchPredictor, - imgs: list | np.ndarray, - labels: list, - device: str = "cpu", - *, - return_probabilities: bool, - return_labels: bool, - ) -> np.ndarray: - """Process patch mode. + predictions = self.model.postproc_func( + probabilities, + ) - Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - labels (list): - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". + raw_predictions["predictions"] = predictions - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset + if return_probabilities is not False: + return raw_predictions - """ - if labels: - # if a labels is provided, then return with the prediction - return_labels = bool(labels) - - if labels and len(labels) != len(imgs): - msg = f"len(labels) != len(imgs) : {len(labels)} != {len(imgs)}" - raise ValueError( - msg, - ) + del raw_predictions["probabilities"] - # don't return coordinates if patches are already extracted - return_coordinates = False - dataset = PatchDataset(imgs, labels) - return self._predict_engine( - dataset, - return_probabilities=return_probabilities, - return_labels=return_labels, - return_coordinates=return_coordinates, - device=device, - ) + return raw_predictions - def _predict_tile_wsi( # noqa: PLR0913 + def post_process_wsi( self: PatchPredictor, - imgs: list, - masks: list | None, - labels: list, - mode: str, - ioconfig: IOPatchPredictorConfig, - save_dir: str | Path, - highest_input_resolution: list[dict], - device: str = "cpu", - *, - save_output: bool, - return_probabilities: bool, - merge_predictions: bool, - ) -> list | dict: - """Predict on Tile and WSIs. + raw_predictions: dict | Path, + **kwargs: Unpack[PredictorRunParams], + ) -> dict | Path: + """Post process WSI output. - Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels (list): - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration.. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. - save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False - highest_input_resolution (list(dict)): - Highest available input resolution. - - - Returns: - dict: - Results are saved to `save_dir` and a dictionary indicating save - location for each input is returned. The dict is in the following - format: - - img_path: path of the input image. - - raw: path to save location for raw prediction, - saved in .json. - - merged: path to .npy contain merged - predictions if - `merge_predictions` is `True`. + Takes the raw output from patch predictions and post-processes it to improve the + results e.g., using information from neighbouring patches. """ - # return coordinates of patches processed within a tile / whole-slide image - return_coordinates = True + return self.post_process_cache_mode(raw_predictions, **kwargs) - input_is_path_like = isinstance(imgs[0], (str, Path)) - default_save_dir = ( - imgs[0].parent / "output" if input_is_path_like else Path.cwd() - ) - save_dir = default_save_dir if save_dir is None else Path(save_dir) - - # None if no output - outputs = None - - self._ioconfig = ioconfig - # generate a list of output file paths if number of input images > 1 - file_dict = OrderedDict() - - if len(imgs) > 1: - save_output = True - - for idx, img_path in enumerate(imgs): - img_path_ = Path(img_path) - img_label = None if labels is None else labels[idx] - img_mask = None if masks is None else masks[idx] - - dataset = WSIPatchDataset( - img_path_, - mode=mode, - mask_path=img_mask, - patch_input_shape=ioconfig.patch_input_shape, - stride_shape=ioconfig.stride_shape, - resolution=ioconfig.input_resolutions[0]["resolution"], - units=ioconfig.input_resolutions[0]["units"], - ) - output_model = self._predict_engine( - dataset, - return_labels=False, - return_probabilities=return_probabilities, - return_coordinates=return_coordinates, - device=device, - ) - output_model["label"] = img_label - # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.pretrained_model - output_model["resolution"] = highest_input_resolution["resolution"] - output_model["units"] = highest_input_resolution["units"] - - outputs = [output_model] # assign to a list - merged_prediction = None - if merge_predictions: - merged_prediction = self.merge_predictions( - img_path_, - output_model, - resolution=output_model["resolution"], - units=output_model["units"], - postproc_func=self.model.postproc, - ) - outputs.append(merged_prediction) - - if save_output: - # dynamic 0 padding - img_code = f"{idx:0{len(str(len(imgs)))}d}" - - save_info = {} - save_path = save_dir / img_code - raw_save_path = f"{save_path}.raw.json" - save_info["raw"] = raw_save_path - save_as_json(output_model, raw_save_path) - if merge_predictions: - merged_file_path = f"{save_path}.merged.npy" - np.save(merged_file_path, merged_prediction) - save_info["merged"] = merged_file_path - file_dict[str(img_path_)] = save_info - - return file_dict if save_output else outputs - - def predict( # noqa: PLR0913 + def run( self: PatchPredictor, - imgs: list, - masks: list | None = None, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, labels: list | None = None, - mode: str = "patch", - ioconfig: IOPatchPredictorConfig | None = None, - patch_input_shape: tuple[int, int] | None = None, - stride_shape: tuple[int, int] | None = None, - resolution: Resolution | None = None, - units: Units = None, - device: str = "cpu", + ioconfig: ModelIOConfigABC | None = None, *, - return_probabilities: bool = False, - return_labels: bool = False, - merge_predictions: bool = False, - save_dir: str | Path | None = None, - save_output: bool = False, - ) -> np.ndarray | list | dict: - """Make a prediction for a list of input data. + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[PredictorRunParams], + ) -> AnnotationStore | Path | str | dict: + """Run the engine on input images. Args: - imgs (list, ndarray): + images (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. Choose - from either `level`, `power` or `mpp`. Please see - :obj:`WSIReader` for details. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. + IO configuration. save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "zarr" or "AnnotationStore". Default value is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (PredictorRunParams): + Keyword Args to update :class:`EngineABC` attributes during runtime. Returns: - (:class:`numpy.ndarray` or list or dict): + (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple - image tiles or whole-slide images are provided as input, + whole slide images are provided as input, or save_output is True, then results are saved to `save_dir` and a dictionary indicating save location for each input is returned. @@ -887,84 +514,37 @@ def predict( # noqa: PLR0913 - img_path: path of the input image. - raw: path to save location for raw prediction, saved in .json. - - merged: path to .npy contain merged - predictions if `merge_predictions` is `True`. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(wsis, mode="wsi") + >>> class PatchPredictor(EngineABC): + >>> # Define all Abstract methods. + >>> ... + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + >>> output = predictor.run( + >>> image_patches, + >>> patch_mode=True, + >>> output_type="zarr") + >>> output + ... "/path/to/Output.zarr" + >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] - ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} - >>> output['wsi2.svs'] - ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + ... {'/path/to/wsi1.db'} """ - if mode not in ["patch", "wsi", "tile"]: - msg = f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" - raise ValueError( - msg, - ) - if mode == "patch": - return self._predict_patch( - imgs, - labels, - return_probabilities=return_probabilities, - return_labels=return_labels, - device=device, - ) - - if not isinstance(imgs, list): - msg = "Input to `tile` and `wsi` mode must be a list of file paths." - raise TypeError( - msg, - ) - - if mode == "wsi" and masks is not None and len(masks) != len(imgs): - msg = f"len(masks) != len(imgs) : {len(masks)} != {len(imgs)}" - raise ValueError( - msg, - ) - - ioconfig = self._update_ioconfig( - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ) - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - ioconfig = ioconfig.to_baseline() - - fx_list = ioconfig.scale_to_highest( - ioconfig.input_resolutions, - ioconfig.input_resolutions[0]["units"], - ) - fx_list = zip(fx_list, ioconfig.input_resolutions) - fx_list = sorted(fx_list, key=lambda x: x[0]) - highest_input_resolution = fx_list[0][1] - - save_dir = self._prepare_save_dir(save_dir, imgs) - - return self._predict_tile_wsi( - imgs=imgs, + return super().run( + images=images, masks=masks, labels=labels, - mode=mode, - return_probabilities=return_probabilities, - device=device, ioconfig=ioconfig, - merge_predictions=merge_predictions, + patch_mode=patch_mode, save_dir=save_dir, - save_output=save_output, - highest_input_resolution=highest_input_resolution, + overwrite=overwrite, + output_type=output_type, + **kwargs, ) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 791f369f7..eb7d17b75 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -20,14 +20,15 @@ from tiatoolbox import logger, rcParam from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.architecture.utils import compile_model -from tiatoolbox.models.models_abc import IOConfigABC, model_to +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset +from tiatoolbox.models.models_abc import model_to from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader -if TYPE_CHECKING: # pragma: no cover - from multiprocessing.managers import Namespace +from .io_config import IOSegmentorConfig +if TYPE_CHECKING: # pragma: no cover from tiatoolbox.type_hints import IntPair, Resolution, Units @@ -111,327 +112,6 @@ def _prepare_save_output( return is_on_drive, count_canvas, cum_canvas -class IOSegmentorConfig(IOConfigABC): - """Contain semantic segmentor input and output information. - - Args: - input_resolutions (list): - Resolution of each input head of model inference, must be in - the same order as `target model.forward()`. - output_resolutions (list): - Resolution of each output head from model inference, must be - in the same order as target model.infer_batch(). - patch_input_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest input in (height, width). - patch_output_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest output in (height, width). - save_resolution (dict): - Resolution to save all output. - - Examples: - >>> # Defining io for a network having 1 input and 1 output at the - >>> # same resolution - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - - Examples: - >>> # Defining io for a network having 3 input and 2 output - >>> # at the same resolution, the output is then merged at a - >>> # different resolution. - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... {"units": "mpp", "resolution": 0.75}, - ... ], - ... output_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... save_resolution={"units": "mpp", "resolution": 4.0}, - ... ) - - """ - - # We pre-define to follow enforcement, actual initialisation in init - input_resolutions = None - output_resolutions = None - - def __init__( - self: IOSegmentorConfig, - input_resolutions: list[dict], - output_resolutions: list[dict], - patch_input_shape: IntPair, - patch_output_shape: IntPair, - save_resolution: dict | None = None, - **kwargs: dict, - ) -> None: - """Initialize :class:`IOSegmentorConfig`.""" - self._kwargs = kwargs - self.patch_input_shape = patch_input_shape - self.patch_output_shape = patch_output_shape - self.stride_shape = None - self.input_resolutions = input_resolutions - self.output_resolutions = output_resolutions - - self.resolution_unit = input_resolutions[0]["units"] - self.save_resolution = save_resolution - - for variable, value in kwargs.items(): - self.__setattr__(variable, value) - - self._validate() - - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - - def _validate(self: IOSegmentorConfig) -> None: - """Validate the data format.""" - resolutions = self.input_resolutions + self.output_resolutions - units = [v["units"] for v in resolutions] - units = np.unique(units) - if len(units) != 1 or units[0] not in [ - "power", - "baseline", - "mpp", - ]: - msg = f"Invalid resolution units `{units[0]}`." - raise ValueError(msg) - - @staticmethod - def scale_to_highest(resolutions: list[dict], units: Units) -> np.ndarray: - """Get the scaling factor from input resolutions. - - This will convert resolutions to a scaling factor with respect to - the highest resolution found in the input resolutions list. - - Args: - resolutions (list): - A list of resolutions where one is defined as - `{'resolution': value, 'unit': value}` - units (Units): - Units that the resolutions are at. - - Returns: - :class:`numpy.ndarray`: - A 1D array of scaling factors having the same length as - `resolutions` - - """ - old_val = [v["resolution"] for v in resolutions] - if units not in ["baseline", "mpp", "power"]: - msg = ( - f"Unknown units `{units}`. " - f"Units should be one of 'baseline', 'mpp' or 'power'." - ) - raise ValueError( - msg, - ) - if units == "baseline": - return old_val - if units == "mpp": - return np.min(old_val) / np.array(old_val) - return np.array(old_val) / np.max(old_val) - - def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig: - """Return a new config object converted to baseline form. - - This will return a new :class:`IOSegmentorConfig` where - resolutions have been converted to baseline format with the - highest possible resolution found in both input and output as - reference. - - """ - resolutions = self.input_resolutions + self.output_resolutions - if self.save_resolution is not None: - resolutions.append(self.save_resolution) - - scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(self.input_resolutions) - num_output_resolutions = len(self.output_resolutions) - - end_idx = num_input_resolutions - input_resolutions = [ - {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] - ] - end_idx = num_input_resolutions + num_output_resolutions - output_resolutions = [ - {"units": "baseline", "resolution": v} - for v in scale_factors[num_input_resolutions:end_idx] - ] - - save_resolution = None - if self.save_resolution is not None: - save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} - return IOSegmentorConfig( - input_resolutions=input_resolutions, - output_resolutions=output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - save_resolution=save_resolution, - **self._kwargs, - ) - - -class WSIStreamDataset(torch_data.Dataset): - """Reading a wsi in parallel mode with persistent workers. - - To speed up the inference process for multiple WSIs. The - `torch.utils.data.Dataloader` is set to run in persistent mode. - Normally, this will prevent workers from altering their initial - states (such as provided input etc.). To sidestep this, we use a - shared parallel workspace context manager to send data and signal - from the main thread, thus allowing each worker to load a new wsi as - well as corresponding patch information. - - Args: - mp_shared_space (:class:`Namespace`): - A shared multiprocessing space, must be from - `torch.multiprocessing`. - ioconfig (:class:`IOSegmentorConfig`): - An object which contains I/O placement for patches. - wsi_paths (list): List of paths pointing to a WSI or tiles. - preproc (Callable): - Pre-processing function to be applied to a patch. - mode (str): - Either `"wsi"` or `"tile"` to indicate the format of images - in `wsi_paths`. - - Examples: - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - >>> mp_manager = torch_mp.Manager() - >>> mp_shared_space = mp_manager.Namespace() - >>> mp_shared_space.signal = 1 # adding variable to the shared space - >>> wsi_paths = ['A.svs', 'B.svs'] - >>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space) - - """ - - def __init__( - self: WSIStreamDataset, - ioconfig: IOSegmentorConfig, - wsi_paths: list[str | Path], - mp_shared_space: Namespace, - preproc: Callable[[np.ndarray], np.ndarray] | None = None, - mode: str = "wsi", - ) -> None: - """Initialize :class:`WSIStreamDataset`.""" - super().__init__() - self.mode = mode - self.preproc = preproc - self.ioconfig = copy.deepcopy(ioconfig) - - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - self.ioconfig = self.ioconfig.to_baseline() - - self.mp_shared_space = mp_shared_space - self.wsi_paths = wsi_paths - self.wsi_idx = None # to be received externally via thread communication - self.reader = None - - def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader: - """Get appropriate reader for input path.""" - img_path = Path(img_path) - if self.mode == "wsi": - return WSIReader.open(img_path) - img = imread(img_path) - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - objective_power=10, - axes="YXS", - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - return VirtualWSIReader( - img, - info=metadata, - ) - - def __len__(self: WSIStreamDataset) -> int: - """Return the length of the instance attributes.""" - return len(self.mp_shared_space.patch_inputs) - - @staticmethod - def collate_fn(batch: list | np.ndarray) -> torch.Tensor: - """Prototype to handle reading exception. - - This will exclude any sample with `None` from the batch. As - such, wrapping `__getitem__` with try-catch and return `None` - upon exceptions will prevent crashing the entire program. But as - a side effect, the batch may not have the size as defined. - - """ - batch = [v for v in batch if v is not None] - return torch.utils.data.dataloader.default_collate(batch) - - def __getitem__(self: WSIStreamDataset, idx: int) -> tuple: - """Get an item from the dataset.""" - # ! no need to lock as we do not modify source value in shared space - if self.wsi_idx != self.mp_shared_space.wsi_idx: - self.wsi_idx = int(self.mp_shared_space.wsi_idx.item()) - self.reader = self._get_reader(self.wsi_paths[self.wsi_idx]) - - # this is in XY and at requested resolution (not baseline) - bounds = self.mp_shared_space.patch_inputs[idx] - bounds = bounds.numpy() # expected to be a torch.Tensor - - # be the same as bounds br-tl, unless bounds are of float - patch_data_ = [] - scale_factors = self.ioconfig.scale_to_highest( - self.ioconfig.input_resolutions, - self.ioconfig.resolution_unit, - ) - for idy, resolution in enumerate(self.ioconfig.input_resolutions): - resolution_bounds = np.round(bounds * scale_factors[idy]) - patch_data = self.reader.read_bounds( - resolution_bounds.astype(np.int32), - coord_space="resolution", - pad_constant_values=0, # expose this ? - **resolution, - ) - - if self.preproc is not None: - patch_data = patch_data.copy() - patch_data = self.preproc(patch_data) - patch_data_.append(patch_data) - if len(patch_data_) == 1: - patch_data_ = patch_data_[0] - - bound = self.mp_shared_space.patch_outputs[idx] - return patch_data_, bound - - class SemanticSegmentor: """Pixel-wise segmentation predictor. @@ -1075,8 +755,8 @@ def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]: return save_dir, cache_dir + @staticmethod def _update_ioconfig( - self: SemanticSegmentor, ioconfig: IOSegmentorConfig, mode: str, patch_input_shape: IntPair, @@ -1124,17 +804,7 @@ def _update_ioconfig( if stride_shape is None: stride_shape = patch_output_shape - if ioconfig is None and patch_input_shape is None: - if self.ioconfig is None: - msg = ( - "Must provide either `ioconfig` or `patch_input_shape` " - "and `patch_output_shape`" - ) - raise ValueError( - msg, - ) - ioconfig = copy.deepcopy(self.ioconfig) - elif ioconfig is None: + if ioconfig is None: ioconfig = IOSegmentorConfig( input_resolutions=[{"resolution": resolution, "units": units}], output_resolutions=[{"resolution": resolution, "units": units}], @@ -1197,7 +867,7 @@ def _predict_wsi_handle_exception( of file paths. wsi_idx (int): index of current WSI being processed. - img_path(str or Path): + img_path(str): Path to current image. mode (str): Type of input to process. Choose from either `tile` or @@ -1209,7 +879,7 @@ def _predict_wsi_handle_exception( `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - save_dir (str or Path): + save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. @@ -1358,6 +1028,28 @@ def predict( # noqa: PLR0913 save_dir, self._cache_dir = self._prepare_save_dir(save_dir) + if ioconfig is None: + ioconfig = copy.deepcopy(self.ioconfig) + + if ioconfig is None and patch_input_shape is None: + msg = ( + "Must provide either `ioconfig` or " + "`patch_input_shape` and `patch_output_shape`" + ) + raise ValueError( + msg, + ) + + if resolution is None and units is None: + if ioconfig is None: + msg = f"Invalid resolution: `{resolution}` and units: `{units}`. " + raise ValueError( + msg, + ) + + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.input_resolutions[0]["units"] + ioconfig = self._update_ioconfig( ioconfig, mode, @@ -1628,7 +1320,7 @@ def predict( # noqa: PLR0913 Resolution used for reading the image. units (Units): Units of resolution used for reading the image. - save_dir (str): + save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a3af4e7f0..0e2b7d81f 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -8,6 +8,7 @@ import torch import torch._dynamo from torch import device as torch_device +from torch import nn torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001 @@ -18,29 +19,29 @@ import numpy as np -class IOConfigABC(ABC): - """Define an abstract class for holding predictor I/O information. +def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. - Enforcing such that following attributes must always be defined by - the subclass. - - """ + Args: + model (torch.nn.Module): + A torch model. + weights (str or Path): + Path to pretrained weights. - @property - @abstractmethod - def input_resolutions(self: IOConfigABC) -> None: - """Abstract method to update input_resolution.""" - raise NotImplementedError + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. - @property - @abstractmethod - def output_resolutions(self: IOConfigABC) -> None: - """Abstract method to update output_resolutions.""" - raise NotImplementedError + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + model.load_state_dict(saved_state_dict, strict=True) + return model def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: - """Transfers model to specified device e.g., "cpu" or "cuda". + """Transfers model to cpu/gpu. Args: model (torch.nn.Module): @@ -50,7 +51,7 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: Returns: torch.nn.Module: - The model after being moved to specified device. + The model after being moved to cpu/gpu. """ if device != "cpu": @@ -78,11 +79,7 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: @staticmethod @abstractmethod - def infer_batch( - model: torch.nn.Module, - batch_data: np.ndarray, - device: str, - ) -> None: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -179,7 +176,7 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. Args: - model (torch.nn.Module): + self (ModelABC): PyTorch defined model. device (str): Transfers model to the specified device. Default is "cpu". diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index a12dd96b2..364ff76a0 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1201,19 +1201,54 @@ def add_from_dat( store.append_many(anns) +def patch_predictions_as_annotations( + preds: list | np.ndarray, + keys: list, + class_dict: dict, + class_probs: list | np.ndarray, + patch_coords: list, + classes_predicted: list, + labels: list, +) -> list: + """Helper function to generate annotation per patch predictions.""" + annotations = [] + for i, _ in enumerate(patch_coords): + if "probabilities" in keys: + props = { + f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted + } + else: + props = {} + if "labels" in keys: + props["label"] = class_dict[labels[i]] + if len(preds) > 0: + props["type"] = class_dict[preds[i]] + annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + + return annotations + + +def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray: + """Converts a zarr array into a numpy array.""" + if isinstance(zarr_array, zarr.core.Array): + return zarr_array[:] + + return np.array(zarr_array).astype(float) + + def dict_to_store( - patch_output: dict, - scale_factor: tuple[int, int], + patch_output: dict | zarr.group, + scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, ) -> AnnotationStore | Path: """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore. Args: - patch_output (dict): - A dictionary in the TIAToolbox Engines output format. Important - keys are "probabilities", "predictions", "coordinates", and "labels". - scale_factor (tuple[int, int]): + patch_output (dict | zarr.Group): + A dictionary with "probabilities", "predictions", "coordinates", + and "labels" keys. + scale_factor (tuple[float, float]): The scale factor to use when loading the annotations. All coordinates will be multiplied by this factor to allow conversion of annotations saved at non-baseline resolution to baseline. @@ -1235,45 +1270,50 @@ def dict_to_store( # we cant create annotations without coordinates msg = "Patch output must contain coordinates." raise ValueError(msg) + # get relevant keys - class_probs = patch_output.get("probabilities", []) - preds = patch_output.get("predictions", []) + class_probs = get_zarr_array(patch_output.get("probabilities", [])) + preds = get_zarr_array(patch_output.get("predictions", [])) + patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp + patch_coords = patch_coords.astype(float) labels = patch_output.get("labels", []) # get classes to consider if len(class_probs) == 0: classes_predicted = np.unique(preds).tolist() else: classes_predicted = range(len(class_probs[0])) + if class_dict is None: # if no class dict create a default one - class_dict = {i: i for i in np.unique(preds + labels).tolist()} + if len(class_probs) == 0: + class_dict = {i: i for i in np.unique(np.append(preds, labels)).tolist()} + else: + class_dict = {i: i for i in range(len(class_probs[0]))} # find what keys we need to save keys = ["predictions"] keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] # put patch predictions into a store - annotations = [] - for i, pred in enumerate(preds): - if "probabilities" in keys: - props = { - f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted - } - else: - props = {} - if "labels" in keys: - props["label"] = class_dict[labels[i]] - props["type"] = class_dict[pred] - annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + annotations = patch_predictions_as_annotations( + preds, + keys, + class_dict, + class_probs, + patch_coords, + classes_predicted, + labels, + ) + store = SQLiteStore() - keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) + _ = store.append_many(annotations, [str(i) for i in range(len(annotations))]) # if a save director is provided, then dump store into a file if save_path: - # ensure parent directory exisits + # ensure parent directory exists save_path.parent.absolute().mkdir(parents=True, exist_ok=True) # ensure proper db extension save_path = save_path.parent.absolute() / (save_path.stem + ".db") @@ -1313,14 +1353,155 @@ def dict_to_zarr( save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") # save to zarr - predictions_array = np.array(raw_predictions["predictions"]) + probabilities_array = np.array(raw_predictions["probabilities"]) z = zarr.open( - save_path, + str(save_path), mode="w", - shape=predictions_array.shape, + shape=probabilities_array.shape, chunks=chunks, compressor=compressor, ) - z[:] = predictions_array + z[:] = probabilities_array return save_path + + +def wsi_batch_output_to_zarr_group( + wsi_batch_zarr_group: zarr.group | None, + batch_output_probabilities: np.ndarray, + batch_output_predictions: np.ndarray, + batch_output_coordinates: np.ndarray | None, + batch_output_label: np.ndarray | None, + save_path: Path, + **kwargs: dict, +) -> zarr.group | Path: + """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. + + Args: + wsi_batch_zarr_group (zarr.group): + Optional zarr group name consisting of zarrs to save the batch output + values. + batch_output_probabilities (np.ndarray): + Probability batch output from infer wsi. + batch_output_predictions (np.ndarray): + Predictions batch output from infer wsi. + batch_output_coordinates (np.ndarray): + Coordinates batch output from infer wsi. + batch_output_label (np.ndarray): + Labels batch output from infer wsi. + save_path (str or Path): + Path to save the zarr file. + **kwargs (dict): + Keyword Args to update wsi_batch_output_to_zarr_group attributes. + + Returns: + Path to the zarr file storing the :class:`EngineABC` output. + + """ + # Default values for Compressor and Chunks set if not received from kwargs. + compressor = ( + kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + ) + chunks = kwargs.get("chunks", 10000) + + # case 1 - new zarr group + if not wsi_batch_zarr_group: + # ensure proper zarr extension and create persistant zarr group + save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") + wsi_batch_zarr_group = zarr.open(save_path, mode="w") + + # populate the zarr group for the first time + probabilities_zarr = wsi_batch_zarr_group.create_dataset( + name="probabilities", + shape=batch_output_probabilities.shape, + chunks=chunks, + compressor=compressor, + ) + probabilities_zarr[:] = batch_output_probabilities + + predictions_zarr = wsi_batch_zarr_group.create_dataset( + name="predictions", + shape=batch_output_predictions.shape, + chunks=chunks, + compressor=compressor, + ) + predictions_zarr[:] = batch_output_predictions + + if batch_output_coordinates is not None: + coordinates_zarr = wsi_batch_zarr_group.create_dataset( + name="coordinates", + shape=batch_output_coordinates.shape, + chunks=chunks, + compressor=compressor, + ) + coordinates_zarr[:] = batch_output_coordinates + + if batch_output_label is not None: + labels_zarr = wsi_batch_zarr_group.create_dataset( + name="labels", + shape=batch_output_label.shape, + chunks=chunks, + compressor=compressor, + ) + labels_zarr[:] = batch_output_label + + # case 2 - append to existing zarr group + probabilities_zarr = wsi_batch_zarr_group["probabilities"] + probabilities_zarr.append(batch_output_probabilities) + + predictions_zarr = wsi_batch_zarr_group["predictions"] + predictions_zarr.append(batch_output_predictions) + + if batch_output_coordinates is not None: + coordinates_zarr = wsi_batch_zarr_group["coordinates"] + coordinates_zarr.append(batch_output_coordinates) + + if batch_output_label is not None: + labels_zarr = wsi_batch_zarr_group["labels"] + labels_zarr.append(batch_output_label) + + return wsi_batch_zarr_group + + +def write_to_zarr_in_cache_mode( + zarr_group: zarr.group, + output_data_to_save: dict, + **kwargs: dict, +) -> zarr.group | Path: + """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. + + Args: + zarr_group (zarr.group): + Zarr group name consisting of zarr(s) to save the batch output + values. + output_data_to_save (dict): + Output data from the Engine to save to Zarr. Expects the data saved in + dictionary to be a numpy array. + **kwargs (dict): + Keyword Args to update zarr_group attributes. + + Returns: + Path to the zarr file storing the :class:`EngineABC` output. + + """ + # Default values for Compressor and Chunks set if not received from kwargs. + compressor = kwargs.get("compressor", numcodecs.Zstd(level=1)) + + # case 1 - new zarr group + if not zarr_group: + for key, value in output_data_to_save.items(): + # populate the zarr group for the first time + zarr_dataset = zarr_group.create_dataset( + name=key, + shape=value.shape, + compressor=compressor, + ) + zarr_dataset[:] = value + + return zarr_group + + # case 2 - append to existing zarr group + for key, value in output_data_to_save.items(): + zarr_group[key].append(value) + + return zarr_group