From ef55e9539189abddbce62594a585d8c1b0354dd9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 24 Mar 2023 11:18:31 +0000 Subject: [PATCH 01/36] :wrench: Use `pyproject.toml` for `bdist_wheel` configuration - Use `pyproject.toml` for `bdist_wheel` configuration --- pyproject.toml | 7 +++++++ setup.cfg | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c6924538e..669feba45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,3 +34,10 @@ ] ignore_errors = true omit = ['tests/*', 'tiatoolbox/__main__.py', '*/utils/env_detection.py'] + +[build-system] + requires = ["setuptools"] + build-backend = "setuptools.build_meta" + +[tool.distutils.bdist_wheel] + universal = true diff --git a/setup.cfg b/setup.cfg index 638c5f2ec..b459f3c7d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,9 +19,6 @@ replace = version: {new_version} # TIAToolbox version search = TOOLBOX_VER: {current_version} replace = TOOLBOX_VER: {new_version} -[bdist_wheel] -universal = 1 - [flake8] exclude = docs, *__init__*, setup.py max-line-length = 88 @@ -31,9 +28,6 @@ dictionaries = en_US,python,technical max-cognitive-complexity = 14 max-expression-complexity = 7 -[aliases] -test = pytest - [tool:pytest] collect_ignore = ['setup.py', 'benchmark/'] From 8ba6defc5b1fa49b3f0f5fb121728bee322427d3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 31 Mar 2023 10:54:39 +0100 Subject: [PATCH 02/36] :zap: Improve `Engines` performance and implementation - Improve `Engines` performance and implementation --- tiatoolbox/models/engine/engines_abc.py | 75 +++++++++++++++++++ .../models/engine/semantic_segmentor.py | 2 +- tiatoolbox/models/models_abc.py | 19 ----- 3 files changed, 76 insertions(+), 20 deletions(-) create mode 100644 tiatoolbox/models/engine/engines_abc.py diff --git a/tiatoolbox/models/engine/engines_abc.py b/tiatoolbox/models/engine/engines_abc.py new file mode 100644 index 000000000..d8ff14a1f --- /dev/null +++ b/tiatoolbox/models/engine/engines_abc.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple, Union + +import numpy as np + + +class IOConfigABC(ABC): + """Define an abstract class for holding predictor I/O information. + + Enforcing such that following attributes must always be defined by + the subclass. + + """ + + def __init__( + self, + input_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], + **kwargs, + ): + self._kwargs = kwargs + self.resolution_unit = input_resolutions[0]["units"] + self.patch_input_shape = patch_input_shape + self.stride_shape = stride_shape + + self._validate() + + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, key=lambda x: x["resolution"] + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, key=lambda x: x["resolution"] + ) + + def _validate(self): + """Validate the data format.""" + resolutions = self.input_resolutions + self.output_resolutions + units = [v["units"] for v in resolutions] + units = np.unique(units) + if len(units) != 1 or units[0] not in [ + "power", + "baseline", + "mpp", + ]: + raise ValueError(f"Invalid resolution units `{units[0]}`.") + + @property + @abstractmethod + def input_resolutions(self): + raise NotImplementedError + + @property + @abstractmethod + def output_resolutions(self): + raise NotImplementedError + + +class EnginesABC(ABC): + """Abstract base class for engines used in tiatoolbox.""" + + def __init__(self): + super().__init__() + + @abstractmethod + def process_patch(self): + raise NotImplementedError + + # how to deal with patches, list of patches/numpy arrays, WSIs + # how to communicate with sub-processes. + # define how to deal with patches as numpy/zarr arrays. + # convert list of patches/numpy arrays to zarr and then pass to each sub-processes. + # define how to read WSIs, read the image and convert to zarr array. diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index a28d6a421..f06a65029 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -20,7 +20,7 @@ import tqdm from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.models_abc import IOConfigABC +from tiatoolbox.models.engine.engines_abc import IOConfigABC from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import misc from tiatoolbox.utils.misc import imread diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 32ecce5fa..5cba4b32f 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -4,25 +4,6 @@ import torch.nn as nn -class IOConfigABC(ABC): - """Define an abstract class for holding predictor I/O information. - - Enforcing such that following attributes must always be defined by - the subclass. - - """ - - @property - @abstractmethod - def input_resolutions(self): - raise NotImplementedError - - @property - @abstractmethod - def output_resolutions(self): - raise NotImplementedError - - class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" From fac10007db8833d9259053322ce5290505205e0e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 28 Apr 2023 10:29:27 +0100 Subject: [PATCH 03/36] :recycle: Refactor engines_abc.py - Refactor engines_abc.py --- tiatoolbox/models/engine/{engines_abc.py => engine_abc.py} | 2 +- tiatoolbox/models/engine/semantic_segmentor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename tiatoolbox/models/engine/{engines_abc.py => engine_abc.py} (98%) diff --git a/tiatoolbox/models/engine/engines_abc.py b/tiatoolbox/models/engine/engine_abc.py similarity index 98% rename from tiatoolbox/models/engine/engines_abc.py rename to tiatoolbox/models/engine/engine_abc.py index d8ff14a1f..3f51e5681 100644 --- a/tiatoolbox/models/engine/engines_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -58,7 +58,7 @@ def output_resolutions(self): raise NotImplementedError -class EnginesABC(ABC): +class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox.""" def __init__(self): diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index f06a65029..872f2161e 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -20,7 +20,7 @@ import tqdm from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.engine.engines_abc import IOConfigABC +from tiatoolbox.models.engine.engine_abc import IOConfigABC from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import misc from tiatoolbox.utils.misc import imread From 36fd6290d5dd8492947d7a60be9f062bf9e6298e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jul 2023 08:03:56 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/models_abc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index d0810f159..01b4ff690 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -4,7 +4,6 @@ from torch import nn - class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" From e608f7bacbe2b07878366c66a6b8ac616f96cd80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jul 2023 18:20:27 +0000 Subject: [PATCH 05/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/engine_abc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 3f51e5681..365218edd 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -28,11 +28,11 @@ def __init__( if self.resolution_unit == "mpp": self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"] + self.input_resolutions, key=lambda x: x["resolution"], ) else: self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"] + self.input_resolutions, key=lambda x: x["resolution"], ) def _validate(self): From b956bf520d005f6db87b3f3800b7e99d2f677148 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jul 2023 21:47:21 +0000 Subject: [PATCH 06/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/engine_abc.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 365218edd..ad709fda7 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -28,11 +28,13 @@ def __init__( if self.resolution_unit == "mpp": self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"], + self.input_resolutions, + key=lambda x: x["resolution"], ) else: self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"], + self.input_resolutions, + key=lambda x: x["resolution"], ) def _validate(self): @@ -45,7 +47,8 @@ def _validate(self): "baseline", "mpp", ]: - raise ValueError(f"Invalid resolution units `{units[0]}`.") + msg = f"Invalid resolution units `{units[0]}`." + raise ValueError(msg) @property @abstractmethod From d49fd0b0c961b53e99121c9e4e254440c10bf66c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 31 Jul 2023 15:58:34 +0100 Subject: [PATCH 07/36] :recycle: Refactor base code from `IOSegmentorConfig` to `ModelIOConfigABC` (#618) - Moved all ioconfigs to a single file. - Used dataclass to define ioconfig. - Refactor base code from `IOSegmentorConfig` to `ModelIOConfigABC` - Use `ModelIOConfigABC` for `PatchPredictor` instead of `IOSegmentorConfig` --------- Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: John Pocock Co-authored-by: Mark Eastwood <20169086+measty@users.noreply.github.com> Co-authored-by: Mostafa Jahanifar <74412979+mostafajahanifar@users.noreply.github.com> Co-authored-by: Adam Shephard <39619155+adamshephard@users.noreply.github.com> --- .github/workflows/python-package.yml | 2 +- tests/engines/__init__.py | 1 + tests/engines/test_ioconfig.py | 23 + tests/models/test_feature_extractor.py | 6 +- tests/models/test_multi_task_segmentor.py | 18 +- .../models/test_nucleus_instance_segmentor.py | 18 +- tests/models/test_patch_predictor.py | 14 +- tests/models/test_semantic_segmentation.py | 79 ++- tiatoolbox/cli/nucleus_instance_segment.py | 4 +- tiatoolbox/data/pretrained_model.yaml | 100 ++-- tiatoolbox/models/__init__.py | 18 +- tiatoolbox/models/engine/engine_abc.py | 63 +-- tiatoolbox/models/engine/io_config.py | 450 ++++++++++++++++++ .../models/engine/multi_task_segmentor.py | 26 +- .../engine/nucleus_instance_segmentor.py | 14 +- tiatoolbox/models/engine/patch_predictor.py | 31 +- .../models/engine/semantic_segmentor.py | 223 ++------- whitelist.txt | 1 + 18 files changed, 650 insertions(+), 441 deletions(-) create mode 100644 tests/engines/__init__.py create mode 100644 tests/engines/test_ioconfig.py create mode 100644 tiatoolbox/models/engine/io_config.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 1fbadaba9..47a4224cb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,7 +8,7 @@ on: branches: [ develop, pre-release, master, main ] tags: v* pull_request: - branches: [ develop, pre-release, master, main ] + branches: [ develop, pre-release, master, main, dev-define-engines-abc] jobs: build: diff --git a/tests/engines/__init__.py b/tests/engines/__init__.py new file mode 100644 index 000000000..193a523c1 --- /dev/null +++ b/tests/engines/__init__.py @@ -0,0 +1 @@ +"""Unit test package for tiatoolbox engines.""" diff --git a/tests/engines/test_ioconfig.py b/tests/engines/test_ioconfig.py new file mode 100644 index 000000000..41169298b --- /dev/null +++ b/tests/engines/test_ioconfig.py @@ -0,0 +1,23 @@ +"""Tests for IOconfig.""" + +import pytest + +from tiatoolbox.models import ModelIOConfigABC + + +def test_validation_error_io_config() -> None: + """Test Validation Error for ModelIOConfigABC.""" + with pytest.raises(ValueError, match=r".*Multiple resolution units found.*"): + ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + {"units": "mpp", "resolution": 0.25}, + ], + patch_input_shape=(224, 224), + ) + + with pytest.raises(ValueError, match=r"Invalid resolution units.*"): + ModelIOConfigABC( + input_resolutions=[{"units": "level", "resolution": 1.0}], + patch_input_shape=(224, 224), + ) diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py index 18ee51bb2..bd3461bfe 100644 --- a/tests/models/test_feature_extractor.py +++ b/tests/models/test_feature_extractor.py @@ -6,11 +6,9 @@ import numpy as np import torch +from tiatoolbox.models import IOSegmentorConfig from tiatoolbox.models.architecture.vanilla import CNNBackbone -from tiatoolbox.models.engine.semantic_segmentor import ( - DeepFeatureExtractor, - IOSegmentorConfig, -) +from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.wsicore.wsireader import WSIReader diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py index 21689a3b6..bb3d924b7 100644 --- a/tests/models/test_multi_task_segmentor.py +++ b/tests/models/test_multi_task_segmentor.py @@ -12,7 +12,11 @@ import numpy as np import pytest -from tiatoolbox.models import IOSegmentorConfig, MultiTaskSegmentor, SemanticSegmentor +from tiatoolbox.models import ( + IOInstanceSegmentorConfig, + MultiTaskSegmentor, + SemanticSegmentor, +) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite from tiatoolbox.utils.metrics import f1_detection @@ -178,7 +182,7 @@ def test_masked_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -186,7 +190,7 @@ def test_masked_segmentor(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -304,10 +308,10 @@ def test_empty_image(tmp_path): output_types=["semantic"], ) - bcc_wsi_ioconfig = IOSegmentorConfig( + bcc_wsi_ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, + tile_shape=(2048, 2048), patch_input_shape=[1024, 1024], patch_output_shape=[512, 512], stride_shape=[512, 512], @@ -352,7 +356,7 @@ def test_functionality_semantic(remote_sample, tmp_path): output_types=["semantic"], ) - bcc_wsi_ioconfig = IOSegmentorConfig( + bcc_wsi_ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[{"units": "mpp", "resolution": 0.25}], tile_shape=2048, @@ -393,7 +397,7 @@ def test_crash_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py index 6de01c664..026a6647b 100644 --- a/tests/models/test_nucleus_instance_segmentor.py +++ b/tests/models/test_nucleus_instance_segmentor.py @@ -15,7 +15,7 @@ from tiatoolbox import cli from tiatoolbox.models import ( - IOSegmentorConfig, + IOInstanceSegmentorConfig, NucleusInstanceSegmentor, SemanticSegmentor, ) @@ -63,7 +63,7 @@ def helper_tile_info(): # | 12 | 13 | 14 | 15 | # --------------------- # ! assume flag index ordering: left right top bottom - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[ {"units": "mpp", "resolution": 0.25}, @@ -71,7 +71,7 @@ def helper_tile_info(): {"units": "mpp", "resolution": 0.25}, ], margin=1, - tile_shape=[4, 4], + tile_shape=(4, 4), stride_shape=[4, 4], patch_input_shape=[4, 4], patch_output_shape=[4, 4], @@ -247,7 +247,7 @@ def test_crash_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -255,7 +255,7 @@ def test_crash_segmentor(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -299,14 +299,14 @@ def test_functionality_ci(remote_sample, tmp_path): # * test run on wsi, test run with worker # resolution for travis testing, not the correct ones - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[1024, 1024], + tile_shape=(1024, 1024), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -340,7 +340,7 @@ def test_functionality_merge_tile_predictions_ci(remote_sample, tmp_path): mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) resolution = 0.5 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -348,7 +348,7 @@ def test_functionality_merge_tile_predictions_ci(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 86d39d4cd..7113b10f6 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -459,19 +459,6 @@ def __getitem__(self, idx): # ------------------------------------------------------------------------------------- -def test_io_patch_predictor_config(): - """Test for IOConfig.""" - # test for creating - cfg = IOPatchPredictorConfig( - patch_input_shape=[224, 224], - stride_shape=[224, 224], - input_resolutions=[{"resolution": 0.5, "units": "mpp"}], - # test adding random kwarg and they should be accessible as kwargs - crop_from_source=True, - ) - assert cfg.crop_from_source - - # ------------------------------------------------------------------------------------- # Engine # ------------------------------------------------------------------------------------- @@ -545,6 +532,7 @@ def test_io_config_delegation(remote_sample, tmp_path): patch_input_shape=[512, 512], stride_shape=[256, 256], input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + output_resolutions=[], ) predictor.predict( [mini_wsi_svs], diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 2c58d007a..d30dd355f 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -1,7 +1,4 @@ """Test for Semantic Segmentor.""" - -import copy - # ! The garbage collector import gc import multiprocessing @@ -18,13 +15,10 @@ from torch import nn from tiatoolbox import cli -from tiatoolbox.models import SemanticSegmentor +from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) +from tiatoolbox.models.engine.semantic_segmentor import WSIStreamDataset from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imread, imwrite @@ -113,39 +107,6 @@ def infer_batch(model, batch_data, on_gpu): def test_segmentor_ioconfig(): """Test for IOConfig.""" - default_config = { - "input_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - "output_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - "patch_input_shape": [2048, 2048], - "patch_output_shape": [1024, 1024], - "stride_shape": [512, 512], - } - - # error when uniform resolution units are not uniform - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "power", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - - # error when uniform resolution units are not supported - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "alpha", "resolution": 0.25}, - {"units": "alpha", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - ioconfig = IOSegmentorConfig( input_resolutions=[ {"units": "mpp", "resolution": 0.25}, @@ -268,8 +229,8 @@ def test_crash_segmentor(remote_sample): model = _CNNTo1() semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) # fake injection to trigger Segmentor to create parallel - # post processing workers because baseline Semantic Segmentor does not support - # post processing out of the box. It only contains condition to create it + # post-processing workers because baseline Semantic Segmentor does not support + # post-processing out of the box. It only contains condition to create it # for any subclass semantic_segmentor.num_postproc_workers = 1 @@ -301,7 +262,7 @@ def test_crash_segmentor(remote_sample): crash_on_exception=True, ) with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict([], mode="tile", patch_input_shape=[2048, 2048]) + semantic_segmentor.predict([], mode="tile", patch_input_shape=(2048, 2048)) _rm_dir("output") # default output dir test # * test not providing any io_config info when not using pretrained model @@ -314,26 +275,42 @@ def test_crash_segmentor(remote_sample): ) _rm_dir("output") # default output dir test - # * Test crash propagation when parallelize post processing + # * Test crash propagation when parallelize post-processing _rm_dir("output") semantic_segmentor.num_postproc_workers = 2 semantic_segmentor.model.forward = _crash_func with pytest.raises(ValueError, match=r"Propagation Crash."): semantic_segmentor.predict( [mini_wsi_svs], - patch_input_shape=[2048, 2048], + patch_input_shape=(2048, 2048), mode="wsi", on_gpu=ON_GPU, crash_on_exception=True, + resolution=1.0, + units="baseline", ) + + _rm_dir("output") + + with pytest.raises(ValueError, match=r"Invalid resolution.*"): + semantic_segmentor.predict( + [mini_wsi_svs], + patch_input_shape=(2048, 2048), + mode="wsi", + on_gpu=ON_GPU, + crash_on_exception=True, + ) + _rm_dir("output") # test ignore crash semantic_segmentor.predict( [mini_wsi_svs], - patch_input_shape=[2048, 2048], + patch_input_shape=(2048, 2048), mode="wsi", on_gpu=ON_GPU, crash_on_exception=False, + resolution=1.0, + units="baseline", ) _rm_dir("output") @@ -429,7 +406,7 @@ def test_functional_segmentor_merging(tmp_path): _rm_dir(save_dir) save_dir.mkdir() - # * with out of bound location + # * with an out of bound location canvas = semantic_segmentor.merge_prediction( [4, 4], [ @@ -465,8 +442,8 @@ def test_functional_segmentor(remote_sample, tmp_path): model = _CNNTo1() semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) # fake injection to trigger Segmentor to create parallel - # post processing workers because baseline Semantic Segmentor does not support - # post processing out of the box. It only contains condition to create it + # post-processing workers because baseline Semantic Segmentor does not support + # post-processing out of the box. It only contains condition to create it # for any subclass semantic_segmentor.num_postproc_workers = 1 @@ -486,7 +463,7 @@ def test_functional_segmentor(remote_sample, tmp_path): [mini_wsi_jpg], mode="tile", on_gpu=ON_GPU, - patch_input_shape=[512, 512], + patch_input_shape=(512, 512), resolution=1 / resolution, units="baseline", crash_on_exception=True, diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index f851aa761..f34dc8f79 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -63,7 +63,7 @@ def nucleus_instance_segment( verbose, ): """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor + from tiatoolbox.models import IOInstanceSegmentorConfig, NucleusInstanceSegmentor from tiatoolbox.utils import save_as_json files_all, masks_all, output_path = prepare_model_cli( @@ -74,7 +74,7 @@ def nucleus_instance_segment( ) ioconfig = prepare_ioconfig_seg( - IOSegmentorConfig, + IOInstanceSegmentorConfig, pretrained_weights, yaml_config_path, ) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 9eb539efc..280f6dc0a 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -6,7 +6,7 @@ alexnet-kather100k: backbone: alexnet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -20,7 +20,7 @@ resnet18-kather100k: backbone: resnet18 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -34,7 +34,7 @@ resnet34-kather100k: backbone: resnet34 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -48,7 +48,7 @@ resnet50-kather100k: backbone: resnet50 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -62,7 +62,7 @@ resnet101-kather100k: backbone: resnet101 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -76,7 +76,7 @@ resnext50_32x4d-kather100k: backbone: resnext50_32x4d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -90,7 +90,7 @@ resnext101_32x8d-kather100k: backbone: resnext101_32x8d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -104,7 +104,7 @@ wide_resnet50_2-kather100k: backbone: wide_resnet50_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -118,7 +118,7 @@ wide_resnet101_2-kather100k: backbone: wide_resnet101_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -132,7 +132,7 @@ densenet121-kather100k: backbone: densenet121 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -146,7 +146,7 @@ densenet161-kather100k: backbone: densenet161 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -160,7 +160,7 @@ densenet169-kather100k: backbone: densenet169 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -174,7 +174,7 @@ densenet201-kather100k: backbone: densenet201 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -188,7 +188,7 @@ mobilenet_v2-kather100k: backbone: mobilenet_v2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -202,7 +202,7 @@ mobilenet_v3_large-kather100k: backbone: mobilenet_v3_large num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -216,7 +216,7 @@ mobilenet_v3_small-kather100k: backbone: mobilenet_v3_small num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -230,7 +230,7 @@ googlenet-kather100k: backbone: googlenet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -245,7 +245,7 @@ alexnet-pcam: backbone: alexnet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -259,7 +259,7 @@ resnet18-pcam: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -273,7 +273,7 @@ resnet34-pcam: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -287,7 +287,7 @@ resnet50-pcam: backbone: resnet50 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -301,7 +301,7 @@ resnet101-pcam: backbone: resnet101 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -315,7 +315,7 @@ resnext50_32x4d-pcam: backbone: resnext50_32x4d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -329,7 +329,7 @@ resnext101_32x8d-pcam: backbone: resnext101_32x8d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -343,7 +343,7 @@ wide_resnet50_2-pcam: backbone: wide_resnet50_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -357,7 +357,7 @@ wide_resnet101_2-pcam: backbone: wide_resnet101_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -371,7 +371,7 @@ densenet121-pcam: backbone: densenet121 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -385,7 +385,7 @@ densenet161-pcam: backbone: densenet161 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -399,7 +399,7 @@ densenet169-pcam: backbone: densenet169 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -413,7 +413,7 @@ densenet201-pcam: backbone: densenet201 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -427,7 +427,7 @@ mobilenet_v2-pcam: backbone: mobilenet_v2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -441,7 +441,7 @@ mobilenet_v3_large-pcam: backbone: mobilenet_v3_large num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -455,7 +455,7 @@ mobilenet_v3_small-pcam: backbone: mobilenet_v3_small num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -469,7 +469,7 @@ googlenet-pcam: backbone: googlenet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -484,7 +484,7 @@ resnet18-idars-tumour: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [512, 512] stride_shape: [512, 512] @@ -497,7 +497,7 @@ resnet34-idars-msi: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -510,7 +510,7 @@ resnet34-idars-braf: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -523,7 +523,7 @@ resnet34-idars-cimp: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -536,7 +536,7 @@ resnet34-idars-cin: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -549,7 +549,7 @@ resnet34-idars-tp53: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -562,7 +562,7 @@ resnet34-idars-hm: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -579,7 +579,7 @@ fcn-tissue_mask: encoder: "resnet50" decoder_block: [3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 2.0} @@ -600,7 +600,7 @@ fcn_resnet50_unet-bcss: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 0.25} @@ -625,7 +625,7 @@ unet_tissue_mask_tsef: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 1.0} @@ -644,7 +644,7 @@ hovernet_fast-pannuke: num_types: 6 mode: "fast" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -667,7 +667,7 @@ hovernet_fast-monusac: num_types: 5 mode: "fast" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -690,7 +690,7 @@ hovernet_original-consep: num_types: 5 mode: "original" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -713,7 +713,7 @@ hovernet_original-kumar: num_types: null # None in python ?, only do instance segmentation mode: "original" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -735,7 +735,7 @@ hovernetplus-oed: num_types: 3 num_layers: 5 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.50} @@ -759,7 +759,7 @@ micronet-consep: num_input_channels: 3 num_output_channels: 2 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index a1057495d..db809725e 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -8,17 +8,17 @@ from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN -from .engine.multi_task_segmentor import MultiTaskSegmentor -from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor -from .engine.patch_predictor import ( +from .engine.io_config import ( + IOInstanceSegmentorConfig, IOPatchPredictorConfig, - PatchDataset, - PatchPredictor, - WSIPatchDataset, + IOSegmentorConfig, + ModelIOConfigABC, ) +from .engine.multi_task_segmentor import MultiTaskSegmentor +from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from .engine.patch_predictor import PatchDataset, PatchPredictor, WSIPatchDataset from .engine.semantic_segmentor import ( DeepFeatureExtractor, - IOSegmentorConfig, SemanticSegmentor, WSIStreamDataset, ) @@ -35,4 +35,8 @@ "NucleusInstanceSegmentor", "PatchPredictor", "SemanticSegmentor", + "IOPatchPredictorConfig", + "IOSegmentorConfig", + "IOInstanceSegmentorConfig", + "ModelIOConfigABC", ] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index ad709fda7..223c4a051 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,74 +1,17 @@ +"""Defines Abstract Base Class for TIAToolbox Model Engines.""" from abc import ABC, abstractmethod -from typing import List, Tuple, Union - -import numpy as np - - -class IOConfigABC(ABC): - """Define an abstract class for holding predictor I/O information. - - Enforcing such that following attributes must always be defined by - the subclass. - - """ - - def __init__( - self, - input_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - **kwargs, - ): - self._kwargs = kwargs - self.resolution_unit = input_resolutions[0]["units"] - self.patch_input_shape = patch_input_shape - self.stride_shape = stride_shape - - self._validate() - - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - - def _validate(self): - """Validate the data format.""" - resolutions = self.input_resolutions + self.output_resolutions - units = [v["units"] for v in resolutions] - units = np.unique(units) - if len(units) != 1 or units[0] not in [ - "power", - "baseline", - "mpp", - ]: - msg = f"Invalid resolution units `{units[0]}`." - raise ValueError(msg) - - @property - @abstractmethod - def input_resolutions(self): - raise NotImplementedError - - @property - @abstractmethod - def output_resolutions(self): - raise NotImplementedError class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox.""" def __init__(self): + """Initialize Engine.""" super().__init__() @abstractmethod def process_patch(self): + """Process an image patch.""" raise NotImplementedError # how to deal with patches, list of patches/numpy arrays, WSIs diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py new file mode 100644 index 000000000..6ad2f634c --- /dev/null +++ b/tiatoolbox/models/engine/io_config.py @@ -0,0 +1,450 @@ +"""Defines IOConfig for Model Engines.""" +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.wsicore.wsimeta import Units + + +@dataclass +class ModelIOConfigABC: + """Defines a data class for holding a deep learning model's I/O information. + + Enforcing such that following attributes must always be defined by + the subclass. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.to_baseline() + + """ + + input_resolutions: list[dict] + patch_input_shape: list[int] | np.ndarray | tuple[int, int] + stride_shape: list[int] | np.ndarray | tuple[int, int] = None + output_resolutions: list[dict] = field(default_factory=list) + + def __post_init__(self: ModelIOConfigABC) -> None: + """Perform post initialization tasks.""" + if self.stride_shape is None: + self.stride_shape = self.patch_input_shape + + self.resolution_unit = self.input_resolutions[0]["units"] + + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, + key=lambda x: x["resolution"], + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, + key=lambda x: x["resolution"], + ) + + self._validate() + + def _validate(self: ModelIOConfigABC) -> None: + """Validate the data format.""" + resolutions = self.input_resolutions + self.output_resolutions + units = {v["units"] for v in resolutions} + + if len(units) != 1: + msg = ( + f"Multiple resolution units found: `{units}`. " + f"Mixing resolution units is not allowed." + ) + raise ValueError( + msg, + ) + + if units.pop() not in [ + "power", + "baseline", + "mpp", + ]: + msg = f"Invalid resolution units `{units}`." + raise ValueError(msg) + + @staticmethod + def scale_to_highest(resolutions: list[dict], units: Units) -> np.array: + """Get the scaling factor from input resolutions. + + This will convert resolutions to a scaling factor with respect to + the highest resolution found in the input resolutions list. If a model + requires images at multiple resolutions. This helps to read the image a + single resolution. The image will be read at the highest required resolution + and will be scaled for low resolution requirements using interpolation. + + Args: + resolutions (list): + A list of resolutions where one is defined as + `{'resolution': value, 'unit': value}` + units (Units): + Resolution units. + + Returns: + :class:`numpy.ndarray`: + A 1D array of scaling factors having the same length as + `resolutions`. + + Examples: + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.5}, + ... ], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.scale_to_highest() + ... array([1. , 0.5]) # output + >>> + >>> # Defining io for a base network and converting to baseline. + >>> ioconfig = ModelIOConfigABC( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.5}, + ... {"units": "mpp", "resolution": 0.25}, + ... ], + ... output_resolutions=[{"units": "mpp", "resolution": 1.0}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + >>> ioconfig = ioconfig.scale_to_highest() + ... array([0.5 , 1.]) # output + + """ + old_vals = [v["resolution"] for v in resolutions] + if units not in {"baseline", "mpp", "power"}: + msg = ( + f"Unknown units `{units}`. " + f"Units should be one of 'baseline', 'mpp' or 'power'." + ) + raise ValueError( + msg, + ) + if units == "baseline": + return old_vals + if units == "mpp": + return np.min(old_vals) / np.array(old_vals) + return np.array(old_vals) / np.max(old_vals) + + def to_baseline(self: ModelIOConfigABC) -> ModelIOConfigABC: + """Returns a new config object converted to baseline form. + + This will return a new :class:`ModelIOConfigABC` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + resolutions = self.input_resolutions + self.output_resolutions + save_resolution = getattr(self, "save_resolution", None) + if save_resolution is not None: + resolutions.append(save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + num_input_resolutions = len(self.input_resolutions) + + end_idx = num_input_resolutions + input_resolutions = [ + {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] + ] + + num_input_resolutions = len(self.input_resolutions) + num_output_resolutions = len(self.output_resolutions) + + end_idx = num_input_resolutions + num_output_resolutions + output_resolutions = [ + {"units": "baseline", "resolution": v} + for v in scale_factors[num_input_resolutions:end_idx] + ] + + return replace( + self, + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + ) + + +@dataclass +class IOSegmentorConfig(ModelIOConfigABC): + """Contains semantic segmentor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... ) + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... ) + + """ + + patch_output_shape: list[int] | np.ndarray | tuple[int, int] = None + save_resolution: dict = None + + def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig: + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + new_config = super().to_baseline() + resolutions = self.input_resolutions + self.output_resolutions + if self.save_resolution is not None: + resolutions.append(self.save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + + save_resolution = None + if self.save_resolution is not None: + save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} + + return replace( + self, + input_resolutions=new_config.input_resolutions, + output_resolutions=new_config.output_resolutions, + save_resolution=save_resolution, + ) + + +class IOPatchPredictorConfig(ModelIOConfigABC): + """Contains patch predictor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + Examples: + >>> # Defining io for a patch predictor network + >>> ioconfig = IOPatchPredictorConfig( + ... input_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... output_resolutions=[{"units": "mpp", "resolution": 0.5}], + ... patch_input_shape=(224, 224), + ... stride_shape=(224, 224), + ... ) + + """ + + +@dataclass +class IOInstanceSegmentorConfig(IOSegmentorConfig): + """Contains instance segmentor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOInstanceSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... margin=128, + ... tile_shape=(1024, 1024), + ... ) + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOInstanceSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... margin=128, + ... tile_shape=(1024, 1024), + ... ) + + """ + + margin: int = None + tile_shape: tuple[int, int] = None + + def to_baseline(self: IOInstanceSegmentorConfig) -> IOInstanceSegmentorConfig: + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + return super().to_baseline() diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7857842fa..22aa95538 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -34,14 +34,14 @@ NucleusInstanceSegmentor, _process_instance_predictions, ) -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) + +from .semantic_segmentor import WSIStreamDataset if TYPE_CHECKING: # pragma: no cover import torch + from .io_config import IOInstanceSegmentorConfig + # Python is yet to be able to natively pickle Object method/static method. # Only top-level function is passable to multi-processing as caller. @@ -65,7 +65,7 @@ def _process_tile_predictions( using the output from each task. Args: - ioconfig (:class:`IOSegmentorConfig`): Object defines information + ioconfig (:class:`IOInstanceSegmentorConfig`): Object defines information about input and output placement of patches. tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as (top_left_x, top_left_y, bottom_x, bottom_y). @@ -286,19 +286,23 @@ def __init__( def _predict_one_wsi( self, wsi_idx: int, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, ): """Make a prediction on tile/wsi. Args: - wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): Object which defines I/O placement during - inference and when assembling back to full tile/wsi. - save_path (str): Location to save output prediction as well as possible + wsi_idx (int): + Index of the tile/wsi to be processed within `self`. + ioconfig (IOInstanceSegmentorConfig): + Object which defines I/O placement + during inference and when assembling back to full tile/wsi. + save_path (str): + Location to save output prediction as well as possible intermediate results. - mode (str): `tile` or `wsi` to indicate run mode. + mode (str): + `tile` or `wsi` to indicate run mode. """ cache_dir = f"{self._cache_dir}/" diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index bd96ab0b5..8f7983a89 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -3,7 +3,7 @@ import uuid from collections import deque -from typing import Callable +from typing import TYPE_CHECKING, Callable # replace with the sql database once the PR in place import joblib @@ -14,12 +14,14 @@ from shapely.strtree import STRtree from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, SemanticSegmentor, WSIStreamDataset, ) from tiatoolbox.tools.patchextraction import PatchExtractor +if TYPE_CHECKING: # pragma: no cover + from .io_config import IOInstanceSegmentorConfig + def _process_instance_predictions( inst_dict, @@ -404,7 +406,7 @@ def __init__( @staticmethod def _get_tile_info( image_shape: list[int] | np.ndarray, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, ): """Generating tile information. @@ -422,7 +424,7 @@ def _get_tile_info( image_shape (:class:`numpy.ndarray`, list(int)): The shape of WSI to extract the tile from, assumed to be in `[width, height]`. - ioconfig (:obj:IOSegmentorConfig): + ioconfig (:obj:IOInstanceSegmentorConfig): The input and output configuration objects. Returns: @@ -659,7 +661,7 @@ def _infer_once(self): def _predict_one_wsi( self, wsi_idx: int, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, ): @@ -668,7 +670,7 @@ def _predict_one_wsi( Args: wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): + ioconfig (IOInstanceSegmentorConfig): Object which defines I/O placement during inference and when assembling back to full tile/wsi. save_path (str): diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index ed363e1b4..8042afeaf 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -13,37 +13,15 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset -from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig from tiatoolbox.utils import misc, save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .io_config import IOPatchPredictorConfig + if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import Resolution, Units -class IOPatchPredictorConfig(IOSegmentorConfig): - """Contains patch predictor input and output information.""" - - def __init__( - self, - patch_input_shape=None, - input_resolutions=None, - stride_shape=None, - **kwargs, - ) -> None: - """Initialize :class:`IOPatchPredictorConfig`.""" - stride_shape = patch_input_shape if stride_shape is None else stride_shape - super().__init__( - input_resolutions=input_resolutions, - output_resolutions=[], - stride_shape=stride_shape, - patch_input_shape=patch_input_shape, - patch_output_shape=patch_input_shape, - save_resolution=None, - **kwargs, - ) - - class PatchPredictor: r"""Patch level predictor. @@ -471,10 +449,10 @@ def _update_ioconfig( resolution, units, ): - """Updates the ioconfig. + """Update the ioconfig. Args: - ioconfig (IOPatchPredictorConfig): + ioconfig (:class:`IOPatchPredictorConfig`): Input ioconfig for PatchPredictor. patch_input_shape (tuple): Size of patches input to the model. Patches are at @@ -533,6 +511,7 @@ def _update_ioconfig( input_resolutions=[{"resolution": resolution, "units": units}], patch_input_shape=patch_input_shape, stride_shape=stride_shape, + output_resolutions=[], ) @staticmethod diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 0df56d29d..6805281dd 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -18,11 +18,12 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.engine.engine_abc import IOConfigABC from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread, misc from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader +from .io_config import IOSegmentorConfig + if TYPE_CHECKING: # pragma: no cover from multiprocessing.managers import Namespace @@ -106,184 +107,6 @@ def _prepare_save_output( return is_on_drive, count_canvas, cum_canvas -class IOSegmentorConfig(IOConfigABC): - """Contain semantic segmentor input and output information. - - Args: - input_resolutions (list): - Resolution of each input head of model inference, must be in - the same order as `target model.forward()`. - output_resolutions (list): - Resolution of each output head from model inference, must be - in the same order as target model.infer_batch(). - patch_input_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest input in (height, width). - patch_output_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest output in (height, width). - save_resolution (dict): - Resolution to save all output. - - Examples: - >>> # Defining io for a network having 1 input and 1 output at the - >>> # same resolution - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - - Examples: - >>> # Defining io for a network having 3 input and 2 output - >>> # at the same resolution, the output is then merged at a - >>> # different resolution. - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... {"units": "mpp", "resolution": 0.75}, - ... ], - ... output_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... save_resolution={"units": "mpp", "resolution": 4.0}, - ... ) - - """ - - # We pre-define to follow enforcement, actual initialisation in init - input_resolutions = None - output_resolutions = None - - def __init__( - self, - input_resolutions: list[dict], - output_resolutions: list[dict], - patch_input_shape: list[int] | np.ndarray, - patch_output_shape: list[int] | np.ndarray, - save_resolution: dict | None = None, - **kwargs, - ) -> None: - """Initialize :class:`IOSegmentorConfig`.""" - self._kwargs = kwargs - self.patch_input_shape = patch_input_shape - self.patch_output_shape = patch_output_shape - self.stride_shape = None - self.input_resolutions = input_resolutions - self.output_resolutions = output_resolutions - - self.resolution_unit = input_resolutions[0]["units"] - self.save_resolution = save_resolution - - for variable, value in kwargs.items(): - self.__setattr__(variable, value) - - self._validate() - - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, - key=lambda x: x["resolution"], - ) - - def _validate(self): - """Validate the data format.""" - resolutions = self.input_resolutions + self.output_resolutions - units = [v["units"] for v in resolutions] - units = np.unique(units) - if len(units) != 1 or units[0] not in [ - "power", - "baseline", - "mpp", - ]: - msg = f"Invalid resolution units `{units[0]}`." - raise ValueError(msg) - - @staticmethod - def scale_to_highest(resolutions: list[dict], units: Units): - """Get the scaling factor from input resolutions. - - This will convert resolutions to a scaling factor with respect to - the highest resolution found in the input resolutions list. - - Args: - resolutions (list): - A list of resolutions where one is defined as - `{'resolution': value, 'unit': value}` - units (Units): - Units that the resolutions are at. - - Returns: - :class:`numpy.ndarray`: - A 1D array of scaling factors having the same length as - `resolutions` - - """ - old_val = [v["resolution"] for v in resolutions] - if units not in ["baseline", "mpp", "power"]: - msg = ( - f"Unknown units `{units}`. " - f"Units should be one of 'baseline', 'mpp' or 'power'." - ) - raise ValueError( - msg, - ) - if units == "baseline": - return old_val - if units == "mpp": - return np.min(old_val) / np.array(old_val) - return np.array(old_val) / np.max(old_val) - - def to_baseline(self): - """Return a new config object converted to baseline form. - - This will return a new :class:`IOSegmentorConfig` where - resolutions have been converted to baseline format with the - highest possible resolution found in both input and output as - reference. - - """ - resolutions = self.input_resolutions + self.output_resolutions - if self.save_resolution is not None: - resolutions.append(self.save_resolution) - - scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(self.input_resolutions) - num_output_resolutions = len(self.output_resolutions) - - end_idx = num_input_resolutions - input_resolutions = [ - {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] - ] - end_idx = num_input_resolutions + num_output_resolutions - output_resolutions = [ - {"units": "baseline", "resolution": v} - for v in scale_factors[num_input_resolutions:end_idx] - ] - - save_resolution = None - if self.save_resolution is not None: - save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} - return IOSegmentorConfig( - input_resolutions=input_resolutions, - output_resolutions=output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - save_resolution=save_resolution, - **self._kwargs, - ) - - class WSIStreamDataset(torch_data.Dataset): """Reading a wsi in parallel mode with persistent workers. @@ -1064,8 +887,8 @@ def _prepare_save_dir(save_dir): return save_dir, cache_dir + @staticmethod def _update_ioconfig( - self, ioconfig, mode, patch_input_shape, @@ -1113,17 +936,7 @@ def _update_ioconfig( if stride_shape is None: stride_shape = patch_output_shape - if ioconfig is None and patch_input_shape is None: - if self.ioconfig is None: - msg = ( - "Must provide either `ioconfig` or `patch_input_shape` " - "and `patch_output_shape`" - ) - raise ValueError( - msg, - ) - ioconfig = copy.deepcopy(self.ioconfig) - elif ioconfig is None: + if ioconfig is None: ioconfig = IOSegmentorConfig( input_resolutions=[{"resolution": resolution, "units": units}], output_resolutions=[{"resolution": resolution, "units": units}], @@ -1252,8 +1065,8 @@ def predict( patch_input_shape=None, patch_output_shape=None, stride_shape=None, - resolution=1.0, - units="baseline", + resolution=None, + units=None, save_dir=None, crash_on_exception=False, ): @@ -1342,6 +1155,28 @@ def predict( save_dir, self._cache_dir = self._prepare_save_dir(save_dir) + if ioconfig is None: + ioconfig = copy.deepcopy(self.ioconfig) + + if ioconfig is None and patch_input_shape is None: + msg = ( + "Must provide either `ioconfig` or " + "`patch_input_shape` and `patch_output_shape`" + ) + raise ValueError( + msg, + ) + + if resolution is None and units is None: + if ioconfig is None: + msg = f"Invalid resolution: `{resolution}` and units: `{units}`. " + raise ValueError( + msg, + ) + + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.input_resolutions[0]["units"] + ioconfig = self._update_ioconfig( ioconfig, mode, @@ -1605,7 +1440,7 @@ def predict( Resolution used for reading the image. units (Units): Units of resolution used for reading the image. - save_dir (str): + save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. diff --git a/whitelist.txt b/whitelist.txt index 07a1b13c3..d1e723f26 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -96,6 +96,7 @@ coord coords csv cuda +customizable cv2 dataframe dataset From 117218735295f000bc1e35df14f0902a8a8d1a64 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 31 Jul 2023 19:47:55 +0100 Subject: [PATCH 08/36] :recycle: Move Dataset Classes to `dataset_abc.py` (#637) - Move `WSIStreamDataset`, `PatchDatasetABC` and `WSIPatchDataset` to `dataset_abc.py` --------- Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/models/test_dataset.py | 3 +- tiatoolbox/models/__init__.py | 20 +- tiatoolbox/models/dataset/__init__.py | 20 +- tiatoolbox/models/dataset/classification.py | 293 ------------ tiatoolbox/models/dataset/dataset_abc.py | 440 ++++++++++++++++++ tiatoolbox/models/engine/__init__.py | 6 + tiatoolbox/models/engine/io_config.py | 3 +- .../models/engine/multi_task_segmentor.py | 5 +- .../engine/nucleus_instance_segmentor.py | 6 +- tiatoolbox/models/engine/patch_predictor.py | 9 +- .../models/engine/semantic_segmentor.py | 148 +----- 11 files changed, 489 insertions(+), 464 deletions(-) diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index 72622731c..370b8034b 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -7,7 +7,8 @@ import pytest from tiatoolbox import rcParam -from tiatoolbox.models.dataset import DatasetInfoABC, KatherPatchDataset, PatchDataset +from tiatoolbox.models import PatchDataset +from tiatoolbox.models.dataset import DatasetInfoABC, KatherPatchDataset from tiatoolbox.utils import download_data, unzip_data from tiatoolbox.utils import env_detection as toolbox_env diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index db809725e..ecd173ced 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -1,6 +1,5 @@ """Models package for the models implemented in tiatoolbox.""" -from tiatoolbox.models import architecture, dataset, engine, models_abc - +from . import architecture, dataset, engine, models_abc from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus from .architecture.idars import IDaRS @@ -8,6 +7,7 @@ from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN +from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset from .engine.io_config import ( IOInstanceSegmentorConfig, IOPatchPredictorConfig, @@ -16,14 +16,14 @@ ) from .engine.multi_task_segmentor import MultiTaskSegmentor from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor -from .engine.patch_predictor import PatchDataset, PatchPredictor, WSIPatchDataset -from .engine.semantic_segmentor import ( - DeepFeatureExtractor, - SemanticSegmentor, - WSIStreamDataset, -) +from .engine.patch_predictor import PatchPredictor +from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor __all__ = [ + "architecture", + "dataset", + "engine", + "models_abc", "HoVerNet", "HoVerNetPlus", "IDaRS", @@ -39,4 +39,8 @@ "IOSegmentorConfig", "IOInstanceSegmentorConfig", "ModelIOConfigABC", + "DeepFeatureExtractor", + "WSIStreamDataset", + "WSIPatchDataset", + "PatchDataset", ] diff --git a/tiatoolbox/models/dataset/__init__.py b/tiatoolbox/models/dataset/__init__.py index 9c09991fa..49d59a61a 100644 --- a/tiatoolbox/models/dataset/__init__.py +++ b/tiatoolbox/models/dataset/__init__.py @@ -1,9 +1,21 @@ """Contains dataset functionality for use with models in tiatoolbox.""" -from tiatoolbox.models.dataset.classification import ( +from tiatoolbox.models.dataset.classification import predefined_preproc_func + +from .dataset_abc import ( PatchDataset, + PatchDatasetABC, WSIPatchDataset, - predefined_preproc_func, + WSIStreamDataset, ) -from tiatoolbox.models.dataset.dataset_abc import PatchDatasetABC -from tiatoolbox.models.dataset.info import DatasetInfoABC, KatherPatchDataset +from .info import DatasetInfoABC, KatherPatchDataset + +__all__ = [ + "predefined_preproc_func", + "PatchDatasetABC", + "WSIPatchDataset", + "PatchDataset", + "WSIStreamDataset", + "DatasetInfoABC", + "KatherPatchDataset", +] diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index 48ccd35b2..bdd947f20 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -1,18 +1,8 @@ """Define classes and methods for classification datasets.""" -from pathlib import Path -import cv2 -import numpy as np import PIL from torchvision import transforms -from tiatoolbox import logger -from tiatoolbox.models.dataset import dataset_abc -from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread -from tiatoolbox.wsicore.wsimeta import WSIMeta -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader - class _TorchPreprocCaller: """Wrapper for applying PyTorch transforms. @@ -64,286 +54,3 @@ def predefined_preproc_func(dataset_name): preprocs = preproc_dict[dataset_name] return _TorchPreprocCaller(preprocs) - - -class PatchDataset(dataset_abc.PatchDatasetABC): - """Define PatchDataset for torch inference. - - Define a simple patch dataset, which inherits from the - `torch.utils.data.Dataset` class. - - Attributes: - inputs: - Either a list of patches, where each patch is a ndarray or a - list of valid path with its extension be (".jpg", ".jpeg", - ".tif", ".tiff", ".png") pointing to an image. - labels: - List of labels for sample at the same index in `inputs`. - Default is `None`. - - Examples: - >>> # A user defined preproc func and expected behavior - >>> preproc_func = lambda img: img/2 # reduce intensity by half - >>> transformed_img = preproc_func(img) - >>> # create a dataset to get patches preprocessed by the above function - >>> ds = PatchDataset( - ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], - ... labels=["labels1", "labels2"], - ... ) - - """ - - def __init__(self, inputs, labels=None) -> None: - """Initialize :class:`PatchDataset`.""" - super().__init__() - - self.data_is_npy_alike = False - - self.inputs = inputs - self.labels = labels - - # perform check on the input - self._check_input_integrity(mode="patch") - - def __getitem__(self, idx): - """Get an item from the dataset.""" - patch = self.inputs[idx] - - # Mode 0 is list of paths - if not self.data_is_npy_alike: - patch = self.load_img(patch) - - # Apply preprocessing to selected patch - patch = self._preproc(patch) - - data = { - "image": patch, - } - if self.labels is not None: - data["label"] = self.labels[idx] - return data - - return data - - -class WSIPatchDataset(dataset_abc.PatchDatasetABC): - """Define a WSI-level patch dataset. - - Attributes: - reader (:class:`.WSIReader`): - A WSI Reader or Virtual Reader for reading pyramidal image - or large tile in pyramidal way. - inputs: - List of coordinates to read from the `reader`, each - coordinate is of the form `[start_x, start_y, end_x, - end_y]`. - patch_input_shape: - A tuple (int, int) or ndarray of shape (2,). Expected size to - read from `reader` at requested `resolution` and `units`. - Expected to be `(height, width)`. - resolution: - See (:class:`.WSIReader`) for details. - units: - See (:class:`.WSIReader`) for details. - preproc_func: - Preprocessing function used to transform the input data. It will - be called on each patch before returning it. - - """ - - def __init__( - self, - img_path, - mode="wsi", - mask_path=None, - patch_input_shape=None, - stride_shape=None, - resolution=None, - units=None, - auto_get_mask=True, - min_mask_ratio=0, - preproc_func=None, - ) -> None: - """Create a WSI-level patch dataset. - - Args: - mode (str): - Can be either `wsi` or `tile` to denote the image to - read is either a whole-slide image or a large image - tile. - img_path (str or Path): - Valid to pyramidal whole-slide image or large tile to - read. - mask_path (str or Path): - Valid mask image. - patch_input_shape: - A tuple (int, int) or ndarray of shape (2,). Expected - shape to read from `reader` at requested `resolution` - and `units`. Expected to be positive and of (height, - width). Note, this is not at `resolution` coordinate - space. - stride_shape: - A tuple (int, int) or ndarray of shape (2,). Expected - stride shape to read at requested `resolution` and - `units`. Expected to be positive and of (height, width). - Note, this is not at level 0. - resolution: - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. - units: - Units in which `resolution` is defined. - auto_get_mask: - If `True`, then automatically get simple threshold mask using - WSIReader.tissue_mask() function. - min_mask_ratio: - Only patches with positive area percentage above this value are - included. Defaults to 0. - preproc_func: - Preprocessing function used to transform the input data. If - supplied, the function will be called on each patch before - returning it. - - Examples: - >>> # A user defined preproc func and expected behavior - >>> preproc_func = lambda img: img/2 # reduce intensity by half - >>> transformed_img = preproc_func(img) - >>> # Create a dataset to get patches from WSI with above - >>> # preprocessing function - >>> ds = WSIPatchDataset( - ... img_path='/A/B/C/wsi.svs', - ... mode="wsi", - ... patch_input_shape=[512, 512], - ... stride_shape=[256, 256], - ... auto_get_mask=False, - ... preproc_func=preproc_func - ... ) - - """ - super().__init__() - - # Is there a generic func for path test in toolbox? - if not Path.is_file(Path(img_path)): - msg = "`img_path` must be a valid file path." - raise ValueError(msg) - if mode not in ["wsi", "tile"]: - msg = f"`{mode}` is not supported." - raise ValueError(msg) - patch_input_shape = np.array(patch_input_shape) - stride_shape = np.array(stride_shape) - - if ( - not np.issubdtype(patch_input_shape.dtype, np.integer) - or np.size(patch_input_shape) > 2 - or np.any(patch_input_shape < 0) - ): - msg = f"Invalid `patch_input_shape` value {patch_input_shape}." - raise ValueError(msg) - if ( - not np.issubdtype(stride_shape.dtype, np.integer) - or np.size(stride_shape) > 2 - or np.any(stride_shape < 0) - ): - msg = f"Invalid `stride_shape` value {stride_shape}." - raise ValueError(msg) - - self.preproc_func = preproc_func - img_path = Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # infer value such that read if mask provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - units = "mpp" - resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) - - # may decouple into misc ? - # the scaling factor will scale base level to requested read resolution/units - wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) - - # use all patches, as long as it overlaps source image - self.inputs = PatchExtractor.get_coordinates( - image_shape=wsi_shape, - patch_input_shape=patch_input_shape[::-1], - stride_shape=stride_shape[::-1], - input_within_bound=False, - ) - - mask_reader = None - if mask_path is not None: - mask_path = Path(mask_path) - if not Path.is_file(mask_path): - msg = "`mask_path` must be a valid file path." - raise ValueError(msg) - mask = imread(mask_path) # assume to be gray - mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) - mask = np.array(mask > 0, dtype=np.uint8) - - mask_reader = VirtualWSIReader(mask) - mask_reader.info = self.reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: - # if no mask provided and `wsi` mode, generate basic tissue - # mask on the fly - mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") - # ? will this mess up ? - mask_reader.info = self.reader.info - - if mask_reader is not None: - selected = PatchExtractor.filter_coordinates( - mask_reader, # must be at the same resolution - self.inputs, # must already be at requested resolution - wsi_shape=wsi_shape, - min_mask_ratio=min_mask_ratio, - ) - self.inputs = self.inputs[selected] - - if len(self.inputs) == 0: - msg = "No patch coordinates remain after filtering." - raise ValueError(msg) - - self.patch_input_shape = patch_input_shape - self.resolution = resolution - self.units = units - - # Perform check on the input - self._check_input_integrity(mode="wsi") - - def __getitem__(self, idx): - """Get an item from the dataset.""" - coords = self.inputs[idx] - # Read image patch from the whole-slide image - patch = self.reader.read_bounds( - coords, - resolution=self.resolution, - units=self.units, - pad_constant_values=255, - coord_space="resolution", - ) - - # Apply preprocessing to selected patch - patch = self._preproc(patch) - - return {"image": patch, "coords": np.array(coords)} diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 993e7377d..c74eefbbe 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -1,11 +1,25 @@ """Define dataset abstract classes.""" +from __future__ import annotations + +import copy from abc import ABC, abstractmethod from pathlib import Path +from typing import TYPE_CHECKING, Callable +import cv2 import numpy as np import torch +import torch.utils.data as torch_data +from tiatoolbox import logger +from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader + +if TYPE_CHECKING: # pragma: no cover + from multiprocessing.managers import Namespace + + from tiatoolbox.models.engine.io_config import IOSegmentorConfig class PatchDatasetABC(ABC, torch.utils.data.Dataset): @@ -156,3 +170,429 @@ def __len__(self) -> int: def __getitem__(self, idx): """Get an item from the dataset.""" ... # pragma: no cover + + +class WSIStreamDataset(torch_data.Dataset): + """Reading a wsi in parallel mode with persistent workers. + + To speed up the inference process for multiple WSIs. The + `torch.utils.data.Dataloader` is set to run in persistent mode. + Normally, this will prevent workers from altering their initial + states (such as provided input etc.). To sidestep this, we use a + shared parallel workspace context manager to send data and signal + from the main thread, thus allowing each worker to load a new wsi as + well as corresponding patch information. + + Args: + mp_shared_space (:class:`Namespace`): + A shared multiprocessing space, must be from + `torch.multiprocessing`. + ioconfig (:class:`IOSegmentorConfig`): + An object which contains I/O placement for patches. + wsi_paths (list): List of paths pointing to a WSI or tiles. + preproc (Callable): + Pre-processing function to be applied to a patch. + mode (str): + Either `"wsi"` or `"tile"` to indicate the format of images + in `wsi_paths`. + + Examples: + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + >>> mp_manager = torch_mp.Manager() + >>> mp_shared_space = mp_manager.Namespace() + >>> mp_shared_space.signal = 1 # adding variable to the shared space + >>> wsi_paths = ['A.svs', 'B.svs'] + >>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space) + + """ + + def __init__( + self, + ioconfig: IOSegmentorConfig, + wsi_paths: list[str | Path], + mp_shared_space: Namespace, + preproc: Callable[[np.ndarray], np.ndarray] | None = None, + mode="wsi", + ) -> None: + """Initialize :class:`WSIStreamDataset`.""" + super().__init__() + self.mode = mode + self.preproc = preproc + self.ioconfig = copy.deepcopy(ioconfig) + + if mode == "tile": + logger.warning( + "WSIPatchDataset only reads image tile at " + '`units="baseline"`. Resolutions will be converted ' + "to baseline value.", + stacklevel=2, + ) + self.ioconfig = self.ioconfig.to_baseline() + + self.mp_shared_space = mp_shared_space + self.wsi_paths = wsi_paths + self.wsi_idx = None # to be received externally via thread communication + self.reader = None + + def _get_reader(self, img_path): + """Get appropriate reader for input path.""" + img_path = Path(img_path) + if self.mode == "wsi": + return WSIReader.open(img_path) + img = imread(img_path) + # initialise metadata for VirtualWSIReader. + # here, we simulate a whole-slide image, but with a single level. + metadata = WSIMeta( + mpp=np.array([1.0, 1.0]), + objective_power=10, + axes="YXS", + slide_dimensions=np.array(img.shape[:2][::-1]), + level_downsamples=[1.0], + level_dimensions=[np.array(img.shape[:2][::-1])], + ) + return VirtualWSIReader( + img, + info=metadata, + ) + + def __len__(self) -> int: + """Return the length of the instance attributes.""" + return len(self.mp_shared_space.patch_inputs) + + @staticmethod + def collate_fn(batch): + """Prototype to handle reading exception. + + This will exclude any sample with `None` from the batch. As + such, wrapping `__getitem__` with try-catch and return `None` + upon exceptions will prevent crashing the entire program. But as + a side effect, the batch may not have the size as defined. + + """ + batch = [v for v in batch if v is not None] + return torch.utils.data.dataloader.default_collate(batch) + + def __getitem__(self, idx: int): + """Get an item from the dataset.""" + # ! no need to lock as we do not modify source value in shared space + if self.wsi_idx != self.mp_shared_space.wsi_idx: + self.wsi_idx = int(self.mp_shared_space.wsi_idx.item()) + self.reader = self._get_reader(self.wsi_paths[self.wsi_idx]) + + # this is in XY and at requested resolution (not baseline) + bounds = self.mp_shared_space.patch_inputs[idx] + bounds = bounds.numpy() # expected to be a torch.Tensor + + # be the same as bounds br-tl, unless bounds are of float + patch_data_ = [] + scale_factors = self.ioconfig.scale_to_highest( + self.ioconfig.input_resolutions, + self.ioconfig.resolution_unit, + ) + for idy, resolution in enumerate(self.ioconfig.input_resolutions): + resolution_bounds = np.round(bounds * scale_factors[idy]) + patch_data = self.reader.read_bounds( + resolution_bounds.astype(np.int32), + coord_space="resolution", + pad_constant_values=0, # expose this ? + **resolution, + ) + + if self.preproc is not None: + patch_data = patch_data.copy() + patch_data = self.preproc(patch_data) + patch_data_.append(patch_data) + if len(patch_data_) == 1: + patch_data_ = patch_data_[0] + + bound = self.mp_shared_space.patch_outputs[idx] + return patch_data_, bound + + +class WSIPatchDataset(PatchDatasetABC): + """Define a WSI-level patch dataset. + + Attributes: + reader (:class:`.WSIReader`): + A WSI Reader or Virtual Reader for reading pyramidal image + or large tile in pyramidal way. + inputs: + List of coordinates to read from the `reader`, each + coordinate is of the form `[start_x, start_y, end_x, + end_y]`. + patch_input_shape: + A tuple (int, int) or ndarray of shape (2,). Expected size to + read from `reader` at requested `resolution` and `units`. + Expected to be `(height, width)`. + resolution: + See (:class:`.WSIReader`) for details. + units: + See (:class:`.WSIReader`) for details. + preproc_func: + Preprocessing function used to transform the input data. It will + be called on each patch before returning it. + + """ + + def __init__( + self, + img_path, + mode="wsi", + mask_path=None, + patch_input_shape=None, + stride_shape=None, + resolution=None, + units=None, + auto_get_mask=True, + min_mask_ratio=0, + preproc_func=None, + ) -> None: + """Create a WSI-level patch dataset. + + Args: + mode (str): + Can be either `wsi` or `tile` to denote the image to + read is either a whole-slide image or a large image + tile. + img_path (str or Path): + Valid to pyramidal whole-slide image or large tile to + read. + mask_path (str or Path): + Valid mask image. + patch_input_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + shape to read from `reader` at requested `resolution` + and `units`. Expected to be positive and of (height, + width). Note, this is not at `resolution` coordinate + space. + stride_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + stride shape to read at requested `resolution` and + `units`. Expected to be positive and of (height, width). + Note, this is not at level 0. + resolution: + Check (:class:`.WSIReader`) for details. When + `mode='tile'`, value is fixed to be `resolution=1.0` and + `units='baseline'` units: check (:class:`.WSIReader`) for + details. + units: + Units in which `resolution` is defined. + auto_get_mask: + If `True`, then automatically get simple threshold mask using + WSIReader.tissue_mask() function. + min_mask_ratio: + Only patches with positive area percentage above this value are + included. Defaults to 0. + preproc_func: + Preprocessing function used to transform the input data. If + supplied, the function will be called on each patch before + returning it. + + Examples: + >>> # A user defined preproc func and expected behavior + >>> preproc_func = lambda img: img/2 # reduce intensity by half + >>> transformed_img = preproc_func(img) + >>> # Create a dataset to get patches from WSI with above + >>> # preprocessing function + >>> ds = WSIPatchDataset( + ... img_path='/A/B/C/wsi.svs', + ... mode="wsi", + ... patch_input_shape=[512, 512], + ... stride_shape=[256, 256], + ... auto_get_mask=False, + ... preproc_func=preproc_func + ... ) + + """ + super().__init__() + + # Is there a generic func for path test in toolbox? + if not Path.is_file(Path(img_path)): + msg = "`img_path` must be a valid file path." + raise ValueError(msg) + if mode not in ["wsi", "tile"]: + msg = f"`{mode}` is not supported." + raise ValueError(msg) + patch_input_shape = np.array(patch_input_shape) + stride_shape = np.array(stride_shape) + + if ( + not np.issubdtype(patch_input_shape.dtype, np.integer) + or np.size(patch_input_shape) > 2 + or np.any(patch_input_shape < 0) + ): + msg = f"Invalid `patch_input_shape` value {patch_input_shape}." + raise ValueError(msg) + if ( + not np.issubdtype(stride_shape.dtype, np.integer) + or np.size(stride_shape) > 2 + or np.any(stride_shape < 0) + ): + msg = f"Invalid `stride_shape` value {stride_shape}." + raise ValueError(msg) + + self.preproc_func = preproc_func + img_path = Path(img_path) + if mode == "wsi": + self.reader = WSIReader.open(img_path) + else: + logger.warning( + "WSIPatchDataset only reads image tile at " + '`units="baseline"` and `resolution=1.0`.', + stacklevel=2, + ) + img = imread(img_path) + axes = "YXS"[: len(img.shape)] + # initialise metadata for VirtualWSIReader. + # here, we simulate a whole-slide image, but with a single level. + # ! should we expose this so that use can provide their metadata ? + metadata = WSIMeta( + mpp=np.array([1.0, 1.0]), + axes=axes, + objective_power=10, + slide_dimensions=np.array(img.shape[:2][::-1]), + level_downsamples=[1.0], + level_dimensions=[np.array(img.shape[:2][::-1])], + ) + # infer value such that read if mask provided is through + # 'mpp' or 'power' as varying 'baseline' is locked atm + units = "mpp" + resolution = 1.0 + self.reader = VirtualWSIReader( + img, + info=metadata, + ) + + # may decouple into misc ? + # the scaling factor will scale base level to requested read resolution/units + wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) + + # use all patches, as long as it overlaps source image + self.inputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + input_within_bound=False, + ) + + mask_reader = None + if mask_path is not None: + mask_path = Path(mask_path) + if not Path.is_file(mask_path): + msg = "`mask_path` must be a valid file path." + raise ValueError(msg) + mask = imread(mask_path) # assume to be gray + mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) + mask = np.array(mask > 0, dtype=np.uint8) + + mask_reader = VirtualWSIReader(mask) + mask_reader.info = self.reader.info + elif auto_get_mask and mode == "wsi" and mask_path is None: + # if no mask provided and `wsi` mode, generate basic tissue + # mask on the fly + mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") + # ? will this mess up ? + mask_reader.info = self.reader.info + + if mask_reader is not None: + selected = PatchExtractor.filter_coordinates( + mask_reader, # must be at the same resolution + self.inputs, # must already be at requested resolution + wsi_shape=wsi_shape, + min_mask_ratio=min_mask_ratio, + ) + self.inputs = self.inputs[selected] + + if len(self.inputs) == 0: + msg = "No patch coordinates remain after filtering." + raise ValueError(msg) + + self.patch_input_shape = patch_input_shape + self.resolution = resolution + self.units = units + + # Perform check on the input + self._check_input_integrity(mode="wsi") + + def __getitem__(self, idx): + """Get an item from the dataset.""" + coords = self.inputs[idx] + # Read image patch from the whole-slide image + patch = self.reader.read_bounds( + coords, + resolution=self.resolution, + units=self.units, + pad_constant_values=255, + coord_space="resolution", + ) + + # Apply preprocessing to selected patch + patch = self._preproc(patch) + + return {"image": patch, "coords": np.array(coords)} + + +class PatchDataset(PatchDatasetABC): + """Define PatchDataset for torch inference. + + Define a simple patch dataset, which inherits from the + `torch.utils.data.Dataset` class. + + Attributes: + inputs: + Either a list of patches, where each patch is a ndarray or a + list of valid path with its extension be (".jpg", ".jpeg", + ".tif", ".tiff", ".png") pointing to an image. + labels: + List of labels for sample at the same index in `inputs`. + Default is `None`. + + Examples: + >>> # A user defined preproc func and expected behavior + >>> preproc_func = lambda img: img/2 # reduce intensity by half + >>> transformed_img = preproc_func(img) + >>> # create a dataset to get patches preprocessed by the above function + >>> ds = PatchDataset( + ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], + ... labels=["labels1", "labels2"], + ... ) + + """ + + def __init__(self, inputs, labels=None) -> None: + """Initialize :class:`PatchDataset`.""" + super().__init__() + + self.data_is_npy_alike = False + + self.inputs = inputs + self.labels = labels + + # perform check on the input + self._check_input_integrity(mode="patch") + + def __getitem__(self, idx): + """Get an item from the dataset.""" + patch = self.inputs[idx] + + # Mode 0 is list of paths + if not self.data_is_npy_alike: + patch = self.load_img(patch) + + # Apply preprocessing to selected patch + patch = self._preproc(patch) + + data = { + "image": patch, + } + if self.labels is not None: + data["label"] = self.labels[idx] + return data + + return data diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 2cba98a32..0a5968b44 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -4,3 +4,9 @@ patch_predictor, semantic_segmentor, ) + +__all__ = [ + "nucleus_instance_segmentor", + "patch_predictor", + "semantic_segmentor", +] diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 6ad2f634c..8b397b798 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -7,7 +7,7 @@ import numpy as np if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.wsicore.wsimeta import Units + from tiatoolbox.typing import Units @dataclass @@ -257,6 +257,7 @@ class IOSegmentorConfig(ModelIOConfigABC): ... patch_output_shape=(1024, 1024), ... stride_shape=(512, 512), ... ) + ... >>> # Defining io for a network having 3 input and 2 output >>> # at the same resolution, the output is then merged at a >>> # different resolution. diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 22aa95538..71f37cb3a 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -30,13 +30,12 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.models.engine.nucleus_instance_segmentor import ( NucleusInstanceSegmentor, _process_instance_predictions, ) -from .semantic_segmentor import WSIStreamDataset - if TYPE_CHECKING: # pragma: no cover import torch @@ -442,7 +441,7 @@ def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): callback(*future) continue # some errors happen, log it and propagate exception - # ! this will lead to discard a whole bunch of + # ! this will lead to discard a bunch of # ! inferred tiles within this current WSI if future.exception() is not None: raise future.exception() # noqa: RSE102 diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 8f7983a89..efb7ff67a 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -13,10 +13,8 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree -from tiatoolbox.models.engine.semantic_segmentor import ( - SemanticSegmentor, - WSIStreamDataset, -) +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset +from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor from tiatoolbox.tools.patchextraction import PatchExtractor if TYPE_CHECKING: # pragma: no cover diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 8042afeaf..4b0d158eb 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -11,16 +11,17 @@ import tqdm from tiatoolbox import logger -from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset from tiatoolbox.utils import misc, save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader -from .io_config import IOPatchPredictorConfig - if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import Resolution, Units +from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset + +from .io_config import IOPatchPredictorConfig + class PatchPredictor: r"""Patch level predictor. diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 6805281dd..e22bdcc84 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -18,15 +18,14 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread, misc -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader from .io_config import IOSegmentorConfig if TYPE_CHECKING: # pragma: no cover - from multiprocessing.managers import Namespace - from tiatoolbox.typing import Resolution, Units @@ -107,149 +106,6 @@ def _prepare_save_output( return is_on_drive, count_canvas, cum_canvas -class WSIStreamDataset(torch_data.Dataset): - """Reading a wsi in parallel mode with persistent workers. - - To speed up the inference process for multiple WSIs. The - `torch.utils.data.Dataloader` is set to run in persistent mode. - Normally, this will prevent workers from altering their initial - states (such as provided input etc.). To sidestep this, we use a - shared parallel workspace context manager to send data and signal - from the main thread, thus allowing each worker to load a new wsi as - well as corresponding patch information. - - Args: - mp_shared_space (:class:`Namespace`): - A shared multiprocessing space, must be from - `torch.multiprocessing`. - ioconfig (:class:`IOSegmentorConfig`): - An object which contains I/O placement for patches. - wsi_paths (list): List of paths pointing to a WSI or tiles. - preproc (Callable): - Pre-processing function to be applied to a patch. - mode (str): - Either `"wsi"` or `"tile"` to indicate the format of images - in `wsi_paths`. - - Examples: - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - >>> mp_manager = torch_mp.Manager() - >>> mp_shared_space = mp_manager.Namespace() - >>> mp_shared_space.signal = 1 # adding variable to the shared space - >>> wsi_paths = ['A.svs', 'B.svs'] - >>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space) - - """ - - def __init__( - self, - ioconfig: IOSegmentorConfig, - wsi_paths: list[str | Path], - mp_shared_space: Namespace, - preproc: Callable[[np.ndarray], np.ndarray] | None = None, - mode="wsi", - ) -> None: - """Initialize :class:`WSIStreamDataset`.""" - super().__init__() - self.mode = mode - self.preproc = preproc - self.ioconfig = copy.deepcopy(ioconfig) - - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - self.ioconfig = self.ioconfig.to_baseline() - - self.mp_shared_space = mp_shared_space - self.wsi_paths = wsi_paths - self.wsi_idx = None # to be received externally via thread communication - self.reader = None - - def _get_reader(self, img_path): - """Get appropriate reader for input path.""" - img_path = Path(img_path) - if self.mode == "wsi": - return WSIReader.open(img_path) - img = imread(img_path) - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - objective_power=10, - axes="YXS", - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - return VirtualWSIReader( - img, - info=metadata, - ) - - def __len__(self) -> int: - """Return the length of the instance attributes.""" - return len(self.mp_shared_space.patch_inputs) - - @staticmethod - def collate_fn(batch): - """Prototype to handle reading exception. - - This will exclude any sample with `None` from the batch. As - such, wrapping `__getitem__` with try-catch and return `None` - upon exceptions will prevent crashing the entire program. But as - a side effect, the batch may not have the size as defined. - - """ - batch = [v for v in batch if v is not None] - return torch.utils.data.dataloader.default_collate(batch) - - def __getitem__(self, idx: int): - """Get an item from the dataset.""" - # ! no need to lock as we do not modify source value in shared space - if self.wsi_idx != self.mp_shared_space.wsi_idx: - self.wsi_idx = int(self.mp_shared_space.wsi_idx.item()) - self.reader = self._get_reader(self.wsi_paths[self.wsi_idx]) - - # this is in XY and at requested resolution (not baseline) - bounds = self.mp_shared_space.patch_inputs[idx] - bounds = bounds.numpy() # expected to be a torch.Tensor - - # be the same as bounds br-tl, unless bounds are of float - patch_data_ = [] - scale_factors = self.ioconfig.scale_to_highest( - self.ioconfig.input_resolutions, - self.ioconfig.resolution_unit, - ) - for idy, resolution in enumerate(self.ioconfig.input_resolutions): - resolution_bounds = np.round(bounds * scale_factors[idy]) - patch_data = self.reader.read_bounds( - resolution_bounds.astype(np.int32), - coord_space="resolution", - pad_constant_values=0, # expose this ? - **resolution, - ) - - if self.preproc is not None: - patch_data = patch_data.copy() - patch_data = self.preproc(patch_data) - patch_data_.append(patch_data) - if len(patch_data_) == 1: - patch_data_ = patch_data_[0] - - bound = self.mp_shared_space.patch_outputs[idx] - return patch_data_, bound - - class SemanticSegmentor: """Pixel-wise segmentation predictor. From 613b5cfe9d98a0df468524ef29954166b51e2f8d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:27:14 +0000 Subject: [PATCH 09/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 2 +- tiatoolbox/models/engine/engine_abc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 87ea7f317..f754fcff2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -546,7 +546,7 @@ class chdir(AbstractContextManager): # noqa: N801 """ - def __init__(self, path): + def __init__(self, path) -> None: self.path = path self._old_cwd = [] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 223c4a051..69d66af73 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,7 +5,7 @@ class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox.""" - def __init__(self): + def __init__(self) -> None: """Initialize Engine.""" super().__init__() From 59a3553c8c37e632eb5f7cd5eb64b902f20fc046 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 29 Aug 2023 17:32:10 +0100 Subject: [PATCH 10/36] :recycle: Refactor Minor Changes to Keep #635 Clean (#705) * :recycle: Refactor minor changes to keep #635 clean --- pyproject.toml | 2 +- .../_test_feature_extractor.py} | 0 .../_test_multi_task_segmentor.py} | 0 .../_test_nucleus_instance_segmentor.py} | 0 .../_test_patch_predictor.py} | 0 .../_test_semantic_segmentation.py} | 2 +- tests/models/test_abc.py | 22 ++++++++++++++- tests/models/test_arch_vanilla.py | 2 +- tests/test_utils.py | 18 ------------- tests/test_wsimeta.py | 1 - tiatoolbox/models/__init__.py | 2 ++ tiatoolbox/models/dataset/dataset_abc.py | 17 +++++++----- tiatoolbox/models/engine/patch_predictor.py | 5 ++-- .../models/engine/semantic_segmentor.py | 8 ++++-- tiatoolbox/models/models_abc.py | 27 ++++++++++++++++++- tiatoolbox/utils/misc.py | 20 -------------- 16 files changed, 72 insertions(+), 54 deletions(-) rename tests/{models/test_feature_extractor.py => engines/_test_feature_extractor.py} (100%) rename tests/{models/test_multi_task_segmentor.py => engines/_test_multi_task_segmentor.py} (100%) rename tests/{models/test_nucleus_instance_segmentor.py => engines/_test_nucleus_instance_segmentor.py} (100%) rename tests/{models/test_patch_predictor.py => engines/_test_patch_predictor.py} (100%) rename tests/{models/test_semantic_segmentation.py => engines/_test_semantic_segmentation.py} (99%) diff --git a/pyproject.toml b/pyproject.toml index 0bead17ab..91551031f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ select = [ "C90", # mccabe "T10", # flake8-debugger "T20", # flake8-print - "ANN", # flake8-annotations + # "ANN", # flake8-annotations "ARG", # flake8-unused-arguments "BLE", # flake8-blind-except "COM", # flake8-commas diff --git a/tests/models/test_feature_extractor.py b/tests/engines/_test_feature_extractor.py similarity index 100% rename from tests/models/test_feature_extractor.py rename to tests/engines/_test_feature_extractor.py diff --git a/tests/models/test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py similarity index 100% rename from tests/models/test_multi_task_segmentor.py rename to tests/engines/_test_multi_task_segmentor.py diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py similarity index 100% rename from tests/models/test_nucleus_instance_segmentor.py rename to tests/engines/_test_nucleus_instance_segmentor.py diff --git a/tests/models/test_patch_predictor.py b/tests/engines/_test_patch_predictor.py similarity index 100% rename from tests/models/test_patch_predictor.py rename to tests/engines/_test_patch_predictor.py diff --git a/tests/models/test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py similarity index 99% rename from tests/models/test_semantic_segmentation.py rename to tests/engines/_test_semantic_segmentation.py index f923cb65b..fe334d5fe 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/engines/_test_semantic_segmentation.py @@ -289,7 +289,7 @@ def test_crash_segmentor(remote_sample: Callable) -> None: units="baseline", ) - _rm_dir("output") + shutil.rmtree("output", ignore_errors=True) with pytest.raises(ValueError, match=r"Invalid resolution.*"): semantic_segmentor.predict( diff --git a/tests/models/test_abc.py b/tests/models/test_abc.py index c097499f0..52fa4a6e8 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_abc.py @@ -1,8 +1,10 @@ """Unit test package for ABC and __init__ .""" +from __future__ import annotations import pytest -from tiatoolbox import rcParam +import tiatoolbox.models +from tiatoolbox import rcParam, utils from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import env_detection as toolbox_env @@ -105,3 +107,21 @@ def infer_batch() -> None: # coverage setter check model.postproc_func = None # skipcq: PYL-W0201 assert model.postproc_func(2) == 0 + + +def test_model_to() -> None: + """Test for placing model on device.""" + import torchvision.models as torch_models + from torch import nn + + # Test on GPU + # no GPU on Travis so this will crash + if not utils.env_detection.has_gpu(): + model = torch_models.resnet18() + with pytest.raises((AssertionError, RuntimeError)): + _ = tiatoolbox.models.models_abc.model_to(on_gpu=True, model=model) + + # Test on CPU + model = torch_models.resnet18() + model = tiatoolbox.models.models_abc.model_to(on_gpu=False, model=model) + assert isinstance(model, nn.Module) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 26020aa07..a2b1ac5c9 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -5,7 +5,7 @@ import torch from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.utils.misc import model_to +from tiatoolbox.models.models_abc import model_to ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator diff --git a/tests/test_utils.py b/tests/test_utils.py index 4eb2d42cd..ef6afa734 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1312,24 +1312,6 @@ def test_select_device() -> None: assert device == "cpu" -def test_model_to() -> None: - """Test for placing model on device.""" - import torchvision.models as torch_models - from torch import nn - - # Test on GPU - # no GPU on Travis so this will crash - if not utils.env_detection.has_gpu(): - model = torch_models.resnet18() - with pytest.raises((AssertionError, RuntimeError)): - _ = misc.model_to(on_gpu=True, model=model) - - # Test on CPU - model = torch_models.resnet18() - model = misc.model_to(on_gpu=False, model=model) - assert isinstance(model, nn.Module) - - def test_save_as_json(tmp_path: Path) -> None: """Test save data to json.""" # This should be broken up into separate tests! diff --git a/tests/test_wsimeta.py b/tests/test_wsimeta.py index bc3555e36..01b1cac8b 100644 --- a/tests/test_wsimeta.py +++ b/tests/test_wsimeta.py @@ -8,7 +8,6 @@ from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader -# noinspection PyTypeChecker def test_wsimeta_init_fail() -> None: """Test incorrect init for WSIMeta raises TypeError.""" with pytest.raises(TypeError): diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index ecd173ced..e91a3b68c 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -1,4 +1,6 @@ """Models package for the models implemented in tiatoolbox.""" +from __future__ import annotations + from . import architecture, dataset, engine, models_abc from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index fc9abdf7e..9cc8bb96c 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -340,7 +340,7 @@ class WSIPatchDataset(PatchDatasetABC): """ - def __init__( + def __init__( # noqa: PLR0913, PLR0915 self, img_path, mode="wsi", @@ -349,9 +349,10 @@ def __init__( stride_shape=None, resolution=None, units=None, - auto_get_mask=True, min_mask_ratio=0, preproc_func=None, + *, + auto_get_mask=True, ) -> None: """Create a WSI-level patch dataset. @@ -424,14 +425,14 @@ def __init__( if ( not np.issubdtype(patch_input_shape.dtype, np.integer) - or np.size(patch_input_shape) > 2 + or np.size(patch_input_shape) > 2 # noqa: PLR2004 or np.any(patch_input_shape < 0) ): msg = f"Invalid `patch_input_shape` value {patch_input_shape}." raise ValueError(msg) if ( not np.issubdtype(stride_shape.dtype, np.integer) - or np.size(stride_shape) > 2 + or np.size(stride_shape) > 2 # noqa: PLR2004 or np.any(stride_shape < 0) ): msg = f"Invalid `stride_shape` value {stride_shape}." @@ -565,7 +566,11 @@ class PatchDataset(PatchDatasetABC): """ - def __init__(self, inputs, labels=None) -> None: + def __init__( + self: PatchDataset, + inputs: np.ndarray | list, + labels: list | None = None, + ) -> None: """Initialize :class:`PatchDataset`.""" super().__init__() @@ -577,7 +582,7 @@ def __init__(self, inputs, labels=None) -> None: # perform check on the input self._check_input_integrity(mode="patch") - def __getitem__(self, idx): + def __getitem__(self: PatchDataset, idx: int) -> dict: """Get an item from the dataset.""" patch = self.inputs[idx] diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index a9dcce8a7..807cd9fad 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -10,8 +10,9 @@ import torch import tqdm +import tiatoolbox.models.models_abc from tiatoolbox import logger -from tiatoolbox.utils import misc, save_as_json +from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -402,7 +403,7 @@ def _predict_engine( ) # use external for testing - model = misc.model_to(model=self.model, on_gpu=on_gpu) + model = tiatoolbox.models.models_abc.model_to(model=self.model, on_gpu=on_gpu) cum_output = { "probabilities": [], diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 33acd4bd5..e1341c640 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -16,11 +16,12 @@ import torch.utils.data as torch_data import tqdm +import tiatoolbox.models.models_abc from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread, misc +from tiatoolbox.utils import imread from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader from .io_config import IOSegmentorConfig @@ -1049,7 +1050,10 @@ def predict( # noqa: PLR0913 # use external for testing self._on_gpu = on_gpu - self._model = misc.model_to(model=self.model, on_gpu=on_gpu) + self._model = tiatoolbox.models.models_abc.model_to( + model=self.model, + on_gpu=on_gpu, + ) # workers should be > 0 else Value Error will be thrown self._prepare_workers() diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 4edc5defa..9c5bb4cd1 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -1,9 +1,15 @@ """Define Abstract Base Class for Models defined in tiatoolbox.""" +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -import numpy as np +import torch from torch import nn +if TYPE_CHECKING: # pragma: no cover + import numpy as np + class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" @@ -109,3 +115,22 @@ def postproc_func(self, func): self._postproc = self.postproc else: self._postproc = func + + +def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): PyTorch defined model. + on_gpu (bool): Transfers model to gpu if True otherwise to cpu. + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + + """ + if on_gpu: # DataParallel work only for cuda + model = torch.nn.DataParallel(model) + return model.to("cuda") + + return model.to("cpu") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 89e60970e..98a6fe7ec 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -14,7 +14,6 @@ import numpy as np import pandas as pd import requests -import torch import yaml from filelock import FileLock from shapely.affinity import translate @@ -873,25 +872,6 @@ def select_device(*, on_gpu: bool) -> str: return "cpu" -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu") - - def get_bounding_box(img: np.ndarray) -> np.ndarray: """Get bounding box coordinate information. From 112d2b4e848dc7a65d324e5ee629cd21e8d172e3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 30 Aug 2023 11:01:09 +0100 Subject: [PATCH 11/36] :twisted_rightwards_arrows: Merge develop into dev-define-engines-abc --- pre-commit/missing_imports.py | 9 +-- pre-commit/requirements_consistency.py | 5 +- tests/engines/_test_multi_task_segmentor.py | 4 +- .../_test_nucleus_instance_segmentor.py | 4 +- tests/engines/_test_patch_predictor.py | 63 ++++++++++++------- tests/models/test_dataset.py | 9 +-- 6 files changed, 58 insertions(+), 36 deletions(-) diff --git a/pre-commit/missing_imports.py b/pre-commit/missing_imports.py index 12fc58bb8..09fb024e9 100644 --- a/pre-commit/missing_imports.py +++ b/pre-commit/missing_imports.py @@ -14,6 +14,7 @@ import sys import tokenize from pathlib import Path +from typing import NoReturn from requirements_consistency import parse_requirements @@ -135,7 +136,7 @@ def stems(node: ast.Import | ast.ImportFrom) -> list[tuple[str, str]]: ) -def main(): +def main() -> NoReturn: """Main entry point.""" parser = argparse.ArgumentParser( description="Static analysis of requirements files and import statements.", @@ -219,13 +220,13 @@ def find_bad_imports( return result -def find_comments(path, line_num: int): +def find_comments(path: str | Path, line_num: int) -> list: """Find comments on the given line. Args: - path: + path (str | Path): Path to the file. - line_num: + line_num (int): Line number to find comments on. Returns: diff --git a/pre-commit/requirements_consistency.py b/pre-commit/requirements_consistency.py index f46795bce..1b3dde8be 100644 --- a/pre-commit/requirements_consistency.py +++ b/pre-commit/requirements_consistency.py @@ -4,6 +4,7 @@ import importlib import sys from pathlib import Path +from typing import NoReturn import yaml from pkg_resources import Requirement @@ -100,7 +101,7 @@ def parse_conda(file_path: Path) -> dict[str, Requirement]: return packages -def parse_setup_py(file_path) -> dict[str, Requirement]: +def parse_setup_py(file_path: Path) -> dict[str, Requirement]: """Parse a setup.py file. Args: @@ -233,7 +234,7 @@ def in_common_consistent(all_requirements: dict[Path, dict[str, Requirement]]) - return consistent -def main(): +def main() -> NoReturn: """Main entry point for the hook.""" root = Path(__file__).parent.parent test_files_exist(root) diff --git a/tests/engines/_test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py index 5b5b1de5c..c3cc85cea 100644 --- a/tests/engines/_test_multi_task_segmentor.py +++ b/tests/engines/_test_multi_task_segmentor.py @@ -32,13 +32,13 @@ # ---------------------------------------------------- -def _crash_func(_) -> None: +def _crash_func(_: object) -> None: """Helper to induce crash.""" msg = "Propagation Crash." raise ValueError(msg) -def semantic_postproc_func(raw_output): +def semantic_postproc_func(raw_output: np.ndarray) -> np.ndarray: """Function to post process semantic segmentations. Post processes semantic segmentation to form one map output. diff --git a/tests/engines/_test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py index 02900da71..6d3ea2f67 100644 --- a/tests/engines/_test_nucleus_instance_segmentor.py +++ b/tests/engines/_test_nucleus_instance_segmentor.py @@ -36,13 +36,13 @@ # ---------------------------------------------------- -def _crash_func(_x) -> None: +def _crash_func(_x: object) -> None: """Helper to induce crash.""" msg = "Propagation Crash." raise ValueError(msg) -def helper_tile_info(): +def helper_tile_info() -> list: """Helper function for tile information.""" predictor = NucleusInstanceSegmentor(model="A") # ! assuming the tiles organized as follows (coming out from diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py index f10d9fca6..985970d42 100644 --- a/tests/engines/_test_patch_predictor.py +++ b/tests/engines/_test_patch_predictor.py @@ -1,4 +1,5 @@ """Test for Patch Predictor.""" +from __future__ import annotations import copy import shutil @@ -32,7 +33,10 @@ # ------------------------------------------------------------------------------------- -def test_patch_dataset_path_imgs(sample_patch1, sample_patch2) -> None: +def test_patch_dataset_path_imgs( + sample_patch1: str | Path, + sample_patch2: str | Path, +) -> None: """Test for patch dataset with a list of file paths as input.""" size = (224, 224, 3) @@ -212,18 +216,21 @@ def test_patch_dataset_crash(tmp_path: Path) -> None: predefined_preproc_func("secret-dataset") -def test_wsi_patch_dataset(sample_wsi_dict, tmp_path: Path) -> None: # noqa: PLR0915 +def test_wsi_patch_dataset( # noqa: PLR0915 + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: """A test for creation and bare output.""" # convert to pathlib Path to prevent wsireader complaint mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - def reuse_init(img_path=mini_wsi_svs, **kwargs): + def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: """Testing function.""" return WSIPatchDataset(img_path=img_path, **kwargs) - def reuse_init_wsi(**kwargs): + def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: """Testing function.""" return reuse_init(mode="wsi", **kwargs) @@ -231,13 +238,13 @@ def reuse_init_wsi(**kwargs): # intentionally created to check error # skipcq class Proto(PatchDatasetABC): - def __init__(self) -> None: + def __init__(self: Proto) -> None: super().__init__() self.inputs = "CRASH" self._check_input_integrity("wsi") # skipcq - def __getitem__(self, idx): + def __getitem__(self: Proto, idx: int) -> object: """Get an item from the dataset.""" with pytest.raises( @@ -416,7 +423,7 @@ def test_patch_dataset_abc() -> None: # skipcq class Proto(PatchDatasetABC): # skipcq - def __init__(self) -> None: + def __init__(self: Proto) -> None: super().__init__() # crash due to undefined __getitem__ @@ -426,11 +433,11 @@ def __init__(self) -> None: # skipcq class Proto(PatchDatasetABC): # skipcq - def __init__(self) -> None: + def __init__(self: Proto) -> None: super().__init__() # skipcq - def __getitem__(self, idx): + def __getitem__(self: Proto, idx: int) -> None: """Get an item from the dataset.""" ds = Proto() # skipcq @@ -598,7 +605,11 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: shutil.rmtree(tmp_path / "dump", ignore_errors=True) -def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path: Path) -> None: +def test_patch_predictor_api( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: """Helper function to get the model output using API 1.""" save_dir_path = tmp_path @@ -694,7 +705,11 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path: Path) -> No assert len(output["predictions"]) == len(output["probabilities"]) -def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path, chdir: Callable) -> None: +def test_wsi_predictor_api( + sample_wsi_dict: dict, + tmp_path: Path, + chdir: Callable, +) -> None: """Test normal run of wsi predictor.""" save_dir_path = tmp_path @@ -814,7 +829,7 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path, chdir: Callable) -> shutil.rmtree("output", ignore_errors=True) -def test_wsi_predictor_merge_predictions(sample_wsi_dict) -> None: +def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: """Test normal run of wsi predictor with merge predictions option.""" # convert to pathlib Path to prevent reader complaint mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) @@ -915,11 +930,12 @@ def test_wsi_predictor_merge_predictions(sample_wsi_dict) -> None: def _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=None, - predictions_check=None, - on_gpu=ON_GPU, + inputs: list, + pretrained_model: str, + probabilities_check: list | None = None, + predictions_check: list | None = None, + *, + on_gpu: bool = ON_GPU, ) -> None: """Test the predictions of multiple models included in tiatoolbox.""" predictor = PatchPredictor( @@ -954,7 +970,10 @@ def _test_predictor_output( ) -def test_patch_predictor_kather100k_output(sample_patch1, sample_patch2) -> None: +def test_patch_predictor_kather100k_output( + sample_patch1: Path, + sample_patch2: Path, +) -> None: """Test the output of patch prediction models on Kather100K dataset.""" inputs = [Path(sample_patch1), Path(sample_patch2)] pretrained_info = { @@ -989,7 +1008,7 @@ def test_patch_predictor_kather100k_output(sample_patch1, sample_patch2) -> None break -def test_patch_predictor_pcam_output(sample_patch3, sample_patch4) -> None: +def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: """Test the output of patch prediction models on PCam dataset.""" inputs = [Path(sample_patch3), Path(sample_patch4)] pretrained_info = { @@ -1029,7 +1048,7 @@ def test_patch_predictor_pcam_output(sample_patch3, sample_patch4) -> None: # ------------------------------------------------------------------------------------- -def test_command_line_models_file_not_found(sample_svs, tmp_path: Path) -> None: +def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: """Test for models CLI file not found error.""" runner = CliRunner() model_file_not_found_result = runner.invoke( @@ -1050,7 +1069,7 @@ def test_command_line_models_file_not_found(sample_svs, tmp_path: Path) -> None: assert isinstance(model_file_not_found_result.exception, FileNotFoundError) -def test_command_line_models_incorrect_mode(sample_svs, tmp_path: Path) -> None: +def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: """Test for models CLI mode not in wsi, tile.""" runner = CliRunner() mode_not_in_wsi_tile_result = runner.invoke( @@ -1073,7 +1092,7 @@ def test_command_line_models_incorrect_mode(sample_svs, tmp_path: Path) -> None: assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) -def test_cli_model_single_file(sample_svs, tmp_path: Path) -> None: +def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: """Test for models CLI single file.""" runner = CliRunner() models_wsi_result = runner.invoke( diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index b14556004..1b9046bc4 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -1,4 +1,5 @@ """Test for predefined dataset within toolbox.""" +from __future__ import annotations import shutil from pathlib import Path @@ -16,7 +17,7 @@ class Proto1(DatasetInfoABC): """Intentionally created to check error with new attribute a.""" - def __init__(self) -> None: + def __init__(self: Proto1) -> None: """Proto1 initialization.""" self.a = "a" @@ -24,7 +25,7 @@ def __init__(self) -> None: class Proto2(DatasetInfoABC): """Intentionally created to check error with attribute inputs.""" - def __init__(self) -> None: + def __init__(self: Proto2) -> None: """Proto2 initialization.""" self.inputs = "a" @@ -32,7 +33,7 @@ def __init__(self) -> None: class Proto3(DatasetInfoABC): """Intentionally created to check error with attribute inputs and labels.""" - def __init__(self) -> None: + def __init__(self: Proto3) -> None: """Proto3 initialization.""" self.inputs = "a" self.labels = "a" @@ -41,7 +42,7 @@ def __init__(self) -> None: class Proto4(DatasetInfoABC): """Intentionally created to check error with attribute inputs and label names.""" - def __init__(self) -> None: + def __init__(self: Proto4) -> None: """Proto4 initialization.""" self.inputs = "a" self.label_names = "a" From 5bfdcb1fe28e13e523a16491d9173f102a820c17 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 12:57:49 +0100 Subject: [PATCH 12/36] :recycle: Move `DataSet` tests to `test_dataset.py` (#708) - Move `DataSet` tests to `test_dataset.py` --- tests/engines/_test_patch_predictor.py | 444 +------------------------ tests/models/test_dataset.py | 438 +++++++++++++++++++++++- 2 files changed, 436 insertions(+), 446 deletions(-) diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py index 985970d42..b3322635d 100644 --- a/tests/engines/_test_patch_predictor.py +++ b/tests/engines/_test_patch_predictor.py @@ -6,459 +6,17 @@ from pathlib import Path from typing import Callable -import cv2 import numpy as np import pytest -import torch from click.testing import CliRunner from tiatoolbox import cli from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.models.dataset import ( - PatchDataset, - PatchDatasetABC, - WSIPatchDataset, - predefined_preproc_func, -) -from tiatoolbox.utils import download_data, imread, imwrite +from tiatoolbox.utils import download_data, imwrite from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = toolbox_env.has_gpu() -RNG = np.random.default_rng() # Numpy Random Generator - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_patch_dataset_path_imgs( - sample_patch1: str | Path, - sample_patch2: str | Path, -) -> None: - """Test for patch dataset with a list of file paths as input.""" - size = (224, 224, 3) - - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_list_imgs(tmp_path: Path) -> None: - """Test for patch dataset with a list of images as input.""" - save_dir_path = tmp_path - - size = (5, 5, 3) - img = RNG.integers(low=0, high=255, size=size) - list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) - - dataset.preproc_func = lambda x: x - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - # test for changing to another preproc - dataset.preproc_func = lambda x: x - 10 - item = dataset[0] - assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 - - # * test for loading npy - # remove previously generated data - if Path.exists(save_dir_path): - shutil.rmtree(save_dir_path, ignore_errors=True) - Path.mkdir(save_dir_path, parents=True) - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - assert imgs[0] is not None - # test for path object - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - - -def test_patch_datasetarray_imgs() -> None: - """Test for patch dataset with a numpy array of a list of images.""" - size = (5, 5, 3) - img = RNG.integers(0, 255, size=size) - list_imgs = [img, img, img] - labels = [1, 2, 3] - array_imgs = np.array(list_imgs) - - # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) - an_item = dataset[2] - assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) - an_item = dataset[2] - assert "label" not in an_item - - dataset = PatchDataset(array_imgs) - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_crash(tmp_path: Path) -> None: - """Test to make sure patch dataset crashes with incorrect input.""" - # all below examples should fail when input to PatchDataset - save_dir_path = tmp_path - - # not supported input type - imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} - with pytest.raises( - ValueError, - match=r".*Input must be either a list/array of images.*", - ): - _ = PatchDataset(imgs) - - # ndarray of mixed dtype - imgs = np.array( - [RNG.integers(0, 255, (4, 5, 3)), "Should crash"], - dtype=object, - ) - with pytest.raises(ValueError, match="Provided input array is non-numerical."): - _ = PatchDataset(imgs) - - # ndarray(s) of NHW images - imgs = RNG.integers(0, 255, (4, 4, 4)) - with pytest.raises(ValueError, match=r".*array of the form HWC*"): - _ = PatchDataset(imgs) - - # list of ndarray(s) with different sizes - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 5, 3)), - ] - with pytest.raises(ValueError, match="Images must have the same dimensions."): - _ = PatchDataset(imgs) - - # list of ndarray(s) with HW and HWC mixed up - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 4)), - ] - with pytest.raises( - ValueError, - match="Each sample must be an array of the form HWC.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = ["you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list not exist paths - with pytest.raises( - ValueError, - match=r".*valid image paths.*", - ): - _ = PatchDataset(["img.npy"]) - - # ** test different extension parser - # save dummy data to temporary location - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - - torch.save({"a": "a"}, save_dir_path / "sample1.tar") - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - - imgs = [ - save_dir_path / "sample1.tar", - save_dir_path / "sample2.npy", - ] - with pytest.raises( - ValueError, - match="Cannot load image data from", - ): - _ = PatchDataset(imgs) - - # preproc func for not defined dataset - with pytest.raises( - ValueError, - match=r".* preprocessing .* does not exist.", - ): - predefined_preproc_func("secret-dataset") - - -def test_wsi_patch_dataset( # noqa: PLR0915 - sample_wsi_dict: dict, - tmp_path: Path, -) -> None: - """A test for creation and bare output.""" - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) - - def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return reuse_init(mode="wsi", **kwargs) - - # test for ABC validate - # intentionally created to check error - # skipcq - class Proto(PatchDatasetABC): - def __init__(self: Proto) -> None: - super().__init__() - self.inputs = "CRASH" - self._check_input_integrity("wsi") - - # skipcq - def __getitem__(self: Proto, idx: int) -> object: - """Get an item from the dataset.""" - - with pytest.raises( - ValueError, - match=r".*`inputs` should be a list of patch coordinates.*", - ): - Proto() # skipcq - - # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): - WSIPatchDataset( - img_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - ) - - # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): - WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - - # invalid patch - with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): - reuse_init() - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, "a"]) - with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): - reuse_init_wsi(patch_input_shape=512) - # invalid stride - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) - # negative - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) - - # * for wsi - # dummy test for analysing the output - # stride and patch size should be as expected - patch_size = [512, 512] - stride_size = [256, 256] - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - reader = WSIReader.open(mini_wsi_svs) - # tiling top to bottom, left to right - ds_roi = ds[2]["image"] - step_idx = 2 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - rd_roi = reader.read_bounds( - start + end, - resolution=1.0, - units="mpp", - coord_space="resolution", - ) - correlation = np.corrcoef( - cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert ds_roi.shape[0] == rd_roi.shape[0] - assert ds_roi.shape[1] == rd_roi.shape[1] - assert np.min(correlation) > 0.9, correlation - - # test creation with auto mask gen and input mask - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=True, - ) - assert len(ds) > 0 - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - negative_mask = imread(mini_wsi_msk) - negative_mask = np.zeros_like(negative_mask) - negative_mask_path = tmp_path / "negative_mask.png" - imwrite(negative_mask_path, negative_mask) - with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - - -def test_patch_dataset_abc() -> None: - """Test for ABC methods. - - Test missing definition for abstract intentionally created to check error. - - """ - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # crash due to undefined __getitem__ - with pytest.raises(TypeError): - Proto() # skipcq - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # skipcq - def __getitem__(self: Proto, idx: int) -> None: - """Get an item from the dataset.""" - - ds = Proto() # skipcq - - # test setter and getter - assert ds.preproc_func(1) == 1 - ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 - assert ds.preproc_func(1) == 0 - assert ds.preproc(1) == 1, "Must be unchanged!" - ds.preproc_func = None # skipcq: PYL-W0201 - assert ds.preproc_func(2) == 2 - - # test assign uncallable to preproc_func/postproc_func - with pytest.raises(ValueError, match=r".*callable*"): - ds.preproc_func = 1 # skipcq: PYL-W0201 - - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - # ------------------------------------------------------------------------------------- # Engine diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index 1b9046bc4..538b2dbd4 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -4,14 +4,24 @@ import shutil from pathlib import Path +import cv2 import numpy as np import pytest +import torch from tiatoolbox import rcParam -from tiatoolbox.models import PatchDataset -from tiatoolbox.models.dataset import DatasetInfoABC, KatherPatchDataset -from tiatoolbox.utils import download_data, unzip_data +from tiatoolbox.models import PatchDataset, WSIPatchDataset +from tiatoolbox.models.dataset import ( + DatasetInfoABC, + KatherPatchDataset, + PatchDatasetABC, + predefined_preproc_func, +) +from tiatoolbox.utils import download_data, imread, imwrite, unzip_data from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore import WSIReader + +RNG = np.random.default_rng() # Numpy Random Generator class Proto1(DatasetInfoABC): @@ -116,3 +126,425 @@ def test_kather_dataset(tmp_path: Path) -> None: # remove generated data shutil.rmtree(save_dir_path, ignore_errors=True) + + +def test_patch_dataset_path_imgs( + sample_patch1: str | Path, + sample_patch2: str | Path, +) -> None: + """Test for patch dataset with a list of file paths as input.""" + size = (224, 224, 3) + + dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_list_imgs(tmp_path: Path) -> None: + """Test for patch dataset with a list of images as input.""" + save_dir_path = tmp_path + + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs) + + dataset.preproc_func = lambda x: x + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + # test for changing to another preproc + dataset.preproc_func = lambda x: x - 10 + item = dataset[0] + assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 + + # * test for loading npy + # remove previously generated data + if Path.exists(save_dir_path): + shutil.rmtree(save_dir_path, ignore_errors=True) + Path.mkdir(save_dir_path, parents=True) + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + assert imgs[0] is not None + # test for path object + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + + +def test_patch_datasetarray_imgs() -> None: + """Test for patch dataset with a numpy array of a list of images.""" + size = (5, 5, 3) + img = RNG.integers(0, 255, size=size) + list_imgs = [img, img, img] + labels = [1, 2, 3] + array_imgs = np.array(list_imgs) + + # test different setter for label + dataset = PatchDataset(array_imgs, labels=labels) + an_item = dataset[2] + assert an_item["label"] == 3 + dataset = PatchDataset(array_imgs, labels=None) + an_item = dataset[2] + assert "label" not in an_item + + dataset = PatchDataset(array_imgs) + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_crash(tmp_path: Path) -> None: + """Test to make sure patch dataset crashes with incorrect input.""" + # all below examples should fail when input to PatchDataset + save_dir_path = tmp_path + + # not supported input type + imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} + with pytest.raises( + ValueError, + match=r".*Input must be either a list/array of images.*", + ): + _ = PatchDataset(imgs) + + # ndarray of mixed dtype + imgs = np.array( + [RNG.integers(0, 255, (4, 5, 3)), "Should crash"], + dtype=object, + ) + with pytest.raises(ValueError, match="Provided input array is non-numerical."): + _ = PatchDataset(imgs) + + # ndarray(s) of NHW images + imgs = RNG.integers(0, 255, (4, 4, 4)) + with pytest.raises(ValueError, match=r".*array of the form HWC*"): + _ = PatchDataset(imgs) + + # list of ndarray(s) with different sizes + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 5, 3)), + ] + with pytest.raises(ValueError, match="Images must have the same dimensions."): + _ = PatchDataset(imgs) + + # list of ndarray(s) with HW and HWC mixed up + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 4)), + ] + with pytest.raises( + ValueError, + match="Each sample must be an array of the form HWC.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = ["you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list not exist paths + with pytest.raises( + ValueError, + match=r".*valid image paths.*", + ): + _ = PatchDataset(["img.npy"]) + + # ** test different extension parser + # save dummy data to temporary location + # remove prev generated data + shutil.rmtree(save_dir_path, ignore_errors=True) + save_dir_path.mkdir(parents=True) + + torch.save({"a": "a"}, save_dir_path / "sample1.tar") + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + + imgs = [ + save_dir_path / "sample1.tar", + save_dir_path / "sample2.npy", + ] + with pytest.raises( + ValueError, + match="Cannot load image data from", + ): + _ = PatchDataset(imgs) + + # preproc func for not defined dataset + with pytest.raises( + ValueError, + match=r".* preprocessing .* does not exist.", + ): + predefined_preproc_func("secret-dataset") + + +def test_wsi_patch_dataset( # noqa: PLR0915 + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """A test for creation and bare output.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return WSIPatchDataset(img_path=img_path, **kwargs) + + def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return reuse_init(mode="wsi", **kwargs) + + # test for ABC validate + # intentionally created to check error + # skipcq + class Proto(PatchDatasetABC): + def __init__(self: Proto) -> None: + super().__init__() + self.inputs = "CRASH" + self._check_input_integrity("wsi") + + # skipcq + def __getitem__(self: Proto, idx: int) -> object: + """Get an item from the dataset.""" + + with pytest.raises( + ValueError, + match=r".*`inputs` should be a list of patch coordinates.*", + ): + Proto() # skipcq + + # invalid path input + with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + WSIPatchDataset( + img_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + ) + + # invalid mask path input + with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): + WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + + # invalid mode + with pytest.raises(ValueError, match="`X` is not supported."): + reuse_init(mode="X") + + # invalid patch + with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): + reuse_init() + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, "a"]) + with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): + reuse_init_wsi(patch_input_shape=512) + # invalid stride + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) + # negative + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) + + # * for wsi + # dummy test for analysing the output + # stride and patch size should be as expected + patch_size = [512, 512] + stride_size = [256, 256] + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + reader = WSIReader.open(mini_wsi_svs) + # tiling top to bottom, left to right + ds_roi = ds[2]["image"] + step_idx = 2 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + rd_roi = reader.read_bounds( + start + end, + resolution=1.0, + units="mpp", + coord_space="resolution", + ) + correlation = np.corrcoef( + cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert ds_roi.shape[0] == rd_roi.shape[0] + assert ds_roi.shape[1] == rd_roi.shape[1] + assert np.min(correlation) > 0.9, correlation + + # test creation with auto mask gen and input mask + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=True, + ) + assert len(ds) > 0 + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=mini_wsi_msk, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + negative_mask = imread(mini_wsi_msk) + negative_mask = np.zeros_like(negative_mask) + negative_mask_path = tmp_path / "negative_mask.png" + imwrite(negative_mask_path, negative_mask) + with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=negative_mask_path, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + # * for tile + reader = WSIReader.open(mini_wsi_jpg) + tile_ds = WSIPatchDataset( + img_path=mini_wsi_jpg, + mode="tile", + patch_input_shape=patch_size, + stride_shape=stride_size, + auto_get_mask=False, + ) + step_idx = 3 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + roi2 = reader.read_bounds( + start + end, + resolution=1.0, + units="baseline", + coord_space="resolution", + ) + roi1 = tile_ds[3]["image"] # match with step_index + correlation = np.corrcoef( + cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert roi1.shape[0] == roi2.shape[0] + assert roi1.shape[1] == roi2.shape[1] + assert np.min(correlation) > 0.9, correlation + + +def test_patch_dataset_abc() -> None: + """Test for ABC methods. + + Test missing definition for abstract intentionally created to check error. + + """ + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # crash due to undefined __getitem__ + with pytest.raises(TypeError): + Proto() # skipcq + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # skipcq + def __getitem__(self: Proto, idx: int) -> None: + """Get an item from the dataset.""" + + ds = Proto() # skipcq + + # test setter and getter + assert ds.preproc_func(1) == 1 + ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 + assert ds.preproc_func(1) == 0 + assert ds.preproc(1) == 1, "Must be unchanged!" + ds.preproc_func = None # skipcq: PYL-W0201 + assert ds.preproc_func(2) == 2 + + # test assign uncallable to preproc_func/postproc_func + with pytest.raises(ValueError, match=r".*callable*"): + ds.preproc_func = 1 # skipcq: PYL-W0201 From bd134b0518ed2f6c3fdaad14eecf259d11e09dd0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:42:56 +0100 Subject: [PATCH 13/36] :bug: Fix `tiatoolbox/models/dataset/classification.py` for annotations --- tiatoolbox/models/dataset/classification.py | 79 ++++++++++++++++++++- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index bdd947f20..ce54f64e8 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -1,8 +1,17 @@ """Define classes and methods for classification datasets.""" +from __future__ import annotations + +from typing import TYPE_CHECKING import PIL +from PIL import Image from torchvision import transforms +from tiatoolbox.models.dataset import dataset_abc + +if TYPE_CHECKING: # pragma: no cover + import numpy as np + class _TorchPreprocCaller: """Wrapper for applying PyTorch transforms. @@ -16,16 +25,16 @@ class _TorchPreprocCaller: """ - def __init__(self, preprocs) -> None: + def __init__(self: _TorchPreprocCaller, preprocs: list) -> None: self.func = transforms.Compose(preprocs) - def __call__(self, img): + def __call__(self: _TorchPreprocCaller, img: np.ndarray) -> Image: img = PIL.Image.fromarray(img) img = self.func(img) return img.permute(1, 2, 0) -def predefined_preproc_func(dataset_name): +def predefined_preproc_func(dataset_name: str) -> _TorchPreprocCaller: """Get the preprocessing information used for the pretrained model. Args: @@ -54,3 +63,67 @@ def predefined_preproc_func(dataset_name): preprocs = preproc_dict[dataset_name] return _TorchPreprocCaller(preprocs) + + +class PatchDataset(dataset_abc.PatchDatasetABC): + """Define PatchDataset for torch inference. + + Define a simple patch dataset, which inherits from the + `torch.utils.data.Dataset` class. + + Attributes: + inputs (list or np.ndarray): + Either a list of patches, where each patch is a ndarray or a + list of valid path with its extension be (".jpg", ".jpeg", + ".tif", ".tiff", ".png") pointing to an image. + labels (list): + List of labels for sample at the same index in `inputs`. + Default is `None`. + + Examples: + >>> # A user defined preproc func and expected behavior + >>> preproc_func = lambda img: img/2 # reduce intensity by half + >>> transformed_img = preproc_func(img) + >>> # create a dataset to get patches preprocessed by the above function + >>> ds = PatchDataset( + ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], + ... labels=["labels1", "labels2"], + ... ) + + """ + + def __init__( + self: PatchDataset, + inputs: np.ndarray | list, + labels: list | None = None, + ) -> None: + """Initialize :class:`PatchDataset`.""" + super().__init__() + + self.data_is_npy_alike = False + + self.inputs = inputs + self.labels = labels + + # perform check on the input + self._check_input_integrity(mode="patch") + + def __getitem__(self: PatchDataset, idx: int) -> dict: + """Get an item from the dataset.""" + patch = self.inputs[idx] + + # Mode 0 is list of paths + if not self.data_is_npy_alike: + patch = self.load_img(patch) + + # Apply preprocessing to selected patch + patch = self._preproc(patch) + + data = { + "image": patch, + } + if self.labels is not None: + data["label"] = self.labels[idx] + return data + + return data From 02115bfb0bf9796de2c6d21631498fa31bd14278 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:44:59 +0100 Subject: [PATCH 14/36] :bug: Fix `tiatoolbox/models/dataset/classification.py` for annotations --- tiatoolbox/models/dataset/classification.py | 66 --------------------- 1 file changed, 66 deletions(-) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index ce54f64e8..783aca0cd 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -7,8 +7,6 @@ from PIL import Image from torchvision import transforms -from tiatoolbox.models.dataset import dataset_abc - if TYPE_CHECKING: # pragma: no cover import numpy as np @@ -63,67 +61,3 @@ def predefined_preproc_func(dataset_name: str) -> _TorchPreprocCaller: preprocs = preproc_dict[dataset_name] return _TorchPreprocCaller(preprocs) - - -class PatchDataset(dataset_abc.PatchDatasetABC): - """Define PatchDataset for torch inference. - - Define a simple patch dataset, which inherits from the - `torch.utils.data.Dataset` class. - - Attributes: - inputs (list or np.ndarray): - Either a list of patches, where each patch is a ndarray or a - list of valid path with its extension be (".jpg", ".jpeg", - ".tif", ".tiff", ".png") pointing to an image. - labels (list): - List of labels for sample at the same index in `inputs`. - Default is `None`. - - Examples: - >>> # A user defined preproc func and expected behavior - >>> preproc_func = lambda img: img/2 # reduce intensity by half - >>> transformed_img = preproc_func(img) - >>> # create a dataset to get patches preprocessed by the above function - >>> ds = PatchDataset( - ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], - ... labels=["labels1", "labels2"], - ... ) - - """ - - def __init__( - self: PatchDataset, - inputs: np.ndarray | list, - labels: list | None = None, - ) -> None: - """Initialize :class:`PatchDataset`.""" - super().__init__() - - self.data_is_npy_alike = False - - self.inputs = inputs - self.labels = labels - - # perform check on the input - self._check_input_integrity(mode="patch") - - def __getitem__(self: PatchDataset, idx: int) -> dict: - """Get an item from the dataset.""" - patch = self.inputs[idx] - - # Mode 0 is list of paths - if not self.data_is_npy_alike: - patch = self.load_img(patch) - - # Apply preprocessing to selected patch - patch = self._preproc(patch) - - data = { - "image": patch, - } - if self.labels is not None: - data["label"] = self.labels[idx] - return data - - return data From 77921c97efcc48443950f8096c33f9ecdc4cd458 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:52:59 +0100 Subject: [PATCH 15/36] :bug: Fix `tiatoolbox/models/dataset/dataset_abc.py` for annotations --- tiatoolbox/models/dataset/dataset_abc.py | 75 +++++++++++++----------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 9cc8bb96c..c2c8ef0cf 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -20,13 +20,14 @@ from multiprocessing.managers import Namespace from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.typing import IntPair, Resolution, Units class PatchDatasetABC(ABC, torch.utils.data.Dataset): """Define abstract base class for patch dataset.""" def __init__( - self, + self: PatchDatasetABC, ) -> None: """Initialize :class:`PatchDatasetABC`.""" super().__init__() @@ -36,7 +37,7 @@ def __init__( self.labels = [] @staticmethod - def _check_shape_integrity(shapes): + def _check_shape_integrity(shapes: list | np.ndarray) -> None: """Checks the integrity of input shapes. Args: @@ -56,7 +57,7 @@ def _check_shape_integrity(shapes): msg = "Images must have the same dimensions." raise ValueError(msg) - def _check_input_integrity(self, mode): + def _check_input_integrity(self: PatchDatasetABC, mode: str) -> None: """Check that variables received during init are valid. These checks include: @@ -113,11 +114,15 @@ def _check_input_integrity(self, mode): raise ValueError(msg) @staticmethod - def load_img(path): + def load_img(path: str | Path) -> np.ndarray: """Load an image from a provided path. Args: - path (str): Path to an image file. + path (str or Path): Path to an image file. + + Returns: + :class:`numpy.ndarray`: + Image as a numpy array. """ path = Path(path) @@ -129,12 +134,12 @@ def load_img(path): return imread(path, as_uint8=False) @staticmethod - def preproc(image): + def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of loader.""" return image @property - def preproc_func(self): + def preproc_func(self: PatchDatasetABC) -> Callable: """Return the current pre-processing function of this instance. The returned function is expected to behave as follows: @@ -144,7 +149,7 @@ def preproc_func(self): return self._preproc @preproc_func.setter - def preproc_func(self, func): + def preproc_func(self: PatchDatasetABC, func: Callable) -> None: """Set the pre-processing function for this instance. If `func=None`, the method will default to `self.preproc`. @@ -162,12 +167,12 @@ def preproc_func(self, func): msg = f"{func} is not callable!" raise ValueError(msg) - def __len__(self) -> int: + def __len__(self: PatchDatasetABC) -> int: """Return the length of the instance attributes.""" return len(self.inputs) @abstractmethod - def __getitem__(self, idx): + def __getitem__(self: PatchDatasetABC, idx: int) -> None: """Get an item from the dataset.""" ... # pragma: no cover @@ -213,12 +218,12 @@ class WSIStreamDataset(torch_data.Dataset): """ def __init__( - self, + self: WSIStreamDataset, ioconfig: IOSegmentorConfig, wsi_paths: list[str | Path], mp_shared_space: Namespace, preproc: Callable[[np.ndarray], np.ndarray] | None = None, - mode="wsi", + mode: str = "wsi", ) -> None: """Initialize :class:`WSIStreamDataset`.""" super().__init__() @@ -240,7 +245,7 @@ def __init__( self.wsi_idx = None # to be received externally via thread communication self.reader = None - def _get_reader(self, img_path): + def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader: """Get appropriate reader for input path.""" img_path = Path(img_path) if self.mode == "wsi": @@ -261,12 +266,12 @@ def _get_reader(self, img_path): info=metadata, ) - def __len__(self) -> int: + def __len__(self: WSIStreamDataset) -> int: """Return the length of the instance attributes.""" return len(self.mp_shared_space.patch_inputs) @staticmethod - def collate_fn(batch): + def collate_fn(batch: list | np.ndarray) -> torch.Tensor: """Prototype to handle reading exception. This will exclude any sample with `None` from the batch. As @@ -278,7 +283,7 @@ def collate_fn(batch): batch = [v for v in batch if v is not None] return torch.utils.data.dataloader.default_collate(batch) - def __getitem__(self, idx: int): + def __getitem__(self: WSIStreamDataset, idx: int) -> tuple: """Get an item from the dataset.""" # ! no need to lock as we do not modify source value in shared space if self.wsi_idx != self.mp_shared_space.wsi_idx: @@ -341,18 +346,18 @@ class WSIPatchDataset(PatchDatasetABC): """ def __init__( # noqa: PLR0913, PLR0915 - self, - img_path, - mode="wsi", - mask_path=None, - patch_input_shape=None, - stride_shape=None, - resolution=None, - units=None, - min_mask_ratio=0, - preproc_func=None, + self: WSIPatchDataset, + img_path: str | Path, + mode: str = "wsi", + mask_path: str | Path | None = None, + patch_input_shape: IntPair = None, + stride_shape: IntPair = None, + resolution: Resolution = None, + units: Units = None, + min_mask_ratio: float = 0, + preproc_func: Callable | None = None, *, - auto_get_mask=True, + auto_get_mask: bool = True, ) -> None: """Create a WSI-level patch dataset. @@ -377,20 +382,20 @@ def __init__( # noqa: PLR0913, PLR0915 stride shape to read at requested `resolution` and `units`. Expected to be positive and of (height, width). Note, this is not at level 0. - resolution: + resolution (Resolution): Check (:class:`.WSIReader`) for details. When `mode='tile'`, value is fixed to be `resolution=1.0` and `units='baseline'` units: check (:class:`.WSIReader`) for details. - units: + units (Units): Units in which `resolution` is defined. - auto_get_mask: + auto_get_mask (bool): If `True`, then automatically get simple threshold mask using WSIReader.tissue_mask() function. - min_mask_ratio: + min_mask_ratio (float): Only patches with positive area percentage above this value are included. Defaults to 0. - preproc_func: + preproc_func (Callable): Preprocessing function used to transform the input data. If supplied, the function will be called on each patch before returning it. @@ -521,7 +526,7 @@ def __init__( # noqa: PLR0913, PLR0915 # Perform check on the input self._check_input_integrity(mode="wsi") - def __getitem__(self, idx): + def __getitem__(self: WSIPatchDataset, idx: int) -> dict: """Get an item from the dataset.""" coords = self.inputs[idx] # Read image patch from the whole-slide image @@ -546,11 +551,11 @@ class PatchDataset(PatchDatasetABC): `torch.utils.data.Dataset` class. Attributes: - inputs: + inputs (list or np.ndarray): Either a list of patches, where each patch is a ndarray or a list of valid path with its extension be (".jpg", ".jpeg", ".tif", ".tiff", ".png") pointing to an image. - labels: + labels (list): List of labels for sample at the same index in `inputs`. Default is `None`. From a6cd50808ecd45b5675390dbf0e98b870abcbe84 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:24:18 +0000 Subject: [PATCH 16/36] :art: New `EngineABC` Design for `Patches` (#635) - New `EngineABC` design - Update `PatchPredictor` based on the new design - Move `model_to` from `utils.misc` to `model_abc` - Define `load_torch_model` in `model_abc` --------- Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: measty <20169086+measty@users.noreply.github.com> Co-authored-by: abishekrajvg --- .gitignore | 3 + tests/engines/_test_feature_extractor.py | 99 -- tests/engines/_test_multi_task_segmentor.py | 423 --------- .../_test_nucleus_instance_segmentor.py | 596 ------------ tests/engines/_test_patch_predictor.py | 763 ---------------- tests/engines/_test_semantic_segmentation.py | 854 ------------------ tests/engines/test_engine_abc.py | 418 +++++++++ tests/models/test_arch_mapde.py | 4 +- tests/models/test_arch_micronet.py | 2 +- tests/models/test_arch_nuclick.py | 3 +- tests/models/test_arch_sccnn.py | 17 +- tests/models/test_arch_unet.py | 5 +- tests/models/test_arch_vanilla.py | 6 +- tests/models/test_hovernet.py | 9 +- tests/models/test_hovernetplus.py | 3 +- .../{test_abc.py => test_models_abc.py} | 4 +- tests/test_annotation_tilerendering.py | 1 + tiatoolbox/annotation/storage.py | 15 + tiatoolbox/cli/patch_predictor.py | 2 +- tiatoolbox/models/architecture/__init__.py | 20 +- tiatoolbox/models/architecture/hovernet.py | 8 +- .../models/architecture/hovernetplus.py | 8 +- tiatoolbox/models/architecture/mapde.py | 8 +- tiatoolbox/models/architecture/micronet.py | 8 +- tiatoolbox/models/architecture/nuclick.py | 11 +- tiatoolbox/models/architecture/sccnn.py | 8 +- tiatoolbox/models/architecture/unet.py | 8 +- tiatoolbox/models/architecture/vanilla.py | 17 +- tiatoolbox/models/engine/__init__.py | 4 +- tiatoolbox/models/engine/engine_abc.py | 633 ++++++++++++- .../models/engine/multi_task_segmentor.py | 57 +- .../engine/nucleus_instance_segmentor.py | 87 +- tiatoolbox/models/engine/patch_predictor.py | 272 +++--- .../models/engine/semantic_segmentor.py | 156 ++-- tiatoolbox/models/models_abc.py | 92 +- 35 files changed, 1496 insertions(+), 3128 deletions(-) delete mode 100644 tests/engines/_test_feature_extractor.py delete mode 100644 tests/engines/_test_multi_task_segmentor.py delete mode 100644 tests/engines/_test_nucleus_instance_segmentor.py delete mode 100644 tests/engines/_test_patch_predictor.py delete mode 100644 tests/engines/_test_semantic_segmentation.py create mode 100644 tests/engines/test_engine_abc.py rename tests/models/{test_abc.py => test_models_abc.py} (96%) diff --git a/.gitignore b/.gitignore index a192542d6..16ea54a83 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,6 @@ ENV/ # vim/vi generated *.swp + +# output zarr generated +*.zarr diff --git a/tests/engines/_test_feature_extractor.py b/tests/engines/_test_feature_extractor.py deleted file mode 100644 index 3315cf0c3..000000000 --- a/tests/engines/_test_feature_extractor.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Test for feature extractor.""" - -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import torch - -from tiatoolbox.models import IOSegmentorConfig -from tiatoolbox.models.architecture.vanilla import CNNBackbone -from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_functional(remote_sample: Callable, tmp_path: Path) -> None: - """Test for feature extraction.""" - save_dir = tmp_path / "output" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * test providing pretrained from torch vs pretrained_model.yaml - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - extractor = DeepFeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask") - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - assert len(features.shape) == 4 - - # * test same output between full infer and engine - # pre-emptive clean up - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[256, 256], - save_resolution={"units": "mpp", "resolution": 8.0}, - ) - - model = CNNBackbone("resnet50") - extractor = DeepFeatureExtractor(batch_size=4, model=model) - # should still run because we skip exception - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - - reader = WSIReader.open(mini_wsi_svs) - patches = [ - reader.read_bounds( - positions[patch_idx], - resolution=0.25, - units="mpp", - pad_constant_values=0, - coord_space="resolution", - ) - for patch_idx in range(4) - ] - patches = np.array(patches) - patches = torch.from_numpy(patches) # NHWC - patches = patches.permute(0, 3, 1, 2) # NCHW - patches = patches.type(torch.float32) - model = model.to("cpu") - # Inference mode - model.eval() - with torch.inference_mode(): - _features = model(patches).numpy() - # ! must maintain same batch size and likely same ordering - # ! else the output values will not exactly be the same (still < 1.0e-4 - # ! of epsilon though) - assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 diff --git a/tests/engines/_test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py deleted file mode 100644 index c3cc85cea..000000000 --- a/tests/engines/_test_multi_task_segmentor.py +++ /dev/null @@ -1,423 +0,0 @@ -"""Unit test package for HoVerNet+.""" - -import copy - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest - -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - MultiTaskSegmentor, - SemanticSegmentor, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection - -ON_GPU = toolbox_env.has_gpu() -BATCH_SIZE = 1 if not ON_GPU else 8 # 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def semantic_postproc_func(raw_output: np.ndarray) -> np.ndarray: - """Function to post process semantic segmentations. - - Post processes semantic segmentation to form one map output. - - """ - return np.argmax(raw_output, axis=-1) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for multi task segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("svs-1-small")) - save_dir = root_save_dir / "multitask" - shutil.rmtree(save_dir, ignore_errors=True) - - # * generate full output w/o parallel post-processing worker first - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_a = joblib.load(f"{output[0][1]}.0.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - assert multi_segmentor.num_postproc_workers == NUM_POSTPROC_WORKERS - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_b = joblib.load(f"{output[0][1]}.0.dat") - layer_map_b = np.load(f"{output[0][1]}.1.npy") - assert len(inst_dict_b) > 0, "Must have some nuclei" - assert layer_map_b is not None, "Must have some layers." - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - -def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - required_dims = (258, 258) - # above image is 512 x 512 at 0.252 mpp resolution. This is 258 x 258 at 0.500 mpp. - - save_dir = f"{root_save_dir}/multi/" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - layer_map = np.load(f"{output[0][1]}.1.npy") - - assert len(inst_dict) > 0, "Must have some nuclei." - assert layer_map is not None, "Must have some layers." - assert ( - layer_map.shape == required_dims - ), "Output layer map dimensions must be same as the expected output shape" - - -def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test segmentor when image is masked.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = root_save_dir / "instance" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - output = multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_functionality_process_instance_predictions( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test the functionality of instance predictions processing.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(4)] - - dummy_reference = [{i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)}] - - dummy_tiles = [np.zeros((512, 512))] - dummy_bounds = np.array([0, 0, 512, 512]) - - multi_segmentor.wsi_layers = [np.zeros_like(raw_maps[0][..., 0])] - multi_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - multi_segmentor._futures = [ - [dummy_reference, [dummy_reference[0].keys()], dummy_tiles, dummy_bounds], - ] - multi_segmentor._merge_post_process_results() - assert len(multi_segmentor._wsi_inst_info[0]) == 0 - - -def test_empty_image(tmp_path: Path) -> None: - """Test MultiTaskSegmentor for an empty image.""" - root_save_dir = Path(tmp_path) - sample_patch = np.ones((256, 256, 3), dtype="uint8") * 255 - sample_patch_path = root_save_dir / "sample_tile.png" - imwrite(sample_patch_path, sample_patch) - - save_dir = root_save_dir / "hovernetplus" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "hovernet" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "semantic" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=(2048, 2048), - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - -def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - with pytest.raises( - ValueError, - match=r"Output type must be specified for instance or semantic segmentation.", - ): - MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - save_dir = f"{root_save_dir}/multi/" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - multi_segmentor.model.postproc_func = semantic_postproc_func - - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - layer_map = np.load(f"{output[0][1]}.0.npy") - - assert layer_map is not None, "Must have some segmentations." - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/multi/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernetplus-oed", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Crash."): - multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) diff --git a/tests/engines/_test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py deleted file mode 100644 index 6d3ea2f67..000000000 --- a/tests/engines/_test_nucleus_instance_segmentor.py +++ /dev/null @@ -1,596 +0,0 @@ -"""Test for Nucleus Instance Segmentor.""" - -import copy - -# ! The garbage collector -import gc -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest -import yaml -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - NucleusInstanceSegmentor, - SemanticSegmentor, -) -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - _process_tile_predictions, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def helper_tile_info() -> list: - """Helper function for tile information.""" - predictor = NucleusInstanceSegmentor(model="A") - # ! assuming the tiles organized as follows (coming out from - # ! PatchExtractor). If this is broken, need to check back - # ! PatchExtractor output ordering first - # left to right, top to bottom - # --------------------- - # | 0 | 1 | 2 | 3 | - # --------------------- - # | 4 | 5 | 6 | 7 | - # --------------------- - # | 8 | 9 | 10 | 11 | - # --------------------- - # | 12 | 13 | 14 | 15 | - # --------------------- - # ! assume flag index ordering: left right top bottom - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - ], - margin=1, - tile_shape=(4, 4), - stride_shape=[4, 4], - patch_input_shape=[4, 4], - patch_output_shape=[4, 4], - ) - - return predictor._get_tile_info([16, 16], ioconfig) - - -# ---------------------------------------------------- - - -def test_get_tile_info() -> None: - """Test for getting tile info.""" - info = helper_tile_info() - _, flag = info[0] # index 0 should be full grid, removal - # removal flag at top edges - assert ( - np.sum( - np.nonzero(flag[:, 0]) - != np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), - ) - == 0 - ), "Fail Top" - # removal flag at bottom edges - assert ( - np.sum( - np.nonzero(flag[:, 1]) != np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), - ) - == 0 - ), "Fail Bottom" - # removal flag at left edges - assert ( - np.sum( - np.nonzero(flag[:, 2]) - != np.array([1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]), - ) - == 0 - ), "Fail Left" - # removal flag at right edges - assert ( - np.sum( - np.nonzero(flag[:, 3]) - != np.array([0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]), - ) - == 0 - ), "Fail Right" - - -def test_vertical_boundary_boxes() -> None: - """Test for vertical boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [3, 0, 5, 4], - [7, 0, 9, 4], - [11, 0, 13, 4], - [3, 4, 5, 8], - [7, 4, 9, 8], - [11, 4, 13, 8], - [3, 8, 5, 12], - [7, 8, 9, 12], - [11, 8, 13, 12], - [3, 12, 5, 16], - [7, 12, 9, 16], - [11, 12, 13, 16], - ], - ) - _flag = np.array( - [ - [0, 1, 0, 0], - [0, 1, 0, 0], - [0, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - ], - ) - boxes, flag = info[1] - assert np.sum(_boxes - boxes) == 0, "Wrong Vertical Bounds" - assert np.sum(flag - _flag) == 0, "Fail Vertical Flag" - - -def test_horizontal_boundary_boxes() -> None: - """Test for horizontal boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [0, 3, 4, 5], - [4, 3, 8, 5], - [8, 3, 12, 5], - [12, 3, 16, 5], - [0, 7, 4, 9], - [4, 7, 8, 9], - [8, 7, 12, 9], - [12, 7, 16, 9], - [0, 11, 4, 13], - [4, 11, 8, 13], - [8, 11, 12, 13], - [12, 11, 16, 13], - ], - ) - _flag = np.array( - [ - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - ], - ) - boxes, flag = info[2] - assert np.sum(_boxes - boxes) == 0, "Wrong Horizontal Bounds" - assert np.sum(flag - _flag) == 0, "Fail Horizontal Flag" - - -def test_cross_section_boundary_boxes() -> None: - """Test for cross-section boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [2, 2, 6, 6], - [6, 2, 10, 6], - [10, 2, 14, 6], - [2, 6, 6, 10], - [6, 6, 10, 10], - [10, 6, 14, 10], - [2, 10, 6, 14], - [6, 10, 10, 14], - [10, 10, 14, 14], - ], - ) - _flag = np.array( - [ - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - ], - ) - boxes, flag = info[3] - assert np.sum(boxes - _boxes) == 0, "Wrong Cross Section Bounds" - assert np.sum(flag - _flag) == 0, "Fail Cross Section Flag" - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/instance/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - instance_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree("output", ignore_errors=True) - shutil.rmtree(save_dir, ignore_errors=True) - instance_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - instance_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for nuclei instance segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - save_dir = f"{root_save_dir}/instance/" - - # * test run on wsi, test run with worker - # resolution for travis testing, not the correct ones - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(1024, 1024), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - inst_segmentor = NucleusInstanceSegmentor( - batch_size=1, - num_loader_workers=0, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_merge_tile_predictions_ci( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Functional tests for merging tile predictions.""" - gc.collect() # Force clean up everything on hold - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 0.5 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - # mainly to hook the merge prediction function - inst_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - raw_maps = [[v] for v in raw_maps] # mask it as patch output - - dummy_reference = {i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)} - dummy_flag_mode_list = [ - [[1, 1, 0, 0], 1], - [[0, 0, 1, 1], 2], - [[1, 1, 1, 1], 3], - [[0, 0, 0, 0], 0], - ] - - inst_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - inst_segmentor._futures = [[dummy_reference, dummy_reference.keys()]] - inst_segmentor._merge_post_process_results() - assert len(inst_segmentor._wsi_inst_info) == 0 - - blank_raw_maps = [np.zeros_like(v) for v in raw_maps] - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=dummy_flag_mode_list[0][0], - tile_mode=dummy_flag_mode_list[0][1], - tile_output=[[np.array([0, 0, 512, 512]), blank_raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - for tile_flag, tile_mode in dummy_flag_mode_list: - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - # test exception flag - tile_flag = [0, 0, 0, 0] - with pytest.raises(ValueError, match=r".*Unknown tile mode.*"): - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=-1, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for nuclei instance segmentor.""" - root_save_dir = Path(tmp_path) - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * generate full output w/o parallel post-processing worker first - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - batch_size=8, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_a = joblib.load(f"{output[0][1]}.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - assert inst_segmentor.num_postproc_workers == 2 - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_b = joblib.load(f"{output[0][1]}.dat") - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - # ** - # To evaluate the precision of doing post-processing on tile - # then re-assemble without using full image prediction maps, - # we compare its output with the output when doing - # post-processing on the entire images. - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - _, inst_dict_b = semantic_segmentor.model.postproc(raw_maps) - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.9, "Heavy loss of precision!" - - -def test_cli_nucleus_instance_segment_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for nucleus segmentation with IOConfig.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") - - # resolution for travis testing, not the correct ones - config = { - "input_resolutions": [{"units": "mpp", "resolution": resolution}], - "output_resolutions": [ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - "margin": 128, - "tile_shape": [512, 512], - "patch_input_shape": [256, 256], - "patch_output_shape": [164, 164], - "stride_shape": [164, 164], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - - with Path.open(tmp_path / "config.yaml", "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_jpg), - "--pretrained-weights", - str(pretrained_weights), - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--mode", - "tile", - "--output-path", - str(output_path), - "--yaml-config-path", - str(tmp_path.joinpath("config.yaml")), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() - - -def test_cli_nucleus_instance_segment(remote_sample: Callable, tmp_path: Path) -> None: - """Test for nucleus segmentation.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--output-path", - str(output_path), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py deleted file mode 100644 index b3322635d..000000000 --- a/tests/engines/_test_patch_predictor.py +++ /dev/null @@ -1,763 +0,0 @@ -"""Test for Patch Predictor.""" -from __future__ import annotations - -import copy -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor -from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.utils import download_data, imwrite -from tiatoolbox.utils import env_detection as toolbox_env - -ON_GPU = toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - -def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: - """Test for delegating args to io config.""" - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models - model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - kwargs = { - "patch_input_shape": [512, 512], - "resolution": 1.75, - "units": "mpp", - } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained models - ioconfig = IOPatchPredictorConfig( - patch_input_shape=(512, 512), - stride_shape=(256, 256), - input_resolutions=[{"resolution": 1.35, "units": "mpp"}], - output_resolutions=[], - ) - predictor.predict( - [mini_wsi_svs], - ioconfig=ioconfig, - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], - patch_input_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.patch_input_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - stride_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.stride_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - resolution=1.99, - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - units="baseline", - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - -def test_patch_predictor_api( - sample_patch1: Path, - sample_patch2: Path, - tmp_path: Path, -) -> None: - """Helper function to get the model output using API 1.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent reader complaint - inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - # don't run test on GPU - output = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - - # test loading user weight - pretrained_weights_url = ( - "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" - ) - - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - pretrained_weights = ( - save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" - ) - - download_data(pretrained_weights_url, pretrained_weights) - - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, - batch_size=1, - ) - - # --- test different using user model - model = CNNModel(backbone="resnet18", num_classes=9) - # test prediction - predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - -def test_wsi_predictor_api( - sample_wsi_dict: dict, - tmp_path: Path, - chdir: Callable, -) -> None: - """Test normal run of wsi predictor.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - save_dir = f"{save_dir_path}/model_wsi_output" - - # wrapper to make this more clean - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 1.0, - "units": "baseline", - "save_dir": save_dir, - } - # ! add this test back once the read at `baseline` is fixed - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - # coverage test - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True - # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - - # blind test for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 - - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) - - -def _test_predictor_output( - inputs: list, - pretrained_model: str, - probabilities_check: list | None = None, - predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, -) -> None: - """Test the predictions of multiple models included in tiatoolbox.""" - predictor = PatchPredictor( - pretrained_model=pretrained_model, - batch_size=32, - verbose=False, - ) - # don't run test on GPU - output = predictor.predict( - inputs, - return_probabilities=True, - return_labels=False, - on_gpu=on_gpu, - ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): - probabilities_max = max(probabilities_) - assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - - -def test_patch_predictor_kather100k_output( - sample_patch1: Path, - sample_patch2: Path, -) -> None: - """Test the output of patch prediction models on Kather100K dataset.""" - inputs = [Path(sample_patch1), Path(sample_patch2)] - pretrained_info = { - "alexnet-kather100k": [1.0, 0.9999735355377197], - "resnet18-kather100k": [1.0, 0.9999911785125732], - "resnet34-kather100k": [1.0, 0.9979840517044067], - "resnet50-kather100k": [1.0, 0.9999986886978149], - "resnet101-kather100k": [1.0, 0.9999932050704956], - "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], - "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], - "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], - "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], - "densenet121-kather100k": [1.0, 1.0], - "densenet161-kather100k": [1.0, 0.9999959468841553], - "densenet169-kather100k": [1.0, 0.9999934434890747], - "densenet201-kather100k": [1.0, 0.9999983310699463], - "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], - "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], - "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], - "googlenet-kather100k": [1.0, 0.9999639987945557], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[6, 3], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI file not found error.""" - runner = CliRunner() - model_file_not_found_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs)[:-1], - "--file-types", - '"*.ndpi, *.svs"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert model_file_not_found_result.output == "" - assert model_file_not_found_result.exit_code == 1 - assert isinstance(model_file_not_found_result.exception, FileNotFoundError) - - -def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI mode not in wsi, tile.""" - runner = CliRunner() - mode_not_in_wsi_tile_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--file-types", - '"*.ndpi, *.svs"', - "--mode", - '"patch"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output - assert mode_not_in_wsi_tile_result.exit_code != 0 - assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) - - -def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI single file.""" - runner = CliRunner() - models_wsi_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--mode", - "wsi", - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - # Make multiple copies for test - dir_path = tmp_path.joinpath("new_copies") - dir_path.mkdir() - - dir_path_masks = tmp_path.joinpath("new_copies_masks") - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py deleted file mode 100644 index 0bc2babde..000000000 --- a/tests/engines/_test_semantic_segmentation.py +++ /dev/null @@ -1,854 +0,0 @@ -"""Test for Semantic Segmentor.""" -from __future__ import annotations - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.multiprocessing as torch_mp -import torch.nn.functional as F # noqa: N812 -import yaml -from click.testing import CliRunner -from torch import nn - -from tiatoolbox import cli -from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import WSIStreamDataset -from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imread, imwrite -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -class _CNNTo1(ModelABC): - """Contains a convolution. - - Simple model to test functionality, this contains a single - convolution layer which has weight=0 and bias=1. - - """ - - def __init__(self: _CNNTo1) -> None: - super().__init__() - self.conv = nn.Conv2d(3, 1, 3, padding=1) - self.conv.weight.data.fill_(0) - self.conv.bias.data.fill_(1) - - def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor: - """Define how to use layer.""" - return self.conv(img) - - @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o - aggregation for a single data batch. - - Args: - model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): A batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. - - """ - device = "cuda" if on_gpu else "cpu" - #### - model.eval() # infer mode - - #### - img_list = batch_data - - img_list = img_list.to(device).type(torch.float32) - img_list = img_list.permute(0, 3, 1, 2) # to NCHW - - hw = np.array(img_list.shape[2:]) - with torch.inference_mode(): # do not compute gradient - logit_list = model(img_list) - logit_list = centre_crop(logit_list, hw // 2) - logit_list = logit_list.permute(0, 2, 3, 1) # to NHWC - prob_list = F.relu(logit_list) - - prob_list = prob_list.cpu().numpy() - return [prob_list] - - -# ------------------------------------------------------------------------------------- -# IOConfig -# ------------------------------------------------------------------------------------- - - -def test_segmentor_ioconfig() -> None: - """Test for IOConfig.""" - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - assert ioconfig.highest_input_resolution == {"units": "mpp", "resolution": 0.25} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 1.0 - assert ioconfig.input_resolutions[1]["resolution"] == 0.5 - assert ioconfig.input_resolutions[2]["resolution"] == 1 / 3 - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - output_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - save_resolution={"units": "power", "resolution": 8.0}, - ) - assert ioconfig.highest_input_resolution == {"units": "power", "resolution": 40} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 0.5 - assert ioconfig.input_resolutions[1]["resolution"] == 1.0 - assert ioconfig.save_resolution["resolution"] == 8.0 / 40.0 - - resolutions = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ] - with pytest.raises(ValueError, match=r".*Unknown units.*"): - ioconfig.scale_to_highest(resolutions, "axx") - - -# ------------------------------------------------------------------------------------- -# Dataset -# ------------------------------------------------------------------------------------- - - -def test_functional_wsi_stream_dataset(remote_sample: Callable) -> None: - """Functional test for WSIStreamDataset.""" - gc.collect() # Force clean up everything on hold - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - mp_manager = torch_mp.Manager() - mp_shared_space = mp_manager.Namespace() - - sds = WSIStreamDataset(ioconfig, [mini_wsi_svs], mp_shared_space) - # test for collate - out = sds.collate_fn([None, 1, 2, 3]) - assert np.sum(out.numpy() != np.array([1, 2, 3])) == 0 - - # artificial data injection - mp_shared_space.wsi_idx = torch.tensor(0) # a scalar - mp_shared_space.patch_inputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - mp_shared_space.patch_outputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - # test read - for _, sample in enumerate(sds): - patch_data, _ = sample - (patch_resolution1, patch_resolution2, patch_resolution3) = patch_data - assert np.round(patch_resolution1.shape[0] / patch_resolution2.shape[0]) == 2 - assert np.round(patch_resolution1.shape[0] / patch_resolution3.shape[0]) == 3 - - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_crash_segmentor(remote_sample: Callable) -> None: - """Functional crash tests for segmentor.""" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # * test basic crash - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(TypeError, match=r".*`mask_reader`.*"): - semantic_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) - with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): - semantic_segmentor.filter_coordinates( - WSIReader.open(mini_wsi_msk), - np.array([1.0, 2.0]), - ) - semantic_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) - with pytest.raises(ValueError, match=r".*must be a valid file path.*"): - semantic_segmentor.get_reader( - mini_wsi_msk, - "not_exist", - "wsi", - auto_get_mask=True, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(ValueError, match=r".*provide.*"): - SemanticSegmentor() - with pytest.raises(ValueError, match=r".*valid mode.*"): - semantic_segmentor.predict([], mode="abc") - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict([], mode="tile", patch_input_shape=(2048, 2048)) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * Test crash propagation when parallelize post-processing - semantic_segmentor.num_postproc_workers = 2 - semantic_segmentor.model.forward = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - resolution=1.0, - units="baseline", - ) - - shutil.rmtree("output", ignore_errors=True) - - with pytest.raises(ValueError, match=r"Invalid resolution.*"): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) - # test ignore crash - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=False, - resolution=1.0, - units="baseline", - ) - shutil.rmtree("output", ignore_errors=True) - - -def test_functional_segmentor_merging(tmp_path: Path) -> None: - """Functional test for assmebling output.""" - save_dir = Path(tmp_path) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - # predictions with HW - _output = np.array( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 2, 2], - [0, 0, 2, 2], - ], - ) - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # a second rerun to test overlapping count, - # should still maintain same result - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # else will leave hanging file pointer - # and hence cant remove its folder later - del canvas # skipcq - - # * predictions with HWC - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - _ = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - del _ # skipcq - - # * test crashing when switch to image having larger - # * shape but still provide old links - semantic_segmentor.merge_prediction( - [8, 8], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - with pytest.raises(ValueError, match=r".*`save_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.py", - ) - - with pytest.raises(ValueError, match=r".*`cache_count_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - # * test non HW predictions - with pytest.raises(ValueError, match=r".*Prediction is no HW or HWC.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2,), 1), np.full((2,), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - - # * with an out of bound location - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [ - np.full((2, 2), 1), - np.full((2, 2), 2), - np.full((2, 2), 3), - np.full((2, 2), 4), - ], - [[0, 0, 2, 2], [2, 2, 4, 4], [0, 4, 2, 6], [4, 0, 6, 2]], - save_path=None, - ) - assert np.sum(canvas - _output) < 1.0e-8 - del canvas # skipcq - - -def test_functional_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Functional test for segmentor.""" - save_dir = tmp_path / "dump" - # # convert to pathlib Path to prevent wsireader complaint - resolution = 2.0 - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - mini_wsi_msk = f"{tmp_path}/mini_mask.jpg" - imwrite(mini_wsi_msk, (thumb > 0).astype(np.uint8)) - - # preemptive clean up - shutil.rmtree("output", ignore_errors=True) # default output dir test - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # should still run because we skip exception - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=resolution, - units="mpp", - crash_on_exception=False, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * check exception bypass in the log - # there should be no exception, but how to check the log? - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - patch_output_shape=(512, 512), - stride_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=False, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test basic running and merging prediction - # * should dumping all 1 in the output - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "baseline", "resolution": 1.0}], - output_resolutions=[{"units": "baseline", "resolution": 1.0}], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - file_list = [ - mini_wsi_jpg, - mini_wsi_jpg, - ] - output_list = semantic_segmentor.predict( - file_list, - mode="tile", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - pred_2 = np.load(output_list[1][1] + ".raw.0.npy") - assert len(output_list) == 2 - assert np.sum(pred_1 - pred_2) == 0 - # due to overlapping merge and division, will not be - # exactly 1, but should be approximately so - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # * test running with mask and svs - # * also test merging prediction at designated resolution - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[{"units": "mpp", "resolution": resolution}], - save_resolution={"units": "mpp", "resolution": resolution}, - patch_input_shape=[512, 512], - patch_output_shape=[256, 256], - stride_shape=[512, 512], - ) - shutil.rmtree(save_dir, ignore_errors=True) - output_list = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - reader = WSIReader.open(mini_wsi_svs) - expected_shape = reader.slide_dimensions(**ioconfig.save_resolution) - expected_shape = np.array(expected_shape)[::-1] # to YX - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - saved_shape = np.array(pred_1.shape[:2]) - assert np.sum(expected_shape - saved_shape) == 0 - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # check normal run with auto get mask - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - model=model, - auto_generate_mask=True, - ) - _ = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -def test_subclass(remote_sample: Callable, tmp_path: Path) -> None: - """Create subclass and test parallel processing setup.""" - save_dir = Path(tmp_path) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - model = _CNNTo1() - - class XSegmentor(SemanticSegmentor): - """Dummy class to test subclassing.""" - - def __init__(self: XSegmentor) -> None: - super().__init__(model=model) - self.num_postproc_worker = 2 - - semantic_segmentor = XSegmentor() - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(1024, 1024), - patch_output_shape=(512, 512), - stride_shape=(256, 256), - resolution=1.0, - units="baseline", - crash_on_exception=False, - save_dir=save_dir / "raw", - ) - - -# specifically designed for travis -def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: - """Test for load up pretrained and over-writing tile mode ioconfig.""" - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=1.0, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - assert save_dir.joinpath("raw/0.raw.0.npy").exists() - assert save_dir.joinpath("raw/file_map.dat").exists() - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - wsi_with_artifacts = Path(remote_sample("wsi3_20k_20k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [wsi_with_artifacts], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("wsi3_20k_20k_pred"))) - _test_pred = np.load(str(save_dir / "raw" / "0.raw.0.npy")) - _test_pred = (_test_pred[..., 1] > 0.75) * 255 - # divide 255 to binarize - assert np.mean(_cache_pred[..., 0] == _test_pred) > 0.99 - - shutil.rmtree(save_dir, ignore_errors=True) - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=True, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - - wsi_breast = Path(remote_sample("wsi4_4k_4k_svs")) - semantic_segmentor = SemanticSegmentor( - num_loader_workers=4, - batch_size=BATCH_SIZE, - pretrained_model="fcn_resnet50_unet-bcss", - ) - semantic_segmentor.predict( - [wsi_breast], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = np.load(Path(remote_sample("wsi4_4k_4k_pred"))) - _test_pred = np.load(f"{save_dir}/raw/0.raw.0.npy") - _test_pred = np.argmax(_test_pred, axis=-1) - assert np.mean(np.abs(_cache_pred - _test_pred)) < 1.0e-2 - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_cli_semantic_segment_out_exists_error( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation if output path exists.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path, - ], - ) - - assert semantic_segment_result.output == "" - assert semantic_segment_result.exit_code == 1 - assert isinstance(semantic_segment_result.exception, FileExistsError) - - -def test_cli_semantic_segmentation_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation single file custom ioconfig.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") - - config = { - "input_resolutions": [{"units": "mpp", "resolution": 2.0}], - "output_resolutions": [{"units": "mpp", "resolution": 2.0}], - "patch_input_shape": [1024, 1024], - "patch_output_shape": [512, 512], - "stride_shape": [256, 256], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--pretrained-weights", - str(pretrained_weights), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path.joinpath("output"), - "--yaml-config-path", - tmp_path.joinpath("config.yaml"), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("output/0.raw.0.npy").exists() - assert tmp_path.joinpath("output/file_map.dat").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_semantic_segmentation_multi_file( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path / "small_svs_tissue_mask.jpg" - - # Make multiple copies for test - dir_path = tmp_path / "new_copies" - dir_path.mkdir() - - dir_path_masks = tmp_path / "new_copies_masks" - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - dir_path_masks.joinpath("2_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - except OSError: - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("1_" + sample_wsi_msk.name)) - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("2_" + sample_wsi_msk.name)) - - tmp_path = tmp_path / "output" - - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("0.raw.0.npy").exists() - assert tmp_path.joinpath("1.raw.0.npy").exists() - assert tmp_path.joinpath("file_map.dat").exists() - assert tmp_path.joinpath("results.json").exists() - - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("small_svs_tissue_mask"))) - _test_pred = np.load(str(tmp_path.joinpath("0.raw.0.npy"))) - _test_pred = (_test_pred[..., 1] > 0.50) * 255 - - assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3 diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py new file mode 100644 index 000000000..8ab54098f --- /dev/null +++ b/tests/engines/test_engine_abc.py @@ -0,0 +1,418 @@ +"""Test tiatoolbox.models.engine.engine_abc.""" +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, NoReturn + +import numpy as np +import pytest +import torchvision.models as torch_models + +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir +from tiatoolbox.models.engine.io_config import ModelIOConfigABC + +if TYPE_CHECKING: + import torch.nn + + +class TestEngineABC(EngineABC): + """Test EngineABC.""" + + def __init__( + self: TestEngineABC, + model: str | torch.nn.Module, + weights: str | Path | None = None, + verbose: bool | None = None, + ) -> NoReturn: + """Test EngineABC init.""" + super().__init__(model=model, weights=weights, verbose=verbose) + + def infer_wsi(self: EngineABC) -> NoReturn: + """Test infer_wsi.""" + ... # dummy function for tests. + + def post_process_wsi(self: EngineABC) -> NoReturn: + """Test post_process_wsi.""" + ... # dummy function for tests. + + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Test pre_process_wsi.""" + ... # dummy function for tests. + + +def test_engine_abc() -> NoReturn: + """Test EngineABC initialization.""" + with pytest.raises( + TypeError, + match=r".*Can't instantiate abstract class EngineABC with abstract methods*", + ): + # Can't instantiate abstract class with abstract methods + EngineABC() # skipcq + + +def test_engine_abc_incorrect_model_type() -> NoReturn: + """Test EngineABC initialization with incorrect model type.""" + with pytest.raises( + TypeError, + match=r".*missing 1 required positional argument: 'model'", + ): + TestEngineABC() # skipcq + + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + TestEngineABC(model=1) + + +def test_incorrect_ioconfig() -> NoReturn: + """Test EngineABC initialization with incorrect ioconfig.""" + model = torch_models.resnet18() + engine = TestEngineABC(model=model) + with pytest.raises( + ValueError, + match=r".*provide a valid ModelIOConfigABC.*", + ): + engine.run(images=[], masks=[], ioconfig=None) + + +def test_pretrained_ioconfig() -> NoReturn: + """Test EngineABC initialization with pretrained model name in the toolbox.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + ) + assert "predictions" in out + assert "labels" not in out + + +def test_ioconfig() -> NoReturn: + """Test EngineABC initialization with valid ioconfig.""" + ioconfig = ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + ], + patch_input_shape=(224, 224), + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + ioconfig=ioconfig, + ) + + assert "predictions" in out + assert "labels" not in out + + +def test_prepare_engines_save_dir( + tmp_path: pytest.TempPathFactory, + caplog: pytest.LogCaptureFixture, +) -> NoReturn: + """Test prepare save directory for engines.""" + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + overwrite=False, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + overwrite=True, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=True, + len_images=1, + overwrite=False, + ) + assert out_dir is None + + with pytest.raises( + OSError, + match=r".*More than 1 WSIs detected but there is no save directory provided.*", + ): + _ = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=2, + overwrite=False, + ) + + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=1, + overwrite=False, + ) + + assert out_dir == Path.cwd() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_single_output", + patch_mode=False, + len_images=1, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_single_output" + assert out_dir.exists() + assert r"When providing multiple whole-slide images / tiles" not in caplog.text + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_multiple_output", + patch_mode=False, + len_images=2, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_multiple_output" + assert out_dir.exists() + assert r"When providing multiple whole slide images" in caplog.text + + # test for file overwrite with Path.mkdirs() method + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + len_images=1, + overwrite=True, + ) + assert out_path.exists() + + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + len_images=1, + overwrite=True, + ) + assert out_path.exists() + + with pytest.raises(FileExistsError): + out_path = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output" / "output.zarr", + patch_mode=True, + len_images=1, + overwrite=False, + ) + + +def test_engine_initalization() -> NoReturn: + """Test engine initialization.""" + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + _ = TestEngineABC(model=0) + + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + model = CNNModel("alexnet", num_classes=1) + eng = TestEngineABC(model=model) + assert isinstance(eng, EngineABC) + + model = get_pretrained_model("alexnet-kather100k")[0] + weights_path = fetch_pretrained_weights("alexnet-kather100k") + eng = TestEngineABC(model=model, weights=weights_path) + assert isinstance(eng, EngineABC) + + +def test_engine_run() -> NoReturn: + """Test engine run.""" + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*The input numpy array should be four dimensional.*", + ): + eng.run(images=np.zeros((10, 10))) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + TypeError, + match=r"Input must be a list of file paths or a numpy array.", + ): + eng.run(images=1) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*len\(labels\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(1)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*len\(masks\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((1, 224, 224, 3)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*The shape of the numpy array should be NHWC*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((10, 3)), + on_gpu=False, + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + ) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + assert "predictions" in out + assert "labels" in out + + +def test_engine_run_with_verbose() -> NoReturn: + """Test engine run with verbose.""" + """Run pytest with `-rP` option to view progress bar on the captured stderr call""" + + eng = TestEngineABC(model="alexnet-kather100k", verbose=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + + assert "predictions" in out + assert "labels" in out + + +def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the engine run and patch pred store.""" + save_dir = tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + """ test custom zarr output file name""" + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_file="patch_pred_output", + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + ) + + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + class_dict={0: "class0", 1: "class1"}, + ) + + with pytest.raises( + ValueError, + match=r".*Patch output must contain coordinates.", + ): + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_type="AnnotationStore", + scale_factor=(2.0, 2.0), + ) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index df60d3b47..f0142406d 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -44,7 +44,7 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to(select_device(on_gpu=ON_GPU)) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + model = model.to() + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index cd4bd0833..e7aa23d5b 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -39,7 +39,7 @@ def test_functionality( model = model.to(map_location) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=map_location) output, _ = model.postproc(output[0]) assert np.max(np.unique(output)) == 46 diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index fda0c01a6..b84516125 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -10,6 +10,7 @@ from tiatoolbox.models import NuClick from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device ON_GPU = False @@ -53,7 +54,7 @@ def test_functional_nuclick( model = NuClick(num_input_channels=5, num_output_channels=1) pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) postproc_masks = model.postproc( output, do_reconstruction=True, diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index bdec99e0b..58d3f67d0 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -4,9 +4,10 @@ import numpy as np import torch -from tiatoolbox import utils from tiatoolbox.models import SCCNN from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.utils import env_detection +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -14,7 +15,7 @@ def _load_sccnn(name: str) -> torch.nn.Module: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) - map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu()) + map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) @@ -39,11 +40,19 @@ def test_functionality(remote_sample: Callable) -> None: ) batch = torch.from_numpy(patch)[None] model = _load_sccnn(name="sccnn-crchisto") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) model = _load_sccnn(name="sccnn-conic") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index f15a5dc71..69496c7aa 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -8,6 +8,7 @@ from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.unet import UNetModel +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = False @@ -47,7 +48,7 @@ def test_functional_unet(remote_sample: Callable) -> None: model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) _ = output[0] # run untrained network to test for architecture @@ -59,4 +60,4 @@ def test_functional_unet(remote_sample: Callable) -> None: encoder_levels=[32, 64], skip_type="concat", ) - _ = model.infer_batch(model, batch, on_gpu=ON_GPU) + _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index a2b1ac5c9..cfae665b2 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -6,9 +6,11 @@ from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.models_abc import model_to +from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator +device = "cuda" if ON_GPU else "cpu" def test_functional() -> None: @@ -43,8 +45,8 @@ def test_functional() -> None: try: for backbone in backbones: model = CNNModel(backbone, num_classes=1) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index bf77b46ba..dcf2251ac 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -14,6 +14,7 @@ ResidualBlock, TFSamepaddingLayer, ) +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-monusac") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-consep") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-kumar") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index 96d0f9d23..1377fdd82 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -7,6 +7,7 @@ from tiatoolbox.models import HoVerNetPlus from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.transforms import imresize @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernetplus-oed") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." output = [v[0] for v in output] output = model.postproc(output) diff --git a/tests/models/test_abc.py b/tests/models/test_models_abc.py similarity index 96% rename from tests/models/test_abc.py rename to tests/models/test_models_abc.py index 3537735ce..635b13be1 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_models_abc.py @@ -124,9 +124,9 @@ def test_model_to() -> None: if not utils.env_detection.has_gpu(): model = torch_models.resnet18() with pytest.raises((AssertionError, RuntimeError)): - _ = tiatoolbox.models.models_abc.model_to(on_gpu=True, model=model) + _ = tiatoolbox.models.models_abc.model_to(device="cuda", model=model) # Test on CPU model = torch_models.resnet18() - model = tiatoolbox.models.models_abc.model_to(on_gpu=False, model=model) + model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) assert isinstance(model, nn.Module) diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py index 4d4651f06..1b2dc8826 100644 --- a/tests/test_annotation_tilerendering.py +++ b/tests/test_annotation_tilerendering.py @@ -460,6 +460,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None: _, store = fill_store(SQLiteStore, tmp_path / "test.db") def color_fn(props: dict[str, str]) -> tuple[int, int, int]: + """Tests Red for cells, otherwise green.""" # simple test function that returns red for cells, otherwise green. if props["type"] == "cell": return 1, 0, 0 diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index 3440b5aec..4137b4f67 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -2316,6 +2316,21 @@ def _unpack_wkb( cx: float, cy: float, ) -> bytes: + """Return the geometry as bytes using WKB. + + Args: + data (bytes or str): + The WKB/WKT data to be unpacked. + cx (int): + The X coordinate of the centroid/representative point. + cy (float): + The Y coordinate of the centroid/representative point. + + Returns: + bytes: + The geometry as bytes. + + """ return ( self._decompress_data(data) if data diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 2c754d1c6..8c6128e8c 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -83,7 +83,7 @@ def patch_predictor( predictor = PatchPredictor( pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + weights=pretrained_weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index d37ad5c80..7776cdb60 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -1,21 +1,20 @@ """Define a set of models to be used within tiatoolbox.""" from __future__ import annotations -import os from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union - -import torch +from typing import TYPE_CHECKING from tiatoolbox import rcParam from tiatoolbox.models.dataset.classification import predefined_preproc_func +from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover from pathlib import Path - from tiatoolbox.models.models_abc import IOConfigABC + import torch + from tiatoolbox.models.engine.io_config import ModelIOConfigABC __all__ = ["get_pretrained_model", "fetch_pretrained_weights"] PRETRAINED_INFO = rcParam["pretrained_model_info"] @@ -63,7 +62,7 @@ def get_pretrained_model( pretrained_weights: str | Path | None = None, *, overwrite: bool = False, -) -> tuple[torch.nn.Module, IOConfigABC]: +) -> tuple[torch.nn.Module, ModelIOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: @@ -143,14 +142,11 @@ def get_pretrained_model( overwrite=overwrite, ) - # ! assume to be saved in single GPU mode - # always load on to the CPU - saved_state_dict = torch.load(pretrained_weights, map_location="cpu") - model.load_state_dict(saved_state_dict, strict=True) + model = load_torch_model(model=model, weights=pretrained_weights) # ! io_info = info["ioconfig"] creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") - iostate = creator(**io_info["kwargs"]) - return model, iostate + ioconfig = creator(**io_info["kwargs"]) + return model, ioconfig diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index cad29fe83..216e06ee5 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -19,7 +19,6 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc from tiatoolbox.utils.misc import get_bounding_box @@ -781,7 +780,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: return pred_inst, nuc_inst_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -793,8 +792,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: tuple: @@ -806,7 +805,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 59135a350..ddcce67ea 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -12,7 +12,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.architecture.utils import UpSample2x -from tiatoolbox.utils import misc class HoVerNetPlus(HoVerNet): @@ -320,7 +319,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -332,13 +331,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index 21c588c29..863ce985d 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -13,7 +13,6 @@ from skimage.feature import peak_local_max from tiatoolbox.models.architecture.micronet import MicroNet -from tiatoolbox.utils.misc import select_device class MapDe(MicroNet): @@ -258,7 +257,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -271,8 +270,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -281,7 +280,6 @@ def infer_batch( """ patch_imgs = batch_data - device = select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 69daa120f..c18e51e6b 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -18,7 +18,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc def group1_forward_branch( @@ -628,7 +627,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -641,8 +640,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -651,7 +650,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py index 85a759bb6..cb5f52509 100644 --- a/tiatoolbox/models/architecture/nuclick.py +++ b/tiatoolbox/models/architecture/nuclick.py @@ -21,7 +21,6 @@ from tiatoolbox import logger from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import IntPair @@ -646,7 +645,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> np.ndarray: """Run inference on an input batch. @@ -655,16 +654,16 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): a batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. + batch_data (torch.Tensor): + A batch of data generated by torch.utils.data.DataLoader. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w). """ model.eval() - device = misc.select_device(on_gpu=on_gpu) # Assume batch_data is NCHW batch_data = batch_data.to(device).type(torch.float32) diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index bbeb58094..bdb8926e3 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -16,7 +16,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class SCCNN(ModelABC): @@ -354,7 +353,7 @@ def infer_batch( model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -367,8 +366,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list of :class:`numpy.ndarray`: @@ -377,7 +376,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index 8f628fb52..7e2e35c02 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -11,7 +11,6 @@ from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class ResNetEncoder(ResNet): @@ -415,7 +414,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list: """Run inference on an input batch. @@ -428,8 +427,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list: @@ -438,7 +437,6 @@ def infer_batch( """ model.eval() - device = misc.select_device(on_gpu=on_gpu) #### imgs = batch_data diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 5855971d5..2ecbd5b86 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -9,7 +9,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import select_device if TYPE_CHECKING: # pragma: no cover from torchvision.models import WeightsEnum @@ -142,7 +141,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str = "cpu", ) -> np.ndarray: """Run inference on an input batch. @@ -154,11 +153,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() @@ -239,7 +238,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray, ...]: """Run inference on an input batch. @@ -251,11 +250,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 0a5968b44..7d0dfe0e1 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,11 +1,13 @@ """Engines to run models implemented in tiatoolbox.""" -from tiatoolbox.models.engine import ( +from . import ( + engine_abc, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, ) __all__ = [ + "engine_abc", "nucleus_instance_segmentor", "patch_predictor", "semantic_segmentor", diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 69d66af73..7e7c71242 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,21 +1,636 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" +from __future__ import annotations + from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, NoReturn + +import numpy as np +import torch +import tqdm +from torch import nn + +from tiatoolbox import logger +from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.dataset.dataset_abc import PatchDataset +from tiatoolbox.models.models_abc import load_torch_model, model_to +from tiatoolbox.utils.misc import dict_to_store, dict_to_zarr + +if TYPE_CHECKING: # pragma: no cover + import os + + from torch.utils.data import DataLoader + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.wsicore.wsireader import WSIReader + + from .io_config import ModelIOConfigABC + + +def prepare_engines_save_dir( + save_dir: os | Path | None, + len_images: int, + *, + patch_mode: bool, + overwrite: bool, +) -> Path | None: + """Create directory if not defined and number of images is more than 1. + + Args: + save_dir (str or Path): + Path to output directory. + len_images (int): + List of inputs to process. + patch_mode(bool): + Whether to treat input image as a patch or WSI. + overwrite (bool): + Whether to overwrite the results. Default = False. + + Returns: + :class:`Path`: + Path to output directory. + + """ + if patch_mode is True: + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=overwrite) + return save_dir + + if save_dir is None: + if len_images > 1: + msg = ( + "More than 1 WSIs detected but there is no save directory provided." + "Please provide a 'save_dir'." + ) + raise OSError(msg) + return ( + Path.cwd() + ) # save the output to current working directory and return save_dir + + if len_images > 1: + logger.info( + "When providing multiple whole slide images, " + "the outputs will be saved and the locations of outputs " + "will be returned to the calling function.", + ) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + + return save_dir class EngineABC(ABC): - """Abstract base class for engines used in tiatoolbox.""" + """Abstract base class for engines used in tiatoolbox. + + Args: + model (str | nn.Module): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = EngineABC( + ... model="pretrained-model-name", + ... weights="pretrained-local-weights.pth") + + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. default = 0 + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. default = 0 + device (str): + Select the device to run the model. Default is "cpu". + verbose (bool): + Whether to output logging information. + + Attributes: + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A NHWC image or a path to WSI. + patch_mode (str): + Whether to treat input image as a patch or WSI. + default = True. + model (str | nn.Module): + Defined PyTorch model. + Name of an existing model supported by the TIAToolbox for + processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration to run the Engine. + _ioconfig (): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable if `patch_mode` is False in inference. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + batch_size (int): + Number of images fed into the model each time. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + Select the device to run the model. Default is "cpu". + num_loader_workers (int): + Number of workers used in torch.utils.data.DataLoader. + verbose (bool): + Whether to output logging information. + + Examples: + >>> # array of list of 2 image patches as input + >>> import numpy as np + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # array of list of 2 image patches as input + >>> import numpy as np + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # list of 2 image files as input + >>> image = ['path/image1.png', 'path/image2.png'] + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(image, patch_mode=False) + + >>> # list of 2 wsi files as input + >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(wsi_file, patch_mode=True) - def __init__(self) -> None: + """ + + def __init__( + self: EngineABC, + model: str | nn.Module, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = False, + ) -> None: """Initialize Engine.""" super().__init__() + self.masks = None + self.images = None + self.patch_mode = None + self.device = device + + # Initialize model with specified weights and ioconfig. + self.model, self.ioconfig = self._initialize_model_ioconfig( + model=model, + weights=weights, + ) + self.model = model_to(model=self.model, device=self.device) + self._ioconfig = self.ioconfig # runtime ioconfig + + self.batch_size = batch_size + self.num_loader_workers = num_loader_workers + self.num_post_proc_workers = num_post_proc_workers + self.verbose = verbose + self.return_labels = False + self.merge_predictions = False + self.units = "baseline" + self.resolution = 1.0 + self.patch_input_shape = None + self.stride_shape = None + self.labels = None + + @staticmethod + def _initialize_model_ioconfig( + model: str | nn.Module, + weights: str | Path | None, + ) -> tuple[nn.Module, ModelIOConfigABC | None]: + """Helper function to initialize model and ioconfig attributes. + + If a pretrained model provided by the TIAToolbox is requested. The model + can be specified as a string otherwise torch.nn.Module is required. + This function also loads the :class:`ModelIOConfigABC` using the information + from the pretrained models in TIAToolbox. If ioconfig is not available then it + should be provided in the :func:`run` function. + + Args: + model (str | nn.Module): + A torch model which should be run by the engine. + + weights (str | Path | None): + Path to pretrained weights. If no pretrained weights are provided + and the model is provided by TIAToolbox, then pretrained weights will + be automatically loaded from the TIA servers. + + Returns: + nn.Module: + The requested PyTorch model. + + ModelIOConfigABC | None: + The model io configuration for TIAToolbox pretrained models. + Otherwise, None. + + """ + if not isinstance(model, (str, nn.Module)): + msg = "Input model must be a string or 'torch.nn.Module'." + raise TypeError(msg) + + if isinstance(model, str): + # ioconfig is retrieved from the pretrained model in the toolbox. + # list of pretrained models in the TIA Toolbox is available here: + # https://tia-toolbox.readthedocs.io/en/add-bokeh-app/pretrained.html + # no need to provide ioconfig in EngineABC.run() this case. + return get_pretrained_model(model, weights) + + if weights is not None: + model = load_torch_model(model=model, weights=weights) + + return model, None + + def pre_process_patches( + self: EngineABC, + images: np.ndarray | list, + labels: list, + ) -> torch.utils.data.DataLoader: + """Pre-process an image patch.""" + if labels: + # if a labels is provided, then return with the prediction + self.return_labels = bool(labels) + + dataset = PatchDataset(inputs=images, labels=labels) + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + + def infer_patches( + self: EngineABC, + data_loader: DataLoader, + ) -> dict: + """Model inference on an image patch.""" + progress_bar = None + + if self.verbose: + progress_bar = tqdm.tqdm( + total=int(len(data_loader)), + leave=True, + ncols=80, + ascii=True, + position=0, + ) + raw_predictions = { + "predictions": [], + } + + if self.return_labels: + raw_predictions["labels"] = [] + + for _, batch_data in enumerate(data_loader): + batch_output_predictions = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + raw_predictions["predictions"].extend(batch_output_predictions.tolist()) + + if self.return_labels: # be careful of `s` + # We do not use tolist here because label may be of mixed types + # and hence collated as list by torch + raw_predictions["labels"].extend(list(batch_data["label"])) + + if progress_bar: + progress_bar.update() + + if progress_bar: + progress_bar.close() + + return raw_predictions + + def setup_patch_dataset( + self: EngineABC, + raw_predictions: dict, + output_type: str, + save_dir: Path | None = None, + **kwargs: dict, + ) -> Path | AnnotationStore: + """Post-process image patches. + + Args: + raw_predictions (dict): + A dictionary of patch prediction information. + save_dir (Path): + Optional Output Path to directory to save the patch dataset output to a + `.zarr` or `.db` file, provided patch_mode is True. if the patch_mode is + False then save_dir is required. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (dict): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (dict, Path, :class:`SQLiteStore`): + if the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations for each or the + Path to a `.db` file depending on whether a save_dir Path is provided. + Otherwise, the function defaults to returning patch predictor output, either + as a dict or the Path to a `.zarr` file depending on whether a save_dir Path + is provided. + + """ + if not save_dir and output_type != "AnnotationStore": + return raw_predictions + + output_file = ( + kwargs["output_file"] and kwargs.pop("output_file") + if "output_file" in kwargs + else "output" + ) + + save_path = save_dir / output_file + + if output_type == "AnnotationStore": + # scale_factor set from kwargs + scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None + # class_dict set from kwargs + class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None + + return dict_to_store(raw_predictions, scale_factor, class_dict, save_path) + + return dict_to_zarr( + raw_predictions, + save_path, + **kwargs, + ) + + @abstractmethod + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Pre-process a WSI.""" + raise NotImplementedError + + @abstractmethod + def infer_wsi(self: EngineABC) -> NoReturn: + """Model inference on a WSI.""" + raise NotImplementedError + @abstractmethod - def process_patch(self): - """Process an image patch.""" + def post_process_wsi(self: EngineABC) -> NoReturn: + """Post-process a WSI.""" raise NotImplementedError - # how to deal with patches, list of patches/numpy arrays, WSIs - # how to communicate with sub-processes. - # define how to deal with patches as numpy/zarr arrays. - # convert list of patches/numpy arrays to zarr and then pass to each sub-processes. - # define how to read WSIs, read the image and convert to zarr array. + def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: + """Helper function to load ioconfig. + + If the model is provided by TIAToolbox it will load the default ioconfig. + Otherwise, ioconfig must be specified. + + Args: + ioconfig (ModelIOConfigABC): + IO configuration to run the engines. + + Raises: + ValueError: + If no io configuration is provided or found in the pretrained TIAToolbox + models. + + Returns: + ModelIOConfigABC: + The ioconfig used for the run. + + """ + if self.ioconfig is None and ioconfig is None: + msg = ( + "Please provide a valid ModelIOConfigABC. " + "No default ModelIOConfigABC found." + ) + raise ValueError(msg) + + if ioconfig is not None: + self.ioconfig = ioconfig + + return self.ioconfig + + @staticmethod + def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: + """Validate input images for a run.""" + if not isinstance(images, (list, np.ndarray)): + msg = "Input must be a list of file paths or a numpy array." + raise TypeError( + msg, + ) + + if isinstance(images, np.ndarray) and images.ndim != 4: # noqa: PLR2004 + msg = ( + "The input numpy array should be four dimensional." + "The shape of the numpy array should be NHWC." + ) + raise ValueError(msg) + + return images + + @staticmethod + def _validate_input_numbers( + images: list | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ) -> None: + """Validates number of input images, masks and labels.""" + if masks is None and labels is None: + return + + len_images = len(images) + + if masks is not None and len_images != len(masks): + msg = ( + f"len(masks) is not equal to len(images) " + f": {len(masks)} != {len(images)}" + ) + raise ValueError( + msg, + ) + + if labels is not None and len_images != len(labels): + msg = ( + f"len(labels) is not equal to len(images) " + f": {len(labels)} != {len(images)}" + ) + raise ValueError( + msg, + ) + return + + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: dict, + ) -> AnnotationStore | Path | str: + """Run the engine on input images. + + Args: + images (list, ndarray): + List of inputs to process. when using `patch` mode, the + input must be either a list of images, a list of image + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. + ioconfig (IOPatchPredictorConfig): + IO configuration. + save_dir (str or pathlib.Path): + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "zarr", "AnnotationStore". Default is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (dict): + Keyword Args to update :class:`EngineABC` attributes. + + Returns: + (:class:`numpy.ndarray`, dict): + Model predictions of the input dataset. If multiple + whole slide images are provided as input, + or save_output is True, then results are saved to + `save_dir` and a dictionary indicating save location for + each input is returned. + + The dict has the following format: + + - img_path: path of the input image. + - raw: path to save location for raw prediction, + saved in .json. + - merged: path to .npy contain merged + predictions if `merge_predictions` is `True`. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> predictor = EngineABC(model="resnet18-kather100k") + >>> output = predictor.run(wsis, patch_mode=False) + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} + >>> output['wsi2.svs'] + ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> output = predictor.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> labels=list(range(10)), + >>> on_gpu=False, + >>> ) + >>> output + ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., + ... 0.034451354295015335, 0.004817609209567308]], + ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), + ... tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> save_dir = Path("/tmp/patch_output/") + >>> output = eng.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> on_gpu=False, + >>> verbose=False, + >>> save_dir=save_dir, + >>> overwrite=True + >>> ) + >>> output + ... '/tmp/patch_output/output.zarr' + """ + for key in kwargs: + setattr(self, key, kwargs[key]) + + self.patch_mode = patch_mode + + self._validate_input_numbers(images=images, masks=masks, labels=labels) + self.images = self._validate_images_masks(images=images) + + if masks is not None: + self.masks = self._validate_images_masks(images=masks) + + self.labels = labels + + # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + self.model = model_to(model=self.model, device=self.device) + + save_dir = prepare_engines_save_dir( + save_dir, + len(self.images), + patch_mode=patch_mode, + overwrite=overwrite, + ) + + if patch_mode: + data_loader = self.pre_process_patches( + self.images, + self.labels, + ) + raw_predictions = self.infer_patches( + data_loader=data_loader, + ) + return self.setup_patch_dataset( + raw_predictions=raw_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + + return {"save_dir": save_dir} diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7fbae30bb..287fc456b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -39,32 +39,34 @@ if TYPE_CHECKING: # pragma: no cover import torch - from .io_config import IOInstanceSegmentorConfig + from tiatoolbox.typing import IntBounds + + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig # Python is yet to be able to natively pickle Object method/static method. # Only top-level function is passable to multi-processing as caller. # May need 3rd party libraries to use method/static method otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, - model_name, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, + model_name: str, +) -> tuple: """Process Tile Predictions. Function to merge new tile prediction with existing prediction, using the output from each task. Args: - ioconfig (:class:`IOInstanceSegmentorConfig`): Object defines information + ioconfig (:class:`IOSegmentorConfig`): Object defines information about input and output placement of patches. tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as (top_left_x, top_left_y, bottom_x, bottom_y). @@ -239,7 +241,7 @@ class MultiTaskSegmentor(NucleusInstanceSegmentor): """ def __init__( # noqa: PLR0913 - self, + self: MultiTaskSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -286,12 +288,12 @@ def __init__( # noqa: PLR0913 ) def _predict_one_wsi( - self, + self: MultiTaskSegmentor, wsi_idx: int, ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -393,13 +395,13 @@ def _predict_one_wsi( # may need to chain it with parents def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: MultiTaskSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -418,10 +420,15 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: MultiTaskSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): + def callback( + new_inst_dicts: dict, + remove_uuid_lists: list, + tiles: dict, + bounds: IntBounds, + ) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store @@ -444,7 +451,7 @@ def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): callback(*future) continue # some errors happen, log it and propagate exception - # ! this will lead to discard a bunch of + # ! this will lead to discard a whole bunch of # ! inferred tiles within this current WSI if future.exception() is not None: raise future.exception() # noqa: RSE102 diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 4156e2c2a..9aac3b8f5 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -18,18 +18,18 @@ from tiatoolbox.tools.patchextraction import PatchExtractor if TYPE_CHECKING: # pragma: no cover - from .io_config import IOInstanceSegmentorConfig + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig def _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - ref_inst_dict, -): + inst_dict: dict, + ioconfig: IOSegmentorConfig, + tile_shape: list, + tile_flag: list, + tile_mode: int, + tile_tl: tuple, + ref_inst_dict: dict, +) -> list | tuple: """Function to merge new tile prediction with existing prediction. Args: @@ -50,12 +50,12 @@ def _process_instance_predictions( an overlapping tile from tile generation. The predicted instances are immediately added to accumulated output. - 1: Vertical tile strip that stands between two normal tiles - (flag 0). It has the the same height as normal tile but + (flag 0). It has the same height as normal tile but less width (hence vertical strip). - 2: Horizontal tile strip that stands between two normal tiles - (flag 0). It has the the same width as normal tile but + (flag 0). It has the same width as normal tile but less height (hence horizontal strip). - - 3: tile strip stands at the cross section of four normal tiles + - 3: tile strip stands at the cross-section of four normal tiles (flag 0). tile_tl (tuple): Top left coordinates of the current tile. ref_inst_dict (dict): Dictionary contains accumulated output. The @@ -144,7 +144,7 @@ def _process_instance_predictions( msg = f"Unknown tile mode {tile_mode}." raise ValueError(msg) - def retrieve_sel_uids(sel_indices, inst_dict): + def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: """Helper to retrieved selected instance uids.""" if len(sel_indices) > 0: # not sure how costly this is in large dict @@ -153,7 +153,7 @@ def retrieve_sel_uids(sel_indices, inst_dict): remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) - # external removal only for tile at cross sections + # external removal only for tile at cross-sections # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 @@ -186,17 +186,17 @@ def retrieve_sel_uids(sel_indices, inst_dict): # caller. May need 3rd party libraries to use method/static method # otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, +) -> tuple[dict, list]: """Function to merge new tile prediction with existing prediction. Args: @@ -368,7 +368,7 @@ class NucleusInstanceSegmentor(SemanticSegmentor): """ def __init__( - self, + self: NucleusInstanceSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -406,7 +406,7 @@ def __init__( def _get_tile_info( image_shape: list[int] | np.ndarray, ioconfig: IOInstanceSegmentorConfig, - ): + ) -> list[list, ...]: """Generating tile information. To avoid out of memory problem when processing WSI-scale in @@ -467,7 +467,7 @@ def _get_tile_info( # * remove all sides for boxes # unset for those lie within the selection - def unset_removal_flag(boxes, removal_flag): + def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: """Unset removal flags for tiles intersecting image boundaries.""" sel_boxes = [ shapely_box(0, 0, w, 0), # top edge @@ -581,7 +581,12 @@ def unset_removal_flag(boxes, removal_flag): return info - def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): + def _to_shared_space( + self: NucleusInstanceSegmentor, + wsi_idx: int, + patch_inputs: list, + patch_outputs: list, + ) -> None: """Helper functions to transfer variable to shared space. We modify the shared space so that we can update worker info @@ -613,7 +618,7 @@ def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): self._mp_shared_space.patch_outputs = patch_outputs self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() - def _infer_once(self): + def _infer_once(self: NucleusInstanceSegmentor) -> list: """Running the inference only once for the currently active dataloader.""" num_steps = len(self._loader) @@ -640,7 +645,7 @@ def _infer_once(self): sample_outputs = self.model.infer_batch( self._model, sample_datas, - on_gpu=self._on_gpu, + device=self._device, ) # repackage so that it's a N list, each contains # L x etc. output @@ -658,12 +663,12 @@ def _infer_once(self): return cum_output def _predict_one_wsi( - self, + self: NucleusInstanceSegmentor, wsi_idx: int, - ioconfig: IOInstanceSegmentorConfig, + ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -751,13 +756,13 @@ def _predict_one_wsi( self._wsi_inst_info = None # clean up def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: NucleusInstanceSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -775,10 +780,10 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: NucleusInstanceSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dict, remove_uuid_list): + def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 807cd9fad..3092f827b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -4,7 +4,7 @@ import copy from collections import OrderedDict from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, NoReturn import numpy as np import torch @@ -12,20 +12,24 @@ import tiatoolbox.models.models_abc from tiatoolbox import logger +from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + import os -from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.typing import IntPair, Resolution, Units + + from .io_config import ModelIOConfigABC +from .engine_abc import EngineABC from .io_config import IOPatchPredictorConfig -class PatchPredictor: - r"""Patch level predictor. +class PatchPredictor(EngineABC): + r"""Patch level predictor for digital histology images. The models provided by tiatoolbox should give the following results: @@ -125,12 +129,12 @@ class PatchPredictor: be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - pretrained_weights (str): + weights (str): Path to the weight of the corresponding `pretrained_model`. >>> predictor = PatchPredictor( ... pretrained_model="resnet18-kather100k", - ... pretrained_weights="resnet18_local_weight") + ... weights="resnet18_local_weight") batch_size (int): Number of images fed into the model each time. @@ -141,14 +145,14 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. model (nn.Module): Defined PyTorch model. - pretrained_model (str): + model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs @@ -166,7 +170,7 @@ class PatchPredictor: Examples: >>> # list of 2 image patches as input - >>> data = [img1, img2] + >>> data = ['path/img.svs', 'path/img.svs'] >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') @@ -202,38 +206,46 @@ class PatchPredictor: """ def __init__( - self, - batch_size=8, - num_loader_workers=0, - model=None, - pretrained_model=None, - pretrained_weights=None, + self: PatchPredictor, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + model: torch.nn.Module = None, + pretrained_model: str | None = None, + weights: str | None = None, *, - verbose=True, + verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" - super().__init__() + super().__init__( + batch_size=batch_size, + num_loader_workers=num_loader_workers, + num_post_proc_workers=num_post_proc_workers, + model=model, + pretrained_model=pretrained_model, + weights=weights, + verbose=verbose, + ) - self.imgs = None - self.mode = None + def pre_process_wsi(self: PatchPredictor) -> NoReturn: + """Pre-process a WSI.""" + ... - if model is None and pretrained_model is None: - msg = "Must provide either `model` or `pretrained_model`." - raise ValueError(msg) + def infer_wsi(self: PatchPredictor) -> NoReturn: + """Model inference on a WSI.""" + ... - if model is not None: - self.model = model - ioconfig = None # retrieve iostate from provided model ? - else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) + def post_process_patches( + self: PatchPredictor, + raw_predictions: dict, + output_type: str, + ) -> None: + """Post-process an image patch.""" + ... - self.ioconfig = ioconfig # for storing original - self._ioconfig = None # for storing runtime - self.model = model # for runtime, such as after wrapping with nn.DataParallel - self.pretrained_model = pretrained_model - self.batch_size = batch_size - self.num_loader_worker = num_loader_workers - self.verbose = verbose + def post_process_wsi(self: PatchPredictor) -> NoReturn: + """Post-process a WSI.""" + ... @staticmethod def merge_predictions( @@ -241,10 +253,10 @@ def merge_predictions( output: dict, resolution: Resolution | None = None, units: Units | None = None, - postproc_func: Callable | None = None, + post_proc_func: Callable | None = None, *, return_raw: bool = False, - ): + ) -> np.ndarray: """Merge patch level predictions to form a 2-dimensional prediction map. #! Improve how the below reads. @@ -263,7 +275,7 @@ def merge_predictions( units (Units): Units of resolution used when merging predictions. This must be the same `units` used when processing the data. - postproc_func (callable): + post_proc_func (callable): A function to post-process raw prediction from model. By default, internal code uses the `np.argmax` function. return_raw (bool): @@ -345,8 +357,8 @@ def merge_predictions( output = output / (np.expand_dims(denominator, -1) + 1.0e-8) if not return_raw: # convert raw probabilities to predictions - if postproc_func is not None: - output = postproc_func(output) + if post_proc_func is not None: + output = post_proc_func(output) else: output = np.argmax(output, axis=-1) # to make sure background is 0 while class will be 1...N @@ -354,14 +366,14 @@ def merge_predictions( return output def _predict_engine( - self, - dataset, + self: PatchPredictor, + dataset: torch.utils.data.Dataset, *, - return_probabilities=False, - return_labels=False, - return_coordinates=False, - on_gpu=True, - ): + return_probabilities: bool = False, + return_labels: bool = False, + return_coordinates: bool = False, + device: str = "cpu", + ) -> np.ndarray: """Make a prediction on a dataset. The dataset may be mutated. Args: @@ -374,8 +386,8 @@ def _predict_engine( Whether to return labels. return_coordinates (bool): Whether to return patch coordinates. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". Returns: :class:`numpy.ndarray`: @@ -387,7 +399,7 @@ def _predict_engine( # preprocessing must be defined with the dataset dataloader = torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_worker, + num_workers=self.num_loader_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, @@ -403,7 +415,7 @@ def _predict_engine( ) # use external for testing - model = tiatoolbox.models.models_abc.model_to(model=self.model, on_gpu=on_gpu) + model = tiatoolbox.models.models_abc.model_to(model=self.model, device=device) cum_output = { "probabilities": [], @@ -415,7 +427,7 @@ def _predict_engine( batch_output_probabilities = self.model.infer_batch( model, batch_data["image"], - on_gpu=on_gpu, + device=device, ) # We get the index of the class with the maximum probability batch_output_predictions = self.model.postproc_func( @@ -447,13 +459,13 @@ def _predict_engine( return cum_output def _update_ioconfig( - self, - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ): + self: PatchPredictor, + ioconfig: IOPatchPredictorConfig, + patch_input_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOPatchPredictorConfig: """Update the ioconfig. Args: @@ -519,44 +531,15 @@ def _update_ioconfig( output_resolutions=[], ) - @staticmethod - def _prepare_save_dir(save_dir, imgs): - """Create directory if not defined and number of images is more than 1. - - Args: - save_dir (str or pathlib.Path): - Path to output directory. - imgs (list, ndarray): - List of inputs to process. - - Returns: - :class:`pathlib.Path`: - Path to output directory. - - """ - if save_dir is None and len(imgs) > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory set." - "All subsequent output will be saved to current runtime" - "location under folder 'output'. Overwriting may happen!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len(imgs) > 1: - logger.warning( - "When providing multiple whole-slide images / tiles, " - "we save the outputs and return the locations " - "to the corresponding files.", - stacklevel=2, - ) - - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) - - return save_dir - - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): + def _predict_patch( + self: PatchPredictor, + imgs: list | np.ndarray, + labels: list, + *, + return_probabilities: bool, + return_labels: bool, + device: str, + ) -> np.ndarray: """Process patch mode. Args: @@ -574,8 +557,8 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the engine. Returns: :class:`numpy.ndarray`: @@ -600,23 +583,24 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g return_probabilities=return_probabilities, return_labels=return_labels, return_coordinates=return_coordinates, - on_gpu=on_gpu, + device=device, ) def _predict_tile_wsi( # noqa: PLR0913 - self, - imgs, - masks, - labels, - mode, - return_probabilities, - on_gpu, - ioconfig, - merge_predictions, - save_dir, - save_output, - highest_input_resolution, - ): + self: PatchPredictor, + imgs: list, + masks: list | None, + labels: list, + mode: str, + ioconfig: IOPatchPredictorConfig, + save_dir: str | Path, + highest_input_resolution: list[dict], + *, + save_output: bool, + return_probabilities: bool, + merge_predictions: bool, + on_gpu: bool, + ) -> list | dict: """Predict on Tile and WSIs. Args: @@ -626,7 +610,7 @@ def _predict_tile_wsi( # noqa: PLR0913 file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. - masks (list): + masks (list or None): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a @@ -715,7 +699,7 @@ def _predict_tile_wsi( # noqa: PLR0913 ) output_model["label"] = img_label # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.pretrained_model + output_model["pretrained_model"] = self.model output_model["resolution"] = highest_input_resolution["resolution"] output_model["units"] = highest_input_resolution["units"] @@ -727,7 +711,7 @@ def _predict_tile_wsi( # noqa: PLR0913 output_model, resolution=output_model["resolution"], units=output_model["units"], - postproc_func=self.model.postproc, + post_proc_func=self.model.postproc, ) outputs.append(merged_prediction) @@ -748,25 +732,51 @@ def _predict_tile_wsi( # noqa: PLR0913 return file_dict if save_output else outputs + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: dict, + ) -> AnnotationStore | str: + """Run engine.""" + super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) + def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - labels=None, - mode="patch", + self: PatchPredictor, + imgs: list, + masks: list | None = None, + labels: list | None = None, + mode: str = "patch", ioconfig: IOPatchPredictorConfig | None = None, patch_input_shape: tuple[int, int] | None = None, stride_shape: tuple[int, int] | None = None, - resolution=None, - units=None, + resolution: Resolution | None = None, + units: Units = None, *, - return_probabilities=False, - return_labels=False, - on_gpu=True, - merge_predictions=False, - save_dir=None, - save_output=False, - ): + return_probabilities: bool = False, + return_labels: bool = False, + on_gpu: bool = True, + merge_predictions: bool = False, + save_dir: str | Path | None = None, + save_output: bool = False, + ) -> np.ndarray | list | dict: """Make a prediction for a list of input data. Args: diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index e1341c640..237d032f1 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -27,10 +27,13 @@ from .io_config import IOSegmentorConfig if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + from tiatoolbox.typing import IntPair, Resolution, Units -def _estimate_canvas_parameters(sample_prediction, canvas_shape): +def _estimate_canvas_parameters( + sample_prediction: np.ndarray, + canvas_shape: np.ndarray, +) -> tuple[tuple, tuple, bool]: """Estimates canvas parameters. Args: @@ -58,11 +61,11 @@ def _estimate_canvas_parameters(sample_prediction, canvas_shape): def _prepare_save_output( - save_path, - cache_count_path, - canvas_cum_shape_, - canvas_count_shape_, -): + save_path: str | Path, + cache_count_path: str | Path, + canvas_cum_shape_: tuple[int, ...], + canvas_count_shape_: tuple[int, ...], +) -> tuple: """Prepares for saving the cached output.""" if save_path is not None: save_path = Path(save_path) @@ -193,7 +196,7 @@ class SemanticSegmentor: """ def __init__( - self, + self: SemanticSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -251,7 +254,7 @@ def __init__( def get_coordinates( image_shape: list[int] | np.ndarray, ioconfig: IOSegmentorConfig, - ): + ) -> tuple[list, list]: """Calculate patch tiling coordinates. By default, internally, it will call the @@ -309,7 +312,7 @@ def filter_coordinates( bounds: np.ndarray, resolution: Resolution | None = None, units: Units | None = None, - ): + ) -> np.ndarray: """Indicates which coordinate is valid basing on the mask. To use your own approaches, either subclass to overwrite or @@ -369,7 +372,7 @@ def filter_coordinates( scale_factor = mask_real_shape / mask_resolution_shape scale_factor = scale_factor[0] # what if ratio x != y - def sel_func(coord: np.ndarray): + def sel_func(coord: np.ndarray) -> bool: """Accept coord as long as its box contains part of mask.""" coord_in_real_mask = np.ceil(scale_factor * coord).astype(np.int32) start_x, start_y, end_x, end_y = coord_in_real_mask @@ -386,7 +389,7 @@ def get_reader( mode: str, *, auto_get_mask: bool, - ): + ) -> tuple[WSIReader, WSIReader]: """Define how to get reader for mask and source image.""" img_path = Path(img_path) reader = WSIReader.open(img_path) @@ -411,12 +414,12 @@ def get_reader( return reader, mask_reader def _predict_one_wsi( - self, + self: SemanticSegmentor, wsi_idx: int, ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -527,13 +530,13 @@ def _predict_one_wsi( shutil.rmtree(cache_dir) def _process_predictions( - self, + self: SemanticSegmentor, cum_batch_predictions: list, wsi_reader: WSIReader, ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also saving afterwards. @@ -595,7 +598,7 @@ def merge_prediction( locations: list | np.ndarray, save_path: str | Path | None = None, cache_count_path: str | Path | None = None, - ): + ) -> np.ndarray: """Merge patch-level predictions to form a 2-dimensional prediction map. When accumulating the raw prediction onto a same canvas (via @@ -665,7 +668,7 @@ def merge_prediction( canvas_count_shape_, ) - def index(arr, tl, br): + def index(arr: np.ndarray, tl: np.ndarray, br: np.ndarray) -> np.ndarray: """Helper to shorten indexing.""" return arr[tl[0] : br[0], tl[1] : br[1]] @@ -726,7 +729,7 @@ def index(arr, tl, br): return cum_canvas @staticmethod - def _prepare_save_dir(save_dir): + def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]: """Prepare save directory and cache.""" if save_dir is None: logger.warning( @@ -749,14 +752,14 @@ def _prepare_save_dir(save_dir): @staticmethod def _update_ioconfig( - ioconfig, - mode, - patch_input_shape, - patch_output_shape, - stride_shape, - resolution, - units, - ): + ioconfig: IOSegmentorConfig, + mode: str, + patch_input_shape: IntPair, + patch_output_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOSegmentorConfig: """Update ioconfig according to input parameters. Args: @@ -815,7 +818,7 @@ def _update_ioconfig( return ioconfig - def _prepare_workers(self): + def _prepare_workers(self: SemanticSegmentor) -> None: """Prepare number of workers.""" self._postproc_workers = None if self.num_postproc_workers is not None: @@ -823,7 +826,7 @@ def _prepare_workers(self): max_workers=self.num_postproc_workers, ) - def _memory_cleanup(self): + def _memory_cleanup(self: SemanticSegmentor) -> None: """Memory clean up.""" self.imgs = None self.masks = None @@ -838,15 +841,16 @@ def _memory_cleanup(self): self._postproc_workers = None def _predict_wsi_handle_exception( - self, - imgs, - wsi_idx, - img_path, - mode, - ioconfig, - save_dir, - crash_on_exception, - ): + self: SemanticSegmentor, + imgs: list, + wsi_idx: int, + img_path: str | Path, + mode: str, + ioconfig: IOSegmentorConfig, + save_dir: str | Path, + *, + crash_on_exception: bool, + ) -> None: """Predict on multiple WSIs. Args: @@ -916,21 +920,21 @@ def _predict_wsi_handle_exception( logging.exception("Crashed on %s", wsi_save_path) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=None, - units=None, - save_dir=None, + self: SemanticSegmentor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig = None, + patch_input_shape: IntPair = None, + patch_output_shape: IntPair = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - on_gpu=True, - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the object instantiation time @@ -966,8 +970,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1049,10 +1053,10 @@ def predict( # noqa: PLR0913 ) # use external for testing - self._on_gpu = on_gpu + self._device = device self._model = tiatoolbox.models.models_abc.model_to( model=self.model, - on_gpu=on_gpu, + device=device, ) # workers should be > 0 else Value Error will be thrown @@ -1170,7 +1174,7 @@ class DeepFeatureExtractor(SemanticSegmentor): """ def __init__( - self, + self: DeepFeatureExtractor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -1197,13 +1201,13 @@ def __init__( self.process_prediction_per_batch = False def _process_predictions( - self, + self: DeepFeatureExtractor, cum_batch_predictions: list, wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also @@ -1241,21 +1245,21 @@ def _process_predictions( np.save(f"{save_path}.features.{idx}.npy", prediction_list) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=1.0, - units="baseline", - save_dir=None, + self: DeepFeatureExtractor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig | None = None, + patch_input_shape: IntPair | None = None, + patch_output_shape: IntPair | None = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - on_gpu=True, - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the time of object @@ -1291,8 +1295,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1348,7 +1352,7 @@ def predict( # noqa: PLR0913 imgs=imgs, masks=masks, mode=mode, - on_gpu=on_gpu, + device=device, ioconfig=ioconfig, patch_input_shape=patch_input_shape, patch_output_shape=patch_output_shape, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 9c5bb4cd1..98ca29911 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -2,19 +2,64 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable import torch from torch import nn if TYPE_CHECKING: # pragma: no cover + from pathlib import Path + import numpy as np +def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. + + Args: + model (torch.nn.Module): + A torch model. + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + model.load_state_dict(saved_state_dict, strict=True) + return model + + +def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + + """ + if device != "cpu": + # DataParallel work only for cuda + model = torch.nn.DataParallel(model) + + device = torch.device(device) + return model.to(device) + + class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" - def __init__(self) -> None: + def __init__(self: ModelABC) -> None: """Initialize Abstract class ModelABC.""" super().__init__() self._postproc = self.postproc @@ -22,13 +67,13 @@ def __init__(self) -> None: @abstractmethod # This is generic abc, else pylint will complain - def forward(self, *args, **kwargs): + def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -39,29 +84,29 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): batch_data (np.ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ ... # pragma: no cover @staticmethod - def preproc(image): + def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of model.""" return image @staticmethod - def postproc(image): + def postproc(image: np.ndarray) -> np.ndarray: """Define the post-processing of this class of model.""" return image @property - def preproc_func(self): + def preproc_func(self: ModelABC) -> Callable: """Return the current pre-processing function of this instance.""" return self._preproc @preproc_func.setter - def preproc_func(self, func): + def preproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance. If `func=None`, the method will default to `self.preproc`. @@ -73,7 +118,7 @@ def preproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.preproc_func = func - >>> transformed_img = model.preproc_func(img) + >>> transformed_img = model.preproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -86,12 +131,12 @@ def preproc_func(self, func): self._preproc = func @property - def postproc_func(self): + def postproc_func(self: ModelABC) -> Callable: """Return the current post-processing function of this instance.""" return self._postproc @postproc_func.setter - def postproc_func(self, func): + def postproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance of model. If `func=None`, the method will default to `self.postproc`. @@ -104,7 +149,7 @@ def postproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.postproc_func = func - >>> transformed_img = model.postproc_func(img) + >>> transformed_img = model.postproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -115,22 +160,3 @@ def postproc_func(self, func): self._postproc = self.postproc else: self._postproc = func - - -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu") From 1cbf6185d08f7cbe7620a5493df40cb751d1be4e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:01:53 +0000 Subject: [PATCH 17/36] :goal_net: Fix merge errors in classification.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/dataset/classification.py | 3 ++- tiatoolbox/models/models_abc.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index 07d327dcd..40a77c0f5 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -3,11 +3,12 @@ from typing import TYPE_CHECKING -from PIL import Image from torchvision import transforms if TYPE_CHECKING: # pragma: no cover import numpy as np + import torch + from PIL import Image class _TorchPreprocCaller: diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 13f476674..757f33ca4 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -166,7 +166,7 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. Args: - model (torch.nn.Module): + self (ModelABC): PyTorch defined model. device (str): Transfers model to the specified device. Default is "cpu". From 92f0d506cfae18b6ed011ec7e2dd48262b912cdf Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:31:50 +0000 Subject: [PATCH 18/36] :goal_net: Fix merge errors in dataset_abc.py --- tiatoolbox/models/dataset/dataset_abc.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 8558f9098..f239c23f8 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -4,7 +4,7 @@ import copy from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Iterable, List, Union import cv2 import numpy as np @@ -22,6 +22,11 @@ from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.typing import IntPair, Resolution, Units + try: + from typing import TypeGuard + except ImportError: + from typing_extensions import TypeGuard # to support python <3.10 + input_type = Union[List[Union[str, Path, np.ndarray]], np.ndarray] @@ -354,7 +359,7 @@ class WSIPatchDataset(PatchDatasetABC): """ - def __init__( # noqa: PLR0913, PLR0915 + def __init__( # skipcq: PY-R1000 # noqa: PLR0913, PLR0915 self: WSIPatchDataset, img_path: str | Path, mode: str = "wsi", From d1b0d8235fc76aa10b66ec3171686f33a7b8411e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 15 Nov 2023 14:11:53 +0000 Subject: [PATCH 19/36] :goal_net: Fix merge errors in dataset_abc.py --- tests/models/test_dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index 538b2dbd4..5532b5fa2 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -226,7 +226,17 @@ def test_patch_dataset_crash(tmp_path: Path) -> None: # ndarray of mixed dtype imgs = np.array( - [RNG.integers(0, 255, (4, 5, 3)), "Should crash"], + # string array of the same shape + [ + RNG.integers(0, 255, (4, 5, 3)), + np.array( # skipcq: PYL-E1121 + ["PatchDataset should crash here" for _ in range(4 * 5 * 3)], + ).reshape( + 4, + 5, + 3, + ), + ], dtype=object, ) with pytest.raises(ValueError, match="Provided input array is non-numerical."): From 4b357f5b7edea3cdfbb93193f29cd40da184b7ab Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 15 Nov 2023 14:30:59 +0000 Subject: [PATCH 20/36] :goal_net: Fix test_models_abc.py --- tests/models/test_models_abc.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/models/test_models_abc.py b/tests/models/test_models_abc.py index 0dd199c0e..598cb29fa 100644 --- a/tests/models/test_models_abc.py +++ b/tests/models/test_models_abc.py @@ -140,6 +140,16 @@ def test_model_abc() -> None: model.postproc_func = None # skipcq: PYL-W0201 assert model.postproc_func(2) == 0 + # Test load_weights_from_file() method + weights_path = fetch_pretrained_weights("alexnet-kather100k") + with pytest.raises(RuntimeError, match=r".*loading state_dict*"): + _ = model.load_weights_from_file(weights_path) + + # Test on CPU + model = model.to(device="cpu") + assert isinstance(model, nn.Module) + assert model.dummy_param.device.type == "cpu" + def test_model_to() -> None: """Test for placing model on device.""" @@ -157,9 +167,3 @@ def test_model_to() -> None: model = torch_models.resnet18() model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) assert isinstance(model, nn.Module) - assert model.dummy_param.device.type == "cpu" - - # Test load_weights_from_file() method - weights_path = fetch_pretrained_weights("alexnet-kather100k") - with pytest.raises(RuntimeError, match=r".*loading state_dict*"): - _ = model.load_weights_from_file(weights_path) From aaff1f8ba313b3c62fcb3a7b672e55626d831b99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Feb 2024 12:01:09 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 1 + tiatoolbox/models/engine/engine_abc.py | 1 + tiatoolbox/models/engine/io_config.py | 1 + 3 files changed, 3 insertions(+) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 8ab54098f..d72cacfe7 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -1,4 +1,5 @@ """Test tiatoolbox.models.engine.engine_abc.""" + from __future__ import annotations from pathlib import Path diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 7e7c71242..32ad1a562 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,4 +1,5 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" + from __future__ import annotations from abc import ABC, abstractmethod diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 8b397b798..f6c9b9c2c 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -1,4 +1,5 @@ """Defines IOConfig for Model Engines.""" + from __future__ import annotations from dataclasses import dataclass, field, replace From 17581f558a2f81a2ca50935000a7c03a5fb38226 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 15:03:15 +0000 Subject: [PATCH 22/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 6 +++--- tiatoolbox/models/engine/patch_predictor.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index d72cacfe7..3701273e2 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -35,15 +35,15 @@ def __init__( def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" - ... # dummy function for tests. + # dummy function for tests. def post_process_wsi(self: EngineABC) -> NoReturn: """Test post_process_wsi.""" - ... # dummy function for tests. + # dummy function for tests. def pre_process_wsi(self: EngineABC) -> NoReturn: """Test pre_process_wsi.""" - ... # dummy function for tests. + # dummy function for tests. def test_engine_abc() -> NoReturn: diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index e2352e0a5..75bda0380 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -230,11 +230,9 @@ def __init__( def pre_process_wsi(self: PatchPredictor) -> NoReturn: """Pre-process a WSI.""" - ... def infer_wsi(self: PatchPredictor) -> NoReturn: """Model inference on a WSI.""" - ... def post_process_patches( self: PatchPredictor, @@ -242,11 +240,9 @@ def post_process_patches( output_type: str, ) -> None: """Post-process an image patch.""" - ... def post_process_wsi(self: PatchPredictor) -> NoReturn: """Post-process a WSI.""" - ... @staticmethod def merge_predictions( From 90f396c84b15cb83ccd3aa46e45b907a4329f0d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:16:46 +0000 Subject: [PATCH 23/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/dataset/dataset_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 28fe710c5..d69518a6e 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -18,11 +18,11 @@ from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader if TYPE_CHECKING: # pragma: no cover + from collections.abc import Iterable from multiprocessing.managers import Namespace from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.typing import IntPair, Resolution, Units - from collections.abc import Iterable try: from typing import TypeGuard From 466d005cfabc33d0ae37229d365f8aa570eceddc Mon Sep 17 00:00:00 2001 From: Abishek Date: Fri, 26 Apr 2024 14:12:46 +0100 Subject: [PATCH 24/36] =?UTF-8?q?=E2=9C=A8=20Add=20WSI=20processing=20to?= =?UTF-8?q?=20`EngineABC`=20(#737)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements cache mode for processing using EngineABC. Cache_mode saves intermediate results to zarr file which can be converted to AnnotationStore. Full WSI pipeline needs to be implemented for Engines inheriting the base class. The output of models should be a dictionary according to the new Engine design. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .github/workflows/mypy-type-check.yml | 2 +- tests/engines/test_engine_abc.py | 312 ++++++- tests/test_app_bokeh.py | 51 +- tiatoolbox/models/architecture/vanilla.py | 8 +- tiatoolbox/models/engine/engine_abc.py | 881 +++++++++++++++----- tiatoolbox/models/engine/patch_predictor.py | 4 +- tiatoolbox/models/models_abc.py | 7 +- tiatoolbox/utils/misc.py | 204 ++++- 8 files changed, 1155 insertions(+), 314 deletions(-) diff --git a/.github/workflows/mypy-type-check.yml b/.github/workflows/mypy-type-check.yml index b744a1c86..c6c677a73 100644 --- a/.github/workflows/mypy-type-check.yml +++ b/.github/workflows/mypy-type-check.yml @@ -6,7 +6,7 @@ on: push: branches: [ develop, pre-release, master, main ] pull_request: - branches: [ develop, pre-release, master, main ] + branches: [ develop, pre-release, master, main, dev-define-engines-abc ] jobs: diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 3701273e2..d0d6f0bd5 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -2,18 +2,22 @@ from __future__ import annotations +import copy +import shutil from pathlib import Path from typing import TYPE_CHECKING, NoReturn import numpy as np import pytest import torchvision.models as torch_models +import zarr from tiatoolbox.models.architecture import ( fetch_pretrained_weights, get_pretrained_model, ) from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.dataset import PatchDataset, WSIPatchDataset from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir from tiatoolbox.models.engine.io_config import ModelIOConfigABC @@ -33,24 +37,60 @@ def __init__( """Test EngineABC init.""" super().__init__(model=model, weights=weights, verbose=verbose) - def infer_wsi(self: EngineABC) -> NoReturn: - """Test infer_wsi.""" - # dummy function for tests. + def get_dataloader( + self: EngineABC, + images: Path, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + ) -> torch.utils.data.DataLoader: + """Test pre process images.""" + return super().get_dataloader( + images, + masks, + labels, + ioconfig, + patch_mode=patch_mode, + ) - def post_process_wsi(self: EngineABC) -> NoReturn: + def save_wsi_output( + self: EngineABC, + raw_output: dict, + save_dir: Path, + **kwargs: dict, + ) -> Path: """Test post_process_wsi.""" - # dummy function for tests. + return super().save_wsi_output( + raw_output, + save_dir=save_dir, + **kwargs, + ) - def pre_process_wsi(self: EngineABC) -> NoReturn: - """Test pre_process_wsi.""" - # dummy function for tests. + def infer_wsi( + self: EngineABC, + dataloader: torch.utils.data.DataLoader, + img_label: str, + highest_input_resolution: list[dict], + save_dir: Path, + **kwargs: dict, + ) -> dict | np.ndarray: + """Test infer_wsi.""" + return super().infer_wsi( + dataloader, + img_label, + highest_input_resolution, + save_dir, + **kwargs, + ) def test_engine_abc() -> NoReturn: """Test EngineABC initialization.""" with pytest.raises( TypeError, - match=r".*Can't instantiate abstract class EngineABC with abstract methods*", + match=r".*Can't instantiate abstract class EngineABC*", ): # Can't instantiate abstract class with abstract methods EngineABC() # skipcq @@ -125,7 +165,6 @@ def test_prepare_engines_save_dir( out_dir = prepare_engines_save_dir( save_dir=tmp_path / "patch_output", patch_mode=True, - len_images=1, overwrite=False, ) @@ -135,7 +174,6 @@ def test_prepare_engines_save_dir( out_dir = prepare_engines_save_dir( save_dir=tmp_path / "patch_output", patch_mode=True, - len_images=1, overwrite=True, ) @@ -145,35 +183,23 @@ def test_prepare_engines_save_dir( out_dir = prepare_engines_save_dir( save_dir=None, patch_mode=True, - len_images=1, overwrite=False, ) assert out_dir is None with pytest.raises( OSError, - match=r".*More than 1 WSIs detected but there is no save directory provided.*", + match=r".*Input WSIs detected but no save directory provided.*", ): _ = prepare_engines_save_dir( save_dir=None, patch_mode=False, - len_images=2, overwrite=False, ) - out_dir = prepare_engines_save_dir( - save_dir=None, - patch_mode=False, - len_images=1, - overwrite=False, - ) - - assert out_dir == Path.cwd() - out_dir = prepare_engines_save_dir( save_dir=tmp_path / "wsi_single_output", patch_mode=False, - len_images=1, overwrite=False, ) @@ -184,7 +210,6 @@ def test_prepare_engines_save_dir( out_dir = prepare_engines_save_dir( save_dir=tmp_path / "wsi_multiple_output", patch_mode=False, - len_images=2, overwrite=False, ) @@ -196,7 +221,6 @@ def test_prepare_engines_save_dir( out_path = prepare_engines_save_dir( save_dir=tmp_path / "patch_output" / "output.zarr", patch_mode=True, - len_images=1, overwrite=True, ) assert out_path.exists() @@ -204,7 +228,6 @@ def test_prepare_engines_save_dir( out_path = prepare_engines_save_dir( save_dir=tmp_path / "patch_output" / "output.zarr", patch_mode=True, - len_images=1, overwrite=True, ) assert out_path.exists() @@ -213,7 +236,6 @@ def test_prepare_engines_save_dir( out_path = prepare_engines_save_dir( save_dir=tmp_path / "patch_output" / "output.zarr", patch_mode=True, - len_images=1, overwrite=False, ) @@ -238,7 +260,7 @@ def test_engine_initalization() -> NoReturn: assert isinstance(eng, EngineABC) -def test_engine_run() -> NoReturn: +def test_engine_run(tmp_path: Path) -> NoReturn: """Test engine run.""" eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) @@ -315,6 +337,15 @@ def test_engine_run() -> NoReturn: assert "predictions" in out assert "labels" in out + eng = TestEngineABC(model="alexnet-kather100k") + + with pytest.raises(NotImplementedError): + eng.run( + images=np.zeros(shape=(10, 224, 224, 3)), + save_dir=tmp_path / "output", + patch_mode=False, + ) + def test_engine_run_with_verbose() -> NoReturn: """Test engine run with verbose.""" @@ -381,7 +412,7 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: ValueError, match=r".*Patch output must contain coordinates.", ): - out = eng.run( + _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, @@ -394,7 +425,7 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: ValueError, match=r".*Patch output must contain coordinates.", ): - out = eng.run( + _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, @@ -408,7 +439,7 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: ValueError, match=r".*Patch output must contain coordinates.", ): - out = eng.run( + _ = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, @@ -417,3 +448,220 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: output_type="AnnotationStore", scale_factor=(2.0, 2.0), ) + + +def test_cache_mode_patches(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the caching mode.""" + save_dir = tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + cache_mode=True, + ) + assert out.exists(), "Zarr output file does not exist" + + output_file_name = "output2.zarr" + cache_size = 4 + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + cache_mode=True, + cache_size=4, + batch_size=8, + output_file=output_file_name, + ) + assert out.stem == output_file_name.split(".")[0] + assert eng.batch_size == cache_size + assert out.exists(), "Zarr output file does not exist" + + +def test_get_dataloader(sample_svs: Path) -> None: + """Test the get_dataloader function.""" + eng = TestEngineABC(model="alexnet-kather100k") + ioconfig = ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + ], + patch_input_shape=(224, 224), + ) + dataloader = eng.get_dataloader( + images=np.zeros(shape=(10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + ioconfig=ioconfig, + ) + + assert isinstance(dataloader.dataset, PatchDataset) + + dataloader = eng.get_dataloader( + images=sample_svs, + patch_mode=False, + ioconfig=ioconfig, + ) + + assert isinstance(dataloader.dataset, WSIPatchDataset) + + +def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: + """Test the eng.save_output() function.""" + eng = TestEngineABC(model="alexnet-kather100k") + save_path = tmp_path / "output.zarr" + _ = zarr.open(save_path, mode="w") + out = eng.save_wsi_output( + raw_output=save_path, save_path=save_path, output_type="zarr", save_dir=tmp_path + ) + + assert out.exists() + assert out.suffix == ".zarr" + + # Test AnnotationStore + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "other": "other", + } + class_dict = {0: "class0", 1: "class1"} + out = eng.save_wsi_output( + raw_output=patch_output, + scale_factor=(1.0, 1.0), + class_dict=class_dict, + save_dir=tmp_path, + output_type="AnnotationStore", + ) + + assert out.exists() + assert out.suffix == ".db" + + with pytest.raises( + ValueError, + match=r".*supports zarr and AnnotationStore as output_type.", + ): + eng.save_wsi_output( + raw_output=save_path, + save_path=save_path, + output_type="dict", + save_dir=tmp_path, + ) + + +def test_io_config_delegation(tmp_path: Path) -> None: + """Test for delegating args to io config.""" + # test not providing config / full input info for not pretrained models + model = CNNModel("resnet50") + eng = TestEngineABC(model=model) + with pytest.raises(ValueError, match=r".*Please provide a valid ModelIOConfigABC*"): + eng.run( + np.zeros((10, 224, 224, 3)), patch_mode=True, save_dir=tmp_path / "dump" + ) + + kwargs = { + "patch_input_shape": [512, 512], + "resolution": 1.75, + "units": "mpp", + } + + # test providing config / full input info for non pretrained models + ioconfig = ModelIOConfigABC( + patch_input_shape=(512, 512), + stride_shape=(256, 256), + input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + ) + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ioconfig=ioconfig, + ) + assert eng._ioconfig.patch_input_shape == (512, 512) + assert eng._ioconfig.stride_shape == (256, 256) + assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}] + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_mode=True, + save_dir=f"{tmp_path}/dump", + **kwargs, + ) + assert eng._ioconfig.patch_input_shape == [512, 512] + assert eng._ioconfig.stride_shape == [512, 512] + assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}] + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test overwriting pretrained ioconfig + eng = TestEngineABC(model="alexnet-kather100k") + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=1.99, + units="baseline", + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ) + assert eng._ioconfig.patch_input_shape == (300, 300) + assert eng._ioconfig.stride_shape == (300, 300) + assert eng._ioconfig.input_resolutions[0]["resolution"] == 1.99 + assert eng._ioconfig.input_resolutions[0]["units"] == "baseline" + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=None, + units=None, + patch_mode=True, + save_dir=f"{tmp_path}/dump", + ) + assert eng._ioconfig.patch_input_shape == (300, 300) + assert eng._ioconfig.stride_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + eng.ioconfig = None + _ioconfig = eng._update_ioconfig( + ioconfig=None, + patch_input_shape=(300, 300), + stride_shape=(300, 300), + resolution=1.99, + units="baseline", + ) + + assert _ioconfig.patch_input_shape == (300, 300) + assert _ioconfig.stride_shape == (300, 300) + assert _ioconfig.input_resolutions[0]["resolution"] == 1.99 + assert _ioconfig.input_resolutions[0]["units"] == "baseline" + + for key in kwargs: + _kwargs = copy.deepcopy(kwargs) + _kwargs[key] = None + with pytest.raises( + ValueError, + match=r".*Must provide either `ioconfig` or " + r"`patch_input_shape`, `resolution`, and `units`*", + ): + eng._update_ioconfig( + ioconfig=None, + patch_input_shape=_kwargs["patch_input_shape"], + stride_shape=(1, 1), + resolution=_kwargs["resolution"], + units=_kwargs["units"], + ) + + +def test_notimplementederror_wsi_mode( + sample_svs: Path, tmp_path: pytest.TempPathFactory +) -> None: + """Test that NotImplementedError is raised when wsi mode is False. + + A user should implement run method when patch_mode is False. + + """ + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises(NotImplementedError): + eng.run(images=[sample_svs], patch_mode=False, save_dir=tmp_path / "output") diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py index cb99a3df3..87b074a71 100644 --- a/tests/test_app_bokeh.py +++ b/tests/test_app_bokeh.py @@ -18,7 +18,7 @@ import requests from bokeh.application import Application from bokeh.application.handlers import FunctionHandler -from bokeh.events import ButtonClick, DoubleTap, MenuItemClick +from bokeh.events import DoubleTap, MenuItemClick from flask_cors import CORS from matplotlib import colormaps from PIL import Image @@ -423,54 +423,7 @@ def test_load_img_overlay(doc: Document, data_path: pytest.TempPathFactory) -> N assert main.UI["p"].renderers[main.UI["vstate"].layer_dict["layer2"]].alpha == 0.4 -def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> None: - """Test running hovernet on a box.""" - slide_select = doc.get_model_by_name("slide_select0") - slide_select.value = [data_path["slide2"].name] - run_button = doc.get_model_by_name("to_model0") - assert len(main.UI["color_column"].children) == 0 - slide_select.value = [data_path["slide1"].name] - # set up a box selection - main.UI["box_source"].data = { - "x": [1200], - "y": [-2000], - "width": [400], - "height": [400], - } - - # select hovernet model and run it on box - model_select = doc.get_model_by_name("model_drop0") - model_select.value = "hovernet" - - click = ButtonClick(run_button) - run_button._trigger_event(click) - im = get_tile("overlay", 4, 8, 4, show=False) - _, num = label(np.any(im[:, :, :3], axis=2)) - # check there are multiple cells being detected - assert len(main.UI["color_column"].children) > 3 - assert num > 10 - - # test save functionality - save_button = doc.get_model_by_name("save_button0") - click = ButtonClick(save_button) - save_button._trigger_event(click) - saved_path = ( - data_path["base_path"] - / "overlays" - / (data_path["slide1"].stem + "_saved_anns.db") - ) - assert saved_path.exists() - - # load an overlay with different types - cprop_select = doc.get_model_by_name("cprop0") - cprop_select.value = ["prob"] - layer_drop = doc.get_model_by_name("layer_drop0") - click = MenuItemClick(layer_drop, str(data_path["dat_anns"])) - layer_drop._trigger_event(click) - assert main.UI["vstate"].types == ["annotation"] - # check the per-type ui controls have been updated - assert len(main.UI["color_column"].children) == 1 - assert len(main.UI["type_column"].children) == 1 +# test_hovernet_on_box should be fixed before merge to develop. def test_alpha_sliders(doc: Document) -> None: diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 18378aacc..5c19f4c27 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -143,7 +143,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str = "cpu", - ) -> np.ndarray: + ) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -169,7 +169,7 @@ def infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return output.cpu().numpy() + return {"predictions": output.cpu().numpy()} class CNNBackbone(ModelABC): @@ -240,7 +240,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str, - ) -> list[np.ndarray, ...]: + ) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -266,4 +266,4 @@ def infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return [output.cpu().numpy()] + return {"predictions": output.cpu().numpy()} diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 32ad1a562..8fef0c4e2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,21 +1,30 @@ -"""Defines Abstract Base Class for TIAToolbox Model Engines.""" +"""Defines Abstract Base Class for TIAToolbox Engines.""" from __future__ import annotations +import copy from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, TypedDict import numpy as np import torch import tqdm +import zarr from torch import nn +from typing_extensions import Unpack from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.dataset.dataset_abc import PatchDataset -from tiatoolbox.models.models_abc import load_torch_model, model_to -from tiatoolbox.utils.misc import dict_to_store, dict_to_zarr +from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset +from tiatoolbox.models.models_abc import load_torch_model +from tiatoolbox.utils.misc import ( + dict_to_store, + dict_to_zarr, + write_to_zarr_in_cache_mode, +) + +from .io_config import ModelIOConfigABC if TYPE_CHECKING: # pragma: no cover import os @@ -23,25 +32,28 @@ from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.typing import IntPair, Resolution, Units from tiatoolbox.wsicore.wsireader import WSIReader - from .io_config import ModelIOConfigABC - def prepare_engines_save_dir( save_dir: os | Path | None, - len_images: int, *, patch_mode: bool, - overwrite: bool, + overwrite: bool = False, ) -> Path | None: - """Create directory if not defined and number of images is more than 1. + """Create a save directory. + + If patch_mode is False and the save directory is not defined, + this function will raise an error. + + If patch_mode is True and the save directory is defined it will + create save_dir otherwise returns None. Args: save_dir (str or Path): Path to output directory. - len_images (int): - List of inputs to process. patch_mode(bool): Whether to treat input image as a patch or WSI. overwrite (bool): @@ -51,29 +63,30 @@ def prepare_engines_save_dir( :class:`Path`: Path to output directory. + Raises: + OSError: + If the save directory is not defined. + """ if patch_mode is True: if save_dir is not None: + save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=overwrite) return save_dir if save_dir is None: - if len_images > 1: - msg = ( - "More than 1 WSIs detected but there is no save directory provided." - "Please provide a 'save_dir'." - ) - raise OSError(msg) - return ( - Path.cwd() - ) # save the output to current working directory and return save_dir - - if len_images > 1: - logger.info( - "When providing multiple whole slide images, " - "the outputs will be saved and the locations of outputs " - "will be returned to the calling function.", + msg = ( + "Input WSIs detected but no save directory provided." + "Please provide a 'save_dir'." ) + raise OSError(msg) + + logger.info( + "When providing multiple whole slide images, " + "the outputs will be saved and the locations of outputs " + "will be returned to the calling function when `run()`" + "finishes successfully.", + ) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=overwrite) @@ -81,47 +94,146 @@ def prepare_engines_save_dir( return save_dir +class EngineABCRunParams(TypedDict, total=False): + """Class describing the input parameters for the :func:`EngineABC.run()` method. + + Attributes: + batch_size (int): + Number of image patches to feed to the model in a forward pass. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. + ioconfig (ModelIOConfigABC): + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map into a single file from a WSI. + This is only applicable if `patch_mode` is False in inference. + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + output_file (str): + Output file name to save "zarr" or "db". + patch_input_shape (tuple): + Shape of patches input to the model as tuple of height and width (HW). + Patches are requested at read resolution, not with respect to level 0, + and must be positive. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + return_labels (bool): + Whether to return the output labels. + scale_factor (tuple[float, float]): + The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + verbose (bool): + Whether to output logging information. + + """ + + batch_size: int + cache_mode: bool + cache_size: int + class_dict: dict + device: str + ioconfig: ModelIOConfigABC + merge_predictions: bool + num_loader_workers: int + num_post_proc_workers: int + output_file: str + patch_input_shape: IntPair + resolution: Resolution + return_labels: bool + scale_factor: tuple[float, float] + stride_shape: IntPair + units: Units + verbose: bool + + class EngineABC(ABC): - """Abstract base class for engines used in tiatoolbox. + """Abstract base class for TIAToolbox deep learning engines to run CNN models. Args: - model (str | nn.Module): + model (str | ModelABC): A PyTorch model. Default is `None`. - The user can request pretrained models from the toolbox using + The user can request pretrained models from the toolbox model zoo using the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights. + of weights using the `weights` parameter. + batch_size (int): + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. weights (str or Path): Path to the weight of the corresponding `model`. >>> engine = EngineABC( - ... model="pretrained-model-name", - ... weights="pretrained-local-weights.pth") + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) - batch_size (int): - Number of images fed into the model each time. - num_loader_workers (int): - Number of workers to load the data using :class:`torch.utils.data.Dataset`. - Please note that they will also perform preprocessing. default = 0 - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. default = 0 device (str): - Select the device to run the model. Default is "cpu". + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Attributes: images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): - A NHWC image or a path to WSI. + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. patch_mode (str): - Whether to treat input image as a patch or WSI. - default = True. - model (str | nn.Module): - Defined PyTorch model. - Name of an existing model supported by the TIAToolbox for - processing the data. For a full list of pretrained models, + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. + model (str | ModelABC): + A PyTorch model or a name of an existing model from the TIAToolbox model zoo + for processing the data. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also @@ -129,14 +241,15 @@ class EngineABC(ABC): of weights via the `weights` argument. Argument is case-insensitive. ioconfig (ModelIOConfigABC): - Input IO configuration to run the Engine. - _ioconfig (): + Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. + _ioconfig (ModelIOConfigABC): Runtime ioconfig. return_labels (bool): Whether to return the labels with the predictions. merge_predictions (bool): Whether to merge the predictions to form a 2-dimensional map. This is only applicable if `patch_mode` is False in inference. + Default is False. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. @@ -155,43 +268,84 @@ class EngineABC(ABC): `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. cache_mode is always True when + processing WSIs i.e., when `patch_mode` is False. Default value is False. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. Default value is 10,000. labels (list | None): List of labels. Only a single label per image is supported. device (str): - Select the device to run the model. Default is "cpu". + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". num_loader_workers (int): - Number of workers used in torch.utils.data.DataLoader. + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + return_labels (bool): + Whether to return the output labels. Default value is False. + merge_predictions (bool): + Whether to merge WSI predictions into a single file. Default value is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. Default value is 1.0. + units (Units): + Units of resolution used for reading the image. Choose + from either `baseline`, `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. + Default value is `baseline`. verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Examples: - >>> # array of list of 2 image patches as input - >>> import numpy as np + >>> # Inherit from EngineABC + >>> class TestEngineABC(EngineABC): + >>> def __init__( + >>> self, + >>> model, + >>> weights, + >>> verbose, + >>> ): + >>> super().__init__(model=model, weights=weights, verbose=verbose) + >>> # Define all the abstract classes + >>> data = np.array([np.ndarray, np.ndarray]) - >>> engine = EngineABC(model="resnet18-kather100k") + >>> engine = TestEngineABC(model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=True) >>> # array of list of 2 image patches as input - >>> import numpy as np >>> data = np.array([np.ndarray, np.ndarray]) - >>> engine = EngineABC(model="resnet18-kather100k") + >>> engine = TestEngineABC(model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=True) >>> # list of 2 image files as input >>> image = ['path/image1.png', 'path/image2.png'] - >>> engine = EngineABC(model="resnet18-kather100k") + >>> engine = TestEngineABC(model="resnet18-kather100k") >>> output = engine.run(image, patch_mode=False) >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] - >>> engine = EngineABC(model="resnet18-kather100k") - >>> output = engine.run(wsi_file, patch_mode=True) + >>> engine = TestEngineABC(model="resnet18-kather100k") + >>> output = engine.run(wsi_file, patch_mode=False) """ def __init__( self: EngineABC, - model: str | nn.Module, + model: str | ModelABC, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, @@ -203,8 +357,8 @@ def __init__( """Initialize Engine.""" super().__init__() - self.masks = None self.images = None + self.masks = None self.patch_mode = None self.device = device @@ -213,50 +367,59 @@ def __init__( model=model, weights=weights, ) - self.model = model_to(model=self.model, device=self.device) + self.model.to(device=self.device) self._ioconfig = self.ioconfig # runtime ioconfig self.batch_size = batch_size + self.cache_mode: bool = False + self.cache_size: int = self.batch_size if self.batch_size else 10000 + self.labels: list | None = None + self.merge_predictions: bool = False self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers + self.patch_input_shape: IntPair | None = None + self.resolution: Resolution = 1.0 + self.return_labels: bool = False + self.stride_shape: IntPair | None = None + self.units: Units = "baseline" self.verbose = verbose - self.return_labels = False - self.merge_predictions = False - self.units = "baseline" - self.resolution = 1.0 - self.patch_input_shape = None - self.stride_shape = None - self.labels = None @staticmethod def _initialize_model_ioconfig( - model: str | nn.Module, + model: str | ModelABC, weights: str | Path | None, ) -> tuple[nn.Module, ModelIOConfigABC | None]: """Helper function to initialize model and ioconfig attributes. If a pretrained model provided by the TIAToolbox is requested. The model - can be specified as a string otherwise torch.nn.Module is required. + can be specified as a string otherwise :class:`torch.nn.Module` is required. This function also loads the :class:`ModelIOConfigABC` using the information from the pretrained models in TIAToolbox. If ioconfig is not available then it - should be provided in the :func:`run` function. + should be provided in the :func:`run()` function. Args: - model (str | nn.Module): - A torch model which should be run by the engine. + model (str | ModelABC): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights using the `weights` parameter. weights (str | Path | None): Path to pretrained weights. If no pretrained weights are provided - and the model is provided by TIAToolbox, then pretrained weights will + and the `model` is provided by TIAToolbox, then pretrained weights will be automatically loaded from the TIA servers. Returns: - nn.Module: - The requested PyTorch model. + ModelABC: + The requested PyTorch model as a :class:`ModelABC` instance. ModelIOConfigABC | None: The model io configuration for TIAToolbox pretrained models. - Otherwise, None. + If the specified model is not in TIAToolbox model zoo, then the function + returns None. """ if not isinstance(model, (str, nn.Module)): @@ -266,7 +429,7 @@ def _initialize_model_ioconfig( if isinstance(model, str): # ioconfig is retrieved from the pretrained model in the toolbox. # list of pretrained models in the TIA Toolbox is available here: - # https://tia-toolbox.readthedocs.io/en/add-bokeh-app/pretrained.html + # https://tia-toolbox.readthedocs.io/en/latest/pretrained.html # no need to provide ioconfig in EngineABC.run() this case. return get_pretrained_model(model, weights) @@ -275,16 +438,66 @@ def _initialize_model_ioconfig( return model, None - def pre_process_patches( + def get_dataloader( self: EngineABC, - images: np.ndarray | list, - labels: list, + images: Path, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, ) -> torch.utils.data.DataLoader: - """Pre-process an image patch.""" + """Pre-process images and masks and return dataloader for inference. + + Args: + images (list of str or :class:`Path` or :class:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. When `patch_mode` is False + the function expects list of str/paths to WSIs. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + ioconfig (ModelIOConfigABC): + A :class:`ModelIOConfigABC` object. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + + Returns: + torch.utils.data.DataLoader: + :class:`torch.utils.data.DataLoader` for inference. + + + """ if labels: # if a labels is provided, then return with the prediction self.return_labels = bool(labels) + if not patch_mode: + dataset = WSIPatchDataset( + img_path=images, + mode="wsi", + mask_path=masks, + patch_input_shape=ioconfig.patch_input_shape, + stride_shape=ioconfig.stride_shape, + resolution=ioconfig.input_resolutions[0]["resolution"], + units=ioconfig.input_resolutions[0]["units"], + ) + + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + dataset = PatchDataset(inputs=images, labels=labels) dataset.preproc_func = self.model.preproc_func @@ -297,41 +510,81 @@ def pre_process_patches( shuffle=False, ) + @staticmethod + def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict: + """Helper function to append raw output during inference.""" + for key in raw_output: + if raw_predictions[key] is None: + raw_predictions[key] = raw_output[key] + else: + raw_predictions[key] = np.append( + raw_predictions[key], raw_output[key], axis=0 + ) + + return raw_predictions + def infer_patches( self: EngineABC, - data_loader: DataLoader, - ) -> dict: - """Model inference on an image patch.""" + dataloader: DataLoader, + save_path: Path | None, + ) -> dict | Path: + """Runs model inference on image patches and returns output as a dictionary. + + Args: + dataloader (DataLoader): + An :class:`torch.utils.data.DataLoader` object to run inference. + save_path (Path | None): + If `cache_mode` is True then path to save zarr file must be provided. + + Returns: + dict or Path: + Result of model inference as a dictionary. Returns path to + saved zarr file if `cache_mode` is True. + + """ progress_bar = None if self.verbose: progress_bar = tqdm.tqdm( - total=int(len(data_loader)), + total=int(len(dataloader)), leave=True, ncols=80, ascii=True, position=0, ) - raw_predictions = { - "predictions": [], - } + + keys = ["predictions"] if self.return_labels: - raw_predictions["labels"] = [] + keys.append("labels") + + raw_predictions = {key: None for key in keys} + + zarr_group = None + + if self.cache_mode: + zarr_group = zarr.open(save_path, mode="w") - for _, batch_data in enumerate(data_loader): - batch_output_predictions = self.model.infer_batch( + for _, batch_data in enumerate(dataloader): + batch_output = self.model.infer_batch( self.model, batch_data["image"], device=self.device, ) - raw_predictions["predictions"].extend(batch_output_predictions.tolist()) - if self.return_labels: # be careful of `s` - # We do not use tolist here because label may be of mixed types - # and hence collated as list by torch - raw_predictions["labels"].extend(list(batch_data["label"])) + batch_output["labels"] = batch_data["label"].numpy() + + raw_predictions = self._update_model_output( + raw_predictions=raw_predictions, + raw_output=batch_output, + ) + + if self.cache_mode: + zarr_group = write_to_zarr_in_cache_mode( + zarr_group=zarr_group, output_data_to_save=raw_predictions + ) + raw_predictions = {key: None for key in keys} if progress_bar: progress_bar.update() @@ -339,77 +592,166 @@ def infer_patches( if progress_bar: progress_bar.close() + return save_path if self.cache_mode else raw_predictions + + def post_process_patches( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: dict, + ) -> dict | Path: + """Post-process raw patch predictions from inference. + + The output of :func:`infer_patches()` with patch prediction information will be + post-processed using this function. The processed output will be saved in the + respective input format. If `cache_mode` is True, the function processes the + input using zarr group with size specified by `cache_size`. + + Args: + raw_predictions (dict | Path): + A dictionary or path to zarr with patch prediction information. + **kwargs (dict): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: + dict or Path: + Returns patch based output after post-processing. Returns path to + saved zarr file if `cache_mode` is True. + + """ + _ = kwargs.get("predictions") # Key values required for post-processing + + if self.cache_mode: # cache mode + _ = zarr.open(raw_predictions, mode="w") + return raw_predictions - def setup_patch_dataset( + def save_predictions( self: EngineABC, - raw_predictions: dict, + processed_predictions: dict, output_type: str, save_dir: Path | None = None, **kwargs: dict, - ) -> Path | AnnotationStore: - """Post-process image patches. + ) -> dict | AnnotationStore | Path: + """Save model predictions. Args: - raw_predictions (dict): - A dictionary of patch prediction information. + processed_predictions (dict | Path): + A dictionary or path to zarr with model prediction information. save_dir (Path): - Optional Output Path to directory to save the patch dataset output to a - `.zarr` or `.db` file, provided patch_mode is True. if the patch_mode is - False then save_dir is required. + Optional output path to directory to save the patch dataset output to a + `.zarr` or `.db` file, provided `patch_mode` is True. If the + `patch_mode` is False then `save_dir` is required. output_type (str): The desired output type for resulting patch dataset. - **kwargs (dict): - Keyword Args to update setup_patch_dataset() method attributes. + **kwargs (EngineABCRunParams): + Keyword Args required to save the output. - Returns: (dict, Path, :class:`SQLiteStore`): - if the output_type is "AnnotationStore", the function returns the patch - predictor output as an SQLiteStore containing Annotations for each or the - Path to a `.db` file depending on whether a save_dir Path is provided. - Otherwise, the function defaults to returning patch predictor output, either - as a dict or the Path to a `.zarr` file depending on whether a save_dir Path - is provided. + Returns: + dict or Path or :class:`AnnotationStore`: + If the `output_type` is "AnnotationStore", the function returns + the patch predictor output as an SQLiteStore containing Annotations + for each or the Path to a `.db` file depending on whether a + save_dir Path is provided. Otherwise, the function defaults to + returning patch predictor output, either as a dict or the Path to a + `.zarr` file depending on whether a save_dir Path is provided. """ - if not save_dir and output_type != "AnnotationStore": - return raw_predictions + if (self.cache_mode or not save_dir) and output_type != "AnnotationStore": + return processed_predictions - output_file = ( - kwargs["output_file"] and kwargs.pop("output_file") - if "output_file" in kwargs - else "output" - ) + output_file = Path(kwargs.get("output_file", "output.db")) save_path = save_dir / output_file if output_type == "AnnotationStore": # scale_factor set from kwargs - scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None + scale_factor = kwargs.get("scale_factor") # class_dict set from kwargs - class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None - - return dict_to_store(raw_predictions, scale_factor, class_dict, save_path) + class_dict = kwargs.get("class_dict") + + # Need to add support for zarr conversion. + return dict_to_store( + processed_predictions, + scale_factor, + class_dict, + save_path, + ) - return dict_to_zarr( - raw_predictions, - save_path, - **kwargs, + return ( + dict_to_zarr( + processed_predictions, + save_path, + **kwargs, + ) + if isinstance(processed_predictions, dict) + else processed_predictions ) @abstractmethod - def pre_process_wsi(self: EngineABC) -> NoReturn: - """Pre-process a WSI.""" - raise NotImplementedError + def infer_wsi( + self: EngineABC, + dataloader: torch.utils.data.DataLoader, + img_label: str, + highest_input_resolution: list[dict], + save_dir: Path, + **kwargs: dict, + ) -> list: + """Model inference on a WSI. - @abstractmethod - def infer_wsi(self: EngineABC) -> NoReturn: - """Model inference on a WSI.""" + This function must be implemented by subclasses. + + """ + # return coordinates of patches processed within a tile / whole-slide image raise NotImplementedError @abstractmethod - def post_process_wsi(self: EngineABC) -> NoReturn: - """Post-process a WSI.""" - raise NotImplementedError + def save_wsi_output( + self: EngineABC, + raw_output: dict | Path, + save_dir: Path, + output_type: str, + **kwargs: Unpack[EngineABCRunParams], + ) -> AnnotationStore | Path: + """Post-process a WSI. + + Args: + raw_output (dict | Path): + A dictionary with output information or zarr file path. + save_dir (Path): + Output Path to directory to save the patch dataset output to a + `.zarr` or `.db` file + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (AnnotationStore or Path): + If the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations stored in a `.db` + file. Otherwise, the function defaults to returning patch predictor output + stored in a `.zarr` file. + + """ + if ( + output_type == "zarr" + and isinstance(raw_output, Path) + and raw_output.suffix == ".zarr" + ): + return raw_output + + output_file = kwargs.get("output_file", "output") + save_path = save_dir / output_file + + if output_type == "AnnotationStore": + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + + return dict_to_store(raw_output, scale_factor, class_dict, save_path) + + msg = "Only supports zarr and AnnotationStore as output_type." + raise ValueError(msg) def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: """Helper function to load ioconfig. @@ -438,11 +780,84 @@ def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfig ) raise ValueError(msg) - if ioconfig is not None: + if ioconfig and isinstance(ioconfig, ModelIOConfigABC): self.ioconfig = ioconfig return self.ioconfig + def _update_ioconfig( + self: EngineABC, + ioconfig: ModelIOConfigABC, + patch_input_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> ModelIOConfigABC: + """Update IOConfig. + + Args: + ioconfig (:class:`ModelIOConfigABC`): + Input ioconfig for PatchPredictor. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride using during tile and WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. + + Returns: + Updated Patch Predictor IO configuration. + + """ + config_flag = ( + patch_input_shape is None, + resolution is None, + units is None, + ) + if isinstance(ioconfig, ModelIOConfigABC): + return ioconfig + + if self.ioconfig is None and any(config_flag): + msg = ( + "Must provide either " + "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." + ) + raise ValueError( + msg, + ) + + if stride_shape is None: + stride_shape = patch_input_shape + + if self.ioconfig: + ioconfig = copy.deepcopy(self.ioconfig) + # ! not sure if there is a nicer way to set this + if patch_input_shape is not None: + ioconfig.patch_input_shape = patch_input_shape + if stride_shape is not None: + ioconfig.stride_shape = stride_shape + if resolution is not None: + ioconfig.input_resolutions[0]["resolution"] = resolution + if units is not None: + ioconfig.input_resolutions[0]["units"] = units + + return ioconfig + + return ModelIOConfigABC( + input_resolutions=[{"resolution": resolution, "units": units}], + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + output_resolutions=[], + ) + @staticmethod def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: """Validate input images for a run.""" @@ -492,6 +907,88 @@ def _validate_input_numbers( ) return + def _update_run_params( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os | Path | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + overwrite: bool = False, + patch_mode: bool, + **kwargs: Unpack[EngineABCRunParams], + ) -> Path | None: + """Updates runtime parameters. + + Updates runtime parameters for an EngineABC for EngineABC.run(). + + """ + for key in kwargs: + setattr(self, key, kwargs.get(key)) + + self.patch_mode = patch_mode + if self.cache_mode and self.batch_size > self.cache_size: + self.batch_size = self.cache_size + + self._validate_input_numbers(images=images, masks=masks, labels=labels) + self.images = self._validate_images_masks(images=images) + + if masks is not None: + self.masks = self._validate_images_masks(images=masks) + + self.labels = labels + + # if necessary move model parameters to "cpu" or "gpu" and update ioconfig + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + self.model.to(device=self.device) + self._ioconfig = self._update_ioconfig( + ioconfig, + self.patch_input_shape, + self.stride_shape, + self.resolution, + self.units, + ) + + return prepare_engines_save_dir( + save_dir=save_dir, + patch_mode=patch_mode, + overwrite=overwrite, + ) + + def _run_patch_mode( + self: EngineABC, output_type: str, save_dir: Path, **kwargs: EngineABCRunParams + ) -> dict | AnnotationStore | Path: + """Runs the Engine in the patch mode. + + Input arguments are passed from :func:`EngineABC.run()`. + + """ + save_path = None + if self.cache_mode: + output_file = Path(kwargs.get("output_file", "output.db")) + save_path = save_dir / (str(output_file.stem) + ".zarr") + + dataloader = self.get_dataloader( + images=self.images, + labels=self.labels, + patch_mode=True, + ) + raw_predictions = self.infer_patches( + dataloader=dataloader, + save_path=save_path, + ) + processed_predictions = self.post_process_patches( + raw_predictions=raw_predictions, + **kwargs, + ) + return self.save_predictions( + processed_predictions=processed_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + def run( self: EngineABC, images: list[os | Path | WSIReader] | np.ndarray, @@ -503,8 +1000,8 @@ def run( save_dir: os | Path | None = None, # None will not save output overwrite: bool = False, output_type: str = "dict", - **kwargs: dict, - ) -> AnnotationStore | Path | str: + **kwargs: Unpack[EngineABCRunParams], + ) -> AnnotationStore | Path | str | dict: """Run the engine on input images. Args: @@ -534,15 +1031,15 @@ def run( Whether to overwrite the results. Default = False. output_type (str): The format of the output type. "output_type" can be - "zarr", "AnnotationStore". Default is "zarr". + "zarr" or "AnnotationStore". Default value is "zarr". When saving in the zarr format the output is saved using the `python zarr library `__ as a zarr group. If the required output type is an "AnnotationStore" then the output will be intermediately saved as zarr but converted to :class:`AnnotationStore` and saved as a `.db` file at the end of the loop. - **kwargs (dict): - Keyword Args to update :class:`EngineABC` attributes. + **kwargs (EngineABCRunParams): + Keyword Args to update :class:`EngineABC` attributes during runtime. Returns: (:class:`numpy.ndarray`, dict): @@ -562,76 +1059,46 @@ def run( Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] - >>> predictor = EngineABC(model="resnet18-kather100k") + >>> class PatchPredictor(EngineABC): + >>> # Define all Abstract methods. + >>> ... + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + >>> output = predictor.run( + >>> image_patches, + >>> patch_mode=True, + >>> output_type="zarr") + >>> output + ... "/path/to/Output.zarr" >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] - ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} - >>> output['wsi2.svs'] - ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + ... {'/path/to/wsi1.db'} - >>> predictor = EngineABC(model="alexnet-kather100k") - >>> output = predictor.run( - >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - >>> labels=list(range(10)), - >>> on_gpu=False, - >>> ) - >>> output - ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., - ... 0.034451354295015335, 0.004817609209567308]], - ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), - ... tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} - - >>> predictor = EngineABC(model="alexnet-kather100k") - >>> save_dir = Path("/tmp/patch_output/") - >>> output = eng.run( - >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - >>> on_gpu=False, - >>> verbose=False, - >>> save_dir=save_dir, - >>> overwrite=True - >>> ) - >>> output - ... '/tmp/patch_output/output.zarr' """ - for key in kwargs: - setattr(self, key, kwargs[key]) - - self.patch_mode = patch_mode - - self._validate_input_numbers(images=images, masks=masks, labels=labels) - self.images = self._validate_images_masks(images=images) - - if masks is not None: - self.masks = self._validate_images_masks(images=masks) - - self.labels = labels - - # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig - self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) - self.model = model_to(model=self.model, device=self.device) - - save_dir = prepare_engines_save_dir( - save_dir, - len(self.images), - patch_mode=patch_mode, + save_dir = self._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, overwrite=overwrite, + patch_mode=patch_mode, + **kwargs, ) if patch_mode: - data_loader = self.pre_process_patches( - self.images, - self.labels, - ) - raw_predictions = self.infer_patches( - data_loader=data_loader, - ) - return self.setup_patch_dataset( - raw_predictions=raw_predictions, + return self._run_patch_mode( output_type=output_type, save_dir=save_dir, **kwargs, ) - return {"save_dir": save_dir} + # All inherited classes will get scale_factors, + # highest_input_resolution, implement dataloader, + # pre-processing, post-processing and save_output + # for WSIs separately. + raise NotImplementedError diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 75bda0380..5837693e2 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -234,14 +234,14 @@ def pre_process_wsi(self: PatchPredictor) -> NoReturn: def infer_wsi(self: PatchPredictor) -> NoReturn: """Model inference on a WSI.""" - def post_process_patches( + def save_predictions( self: PatchPredictor, raw_predictions: dict, output_type: str, ) -> None: """Post-process an image patch.""" - def post_process_wsi(self: PatchPredictor) -> NoReturn: + def save_wsi_output(self: PatchPredictor) -> NoReturn: """Post-process a WSI.""" @staticmethod diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 068ed8f5c..a27eb670c 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -75,7 +75,7 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -89,6 +89,11 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> Non device (str): Transfers model to the specified device. Default is "cpu". + Returns: + dict: + Returns a dictionary of predictions and other expected outputs + depending on the network architecture. + """ ... # pragma: no cover diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 8f22ce1e1..9a976febf 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1202,9 +1202,35 @@ def add_from_dat( store.append_many(anns) +def patch_predictions_as_annotations( + preds: list, + keys: list, + class_dict: dict, + class_probs: list, + patch_coords: list, + classes_predicted: list, + labels: list, +) -> list: + """Helper function to generate annotation per patch predictions.""" + annotations = [] + for i, pred in enumerate(preds): + if "probabilities" in keys: + props = { + f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted + } + else: + props = {} + if "labels" in keys: + props["label"] = class_dict[labels[i]] + props["type"] = class_dict[pred] + annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + + return annotations + + def dict_to_store( patch_output: dict, - scale_factor: tuple[int, int], + scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, ) -> AnnotationStore | Path: @@ -1212,9 +1238,9 @@ def dict_to_store( Args: patch_output (dict): - A dictionary in the TIAToolbox Engines output format. Important - keys are "probabilities", "predictions", "coordinates", and "labels". - scale_factor (tuple[int, int]): + A dictionary with "probabilities", "predictions", "coordinates", + and "labels" keys. + scale_factor (tuple[float, float]): The scale factor to use when loading the annotations. All coordinates will be multiplied by this factor to allow conversion of annotations saved at non-baseline resolution to baseline. @@ -1239,6 +1265,7 @@ def dict_to_store( # get relevant keys class_probs = patch_output.get("probabilities", []) preds = patch_output.get("predictions", []) + patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp @@ -1248,29 +1275,31 @@ def dict_to_store( classes_predicted = np.unique(preds).tolist() else: classes_predicted = range(len(class_probs[0])) + if class_dict is None: # if no class dict create a default one - class_dict = {i: i for i in np.unique(preds + labels).tolist()} + if len(class_probs) == 0: + class_dict = {i: i for i in np.unique(preds + labels).tolist()} + else: + class_dict = {i: i for i in range(len(class_probs))} # find what keys we need to save keys = ["predictions"] keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] # put patch predictions into a store - annotations = [] - for i, pred in enumerate(preds): - if "probabilities" in keys: - props = { - f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted - } - else: - props = {} - if "labels" in keys: - props["label"] = class_dict[labels[i]] - props["type"] = class_dict[pred] - annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + annotations = patch_predictions_as_annotations( + preds, + keys, + class_dict, + class_probs, + patch_coords, + classes_predicted, + labels, + ) + store = SQLiteStore() - keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) + _ = store.append_many(annotations, [str(i) for i in range(len(annotations))]) # if a save director is provided, then dump store into a file if save_path: @@ -1325,3 +1354,142 @@ def dict_to_zarr( z[:] = predictions_array return save_path + + +def wsi_batch_output_to_zarr_group( + wsi_batch_zarr_group: zarr.group | None, + batch_output_probabilities: np.ndarray, + batch_output_predictions: np.ndarray, + batch_output_coordinates: np.ndarray | None, + batch_output_label: np.ndarray | None, + save_path: Path, + **kwargs: dict, +) -> zarr.group | Path: + """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. + + Args: + wsi_batch_zarr_group (zarr.group): + Optional zarr group name consisting of zarrs to save the batch output + values. + batch_output_probabilities (np.ndarray): + Probability batch output from infer wsi. + batch_output_predictions (np.ndarray): + Predictions batch output from infer wsi. + batch_output_coordinates (np.ndarray): + Coordinates batch output from infer wsi. + batch_output_label (np.ndarray): + Labels batch output from infer wsi. + save_path (str or Path): + Path to save the zarr file. + **kwargs (dict): + Keyword Args to update wsi_batch_output_to_zarr_group attributes. + + Returns: + Path to the zarr file storing the :class:`EngineABC` output. + + """ + # Default values for Compressor and Chunks set if not received from kwargs. + compressor = ( + kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + ) + chunks = kwargs.get("chunks", 10000) + + # case 1 - new zarr group + if not wsi_batch_zarr_group: + # ensure proper zarr extension and create persistant zarr group + save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") + wsi_batch_zarr_group = zarr.open(save_path, mode="w") + + # populate the zarr group for the first time + probabilities_zarr = wsi_batch_zarr_group.create_dataset( + name="probabilities", + shape=batch_output_probabilities.shape, + chunks=chunks, + compressor=compressor, + ) + probabilities_zarr[:] = batch_output_probabilities + + predictions_zarr = wsi_batch_zarr_group.create_dataset( + name="predictions", + shape=batch_output_predictions.shape, + chunks=chunks, + compressor=compressor, + ) + predictions_zarr[:] = batch_output_predictions + + if batch_output_coordinates is not None: + coordinates_zarr = wsi_batch_zarr_group.create_dataset( + name="coordinates", + shape=batch_output_coordinates.shape, + chunks=chunks, + compressor=compressor, + ) + coordinates_zarr[:] = batch_output_coordinates + + if batch_output_label is not None: + labels_zarr = wsi_batch_zarr_group.create_dataset( + name="labels", + shape=batch_output_label.shape, + chunks=chunks, + compressor=compressor, + ) + labels_zarr[:] = batch_output_label + + # case 2 - append to existing zarr group + probabilities_zarr = wsi_batch_zarr_group["probabilities"] + probabilities_zarr.append(batch_output_probabilities) + + predictions_zarr = wsi_batch_zarr_group["predictions"] + predictions_zarr.append(batch_output_predictions) + + if batch_output_coordinates is not None: + coordinates_zarr = wsi_batch_zarr_group["coordinates"] + coordinates_zarr.append(batch_output_coordinates) + + if batch_output_label is not None: + labels_zarr = wsi_batch_zarr_group["labels"] + labels_zarr.append(batch_output_label) + + return wsi_batch_zarr_group + + +def write_to_zarr_in_cache_mode( + zarr_group: zarr.group, + output_data_to_save: dict, + **kwargs: dict, +) -> zarr.group | Path: + """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. + + Args: + zarr_group (zarr.group): + Zarr group name consisting of zarr(s) to save the batch output + values. + output_data_to_save (dict): + Output data from the Engine to save to Zarr. + **kwargs (dict): + Keyword Args to update zarr_group attributes. + + Returns: + Path to the zarr file storing the :class:`EngineABC` output. + + """ + # Default values for Compressor and Chunks set if not received from kwargs. + compressor = kwargs.get("compressor", numcodecs.Zstd(level=1)) + + # case 1 - new zarr group + if not zarr_group: + for key in output_data_to_save: + data_to_save = output_data_to_save[key] + # populate the zarr group for the first time + zarr_dataset = zarr_group.create_dataset( + name=key, + shape=data_to_save.shape, + compressor=compressor, + ) + zarr_dataset[:] = data_to_save + + # case 2 - append to existing zarr group + for key in output_data_to_save: + zarr_group[key].append(output_data_to_save[key]) + + return zarr_group From 3dd2c7b9cacaed0dd03dfc71e20da2df39967e19 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 14 May 2024 10:31:34 +0100 Subject: [PATCH 25/36] :rewind: Restore `test_patch_predictor.py` --- tests/engines/test_patch_predictor.py | 1228 +++++++++++++++++++++++++ 1 file changed, 1228 insertions(+) create mode 100644 tests/engines/test_patch_predictor.py diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py new file mode 100644 index 000000000..ab59efc53 --- /dev/null +++ b/tests/engines/test_patch_predictor.py @@ -0,0 +1,1228 @@ +"""Test for Patch Predictor.""" + +from __future__ import annotations + +import copy +import shutil +from pathlib import Path +from typing import Callable + +import cv2 +import numpy as np +import pytest +import torch +from click.testing import CliRunner + +from tiatoolbox import cli +from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor +from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.dataset import ( + PatchDataset, + PatchDatasetABC, + WSIPatchDataset, + predefined_preproc_func, +) +from tiatoolbox.utils import download_data, imread, imwrite +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore.wsireader import WSIReader + +ON_GPU = toolbox_env.has_gpu() +RNG = np.random.default_rng() # Numpy Random Generator + +# ------------------------------------------------------------------------------------- +# Dataloader +# ------------------------------------------------------------------------------------- + + +def test_patch_dataset_path_imgs( + sample_patch1: str | Path, + sample_patch2: str | Path, +) -> None: + """Test for patch dataset with a list of file paths as input.""" + size = (224, 224, 3) + + dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_list_imgs(tmp_path: Path) -> None: + """Test for patch dataset with a list of images as input.""" + save_dir_path = tmp_path + + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs) + + dataset.preproc_func = lambda x: x + + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + # test for changing to another preproc + dataset.preproc_func = lambda x: x - 10 + item = dataset[0] + assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 + + # * test for loading npy + # remove previously generated data + if Path.exists(save_dir_path): + shutil.rmtree(save_dir_path, ignore_errors=True) + Path.mkdir(save_dir_path, parents=True) + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + assert imgs[0] is not None + # test for path object + imgs = [ + save_dir_path / "sample2.npy", + ] + _ = PatchDataset(imgs) + + +def test_patch_datasetarray_imgs() -> None: + """Test for patch dataset with a numpy array of a list of images.""" + size = (5, 5, 3) + img = RNG.integers(0, 255, size=size) + list_imgs = [img, img, img] + labels = [1, 2, 3] + array_imgs = np.array(list_imgs) + + # test different setter for label + dataset = PatchDataset(array_imgs, labels=labels) + an_item = dataset[2] + assert an_item["label"] == 3 + dataset = PatchDataset(array_imgs, labels=None) + an_item = dataset[2] + assert "label" not in an_item + + dataset = PatchDataset(array_imgs) + for _, sample_data in enumerate(dataset): + sampled_img_shape = sample_data["image"].shape + assert sampled_img_shape[0] == size[0] + assert sampled_img_shape[1] == size[1] + assert sampled_img_shape[2] == size[2] + + +def test_patch_dataset_crash(tmp_path: Path) -> None: + """Test to make sure patch dataset crashes with incorrect input.""" + # all below examples should fail when input to PatchDataset + save_dir_path = tmp_path + + # not supported input type + imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} + with pytest.raises( + ValueError, + match=r".*Input must be either a list/array of images.*", + ): + _ = PatchDataset(imgs) + + # ndarray of mixed dtype + imgs = np.array( + # string array of the same shape + [ + RNG.integers(0, 255, (4, 5, 3)), + np.array( # skipcq: PYL-E1121 + ["you_should_crash_here" for _ in range(4 * 5 * 3)], + ).reshape( + 4, + 5, + 3, + ), + ], + dtype=object, + ) + with pytest.raises(ValueError, match="Provided input array is non-numerical."): + _ = PatchDataset(imgs) + + # ndarray(s) of NHW images + imgs = RNG.integers(0, 255, (4, 4, 4)) + with pytest.raises(ValueError, match=r".*array of the form HWC*"): + _ = PatchDataset(imgs) + + # list of ndarray(s) with different sizes + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 5, 3)), + ] + with pytest.raises(ValueError, match="Images must have the same dimensions."): + _ = PatchDataset(imgs) + + # list of ndarray(s) with HW and HWC mixed up + imgs = [ + RNG.integers(0, 255, (4, 4, 3)), + RNG.integers(0, 255, (4, 4)), + ] + with pytest.raises( + ValueError, + match="Each sample must be an array of the form HWC.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list of mixed dtype + imgs = ["you_should_crash_here", 123, 456] + with pytest.raises( + ValueError, + match="Input must be either a list/array of images or a list of " + "valid image paths.", + ): + _ = PatchDataset(imgs) + + # list not exist paths + with pytest.raises( + ValueError, + match=r".*valid image paths.*", + ): + _ = PatchDataset(["img.npy"]) + + # ** test different extension parser + # save dummy data to temporary location + # remove prev generated data + shutil.rmtree(save_dir_path, ignore_errors=True) + save_dir_path.mkdir(parents=True) + + torch.save({"a": "a"}, save_dir_path / "sample1.tar") + np.save( + str(save_dir_path / "sample2.npy"), + RNG.integers(0, 255, (4, 4, 3)), + ) + + imgs = [ + save_dir_path / "sample1.tar", + save_dir_path / "sample2.npy", + ] + with pytest.raises( + ValueError, + match="Cannot load image data from", + ): + _ = PatchDataset(imgs) + + # preproc func for not defined dataset + with pytest.raises( + ValueError, + match=r".* preprocessing .* does not exist.", + ): + predefined_preproc_func("secret-dataset") + + +def test_wsi_patch_dataset( # noqa: PLR0915 + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """A test for creation and bare output.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return WSIPatchDataset(img_path=img_path, **kwargs) + + def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: + """Testing function.""" + return reuse_init(mode="wsi", **kwargs) + + # test for ABC validate + # intentionally created to check error + # skipcq + class Proto(PatchDatasetABC): + def __init__(self: Proto) -> None: + super().__init__() + self.inputs = "CRASH" + self._check_input_integrity("wsi") + + # skipcq + def __getitem__(self: Proto, idx: int) -> object: + """Get an item from the dataset.""" + + with pytest.raises( + ValueError, + match=r".*`inputs` should be a list of patch coordinates.*", + ): + Proto() # skipcq + + # invalid path input + with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + WSIPatchDataset( + img_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + ) + + # invalid mask path input + with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): + WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + + # invalid mode + with pytest.raises(ValueError, match="`X` is not supported."): + reuse_init(mode="X") + + # invalid patch + with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): + reuse_init() + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, "a"]) + with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): + reuse_init_wsi(patch_input_shape=512) + # invalid stride + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \['512' 'a'\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[512 512 512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) + # negative + with pytest.raises( + ValueError, + match=r"Invalid `patch_input_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) + with pytest.raises( + ValueError, + match=r"Invalid `stride_shape` value \[ 512 -512\].", + ): + reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) + + # * for wsi + # dummy test for analysing the output + # stride and patch size should be as expected + patch_size = [512, 512] + stride_size = [256, 256] + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + reader = WSIReader.open(mini_wsi_svs) + # tiling top to bottom, left to right + ds_roi = ds[2]["image"] + step_idx = 2 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + rd_roi = reader.read_bounds( + start + end, + resolution=1.0, + units="mpp", + coord_space="resolution", + ) + correlation = np.corrcoef( + cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert ds_roi.shape[0] == rd_roi.shape[0] + assert ds_roi.shape[1] == rd_roi.shape[1] + assert np.min(correlation) > 0.9, correlation + + # test creation with auto mask gen and input mask + ds = reuse_init_wsi( + patch_input_shape=patch_size, + stride_shape=stride_size, + resolution=1.0, + units="mpp", + auto_get_mask=True, + ) + assert len(ds) > 0 + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=mini_wsi_msk, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + negative_mask = imread(mini_wsi_msk) + negative_mask = np.zeros_like(negative_mask) + negative_mask_path = tmp_path / "negative_mask.png" + imwrite(negative_mask_path, negative_mask) + with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): + ds = WSIPatchDataset( + img_path=mini_wsi_svs, + mask_path=negative_mask_path, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + # * for tile + reader = WSIReader.open(mini_wsi_jpg) + tile_ds = WSIPatchDataset( + img_path=mini_wsi_jpg, + mode="tile", + patch_input_shape=patch_size, + stride_shape=stride_size, + auto_get_mask=False, + ) + step_idx = 3 # manually calibrate + start = (step_idx * stride_size[1], 0) + end = (start[0] + patch_size[0], start[1] + patch_size[1]) + roi2 = reader.read_bounds( + start + end, + resolution=1.0, + units="baseline", + coord_space="resolution", + ) + roi1 = tile_ds[3]["image"] # match with step_index + correlation = np.corrcoef( + cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), + cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), + ) + assert roi1.shape[0] == roi2.shape[0] + assert roi1.shape[1] == roi2.shape[1] + assert np.min(correlation) > 0.9, correlation + + +def test_patch_dataset_abc() -> None: + """Test for ABC methods. + + Test missing definition for abstract intentionally created to check error. + + """ + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # crash due to undefined __getitem__ + with pytest.raises(TypeError): + Proto() # skipcq + + # skipcq + class Proto(PatchDatasetABC): + # skipcq + def __init__(self: Proto) -> None: + super().__init__() + + # skipcq + def __getitem__(self: Proto, idx: int) -> None: + """Get an item from the dataset.""" + + ds = Proto() # skipcq + + # test setter and getter + assert ds.preproc_func(1) == 1 + ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 + assert ds.preproc_func(1) == 0 + assert ds.preproc(1) == 1, "Must be unchanged!" + ds.preproc_func = None # skipcq: PYL-W0201 + assert ds.preproc_func(2) == 2 + + # test assign uncallable to preproc_func/postproc_func + with pytest.raises(ValueError, match=r".*callable*"): + ds.preproc_func = 1 # skipcq: PYL-W0201 + + +# ------------------------------------------------------------------------------------- +# Dataloader +# ------------------------------------------------------------------------------------- + + +def test_io_patch_predictor_config() -> None: + """Test for IOConfig.""" + # test for creating + cfg = IOPatchPredictorConfig( + patch_input_shape=[224, 224], + stride_shape=[224, 224], + input_resolutions=[{"resolution": 0.5, "units": "mpp"}], + # test adding random kwarg and they should be accessible as kwargs + crop_from_source=True, + ) + assert cfg.crop_from_source + + +# ------------------------------------------------------------------------------------- +# Engine +# ------------------------------------------------------------------------------------- + + +def test_predictor_crash(tmp_path: Path) -> None: + """Test for crash when making predictor.""" + # without providing any model + with pytest.raises(ValueError, match=r"Must provide.*"): + PatchPredictor() + + # provide wrong unknown pretrained model + with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): + PatchPredictor(pretrained_model="secret_model-kather100k") + + # provide wrong model of unknown type, deprecated later with type hint + with pytest.raises(TypeError, match=r".*must be a string.*"): + PatchPredictor(pretrained_model=123) + + # test predict crash + predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) + + with pytest.raises(ValueError, match=r".*not a valid mode.*"): + predictor.predict("aaa", mode="random", save_dir=tmp_path) + # remove previously generated data + shutil.rmtree(tmp_path / "output", ignore_errors=True) + with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): + predictor.predict("aaa", mode="wsi", save_dir=tmp_path) + # remove previously generated data + shutil.rmtree(tmp_path / "output", ignore_errors=True) + with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): + predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) + with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): + predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) + # remove previously generated data + shutil.rmtree(tmp_path / "output", ignore_errors=True) + + +def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: + """Test for delegating args to io config.""" + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) + + # test not providing config / full input info for not pretrained models + model = CNNModel("resnet50") + predictor = PatchPredictor(model=model) + with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): + predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + kwargs = { + "patch_input_shape": [512, 512], + "resolution": 1.75, + "units": "mpp", + } + for key in kwargs: + _kwargs = copy.deepcopy(kwargs) + _kwargs.pop(key) + with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): + predictor.predict( + [mini_wsi_svs], + mode="wsi", + save_dir=f"{tmp_path}/dump", + on_gpu=ON_GPU, + **_kwargs, + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test providing config / full input info for not pretrained models + ioconfig = IOPatchPredictorConfig( + patch_input_shape=(512, 512), + stride_shape=(256, 256), + input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + ) + predictor.predict( + [mini_wsi_svs], + ioconfig=ioconfig, + mode="wsi", + save_dir=f"{tmp_path}/dump", + on_gpu=ON_GPU, + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.predict( + [mini_wsi_svs], + mode="wsi", + save_dir=f"{tmp_path}/dump", + on_gpu=ON_GPU, + **kwargs, + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + # test overwriting pretrained ioconfig + predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) + predictor.predict( + [mini_wsi_svs], + patch_input_shape=(300, 300), + mode="wsi", + on_gpu=ON_GPU, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.patch_input_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.predict( + [mini_wsi_svs], + stride_shape=(300, 300), + mode="wsi", + on_gpu=ON_GPU, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.stride_shape == (300, 300) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.predict( + [mini_wsi_svs], + resolution=1.99, + mode="wsi", + on_gpu=ON_GPU, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.predict( + [mini_wsi_svs], + units="baseline", + mode="wsi", + on_gpu=ON_GPU, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor = PatchPredictor(pretrained_model="resnet18-kather100k") + predictor.predict( + [mini_wsi_svs], + mode="wsi", + merge_predictions=True, + save_dir=f"{tmp_path}/dump", + on_gpu=ON_GPU, + ) + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + +def test_patch_predictor_api( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Helper function to get the model output using API 1.""" + save_dir_path = tmp_path + + # convert to pathlib Path to prevent reader complaint + inputs = [Path(sample_patch1), Path(sample_patch2)] + predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) + # don't run test on GPU + output = predictor.predict( + inputs, + on_gpu=ON_GPU, + save_dir=save_dir_path, + ) + assert sorted(output.keys()) == ["predictions"] + assert len(output["predictions"]) == 2 + shutil.rmtree(save_dir_path, ignore_errors=True) + + output = predictor.predict( + inputs, + labels=[1, "a"], + return_labels=True, + on_gpu=ON_GPU, + save_dir=save_dir_path, + ) + assert sorted(output.keys()) == sorted(["labels", "predictions"]) + assert len(output["predictions"]) == len(output["labels"]) + assert output["labels"] == [1, "a"] + shutil.rmtree(save_dir_path, ignore_errors=True) + + output = predictor.predict( + inputs, + return_probabilities=True, + on_gpu=ON_GPU, + save_dir=save_dir_path, + ) + assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) + assert len(output["predictions"]) == len(output["probabilities"]) + shutil.rmtree(save_dir_path, ignore_errors=True) + + output = predictor.predict( + inputs, + return_probabilities=True, + labels=[1, "a"], + return_labels=True, + on_gpu=ON_GPU, + save_dir=save_dir_path, + ) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) + assert len(output["predictions"]) == len(output["labels"]) + assert len(output["predictions"]) == len(output["probabilities"]) + + # test saving output, should have no effect + _ = predictor.predict( + inputs, + on_gpu=ON_GPU, + save_dir="special_dir_not_exist", + ) + assert not Path.is_dir(Path("special_dir_not_exist")) + + # test loading user weight + pretrained_weights_url = ( + "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" + ) + + # remove prev generated data + shutil.rmtree(save_dir_path, ignore_errors=True) + save_dir_path.mkdir(parents=True) + pretrained_weights = ( + save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" + ) + + download_data(pretrained_weights_url, pretrained_weights) + + _ = PatchPredictor( + pretrained_model="resnet18-kather100k", + pretrained_weights=pretrained_weights, + batch_size=1, + ) + + # --- test different using user model + model = CNNModel(backbone="resnet18", num_classes=9) + # test prediction + predictor = PatchPredictor(model=model, batch_size=1, verbose=False) + output = predictor.predict( + inputs, + return_probabilities=True, + labels=[1, "a"], + return_labels=True, + on_gpu=ON_GPU, + save_dir=save_dir_path, + ) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) + assert len(output["predictions"]) == len(output["labels"]) + assert len(output["predictions"]) == len(output["probabilities"]) + + +def test_wsi_predictor_api( + sample_wsi_dict: dict, + tmp_path: Path, + chdir: Callable, +) -> None: + """Test normal run of wsi predictor.""" + save_dir_path = tmp_path + + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + patch_size = np.array([224, 224]) + predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) + + save_dir = f"{save_dir_path}/model_wsi_output" + + # wrapper to make this more clean + kwargs = { + "return_probabilities": True, + "return_labels": True, + "on_gpu": ON_GPU, + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 1.0, + "units": "baseline", + "save_dir": save_dir, + } + # ! add this test back once the read at `baseline` is fixed + # sanity check, both output should be the same with same resolution read args + wsi_output = predictor.predict( + [mini_wsi_svs], + masks=[mini_wsi_msk], + mode="wsi", + **kwargs, + ) + + shutil.rmtree(save_dir, ignore_errors=True) + + tile_output = predictor.predict( + [mini_wsi_jpg], + masks=[mini_wsi_msk], + mode="tile", + **kwargs, + ) + + wpred = np.array(wsi_output[0]["predictions"]) + tpred = np.array(tile_output[0]["predictions"]) + diff = tpred == wpred + accuracy = np.sum(diff) / np.size(wpred) + assert accuracy > 0.9, np.nonzero(~diff) + + # remove previously generated data + shutil.rmtree(save_dir, ignore_errors=True) + + kwargs = { + "return_probabilities": True, + "return_labels": True, + "on_gpu": ON_GPU, + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 0.5, + "save_dir": save_dir, + "merge_predictions": True, # to test the api coverage + "units": "mpp", + } + + _kwargs = copy.deepcopy(kwargs) + _kwargs["merge_predictions"] = False + # test reading of multiple whole-slide images + output = predictor.predict( + [mini_wsi_svs, mini_wsi_svs], + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + for output_info in output.values(): + assert Path(output_info["raw"]).exists() + assert "merged" not in output_info + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) + + # coverage test + _kwargs = copy.deepcopy(kwargs) + _kwargs["merge_predictions"] = True + # test reading of multiple whole-slide images + predictor.predict( + [mini_wsi_svs, mini_wsi_svs], + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + _kwargs = copy.deepcopy(kwargs) + with pytest.raises(FileExistsError): + predictor.predict( + [mini_wsi_svs, mini_wsi_svs], + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + # remove previously generated data + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) + + with chdir(save_dir_path): + # test reading of multiple whole-slide images + _kwargs = copy.deepcopy(kwargs) + _kwargs["save_dir"] = None # default coverage + _kwargs["return_probabilities"] = False + output = predictor.predict( + [mini_wsi_svs, mini_wsi_svs], + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + assert Path.exists(Path("output")) + for output_info in output.values(): + assert Path(output_info["raw"]).exists() + assert "merged" in output_info + assert Path(output_info["merged"]).exists() + + # remove previously generated data + shutil.rmtree("output", ignore_errors=True) + + +def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: + """Test normal run of wsi predictor with merge predictions option.""" + # convert to pathlib Path to prevent reader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + # blind test + # pseudo output dict from model with 2 patches + output = { + "resolution": 1.0, + "units": "baseline", + "probabilities": [[0.45, 0.55], [0.90, 0.10]], + "predictions": [1, 0], + "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], + } + merged = PatchPredictor.merge_predictions( + np.zeros([4, 4]), + output, + resolution=1.0, + units="baseline", + ) + _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) + assert np.sum(merged - _merged) == 0 + + # blind test for merging probabilities + merged = PatchPredictor.merge_predictions( + np.zeros([4, 4]), + output, + resolution=1.0, + units="baseline", + return_raw=True, + ) + _merged = np.array( + [ + [0.45, 0.45, 0, 0], + [0.45, 0.45, 0, 0], + [0, 0, 0.90, 0.90], + [0, 0, 0.90, 0.90], + ], + ) + assert merged.shape == (4, 4, 2) + assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 + + # integration test + predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) + + kwargs = { + "return_probabilities": True, + "return_labels": True, + "on_gpu": ON_GPU, + "patch_input_shape": np.array([224, 224]), + "stride_shape": np.array([224, 224]), + "resolution": 1.0, + "units": "baseline", + "merge_predictions": True, + } + # sanity check, both output should be the same with same resolution read args + wsi_output = predictor.predict( + [mini_wsi_svs], + masks=[mini_wsi_msk], + mode="wsi", + **kwargs, + ) + + # mock up to change the preproc func and + # force to use the default in merge function + # still should have the same results + kwargs["merge_predictions"] = False + tile_output = predictor.predict( + [mini_wsi_jpg], + masks=[mini_wsi_msk], + mode="tile", + **kwargs, + ) + merged_tile_output = predictor.merge_predictions( + mini_wsi_jpg, + tile_output[0], + resolution=kwargs["resolution"], + units=kwargs["units"], + ) + tile_output.append(merged_tile_output) + + # first make sure nothing breaks with predictions + wpred = np.array(wsi_output[0]["predictions"]) + tpred = np.array(tile_output[0]["predictions"]) + diff = tpred == wpred + accuracy = np.sum(diff) / np.size(wpred) + assert accuracy > 0.9, np.nonzero(~diff) + + merged_wsi = wsi_output[1] + merged_tile = tile_output[1] + # ensure shape of merged predictions of tile and wsi input are the same + assert merged_wsi.shape == merged_tile.shape + # ensure consistent predictions between tile and wsi mode + diff = merged_tile == merged_wsi + accuracy = np.sum(diff) / np.size(merged_wsi) + assert accuracy > 0.9, np.nonzero(~diff) + + +def _test_predictor_output( + inputs: list, + pretrained_model: str, + probabilities_check: list | None = None, + predictions_check: list | None = None, + *, + on_gpu: bool = ON_GPU, +) -> None: + """Test the predictions of multiple models included in tiatoolbox.""" + predictor = PatchPredictor( + pretrained_model=pretrained_model, + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.predict( + inputs, + return_probabilities=True, + return_labels=False, + on_gpu=on_gpu, + ) + predictions = output["predictions"] + probabilities = output["probabilities"] + for idx, probabilities_ in enumerate(probabilities): + probabilities_max = max(probabilities_) + assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( + pretrained_model, + probabilities_max, + probabilities_check[idx], + predictions[idx], + predictions_check[idx], + ) + assert predictions[idx] == predictions_check[idx], ( + pretrained_model, + probabilities_max, + probabilities_check[idx], + predictions[idx], + predictions_check[idx], + ) + + +def test_patch_predictor_kather100k_output( + sample_patch1: Path, + sample_patch2: Path, +) -> None: + """Test the output of patch prediction models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] + pretrained_info = { + "alexnet-kather100k": [1.0, 0.9999735355377197], + "resnet18-kather100k": [1.0, 0.9999911785125732], + "resnet34-kather100k": [1.0, 0.9979840517044067], + "resnet50-kather100k": [1.0, 0.9999986886978149], + "resnet101-kather100k": [1.0, 0.9999932050704956], + "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], + "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], + "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], + "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], + "densenet121-kather100k": [1.0, 1.0], + "densenet161-kather100k": [1.0, 0.9999959468841553], + "densenet169-kather100k": [1.0, 0.9999934434890747], + "densenet201-kather100k": [1.0, 0.9999983310699463], + "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], + "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], + "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], + "googlenet-kather100k": [1.0, 0.9999639987945557], + } + for pretrained_model, expected_prob in pretrained_info.items(): + _test_predictor_output( + inputs, + pretrained_model, + probabilities_check=expected_prob, + predictions_check=[6, 3], + on_gpu=ON_GPU, + ) + # only test 1 on travis to limit runtime + if toolbox_env.running_on_ci(): + break + + +def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: + """Test the output of patch prediction models on PCam dataset.""" + inputs = [Path(sample_patch3), Path(sample_patch4)] + pretrained_info = { + "alexnet-pcam": [0.999980092048645, 0.9769067168235779], + "resnet18-pcam": [0.999992847442627, 0.9466130137443542], + "resnet34-pcam": [1.0, 0.9976525902748108], + "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], + "resnet101-pcam": [1.0, 0.9997289776802063], + "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], + "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], + "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], + "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], + "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], + "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], + "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], + "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], + "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], + "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], + "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], + "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], + } + for pretrained_model, expected_prob in pretrained_info.items(): + _test_predictor_output( + inputs, + pretrained_model, + probabilities_check=expected_prob, + predictions_check=[1, 0], + on_gpu=ON_GPU, + ) + # only test 1 on travis to limit runtime + if toolbox_env.running_on_ci(): + break + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: + """Test for models CLI file not found error.""" + runner = CliRunner() + model_file_not_found_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(sample_svs)[:-1], + "--file-types", + '"*.ndpi, *.svs"', + "--output-path", + str(tmp_path.joinpath("output")), + ], + ) + + assert model_file_not_found_result.output == "" + assert model_file_not_found_result.exit_code == 1 + assert isinstance(model_file_not_found_result.exception, FileNotFoundError) + + +def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: + """Test for models CLI mode not in wsi, tile.""" + runner = CliRunner() + mode_not_in_wsi_tile_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(sample_svs), + "--file-types", + '"*.ndpi, *.svs"', + "--mode", + '"patch"', + "--output-path", + str(tmp_path.joinpath("output")), + ], + ) + + assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output + assert mode_not_in_wsi_tile_result.exit_code != 0 + assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) + + +def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: + """Test for models CLI single file.""" + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(sample_svs), + "--mode", + "wsi", + "--output-path", + str(tmp_path.joinpath("output")), + ], + ) + + assert models_wsi_result.exit_code == 0 + assert tmp_path.joinpath("output/0.merged.npy").exists() + assert tmp_path.joinpath("output/0.raw.json").exists() + assert tmp_path.joinpath("output/results.json").exists() + + +def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: + """Test for models CLI single file with mask.""" + mini_wsi_svs = Path(remote_sample("svs-1-small")) + sample_wsi_msk = remote_sample("small_svs_tissue_mask") + sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) + imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) + sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" + + runner = CliRunner() + models_tiles_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(mini_wsi_svs), + "--mode", + "wsi", + "--masks", + str(sample_wsi_msk), + "--output-path", + str(tmp_path.joinpath("output")), + ], + ) + + assert models_tiles_result.exit_code == 0 + assert tmp_path.joinpath("output/0.merged.npy").exists() + assert tmp_path.joinpath("output/0.raw.json").exists() + assert tmp_path.joinpath("output/results.json").exists() + + +def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: + """Test for models CLI multiple file with mask.""" + mini_wsi_svs = Path(remote_sample("svs-1-small")) + sample_wsi_msk = remote_sample("small_svs_tissue_mask") + sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) + imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) + mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") + + # Make multiple copies for test + dir_path = tmp_path.joinpath("new_copies") + dir_path.mkdir() + + dir_path_masks = tmp_path.joinpath("new_copies_masks") + dir_path_masks.mkdir() + + try: + dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) + except OSError: + shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) + + try: + dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) + except OSError: + shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) + + tmp_path = tmp_path.joinpath("output") + + runner = CliRunner() + models_tiles_result = runner.invoke( + cli.main, + [ + "patch-predictor", + "--img-input", + str(dir_path), + "--mode", + "wsi", + "--masks", + str(dir_path_masks), + "--output-path", + str(tmp_path), + ], + ) + + assert models_tiles_result.exit_code == 0 + assert tmp_path.joinpath("0.merged.npy").exists() + assert tmp_path.joinpath("0.raw.json").exists() + assert tmp_path.joinpath("1.merged.npy").exists() + assert tmp_path.joinpath("1.raw.json").exists() + assert tmp_path.joinpath("2.merged.npy").exists() + assert tmp_path.joinpath("2.raw.json").exists() + assert tmp_path.joinpath("results.json").exists() From c9dfba2f36c7c3895dde9af748ed8c21c792b1d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:33:59 +0000 Subject: [PATCH 26/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/dataset/dataset_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index d69518a6e..d634ccbd4 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -361,7 +361,7 @@ class WSIPatchDataset(PatchDatasetABC): """ - def __init__( # skipcq: PY-R1000 # noqa: PLR0913, PLR0915 + def __init__( # skipcq: PY-R1000 # noqa: PLR0915 self: WSIPatchDataset, img_path: str | Path, mode: str = "wsi", From 81e575db4142a5ff9cb84f7af55fce2bf4350695 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 20 Sep 2024 13:15:13 +0100 Subject: [PATCH 27/36] =?UTF-8?q?=F0=9F=A7=91=E2=80=8D=F0=9F=92=BB=20Defin?= =?UTF-8?q?e=20`PatchPredictor`=20(#783)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Redesigns PatchPredictor engine using the new EngineABC base class. - The WSIs are now processed using the same code as for the processing the patches using WSI based dataloader. - The intermediate output is saved as zarr for the WSIs to resolve memory issues. - The output of model architectures should now be a dictionary. - The output can be specified as AnnotationStore for visualisation using TIAViz. --------- Co-authored-by: abishekrajvg Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/engines/test_engine_abc.py | 107 +- tests/engines/test_patch_predictor.py | 1178 +++++-------------- tests/models/test_dataset.py | 2 +- tests/test_annotation_stores.py | 11 +- tests/test_init.py | 2 +- tests/test_utils.py | 3 +- tiatoolbox/cli/common.py | 79 +- tiatoolbox/cli/patch_predictor.py | 64 +- tiatoolbox/models/architecture/vanilla.py | 5 +- tiatoolbox/models/dataset/dataset_abc.py | 8 +- tiatoolbox/models/engine/engine_abc.py | 247 +++- tiatoolbox/models/engine/patch_predictor.py | 980 +++++---------- tiatoolbox/utils/misc.py | 45 +- 13 files changed, 983 insertions(+), 1748 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index d0d6f0bd5..a5ce07d30 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import logging import shutil from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -11,6 +12,7 @@ import pytest import torchvision.models as torch_models import zarr +from typing_extensions import Unpack from tiatoolbox.models.architecture import ( fetch_pretrained_weights, @@ -18,8 +20,13 @@ ) from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.dataset import PatchDataset, WSIPatchDataset -from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir +from tiatoolbox.models.engine.engine_abc import ( + EngineABC, + EngineABCRunParams, + prepare_engines_save_dir, +) from tiatoolbox.models.engine.io_config import ModelIOConfigABC +from tiatoolbox.utils.misc import write_to_zarr_in_cache_mode if TYPE_CHECKING: import torch.nn @@ -57,31 +64,38 @@ def get_dataloader( def save_wsi_output( self: EngineABC, - raw_output: dict, + processed_output: dict, save_dir: Path, **kwargs: dict, ) -> Path: """Test post_process_wsi.""" return super().save_wsi_output( - raw_output, + processed_output, save_dir=save_dir, **kwargs, ) + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output.""" + return super().post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + def infer_wsi( self: EngineABC, dataloader: torch.utils.data.DataLoader, - img_label: str, - highest_input_resolution: list[dict], - save_dir: Path, + save_path: Path, **kwargs: dict, ) -> dict | np.ndarray: """Test infer_wsi.""" return super().infer_wsi( dataloader, - img_label, - highest_input_resolution, - save_dir, + save_path, **kwargs, ) @@ -115,13 +129,34 @@ def test_incorrect_ioconfig() -> NoReturn: """Test EngineABC initialization with incorrect ioconfig.""" model = torch_models.resnet18() engine = TestEngineABC(model=model) + with pytest.raises( ValueError, - match=r".*provide a valid ModelIOConfigABC.*", + match=r".*Must provide.*`ioconfig`.*", ): engine.run(images=[], masks=[], ioconfig=None) +def test_incorrect_output_type() -> NoReturn: + """Test EngineABC for incorrect output type.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + + with pytest.raises( + TypeError, + match=r".*output_type must be 'dict' or 'zarr' or 'annotationstore*", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + output_type="random", + ) + + def test_pretrained_ioconfig() -> NoReturn: """Test EngineABC initialization with pretrained model name in the toolbox.""" pretrained_model = "alexnet-kather100k" @@ -134,7 +169,7 @@ def test_pretrained_ioconfig() -> NoReturn: patch_mode=True, ioconfig=None, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out @@ -153,7 +188,7 @@ def test_ioconfig() -> NoReturn: ioconfig=ioconfig, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out @@ -260,7 +295,7 @@ def test_engine_initalization() -> NoReturn: assert isinstance(eng, EngineABC) -def test_engine_run(tmp_path: Path) -> NoReturn: +def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn: """Test engine run.""" eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) @@ -316,7 +351,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn: on_gpu=False, patch_mode=True, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -325,7 +360,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn: on_gpu=False, verbose=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -334,14 +369,14 @@ def test_engine_run(tmp_path: Path) -> NoReturn: labels=list(range(10)), on_gpu=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" in out eng = TestEngineABC(model="alexnet-kather100k") with pytest.raises(NotImplementedError): eng.run( - images=np.zeros(shape=(10, 224, 224, 3)), + images=[sample_svs], save_dir=tmp_path / "output", patch_mode=False, ) @@ -358,7 +393,7 @@ def test_engine_run_with_verbose() -> NoReturn: on_gpu=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" in out @@ -513,7 +548,10 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: save_path = tmp_path / "output.zarr" _ = zarr.open(save_path, mode="w") out = eng.save_wsi_output( - raw_output=save_path, save_path=save_path, output_type="zarr", save_dir=tmp_path + processed_output=save_path, + save_path=save_path, + output_type="zarr", + save_dir=tmp_path, ) assert out.exists() @@ -521,13 +559,17 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: # Test AnnotationStore patch_output = { - "predictions": [1, 0, 1], - "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], - "other": "other", + "predictions": np.array([1, 0, 1]), + "coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]), } class_dict = {0: "class0", 1: "class1"} + save_path = tmp_path / "output_db.zarr" + zarr_group = zarr.open(save_path, mode="w") + _ = write_to_zarr_in_cache_mode( + zarr_group=zarr_group, output_data_to_save=patch_output + ) out = eng.save_wsi_output( - raw_output=patch_output, + processed_output=save_path, scale_factor=(1.0, 1.0), class_dict=class_dict, save_dir=tmp_path, @@ -542,28 +584,35 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: match=r".*supports zarr and AnnotationStore as output_type.", ): eng.save_wsi_output( - raw_output=save_path, + processed_output=save_path, save_path=save_path, output_type="dict", save_dir=tmp_path, ) -def test_io_config_delegation(tmp_path: Path) -> None: +def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: """Test for delegating args to io config.""" # test not providing config / full input info for not pretrained models model = CNNModel("resnet50") eng = TestEngineABC(model=model) - with pytest.raises(ValueError, match=r".*Please provide a valid ModelIOConfigABC*"): - eng.run( - np.zeros((10, 224, 224, 3)), patch_mode=True, save_dir=tmp_path / "dump" - ) kwargs = { "patch_input_shape": [512, 512], "resolution": 1.75, "units": "mpp", } + with caplog.at_level(logging.WARNING): + eng.run( + np.zeros((10, 224, 224, 3)), + patch_mode=True, + save_dir=tmp_path / "dump", + patch_input_shape=kwargs["patch_input_shape"], + resolution=kwargs["resolution"], + units=kwargs["units"], + ) + assert "provide a valid ModelIOConfigABC" in caplog.text + shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test providing config / full input info for non pretrained models ioconfig = ModelIOConfigABC( diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index ab59efc53..8f62f5037 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -3,628 +3,124 @@ from __future__ import annotations import copy +import json import shutil +import sqlite3 from pathlib import Path from typing import Callable -import cv2 import numpy as np -import pytest -import torch +import zarr from click.testing import CliRunner from tiatoolbox import cli -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor +from tiatoolbox.models import IOPatchPredictorConfig from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.models.dataset import ( - PatchDataset, - PatchDatasetABC, - WSIPatchDataset, - predefined_preproc_func, -) -from tiatoolbox.utils import download_data, imread, imwrite +from tiatoolbox.models.engine.patch_predictor import PatchPredictor +from tiatoolbox.utils import download_data, imwrite from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader +device = "cuda" if toolbox_env.has_gpu() else "cpu" ON_GPU = toolbox_env.has_gpu() RNG = np.random.default_rng() # Numpy Random Generator -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_patch_dataset_path_imgs( - sample_patch1: str | Path, - sample_patch2: str | Path, -) -> None: - """Test for patch dataset with a list of file paths as input.""" - size = (224, 224, 3) - - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_list_imgs(tmp_path: Path) -> None: - """Test for patch dataset with a list of images as input.""" - save_dir_path = tmp_path - - size = (5, 5, 3) - img = RNG.integers(low=0, high=255, size=size) - list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) - - dataset.preproc_func = lambda x: x - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - # test for changing to another preproc - dataset.preproc_func = lambda x: x - 10 - item = dataset[0] - assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 - - # * test for loading npy - # remove previously generated data - if Path.exists(save_dir_path): - shutil.rmtree(save_dir_path, ignore_errors=True) - Path.mkdir(save_dir_path, parents=True) - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - assert imgs[0] is not None - # test for path object - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - - -def test_patch_datasetarray_imgs() -> None: - """Test for patch dataset with a numpy array of a list of images.""" - size = (5, 5, 3) - img = RNG.integers(0, 255, size=size) - list_imgs = [img, img, img] - labels = [1, 2, 3] - array_imgs = np.array(list_imgs) - - # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) - an_item = dataset[2] - assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) - an_item = dataset[2] - assert "label" not in an_item - - dataset = PatchDataset(array_imgs) - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_crash(tmp_path: Path) -> None: - """Test to make sure patch dataset crashes with incorrect input.""" - # all below examples should fail when input to PatchDataset - save_dir_path = tmp_path - - # not supported input type - imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} - with pytest.raises( - ValueError, - match=r".*Input must be either a list/array of images.*", - ): - _ = PatchDataset(imgs) - - # ndarray of mixed dtype - imgs = np.array( - # string array of the same shape - [ - RNG.integers(0, 255, (4, 5, 3)), - np.array( # skipcq: PYL-E1121 - ["you_should_crash_here" for _ in range(4 * 5 * 3)], - ).reshape( - 4, - 5, - 3, - ), - ], - dtype=object, - ) - with pytest.raises(ValueError, match="Provided input array is non-numerical."): - _ = PatchDataset(imgs) - - # ndarray(s) of NHW images - imgs = RNG.integers(0, 255, (4, 4, 4)) - with pytest.raises(ValueError, match=r".*array of the form HWC*"): - _ = PatchDataset(imgs) - - # list of ndarray(s) with different sizes - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 5, 3)), - ] - with pytest.raises(ValueError, match="Images must have the same dimensions."): - _ = PatchDataset(imgs) - - # list of ndarray(s) with HW and HWC mixed up - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 4)), - ] - with pytest.raises( - ValueError, - match="Each sample must be an array of the form HWC.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = ["you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list not exist paths - with pytest.raises( - ValueError, - match=r".*valid image paths.*", - ): - _ = PatchDataset(["img.npy"]) - - # ** test different extension parser - # save dummy data to temporary location - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - - torch.save({"a": "a"}, save_dir_path / "sample1.tar") - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - - imgs = [ - save_dir_path / "sample1.tar", - save_dir_path / "sample2.npy", - ] - with pytest.raises( - ValueError, - match="Cannot load image data from", - ): - _ = PatchDataset(imgs) - - # preproc func for not defined dataset - with pytest.raises( - ValueError, - match=r".* preprocessing .* does not exist.", - ): - predefined_preproc_func("secret-dataset") - - -def test_wsi_patch_dataset( # noqa: PLR0915 - sample_wsi_dict: dict, - tmp_path: Path, -) -> None: - """A test for creation and bare output.""" - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) - - def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return reuse_init(mode="wsi", **kwargs) - - # test for ABC validate - # intentionally created to check error - # skipcq - class Proto(PatchDatasetABC): - def __init__(self: Proto) -> None: - super().__init__() - self.inputs = "CRASH" - self._check_input_integrity("wsi") - - # skipcq - def __getitem__(self: Proto, idx: int) -> object: - """Get an item from the dataset.""" - - with pytest.raises( - ValueError, - match=r".*`inputs` should be a list of patch coordinates.*", - ): - Proto() # skipcq - - # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): - WSIPatchDataset( - img_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - ) - - # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): - WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - - # invalid patch - with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): - reuse_init() - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, "a"]) - with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): - reuse_init_wsi(patch_input_shape=512) - # invalid stride - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) - # negative - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) - - # * for wsi - # dummy test for analysing the output - # stride and patch size should be as expected - patch_size = [512, 512] - stride_size = [256, 256] - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - reader = WSIReader.open(mini_wsi_svs) - # tiling top to bottom, left to right - ds_roi = ds[2]["image"] - step_idx = 2 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - rd_roi = reader.read_bounds( - start + end, - resolution=1.0, - units="mpp", - coord_space="resolution", - ) - correlation = np.corrcoef( - cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert ds_roi.shape[0] == rd_roi.shape[0] - assert ds_roi.shape[1] == rd_roi.shape[1] - assert np.min(correlation) > 0.9, correlation - - # test creation with auto mask gen and input mask - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=True, - ) - assert len(ds) > 0 - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - negative_mask = imread(mini_wsi_msk) - negative_mask = np.zeros_like(negative_mask) - negative_mask_path = tmp_path / "negative_mask.png" - imwrite(negative_mask_path, negative_mask) - with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - - -def test_patch_dataset_abc() -> None: - """Test for ABC methods. - - Test missing definition for abstract intentionally created to check error. - - """ - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # crash due to undefined __getitem__ - with pytest.raises(TypeError): - Proto() # skipcq - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # skipcq - def __getitem__(self: Proto, idx: int) -> None: - """Get an item from the dataset.""" - - ds = Proto() # skipcq - - # test setter and getter - assert ds.preproc_func(1) == 1 - ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 - assert ds.preproc_func(1) == 0 - assert ds.preproc(1) == 1, "Must be unchanged!" - ds.preproc_func = None # skipcq: PYL-W0201 - assert ds.preproc_func(2) == 2 - - # test assign uncallable to preproc_func/postproc_func - with pytest.raises(ValueError, match=r".*callable*"): - ds.preproc_func = 1 # skipcq: PYL-W0201 - - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_io_patch_predictor_config() -> None: - """Test for IOConfig.""" - # test for creating - cfg = IOPatchPredictorConfig( - patch_input_shape=[224, 224], - stride_shape=[224, 224], - input_resolutions=[{"resolution": 0.5, "units": "mpp"}], - # test adding random kwarg and they should be accessible as kwargs - crop_from_source=True, - ) - assert cfg.crop_from_source - # ------------------------------------------------------------------------------------- # Engine # ------------------------------------------------------------------------------------- -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: """Test for delegating args to io config.""" mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - + predictor = PatchPredictor(model=model, weights=None) kwargs = { "patch_input_shape": [512, 512], "resolution": 1.75, "units": "mpp", } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained models + + # test providing config / full input info for default models without weights ioconfig = IOPatchPredictorConfig( patch_input_shape=(512, 512), stride_shape=(256, 256), input_resolutions=[{"resolution": 1.35, "units": "mpp"}], ) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], ioconfig=ioconfig, - mode="wsi", + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], - mode="wsi", + predictor.run( + images=[mini_wsi_svs], + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, **kwargs, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) + predictor.run( + images=[mini_wsi_svs], patch_input_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.patch_input_shape == (300, 300) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], stride_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.stride_shape == (300, 300) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], resolution=1.99, - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], units="baseline", - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, + predictor.run( + images=[mini_wsi_svs], + units="level", + resolution=0, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "level" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + units="power", + resolution=20, + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "power" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 20 shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -638,59 +134,28 @@ def test_patch_predictor_api( # convert to pathlib Path to prevent reader complaint inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) # don't run test on GPU - output = predictor.predict( + # Default run + output = predictor.run( inputs, - on_gpu=ON_GPU, - save_dir=save_dir_path, + device="cpu", ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 + assert sorted(output.keys()) == ["probabilities"] + assert len(output["probabilities"]) == 2 shutil.rmtree(save_dir_path, ignore_errors=True) - output = predictor.predict( + # whether to return labels + output = predictor.run( inputs, - labels=[1, "a"], + labels=["1", "a"], return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] + assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == ["1", "a"] shutil.rmtree(save_dir_path, ignore_errors=True) - output = predictor.predict( - inputs, - return_probabilities=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - # test loading user weight pretrained_weights_url = ( "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" @@ -705,33 +170,31 @@ def test_patch_predictor_api( download_data(pretrained_weights_url, pretrained_weights) - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, + predictor = PatchPredictor( + model="resnet18-kather100k", + weights=pretrained_weights, batch_size=1, ) + ioconfig = predictor.ioconfig # --- test different using user model model = CNNModel(backbone="resnet18", num_classes=9) # test prediction predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( + output = predictor.run( inputs, - return_probabilities=True, - labels=[1, "a"], + labels=[1, 2], return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, + ioconfig=ioconfig, ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) + assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == [1, 2] def test_wsi_predictor_api( sample_wsi_dict: dict, tmp_path: Path, - chdir: Callable, ) -> None: """Test normal run of wsi predictor.""" save_dir_path = tmp_path @@ -742,15 +205,12 @@ def test_wsi_predictor_api( mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=32) save_dir = f"{save_dir_path}/model_wsi_output" # wrapper to make this more clean kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, "patch_input_shape": patch_size, "stride_shape": patch_size, "resolution": 1.0, @@ -759,236 +219,60 @@ def test_wsi_predictor_api( } # ! add this test back once the read at `baseline` is fixed # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - # coverage test _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], + output = predictor.run( + images=[mini_wsi_svs, mini_wsi_jpg], masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", + patch_mode=False, **_kwargs, ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - # blind test for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 + wsi_pred = zarr.open(str(output[mini_wsi_svs]), mode="r") + tile_pred = zarr.open(str(output[mini_wsi_jpg]), mode="r") + diff = tile_pred["probabilities"][:] == wsi_pred["probabilities"][:] + accuracy = np.sum(diff) / np.size(wsi_pred["probabilities"][:]) + assert accuracy > 0.99, np.nonzero(~diff) - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) def _test_predictor_output( inputs: list, - pretrained_model: str, + model: str, probabilities_check: list | None = None, predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, ) -> None: """Test the predictions of multiple models included in tiatoolbox.""" predictor = PatchPredictor( - pretrained_model=pretrained_model, + model=model, batch_size=32, verbose=False, ) # don't run test on GPU - output = predictor.predict( + output = predictor.run( inputs, return_probabilities=True, return_labels=False, - on_gpu=on_gpu, + device=device, ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): + predictions = output["probabilities"] + for idx, probabilities_ in enumerate(predictions): probabilities_max = max(probabilities_) assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, + model, probabilities_max, probabilities_check[idx], - predictions[idx], + probabilities_, predictions_check[idx], ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, + assert np.argmax(probabilities_) == predictions_check[idx], ( + model, probabilities_max, probabilities_check[idx], - predictions[idx], + probabilities_, predictions_check[idx], ) @@ -1018,52 +302,188 @@ def test_patch_predictor_kather100k_output( "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], "googlenet-kather100k": [1.0, 0.9999639987945557], } - for pretrained_model, expected_prob in pretrained_info.items(): + for model, expected_prob in pretrained_info.items(): _test_predictor_output( inputs, - pretrained_model, + model, probabilities_check=expected_prob, predictions_check=[6, 3], - on_gpu=ON_GPU, ) # only test 1 on travis to limit runtime if toolbox_env.running_on_ci(): break -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], +def _validate_probabilities(predictions: list | dict) -> bool: + """Helper function to test if the probabilities value are valid.""" + if isinstance(predictions, dict): + return all(0 <= probability <= 1 for _, probability in predictions.items()) + + for row in predictions: + for element in row: + if not (0 <= element <= 1): + return False + return True + + +def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None: + """Test normal run of patch predictor for WSIs.""" + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert output_["probabilities"].shape == (70, 9) # number of patches x classes + assert output_["probabilities"].ndim == 2 + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (70, 4) + assert output_["coordinates"].ndim == 2 + assert _validate_probabilities(predictions=output_["probabilities"]) + + +def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> None: + """Test normal run of patch predictor for WSIs.""" + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + units="baseline", + resolution=1.0, + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert output_["probabilities"].shape == (244, 9) # number of patches x classes + assert output_["probabilities"].ndim == 2 + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (244, 4) + assert output_["coordinates"].ndim == 2 + assert _validate_probabilities(predictions=output_["probabilities"]) + + +def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: + """Helper function to extract probabilities from Annotation Store.""" + probs_dict = {} + con = sqlite3.connect(dbfile) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + for item in annotations_properties: + for json_str in item: + probs_dict = json.loads(json_str) + probs_dict.pop("prob_0") + + return probs_dict + + +def test_engine_run_wsi_annotation_store( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 0.5, + "save_dir": save_dir, + "units": "mpp", + "scale_factor": (2.0, 2.0), } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + predictions = _extract_probabilities_from_annotation_store(output_) + assert _validate_probabilities(predictions) + + shutil.rmtree(save_dir) + + +def test_engine_run_wsi_annotation_store_power( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 20, + "save_dir": save_dir, + "units": "power", + } + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + predictions = _extract_probabilities_from_annotation_store(output_) + assert _validate_probabilities(predictions) + + shutil.rmtree(save_dir) # ------------------------------------------------------------------------------------- @@ -1103,14 +523,14 @@ def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> str(sample_svs), "--file-types", '"*.ndpi, *.svs"', - "--mode", + "--patch-mode", '"patch"', "--output-path", str(tmp_path.joinpath("output")), ], ) - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output + assert "Invalid value for '--patch-mode'" in mode_not_in_wsi_tile_result.output assert mode_not_in_wsi_tile_result.exit_code != 0 assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) @@ -1124,47 +544,15 @@ def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: "patch-predictor", "--img-input", str(sample_svs), - "--mode", - "wsi", + "--patch-mode", + "False", "--output-path", - str(tmp_path.joinpath("output")), + str(tmp_path / "output"), ], ) assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() + assert (tmp_path / "output" / (sample_svs.stem + ".db")).exists() def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: @@ -1187,20 +575,18 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("1_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("2_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("3_" + mini_wsi_svs.name)) try: dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") + shutil.copy(mini_wsi_msk, dir_path_masks / ("1_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("2_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("3_" + mini_wsi_msk.name)) runner = CliRunner() models_tiles_result = runner.invoke( @@ -1209,20 +595,18 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) - "patch-predictor", "--img-input", str(dir_path), - "--mode", - "wsi", + "--patch-mode", + str(False), "--masks", str(dir_path_masks), "--output-path", - str(tmp_path), + str(tmp_path / "output"), + "--output-type", + "zarr", ], ) assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() + assert (tmp_path / "output" / ("1_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("2_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("3_" + mini_wsi_svs.stem + ".zarr")).exists() diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index de5b726a7..ab9a6033f 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -309,7 +309,7 @@ def test_patch_dataset_crash(tmp_path: Path) -> None: save_dir_path / "sample2.npy", ] with pytest.raises( - ValueError, + TypeError, match="Cannot load image data from", ): _ = PatchDataset(imgs) diff --git a/tests/test_annotation_stores.py b/tests/test_annotation_stores.py index 01bbdac45..66c990161 100644 --- a/tests/test_annotation_stores.py +++ b/tests/test_annotation_stores.py @@ -53,14 +53,6 @@ FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1]) RNG = np.random.default_rng(0) # Numpy Random Generator -# ---------------------------------------------------------------------- -# Resets -# ---------------------------------------------------------------------- - -# Reset filters in logger. -for filter_ in logger.filters: - logger.removeFilter(filter_) - # ---------------------------------------------------------------------- # Helper Functions # ---------------------------------------------------------------------- @@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math( caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning is shown if the sqlite math module is missing.""" + # Reset filters in logger. + for filter_ in logger.filters[:]: + logger.removeFilter(filter_) monkeypatch.setattr( SQLiteStore, "compile_options", diff --git a/tests/test_init.py b/tests/test_init.py index 509a9c49f..6d8ed8238 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None: logger.addFilter(duplicate_filter) # Reset filters in logger. - for filter_ in logger.filters: + for filter_ in logger.filters[:]: logger.removeFilter(filter_) for _ in range(2): diff --git a/tests/test_utils.py b/tests/test_utils.py index ad8e0e3da..0b1f3f484 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1646,6 +1646,7 @@ def test_patch_pred_store() -> None: """Test patch_pred_store.""" # Define a mock patch_output patch_output = { + "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)], "predictions": [1, 0, 1], "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], "other": "other", @@ -1680,7 +1681,7 @@ def test_patch_pred_store_cdict() -> None: class_dict = {0: "class0", 1: "class1"} store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) - # Check that its an SQLiteStore containing the expected annotations + # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 81ba7b5f4..6b4a23a5a 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -86,6 +86,24 @@ def cli_file_type( ) +def cli_output_type( + usage_help: str = "The format of the output type. " + "'output_type' can be 'zarr' or 'AnnotationStore'. " + "Default value is 'AnnotationStore'.", + default: str = "AnnotationStore", + input_type: click.Choice | None = None, +) -> callable: + """Enables --file-types option for cli.""" + if input_type is None: + input_type = click.Choice(["zarr", "AnnotationStore"], case_sensitive=False) + return click.option( + "--output-type", + help=add_default_to_usage_help(usage_help, default), + default=default, + type=input_type, + ) + + def cli_mode( usage_help: str = "Selected mode to show or save the required information.", default: str = "save", @@ -102,6 +120,20 @@ def cli_mode( ) +def cli_patch_mode( + usage_help: str = "Whether to run the model in patch mode or WSI mode.", + *, + default: bool = False, +) -> callable: + """Enables --return-probabilities option for cli.""" + return click.option( + "--patch-mode", + type=bool, + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_region( usage_help: str = "Image region in the whole slide image to read from. " "default=0 0 2000 2000", @@ -215,7 +247,7 @@ def cli_pretrained_model( ) -> callable: """Enables --pretrained-model option for cli.""" return click.option( - "--pretrained-model", + "--model", help=add_default_to_usage_help(usage_help, default), default=default, ) @@ -234,6 +266,51 @@ def cli_pretrained_weights( ) +def cli_model( + usage_help: str = "Name of the predefined model used to process the data. " + "The format is _. For example, " + "`resnet18-kather100K` is a resnet18 model trained on the Kather dataset. " + "Please see " + "https://tia-toolbox.readthedocs.io/en/latest/usage.html#deep-learning-models " + "for a detailed list of available pretrained models." + "By default, the corresponding pretrained weights will also be" + "downloaded. However, you can override with your own set of weights" + "via the `pretrained_weights` argument. Argument is case insensitive.", + default: str = "resnet18-kather100k", +) -> callable: + """Enables --pretrained-model option for cli.""" + return click.option( + "--model", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + +def cli_weights( + usage_help: str = "Path to the model weight file. If not supplied, the default " + "pretrained weight will be used.", + default: str | None = None, +) -> callable: + """Enables --pretrained-weights option for cli.""" + return click.option( + "--weights", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + +def cli_device( + usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", + default: str = "cpu", +) -> callable: + """Enables --pretrained-weights option for cli.""" + return click.option( + "--device", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_return_probabilities( usage_help: str = "Whether to return raw model probabilities.", *, diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index f6cc1b397..263809146 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -2,25 +2,21 @@ from __future__ import annotations -import click - from tiatoolbox.cli.common import ( cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, - cli_merge_predictions, - cli_mode, + cli_model, cli_num_loader_workers, - cli_on_gpu, cli_output_path, - cli_pretrained_model, - cli_pretrained_weights, + cli_output_type, + cli_patch_mode, cli_resolution, - cli_return_labels, - cli_return_probabilities, cli_units, cli_verbose, + cli_weights, prepare_model_cli, tiatoolbox_cli, ) @@ -35,45 +31,36 @@ @cli_file_type( default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", ) -@cli_mode( - usage_help="Type of input file to process.", - default="wsi", - input_type=click.Choice(["patch", "wsi", "tile"], case_sensitive=False), -) -@cli_pretrained_model(default="resnet18-kather100k") -@cli_pretrained_weights() -@cli_return_probabilities(default=False) -@cli_merge_predictions(default=True) -@cli_return_labels(default=True) -@cli_on_gpu(default=False) +@cli_patch_mode(default=False) +@cli_model(default="resnet18-kather100k") +@cli_weights() +@cli_device(default="cpu") @cli_batch_size(default=1) @cli_resolution(default=0.5) @cli_units(default="mpp") @cli_masks(default=None) @cli_num_loader_workers(default=0) +@cli_output_type(default="AnnotationStore") @cli_verbose(default=True) def patch_predictor( - pretrained_model: str, - pretrained_weights: str, + model: str, + weights: str, img_input: str, file_types: str, masks: str | None, - mode: str, output_path: str, batch_size: int, resolution: float, units: str, num_loader_workers: int, + device: str, + output_type: str, *, - return_probabilities: bool, - return_labels: bool, - merge_predictions: bool, - on_gpu: bool, + patch_mode: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import PatchPredictor - from tiatoolbox.utils import save_as_json + from tiatoolbox.models.engine.patch_predictor import PatchPredictor files_all, masks_all, output_path = prepare_model_cli( img_input=img_input, @@ -83,26 +70,21 @@ def patch_predictor( ) predictor = PatchPredictor( - pretrained_model=pretrained_model, - weights=pretrained_weights, + model=model, + weights=weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, ) - output = predictor.predict( - imgs=files_all, + _ = predictor.run( + images=files_all, masks=masks_all, - mode=mode, - return_probabilities=return_probabilities, - merge_predictions=merge_predictions, - labels=None, - return_labels=return_labels, + patch_mode=patch_mode, resolution=resolution, units=units, - on_gpu=on_gpu, + device=device, save_dir=output_path, save_output=True, + output_type=output_type, ) - - save_as_json(output, str(output_path.joinpath("results.json"))) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 5c19f4c27..e7b956411 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -169,7 +169,7 @@ def infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return {"predictions": output.cpu().numpy()} + return {"probabilities": output.cpu().numpy()} class CNNBackbone(ModelABC): @@ -265,5 +265,6 @@ def infer_batch( # Do not compute the gradient (not training) with torch.inference_mode(): output = model(img_patches_device) + # Output should be a single tensor or scalar - return {"predictions": output.cpu().numpy()} + return {"probabilities": output.cpu().numpy()} diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index d634ccbd4..045bb39b7 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -145,7 +145,7 @@ def load_img(path: str | Path) -> np.ndarray: if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): msg = f"Cannot load image data from `{path.suffix}` files." - raise ValueError(msg) + raise TypeError(msg) return imread(path, as_uint8=False) @@ -399,10 +399,8 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 `units`. Expected to be positive and of (height, width). Note, this is not at level 0. resolution (Resolution): - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. + Requested resolution corresponding to units. Check + (:class:`WSIReader`) for details. units (Units): Units in which `resolution` is defined. auto_get_mask (bool): diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8fef0c4e2..465230116 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import shutil from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, TypedDict @@ -14,7 +15,7 @@ from torch import nn from typing_extensions import Unpack -from tiatoolbox import logger +from tiatoolbox import DuplicateFilter, logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model @@ -128,7 +129,8 @@ class EngineABCRunParams(TypedDict, total=False): num_post_proc_workers (int): Number of workers to postprocess the results of the model. output_file (str): - Output file name to save "zarr" or "db". + Output file name to save "zarr" or "db". If None, path to output is + returned by the engine. patch_input_shape (tuple): Shape of patches input to the model as tuple of height and width (HW). Patches are requested at read resolution, not with respect to level 0, @@ -355,8 +357,6 @@ def __init__( verbose: bool = False, ) -> None: """Initialize Engine.""" - super().__init__() - self.images = None self.masks = None self.patch_mode = None @@ -378,10 +378,10 @@ def __init__( self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers self.patch_input_shape: IntPair | None = None - self.resolution: Resolution = 1.0 + self.resolution: Resolution | None = None self.return_labels: bool = False self.stride_shape: IntPair | None = None - self.units: Units = "baseline" + self.units: Units | None = None self.verbose = verbose @staticmethod @@ -440,7 +440,7 @@ def _initialize_model_ioconfig( def get_dataloader( self: EngineABC, - images: Path, + images: str | Path | list[str | Path] | np.ndarray, masks: Path | None = None, labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, @@ -453,7 +453,7 @@ def get_dataloader( images (list of str or :class:`Path` or :class:`numpy.ndarray`): A list of image patches in NHWC format as a numpy array or a list of str/paths to WSIs. When `patch_mode` is False - the function expects list of str/paths to WSIs. + the function expects path to a single WSI. masks (list | None): List of masks. Only utilised when patch_mode is False. Patches are only generated within a masked area. @@ -470,7 +470,6 @@ def get_dataloader( torch.utils.data.DataLoader: :class:`torch.utils.data.DataLoader` for inference. - """ if labels: # if a labels is provided, then return with the prediction @@ -527,6 +526,8 @@ def infer_patches( self: EngineABC, dataloader: DataLoader, save_path: Path | None, + *, + return_coordinates: bool = False, ) -> dict | Path: """Runs model inference on image patches and returns output as a dictionary. @@ -535,6 +536,9 @@ def infer_patches( An :class:`torch.utils.data.DataLoader` object to run inference. save_path (Path | None): If `cache_mode` is True then path to save zarr file must be provided. + return_coordinates (bool): + Whether to save coordinates in the output. This is required when + this function is called by `infer_wsi` and `patch_mode` is False. Returns: dict or Path: @@ -553,11 +557,14 @@ def infer_patches( position=0, ) - keys = ["predictions"] + keys = ["probabilities"] if self.return_labels: keys.append("labels") + if return_coordinates: + keys.append("coordinates") + raw_predictions = {key: None for key in keys} zarr_group = None @@ -571,9 +578,14 @@ def infer_patches( batch_data["image"], device=self.device, ) + if return_coordinates: + batch_output["coordinates"] = batch_data["coords"].numpy() if self.return_labels: # be careful of `s` - batch_output["labels"] = batch_data["label"].numpy() + if isinstance(batch_data["label"], torch.Tensor): + batch_output["labels"] = batch_data["label"].numpy() + else: + batch_output["labels"] = batch_data["label"] raw_predictions = self._update_model_output( raw_predictions=raw_predictions, @@ -597,7 +609,7 @@ def infer_patches( def post_process_patches( self: EngineABC, raw_predictions: dict | Path, - **kwargs: dict, + **kwargs: Unpack[EngineABCRunParams], ) -> dict | Path: """Post-process raw patch predictions from inference. @@ -609,8 +621,9 @@ def post_process_patches( Args: raw_predictions (dict | Path): A dictionary or path to zarr with patch prediction information. - **kwargs (dict): - Keyword Args to update setup_patch_dataset() method attributes. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. Returns: dict or Path: @@ -618,7 +631,7 @@ def post_process_patches( saved zarr file if `cache_mode` is True. """ - _ = kwargs.get("predictions") # Key values required for post-processing + _ = kwargs.get("probabilities") # Key values required for post-processing if self.cache_mode: # cache mode _ = zarr.open(raw_predictions, mode="w") @@ -627,7 +640,7 @@ def post_process_patches( def save_predictions( self: EngineABC, - processed_predictions: dict, + processed_predictions: dict | Path, output_type: str, save_dir: Path | None = None, **kwargs: dict, @@ -656,26 +669,36 @@ def save_predictions( `.zarr` file depending on whether a save_dir Path is provided. """ - if (self.cache_mode or not save_dir) and output_type != "AnnotationStore": + if ( + self.cache_mode or not save_dir + ) and output_type.lower() != "annotationstore": return processed_predictions - output_file = Path(kwargs.get("output_file", "output.db")) - - save_path = save_dir / output_file + save_path = Path(kwargs.get("output_file", save_dir / "output.db")) - if output_type == "AnnotationStore": + if output_type.lower() == "annotationstore": # scale_factor set from kwargs - scale_factor = kwargs.get("scale_factor") + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs class_dict = kwargs.get("class_dict") + processed_predictions_path: str | Path | None = None + # Need to add support for zarr conversion. - return dict_to_store( + if self.cache_mode: + processed_predictions_path = processed_predictions + processed_predictions = zarr.open(processed_predictions, mode="r") + + out_file = dict_to_store( processed_predictions, scale_factor, class_dict, save_path, ) + if processed_predictions_path is not None: + shutil.rmtree(processed_predictions_path) + + return out_file return ( dict_to_zarr( @@ -691,11 +714,9 @@ def save_predictions( def infer_wsi( self: EngineABC, dataloader: torch.utils.data.DataLoader, - img_label: str, - highest_input_resolution: list[dict], - save_dir: Path, + save_path: Path | str, **kwargs: dict, - ) -> list: + ) -> dict | Path: """Model inference on a WSI. This function must be implemented by subclasses. @@ -704,22 +725,28 @@ def infer_wsi( # return coordinates of patches processed within a tile / whole-slide image raise NotImplementedError + @abstractmethod + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output.""" + _ = kwargs.get("probabilities") # Key values required for post-processing + return raw_predictions + @abstractmethod def save_wsi_output( self: EngineABC, - raw_output: dict | Path, - save_dir: Path, + processed_output: Path, output_type: str, **kwargs: Unpack[EngineABCRunParams], - ) -> AnnotationStore | Path: - """Post-process a WSI. + ) -> Path: + """Aggregate the output at the WSI level and save to file. Args: - raw_output (dict | Path): - A dictionary with output information or zarr file path. - save_dir (Path): - Output Path to directory to save the patch dataset output to a - `.zarr` or `.db` file + processed_output (Path): + Path to Zarr file with intermediate results. output_type (str): The desired output type for resulting patch dataset. **kwargs (EngineABCRunParams): @@ -732,23 +759,22 @@ def save_wsi_output( stored in a `.zarr` file. """ - if ( - output_type == "zarr" - and isinstance(raw_output, Path) - and raw_output.suffix == ".zarr" - ): - return raw_output + if output_type.lower() == "zarr": + msg = "Output file saved at %s.", processed_output + logger.info(msg=msg) + return processed_output - output_file = kwargs.get("output_file", "output") - save_path = save_dir / output_file - - if output_type == "AnnotationStore": + if output_type.lower() == "annotationstore": + save_path = Path(kwargs.get("output_file", processed_output.stem + ".db")) # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # Read zarr file to a dict + raw_output_dict = zarr.open(str(processed_output), mode="r") + # class_dict set from kwargs class_dict = kwargs.get("class_dict") - return dict_to_store(raw_output, scale_factor, class_dict, save_path) + return dict_to_store(raw_output_dict, scale_factor, class_dict, save_path) msg = "Only supports zarr and AnnotationStore as output_type." raise ValueError(msg) @@ -778,7 +804,7 @@ def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfig "Please provide a valid ModelIOConfigABC. " "No default ModelIOConfigABC found." ) - raise ValueError(msg) + logger.warning(msg) if ioconfig and isinstance(ioconfig, ModelIOConfigABC): self.ioconfig = ioconfig @@ -914,6 +940,7 @@ def _update_run_params( labels: list | None = None, save_dir: os | Path | None = None, ioconfig: ModelIOConfigABC | None = None, + output_type: str = "dict", *, overwrite: bool = False, patch_mode: bool, @@ -928,10 +955,17 @@ def _update_run_params( setattr(self, key, kwargs.get(key)) self.patch_mode = patch_mode + if not self.patch_mode: + self.cache_mode = True # if input is WSI run using cache mode. + if self.cache_mode and self.batch_size > self.cache_size: self.batch_size = self.cache_size self._validate_input_numbers(images=images, masks=masks, labels=labels) + if output_type.lower() not in ["dict", "zarr", "annotationstore"]: + msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'." + raise TypeError(msg) + self.images = self._validate_images_masks(images=images) if masks is not None: @@ -966,11 +1000,15 @@ def _run_patch_mode( """ save_path = None if self.cache_mode: - output_file = Path(kwargs.get("output_file", "output.db")) + output_file = Path(kwargs.get("output_file", "output.zarr")) save_path = save_dir / (str(output_file.stem) + ".zarr") + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + dataloader = self.get_dataloader( images=self.images, + masks=self.masks, labels=self.labels, patch_mode=True, ) @@ -982,6 +1020,8 @@ def _run_patch_mode( raw_predictions=raw_predictions, **kwargs, ) + logger.removeFilter(duplicate_filter) + return self.save_predictions( processed_predictions=processed_predictions, output_type=output_type, @@ -989,6 +1029,106 @@ def _run_patch_mode( **kwargs, ) + @staticmethod + def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, float]: + """Calculates scale factor for final output. + + Uses the dataloader resolution and the WSI resolution to calculate scale + factor for final WSI output. + + Args: + dataloader (DataLoader): + Dataloader for the current run. + + Returns: + scale_factor (float | tuple[float, float]): + Scale factor for final output. + + """ + # get units and resolution from dataloader. + dataloader_units = dataloader.dataset.units + dataloader_resolution = dataloader.dataset.resolution + + # if dataloader units is baseline slide resolution is 1.0. + # in this case dataloader resolution / slide resolution will be + # equal to dataloader resolution. + + if dataloader_units in ["mpp", "level", "power"]: + wsimeta_dict = dataloader.dataset.reader.info.as_dict() + + if dataloader_units == "mpp": + slide_resolution = wsimeta_dict[dataloader_units] + scale_factor = np.divide(slide_resolution, dataloader_resolution) + return scale_factor[0], scale_factor[1] + + if dataloader_units == "level": + downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution] + return 1.0 / downsample_ratio, 1.0 / downsample_ratio + + if dataloader_units == "power": + slide_objective_power = wsimeta_dict["objective_power"] + return ( + dataloader_resolution / slide_objective_power, + dataloader_resolution / slide_objective_power, + ) + + return dataloader_resolution + + def _run_wsi_mode( + self: EngineABC, + output_type: str, + save_dir: Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | AnnotationStore | Path: + """Runs the Engine in the WSI mode (patch_mode = False). + + Input arguments are passed from :func:`EngineABC.run()`. + + """ + suffix = ".zarr" + if output_type == "AnnotationStore": + suffix = ".db" + + out = {image: save_dir / (str(image.stem) + suffix) for image in self.images} + + save_path = { + image: save_dir / (str(image.stem) + ".zarr") for image in self.images + } + + for image_num, image in enumerate(self.images): + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + mask = self.masks[image_num] if self.masks is not None else None + dataloader = self.get_dataloader( + images=image, + masks=mask, + patch_mode=False, + ioconfig=self._ioconfig, + ) + + scale_factor = self._calculate_scale_factor(dataloader=dataloader) + + raw_predictions = self.infer_wsi( + dataloader=dataloader, + save_path=save_path[image], + **kwargs, + ) + processed_predictions = self.post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + kwargs["output_file"] = out[image] + kwargs["scale_factor"] = scale_factor + out[image] = self.save_predictions( + processed_predictions=processed_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + logger.removeFilter(duplicate_filter) + + return out + def run( self: EngineABC, images: list[os | Path | WSIReader] | np.ndarray, @@ -1031,7 +1171,7 @@ def run( Whether to overwrite the results. Default = False. output_type (str): The format of the output type. "output_type" can be - "zarr" or "AnnotationStore". Default value is "zarr". + "dict", "zarr" or "AnnotationStore". Default value is "zarr". When saving in the zarr format the output is saved using the `python zarr library `__ as a zarr group. If the required output type is an "AnnotationStore" @@ -1087,6 +1227,7 @@ def run( ioconfig=ioconfig, overwrite=overwrite, patch_mode=patch_mode, + output_type=output_type, **kwargs, ) @@ -1101,4 +1242,8 @@ def run( # highest_input_resolution, implement dataloader, # pre-processing, post-processing and save_output # for WSIs separately. - raise NotImplementedError + return self._run_wsi_mode( + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 5837693e2..b98c6676d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -1,33 +1,26 @@ -"""This module implements patch level prediction.""" +"""Defines Abstract Base Class for TIAToolbox Model Engines.""" from __future__ import annotations -import copy -from collections import OrderedDict -from pathlib import Path -from typing import TYPE_CHECKING, Callable, NoReturn +from typing import TYPE_CHECKING -import numpy as np -import torch -import tqdm +from typing_extensions import Unpack -import tiatoolbox.models.models_abc -from tiatoolbox import logger -from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset -from tiatoolbox.utils import save_as_json -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .engine_abc import EngineABC, EngineABCRunParams if TYPE_CHECKING: # pragma: no cover import os + from pathlib import Path + + import numpy as np + from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore.wsireader import WSIReader from .io_config import ModelIOConfigABC -from .engine_abc import EngineABC -from .io_config import IOPatchPredictorConfig - class PatchPredictor(EngineABC): r"""Patch level predictor for digital histology images. @@ -117,83 +110,161 @@ class PatchPredictor(EngineABC): - 0.867 Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs + model (str | ModelABC): + A PyTorch model or name of pretrained model. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case-insensitive. - weights (str): - Path to the weight of the corresponding `pretrained_model`. - - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k", - ... weights="resnet18_local_weight") - + of weights using the `weights` parameter. Default is `None`. batch_size (int): - Number of images fed into the model each time. + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = EngineABC( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) + + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Attributes: - images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. - mode (str): - Type of input to process. Choose from either `patch`, `tile` - or `wsi`. - model (nn.Module): - Defined PyTorch model. - model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (str): + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. + model (str | ModelABC): + A PyTorch model or a name of an existing model from the TIAToolbox model zoo + for processing the data. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case insensitive. + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. + _ioconfig (ModelIOConfigABC): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable if `patch_mode` is False in inference. + Default is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. - num_loader_worker (int): - Number of workers used in torch.utils.data.DataLoader. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. cache_mode is always True when + processing WSIs i.e., when `patch_mode` is False. Default value is False. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. Default value is 10,000. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + return_labels (bool): + Whether to return the output labels. Default value is False. + merge_predictions (bool): + Whether to merge WSI predictions into a single file. Default value is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. Default value is 1.0. + units (Units): + Units of resolution used for reading the image. Choose + from either `baseline`, `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. + Default value is `baseline`. verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Examples: >>> # list of 2 image patches as input >>> data = ['path/img.svs', 'path/img.svs'] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # array of list of 2 image patches as input >>> data = np.array([img1, img2]) - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") + >>> predictor = PatchPredictor(model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(tile_file, mode='tile') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(tile_file, mode='tile') >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(wsi_file, mode='wsi') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(wsi_file, mode='wsi') References: [1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer @@ -208,526 +279,143 @@ class PatchPredictor(EngineABC): def __init__( self: PatchPredictor, + model: str | ModelABC, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, - model: torch.nn.Module = None, - pretrained_model: str | None = None, - weights: str | None = None, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" super().__init__( + model=model, batch_size=batch_size, num_loader_workers=num_loader_workers, num_post_proc_workers=num_post_proc_workers, - model=model, - pretrained_model=pretrained_model, weights=weights, + device=device, verbose=verbose, ) - def pre_process_wsi(self: PatchPredictor) -> NoReturn: - """Pre-process a WSI.""" - - def infer_wsi(self: PatchPredictor) -> NoReturn: - """Model inference on a WSI.""" - - def save_predictions( + def get_dataloader( self: PatchPredictor, - raw_predictions: dict, - output_type: str, - ) -> None: - """Post-process an image patch.""" - - def save_wsi_output(self: PatchPredictor) -> NoReturn: - """Post-process a WSI.""" - - @staticmethod - def merge_predictions( - img: str | Path | np.ndarray, - output: dict, - resolution: Resolution | None = None, - units: Units | None = None, - post_proc_func: Callable | None = None, + images: Path, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, *, - return_raw: bool = False, - ) -> np.ndarray: - """Merge patch level predictions to form a 2-dimensional prediction map. - - #! Improve how the below reads. - The prediction map will contain values from 0 to N, where N is - the number of classes. Here, 0 is the background which has not - been processed by the model and N is the number of classes - predicted by the model. + patch_mode: bool = True, + ) -> DataLoader: + """Pre-process images and masks and return dataloader for inference. Args: - img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): - A HWC image or a path to WSI. - output (dict): - Output generated by the model. - resolution (Resolution): - Resolution of merged predictions. - units (Units): - Units of resolution used when merging predictions. This - must be the same `units` used when processing the data. - post_proc_func (callable): - A function to post-process raw prediction from model. By - default, internal code uses the `np.argmax` function. - return_raw (bool): - Return raw result without applying the `postproc_func` - on the assembled image. + images (list of str or :class:`Path` or :class:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. When `patch_mode` is False + the function expects list of str/paths to WSIs. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + ioconfig (ModelIOConfigABC): + A :class:`ModelIOConfigABC` object. + patch_mode (bool): + Whether to treat input image as a patch or WSI. Returns: - :class:`numpy.ndarray`: - Merged predictions as a 2D array. + DataLoader: + :class:`DataLoader` for inference. - Examples: - >>> # pseudo output dict from model with 2 patches - >>> output = { - ... 'resolution': 1.0, - ... 'units': 'baseline', - ... 'probabilities': [[0.45, 0.55], [0.90, 0.10]], - ... 'predictions': [1, 0], - ... 'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]], - ... } - >>> merged = PatchPredictor.merge_predictions( - ... np.zeros([4, 4]), - ... output, - ... resolution=1.0, - ... units='baseline' - ... ) - >>> merged - ... array([[2, 2, 0, 0], - ... [2, 2, 0, 0], - ... [0, 0, 1, 1], - ... [0, 0, 1, 1]]) """ - reader = WSIReader.open(img) - if isinstance(reader, VirtualWSIReader): - logger.warning( - "Image is not pyramidal hence read is forced to be " - "at `units='baseline'` and `resolution=1.0`.", - stacklevel=2, - ) - resolution = 1.0 - units = "baseline" - - canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) - canvas_shape = canvas_shape[::-1] # XY to YX - - # may crash here, do we need to deal with this ? - output_shape = reader.slide_dimensions( - resolution=output["resolution"], - units=output["units"], - ) - output_shape = output_shape[::-1] # XY to YX - fx = np.array(canvas_shape) / np.array(output_shape) - - if "probabilities" not in output: - coordinates = output["coordinates"] - predictions = output["predictions"] - denominator = None - output = np.zeros(list(canvas_shape), dtype=np.float32) - else: - coordinates = output["coordinates"] - predictions = output["probabilities"] - num_class = np.array(predictions[0]).shape[0] - denominator = np.zeros(canvas_shape) - output = np.zeros([*list(canvas_shape), num_class], dtype=np.float32) - - for idx, bound in enumerate(coordinates): - prediction = predictions[idx] - # assumed to be in XY - # top-left for output placement - tl = np.ceil(np.array(bound[:2]) * fx).astype(np.int32) - # bot-right for output placement - br = np.ceil(np.array(bound[2:]) * fx).astype(np.int32) - output[tl[1] : br[1], tl[0] : br[0]] += prediction - if denominator is not None: - denominator[tl[1] : br[1], tl[0] : br[0]] += 1 - - # deal with overlapping regions - if denominator is not None: - output = output / (np.expand_dims(denominator, -1) + 1.0e-8) - if not return_raw: - # convert raw probabilities to predictions - if post_proc_func is not None: - output = post_proc_func(output) - else: - output = np.argmax(output, axis=-1) - # to make sure background is 0 while class will be 1...N - output[denominator > 0] += 1 - return output - - def _predict_engine( - self: PatchPredictor, - dataset: torch.utils.data.Dataset, - *, - return_probabilities: bool = False, - return_labels: bool = False, - return_coordinates: bool = False, - device: str = "cpu", - ) -> np.ndarray: - """Make a prediction on a dataset. The dataset may be mutated. - - Args: - dataset (torch.utils.data.Dataset): - PyTorch dataset object created using - `tiatoolbox.models.data.classification.Patch_Dataset`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return labels. - return_coordinates (bool): - Whether to return patch coordinates. - device (str): - Select the device to run the model. Default is "cpu". - - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset - - """ - dataset.preproc_func = self.model.preproc_func - - # preprocessing must be defined with the dataset - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=self.num_loader_workers, - batch_size=self.batch_size, - drop_last=False, - shuffle=False, + return super().get_dataloader( + images, + masks, + labels, + ioconfig, + patch_mode=patch_mode, ) - if self.verbose: - pbar = tqdm.tqdm( - total=int(len(dataloader)), - leave=True, - ncols=80, - ascii=True, - position=0, - ) - - # use external for testing - model = tiatoolbox.models.models_abc.model_to(model=self.model, device=device) - - cum_output = { - "probabilities": [], - "predictions": [], - "coordinates": [], - "labels": [], - } - for _, batch_data in enumerate(dataloader): - batch_output_probabilities = self.model.infer_batch( - model, - batch_data["image"], - device=device, - ) - # We get the index of the class with the maximum probability - batch_output_predictions = self.model.postproc_func( - batch_output_probabilities, - ) - - # tolist might be very expensive - cum_output["probabilities"].extend(batch_output_probabilities.tolist()) - cum_output["predictions"].extend(batch_output_predictions.tolist()) - if return_coordinates: - cum_output["coordinates"].extend(batch_data["coords"].tolist()) - if return_labels: # be careful of `s` - # We do not use tolist here because label may be of mixed types - # and hence collated as list by torch - cum_output["labels"].extend(list(batch_data["label"])) - - if self.verbose: - pbar.update() - if self.verbose: - pbar.close() - - if not return_probabilities: - cum_output.pop("probabilities") - if not return_labels: - cum_output.pop("labels") - if not return_coordinates: - cum_output.pop("coordinates") - - return cum_output - - def _update_ioconfig( - self: PatchPredictor, - ioconfig: IOPatchPredictorConfig, - patch_input_shape: IntPair, - stride_shape: IntPair, - resolution: Resolution, - units: Units, - ) -> IOPatchPredictorConfig: - """Update the ioconfig. + def infer_wsi( + self: EngineABC, + dataloader: DataLoader, + save_path: Path, + **kwargs: EngineABCRunParams, + ) -> Path: + """Model inference on a WSI. Args: - ioconfig (:class:`IOPatchPredictorConfig`): - Input ioconfig for PatchPredictor. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. + dataloader (DataLoader): + A torch dataloader to process WSIs. + + save_path (Path): + Path to save the intermediate output. The intermediate output is saved + in a zarr file. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. Returns: - Updated Patch Predictor IO configuration. + save_path (Path): + Path to zarr file where intermediate output is saved. """ - config_flag = ( - patch_input_shape is None, - resolution is None, - units is None, + _ = kwargs.get("patch_mode", False) + return self.infer_patches( + dataloader=dataloader, + save_path=save_path, + return_coordinates=True, ) - if ioconfig: - return ioconfig - - if self.ioconfig is None and any(config_flag): - msg = ( - "Must provide either " - "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." - ) - raise ValueError( - msg, - ) - - if stride_shape is None: - stride_shape = patch_input_shape - - if self.ioconfig: - ioconfig = copy.deepcopy(self.ioconfig) - # ! not sure if there is a nicer way to set this - if patch_input_shape is not None: - ioconfig.patch_input_shape = patch_input_shape - if stride_shape is not None: - ioconfig.stride_shape = stride_shape - if resolution is not None: - ioconfig.input_resolutions[0]["resolution"] = resolution - if units is not None: - ioconfig.input_resolutions[0]["units"] = units - - return ioconfig - - return IOPatchPredictorConfig( - input_resolutions=[{"resolution": resolution, "units": units}], - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - output_resolutions=[], - ) - - def _predict_patch( - self: PatchPredictor, - imgs: list | np.ndarray, - labels: list, - *, - return_probabilities: bool, - return_labels: bool, - device: str, - ) -> np.ndarray: - """Process patch mode. - Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - device (str): - Select the device to run the engine. + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output. - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset + Takes the raw output from patch predictions and post-processes it to improve the + results e.g., using information from neighbouring patches. """ - if labels: - # if a labels is provided, then return with the prediction - return_labels = bool(labels) - - if labels and len(labels) != len(imgs): - msg = f"len(labels) != len(imgs) : {len(labels)} != {len(imgs)}" - raise ValueError( - msg, - ) - - # don't return coordinates if patches are already extracted - return_coordinates = False - dataset = PatchDataset(imgs, labels) - return self._predict_engine( - dataset, - return_probabilities=return_probabilities, - return_labels=return_labels, - return_coordinates=return_coordinates, - device=device, + return super().post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, ) - def _predict_tile_wsi( # noqa: PLR0913 - self: PatchPredictor, - imgs: list, - masks: list | None, - labels: list, - mode: str, - ioconfig: IOPatchPredictorConfig, - save_dir: str | Path, - highest_input_resolution: list[dict], - *, - save_output: bool, - return_probabilities: bool, - merge_predictions: bool, - on_gpu: bool, - ) -> list | dict: - """Predict on Tile and WSIs. + def save_wsi_output( + self: EngineABC, + processed_output: Path, + output_type: str, + **kwargs: Unpack[EngineABCRunParams], + ) -> Path: + """Aggregate the output at the WSI level and save to file. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list or None): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - on_gpu (bool): - Whether to run model on the GPU. - ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration.. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. - save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False - highest_input_resolution (list(dict)): - Highest available input resolution. - - - Returns: - dict: - Results are saved to `save_dir` and a dictionary indicating save - location for each input is returned. The dict is in the following - format: - - img_path: path of the input image. - - raw: path to save location for raw prediction, - saved in .json. - - merged: path to .npy contain merged - predictions if - `merge_predictions` is `True`. + processed_output (Path): + Path to Zarr file with intermediate results. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (AnnotationStore or Path): + If the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations stored in a `.db` + file. Otherwise, the function defaults to returning patch predictor output + stored in a `.zarr` file. """ - # return coordinates of patches processed within a tile / whole-slide image - return_coordinates = True - - input_is_path_like = isinstance(imgs[0], (str, Path)) - default_save_dir = ( - imgs[0].parent / "output" if input_is_path_like else Path.cwd() + return super().save_wsi_output( + processed_output=processed_output, + output_type=output_type, + **kwargs, ) - save_dir = default_save_dir if save_dir is None else Path(save_dir) - - # None if no output - outputs = None - - self._ioconfig = ioconfig - # generate a list of output file paths if number of input images > 1 - file_dict = OrderedDict() - - if len(imgs) > 1: - save_output = True - - for idx, img_path in enumerate(imgs): - img_path_ = Path(img_path) - img_label = None if labels is None else labels[idx] - img_mask = None if masks is None else masks[idx] - - dataset = WSIPatchDataset( - img_path_, - mode=mode, - mask_path=img_mask, - patch_input_shape=ioconfig.patch_input_shape, - stride_shape=ioconfig.stride_shape, - resolution=ioconfig.input_resolutions[0]["resolution"], - units=ioconfig.input_resolutions[0]["units"], - ) - output_model = self._predict_engine( - dataset, - return_labels=False, - return_probabilities=return_probabilities, - return_coordinates=return_coordinates, - on_gpu=on_gpu, - ) - output_model["label"] = img_label - # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.model - output_model["resolution"] = highest_input_resolution["resolution"] - output_model["units"] = highest_input_resolution["units"] - - outputs = [output_model] # assign to a list - merged_prediction = None - if merge_predictions: - merged_prediction = self.merge_predictions( - img_path_, - output_model, - resolution=output_model["resolution"], - units=output_model["units"], - post_proc_func=self.model.postproc, - ) - outputs.append(merged_prediction) - - if save_output: - # dynamic 0 padding - img_code = f"{idx:0{len(str(len(imgs)))}d}" - - save_info = {} - save_path = save_dir / img_code - raw_save_path = f"{save_path}.raw.json" - save_info["raw"] = raw_save_path - save_as_json(output_model, raw_save_path) - if merge_predictions: - merged_file_path = f"{save_path}.merged.npy" - np.save(merged_file_path, merged_prediction) - save_info["merged"] = merged_file_path - file_dict[str(img_path_)] = save_info - - return file_dict if save_output else outputs def run( self: EngineABC, @@ -740,102 +428,51 @@ def run( save_dir: os | Path | None = None, # None will not save output overwrite: bool = False, output_type: str = "dict", - **kwargs: dict, - ) -> AnnotationStore | str: - """Run engine.""" - super().run( - images=images, - masks=masks, - labels=labels, - ioconfig=ioconfig, - patch_mode=patch_mode, - save_dir=save_dir, - overwrite=overwrite, - output_type=output_type, - **kwargs, - ) - - def predict( # noqa: PLR0913 - self: PatchPredictor, - imgs: list, - masks: list | None = None, - labels: list | None = None, - mode: str = "patch", - ioconfig: IOPatchPredictorConfig | None = None, - patch_input_shape: tuple[int, int] | None = None, - stride_shape: tuple[int, int] | None = None, - resolution: Resolution | None = None, - units: Units = None, - *, - return_probabilities: bool = False, - return_labels: bool = False, - on_gpu: bool = True, - merge_predictions: bool = False, - save_dir: str | Path | None = None, - save_output: bool = False, - ) -> np.ndarray | list | dict: - """Make a prediction for a list of input data. + **kwargs: Unpack[EngineABCRunParams], + ) -> AnnotationStore | Path | str | dict: + """Run the engine on input images. Args: - imgs (list, ndarray): + images (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. Choose - from either `level`, `power` or `mpp`. Please see - :obj:`WSIReader` for details. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. + IO configuration. save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "zarr" or "AnnotationStore". Default value is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (EngineABCRunParams): + Keyword Args to update :class:`EngineABC` attributes during runtime. Returns: (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple - image tiles or whole-slide images are provided as input, + whole slide images are provided as input, or save_output is True, then results are saved to `save_dir` and a dictionary indicating save location for each input is returned. @@ -850,79 +487,34 @@ def predict( # noqa: PLR0913 Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(wsis, mode="wsi") + >>> class PatchPredictor(EngineABC): + >>> # Define all Abstract methods. + >>> ... + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + >>> output = predictor.run( + >>> image_patches, + >>> patch_mode=True, + >>> output_type="zarr") + >>> output + ... "/path/to/Output.zarr" + >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] - ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} - >>> output['wsi2.svs'] - ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + ... {'/path/to/wsi1.db'} """ - if mode not in ["patch", "wsi", "tile"]: - msg = f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" - raise ValueError( - msg, - ) - if mode == "patch": - return self._predict_patch( - imgs, - labels, - return_probabilities, - return_labels, - on_gpu, - ) - - if not isinstance(imgs, list): - msg = "Input to `tile` and `wsi` mode must be a list of file paths." - raise TypeError( - msg, - ) - - if mode == "wsi" and masks is not None and len(masks) != len(imgs): - msg = f"len(masks) != len(imgs) : {len(masks)} != {len(imgs)}" - raise ValueError( - msg, - ) - - ioconfig = self._update_ioconfig( - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ) - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - ioconfig = ioconfig.to_baseline() - - fx_list = ioconfig.scale_to_highest( - ioconfig.input_resolutions, - ioconfig.input_resolutions[0]["units"], - ) - fx_list = zip(fx_list, ioconfig.input_resolutions) - fx_list = sorted(fx_list, key=lambda x: x[0]) - highest_input_resolution = fx_list[0][1] - - save_dir = self._prepare_save_dir(save_dir, imgs) - - return self._predict_tile_wsi( - imgs, - masks, - labels, - mode, - return_probabilities, - on_gpu, - ioconfig, - merge_predictions, - save_dir, - save_output, - highest_input_resolution, + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, ) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 59a14c48c..5a3c8c7fe 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1201,33 +1201,40 @@ def add_from_dat( def patch_predictions_as_annotations( - preds: list, + preds: list | np.ndarray, keys: list, class_dict: dict, - class_probs: list, + class_probs: list | np.ndarray, patch_coords: list, classes_predicted: list, labels: list, ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - for i, pred in enumerate(preds): + for i, probs in enumerate(class_probs): if "probabilities" in keys: - props = { - f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted - } + props = {f"prob_{class_dict[j]}": probs[j] for j in classes_predicted} else: props = {} if "labels" in keys: props["label"] = class_dict[labels[i]] - props["type"] = class_dict[pred] + if len(preds) > 0: + props["type"] = class_dict[preds[i]] annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) return annotations +def _get_zarr_array(zarr_array: zarr.core.Array | np.ndarray) -> np.ndarray: + """Converts a zarr array into a numpy array.""" + if isinstance(zarr_array, zarr.core.Array): + return zarr_array[:] + + return zarr_array + + def dict_to_store( - patch_output: dict, + patch_output: dict | zarr.group, scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, @@ -1235,7 +1242,7 @@ def dict_to_store( """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore. Args: - patch_output (dict): + patch_output (dict | zarr.Group): A dictionary with "probabilities", "predictions", "coordinates", and "labels" keys. scale_factor (tuple[float, float]): @@ -1260,9 +1267,10 @@ def dict_to_store( # we cant create annotations without coordinates msg = "Patch output must contain coordinates." raise ValueError(msg) + # get relevant keys - class_probs = patch_output.get("probabilities", []) - preds = patch_output.get("predictions", []) + class_probs = _get_zarr_array(patch_output.get("probabilities", [])) + preds = _get_zarr_array(patch_output.get("predictions", [])) patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): @@ -1301,7 +1309,7 @@ def dict_to_store( # if a save director is provided, then dump store into a file if save_path: - # ensure parent directory exisits + # ensure parent directory exists save_path.parent.absolute().mkdir(parents=True, exist_ok=True) # ensure proper db extension save_path = save_path.parent.absolute() / (save_path.stem + ".db") @@ -1341,15 +1349,15 @@ def dict_to_zarr( save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") # save to zarr - predictions_array = np.array(raw_predictions["predictions"]) + probabilities_array = np.array(raw_predictions["probabilities"]) z = zarr.open( - save_path, + str(save_path), mode="w", - shape=predictions_array.shape, + shape=probabilities_array.shape, chunks=chunks, compressor=compressor, ) - z[:] = predictions_array + z[:] = probabilities_array return save_path @@ -1463,7 +1471,8 @@ def write_to_zarr_in_cache_mode( Zarr group name consisting of zarr(s) to save the batch output values. output_data_to_save (dict): - Output data from the Engine to save to Zarr. + Output data from the Engine to save to Zarr. Expects the data saved in + dictionary to be a numpy array. **kwargs (dict): Keyword Args to update zarr_group attributes. @@ -1486,6 +1495,8 @@ def write_to_zarr_in_cache_mode( ) zarr_dataset[:] = data_to_save + return zarr_group + # case 2 - append to existing zarr group for key in output_data_to_save: zarr_group[key].append(output_data_to_save[key]) From 8c2f50b396d530ddc7dd2f85c8193e28dbf403b4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:26:41 +0100 Subject: [PATCH 28/36] :bug: Fix `mypy` Type Checks for `cli/common.py` (#864) - Fix `mypy` Type Checks for `cli/common.py` --- tiatoolbox/cli/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 6767fa578..6a278d01f 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -92,7 +92,7 @@ def cli_output_type( "Default value is 'AnnotationStore'.", default: str = "AnnotationStore", input_type: click.Choice | None = None, -) -> callable: +) -> Callable: """Enables --file-types option for cli.""" if input_type is None: input_type = click.Choice(["zarr", "AnnotationStore"], case_sensitive=False) @@ -124,7 +124,7 @@ def cli_patch_mode( usage_help: str = "Whether to run the model in patch mode or WSI mode.", *, default: bool = False, -) -> callable: +) -> Callable: """Enables --return-probabilities option for cli.""" return click.option( "--patch-mode", @@ -277,7 +277,7 @@ def cli_model( "downloaded. However, you can override with your own set of weights" "via the `pretrained_weights` argument. Argument is case insensitive.", default: str = "resnet18-kather100k", -) -> callable: +) -> Callable: """Enables --pretrained-model option for cli.""" return click.option( "--model", @@ -290,7 +290,7 @@ def cli_weights( usage_help: str = "Path to the model weight file. If not supplied, the default " "pretrained weight will be used.", default: str | None = None, -) -> callable: +) -> Callable: """Enables --pretrained-weights option for cli.""" return click.option( "--weights", @@ -302,7 +302,7 @@ def cli_weights( def cli_device( usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", default: str = "cpu", -) -> callable: +) -> Callable: """Enables --pretrained-weights option for cli.""" return click.option( "--device", From 43afaf7fd2c75d976e23877d52e3c67a1aa3e2b3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 17 Nov 2024 14:23:11 +0000 Subject: [PATCH 29/36] :bug: Fix `model_to` import --- tests/models/test_arch_vanilla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index df0e03fb0..894bd2ef3 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -5,7 +5,8 @@ import torch from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel -from tiatoolbox.utils.misc import model_to, select_device +from tiatoolbox.models.models_abc import model_to +from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator From dba269cc503035e6bbf1e935046419a4a2a44a07 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 17 Nov 2024 14:51:51 +0000 Subject: [PATCH 30/36] :bug: Fix `model_to` device specification --- tests/models/test_arch_vanilla.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 894bd2ef3..a87424dfd 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -6,7 +6,6 @@ from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel from tiatoolbox.models.models_abc import model_to -from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator @@ -46,7 +45,7 @@ def test_functional() -> None: for backbone in backbones: model = CNNModel(backbone, num_classes=1) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc @@ -72,8 +71,8 @@ def test_timm_functional() -> None: try: for backbone in backbones: model = TimmModel(backbone=backbone, num_classes=1, pretrained=False) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc From 1fccf15b4db1f95bc82979279c65ef1776457eab Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:52:59 +0000 Subject: [PATCH 31/36] =?UTF-8?q?=E2=9C=A8=20Add=20`PatchPredictor`=20Engi?= =?UTF-8?q?ne=20(#865)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `PatchPredictor` Engine based on `EngineABC` - Add `return_probabilities` option to Params - Removes `merge_predictions` option in `PatchPredictor` engine. - Defines `post_process_cache_mode` which allows running the algorithm on `WSI` - Add `infer_wsi` for WSI inference - Removes `save_wsi_output` as this is not required after post processing. - Removes `merge_predictions` and fixes docstring in EngineABCRunParams - `compile_model` is now moved to EngineABC init - Fixes bug with `_calculate_scale_factor` - Fixes a bug in `class_dict` definition. - `_get_zarr_array` is now a public function `get_zarr_array` in `misc` - `patch_predictions_as_annotations` runs the loop on `patch_coords` instead of `class_probs` --------- Co-authored-by: Mark Eastwood <20169086+measty@users.noreply.github.com> Co-authored-by: Mostafa Jahanifar <74412979+mostafajahanifar@users.noreply.github.com> Co-authored-by: Adam Shephard <39619155+adamshephard@users.noreply.github.com> Co-authored-by: Jiaqi-Lv <60471431+Jiaqi-Lv@users.noreply.github.com> --- tests/engines/test_engine_abc.py | 101 +---- tests/engines/test_patch_predictor.py | 430 ++++++++++++-------- tests/models/test_models_abc.py | 12 + tiatoolbox/cli/common.py | 2 +- tiatoolbox/cli/patch_predictor.py | 18 +- tiatoolbox/models/architecture/utils.py | 8 +- tiatoolbox/models/engine/engine_abc.py | 147 +++---- tiatoolbox/models/engine/patch_predictor.py | 258 ++++++------ tiatoolbox/utils/misc.py | 19 +- 9 files changed, 511 insertions(+), 484 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index a5ce07d30..acfcfc888 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -11,7 +11,6 @@ import numpy as np import pytest import torchvision.models as torch_models -import zarr from typing_extensions import Unpack from tiatoolbox.models.architecture import ( @@ -26,7 +25,6 @@ prepare_engines_save_dir, ) from tiatoolbox.models.engine.io_config import ModelIOConfigABC -from tiatoolbox.utils.misc import write_to_zarr_in_cache_mode if TYPE_CHECKING: import torch.nn @@ -62,19 +60,6 @@ def get_dataloader( patch_mode=patch_mode, ) - def save_wsi_output( - self: EngineABC, - processed_output: dict, - save_dir: Path, - **kwargs: dict, - ) -> Path: - """Test post_process_wsi.""" - return super().save_wsi_output( - processed_output, - save_dir=save_dir, - **kwargs, - ) - def post_process_wsi( self: EngineABC, raw_predictions: dict | Path, @@ -100,16 +85,6 @@ def infer_wsi( ) -def test_engine_abc() -> NoReturn: - """Test EngineABC initialization.""" - with pytest.raises( - TypeError, - match=r".*Can't instantiate abstract class EngineABC*", - ): - # Can't instantiate abstract class with abstract methods - EngineABC() # skipcq - - def test_engine_abc_incorrect_model_type() -> NoReturn: """Test EngineABC initialization with incorrect model type.""" with pytest.raises( @@ -295,7 +270,7 @@ def test_engine_initalization() -> NoReturn: assert isinstance(eng, EngineABC) -def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn: +def test_engine_run() -> NoReturn: """Test engine run.""" eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) @@ -372,14 +347,10 @@ def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn: assert "probabilities" in out assert "labels" in out - eng = TestEngineABC(model="alexnet-kather100k") - - with pytest.raises(NotImplementedError): - eng.run( - images=[sample_svs], - save_dir=tmp_path / "output", - patch_mode=False, - ) + pred = eng.post_process_wsi( + raw_predictions=Path("/path/to/raw_predictions.npy"), + ) + assert str(pred) == "/path/to/raw_predictions.npy" def test_engine_run_with_verbose() -> NoReturn: @@ -542,55 +513,6 @@ def test_get_dataloader(sample_svs: Path) -> None: assert isinstance(dataloader.dataset, WSIPatchDataset) -def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: - """Test the eng.save_output() function.""" - eng = TestEngineABC(model="alexnet-kather100k") - save_path = tmp_path / "output.zarr" - _ = zarr.open(save_path, mode="w") - out = eng.save_wsi_output( - processed_output=save_path, - save_path=save_path, - output_type="zarr", - save_dir=tmp_path, - ) - - assert out.exists() - assert out.suffix == ".zarr" - - # Test AnnotationStore - patch_output = { - "predictions": np.array([1, 0, 1]), - "coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]), - } - class_dict = {0: "class0", 1: "class1"} - save_path = tmp_path / "output_db.zarr" - zarr_group = zarr.open(save_path, mode="w") - _ = write_to_zarr_in_cache_mode( - zarr_group=zarr_group, output_data_to_save=patch_output - ) - out = eng.save_wsi_output( - processed_output=save_path, - scale_factor=(1.0, 1.0), - class_dict=class_dict, - save_dir=tmp_path, - output_type="AnnotationStore", - ) - - assert out.exists() - assert out.suffix == ".db" - - with pytest.raises( - ValueError, - match=r".*supports zarr and AnnotationStore as output_type.", - ): - eng.save_wsi_output( - processed_output=save_path, - save_path=save_path, - output_type="dict", - save_dir=tmp_path, - ) - - def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: """Test for delegating args to io config.""" # test not providing config / full input info for not pretrained models @@ -701,16 +623,3 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) resolution=_kwargs["resolution"], units=_kwargs["units"], ) - - -def test_notimplementederror_wsi_mode( - sample_svs: Path, tmp_path: pytest.TempPathFactory -) -> None: - """Test that NotImplementedError is raised when wsi mode is False. - - A user should implement run method when patch_mode is False. - - """ - eng = TestEngineABC(model="alexnet-kather100k") - with pytest.raises(NotImplementedError): - eng.run(images=[sample_svs], patch_mode=False, save_dir=tmp_path / "output") diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index 8f62f5037..9b647fc97 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -7,27 +7,76 @@ import shutil import sqlite3 from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING, Callable import numpy as np +import torch import zarr from click.testing import CliRunner -from tiatoolbox import cli +from tests.conftest import timed +from tiatoolbox import cli, logger, rcParam from tiatoolbox.models import IOPatchPredictorConfig from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.engine.patch_predictor import PatchPredictor -from tiatoolbox.utils import download_data, imwrite from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import download_data, get_zarr_array, imwrite + +if TYPE_CHECKING: + import pytest device = "cuda" if toolbox_env.has_gpu() else "cpu" -ON_GPU = toolbox_env.has_gpu() -RNG = np.random.default_rng() # Numpy Random Generator -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- +def _test_predictor_output( + inputs: list, + model: str, + probabilities_check: list | None = None, + classification_check: list | None = None, + output_type: str = "dict", + tmp_path: Path | None = None, +) -> None: + """Test the predictions of multiple models included in tiatoolbox.""" + cache_mode = None if tmp_path is None else True + save_dir = None if tmp_path is None else tmp_path / "output" + predictor = PatchPredictor( + model=model, + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + inputs, + return_labels=False, + device=device, + cache_mode=cache_mode, + save_dir=save_dir, + output_type=output_type, + ) + + if tmp_path is not None: + output = zarr.open(output, mode="r") + + probabilities = output["probabilities"] + classification = output["predictions"] + for idx, probabilities_ in enumerate(probabilities): + probabilities_max = max(probabilities_) + assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( + model, + probabilities_max, + probabilities_check[idx], + probabilities_, + classification_check[idx], + ) + assert classification[idx] == classification_check[idx], ( + model, + probabilities_max, + probabilities_check[idx], + probabilities_, + classification_check[idx], + ) + if save_dir: + shutil.rmtree(save_dir) def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: @@ -129,7 +178,7 @@ def test_patch_predictor_api( sample_patch2: Path, tmp_path: Path, ) -> None: - """Helper function to get the model output using API 1.""" + """Test PatchPredictor API.""" save_dir_path = tmp_path # convert to pathlib Path to prevent reader complaint @@ -141,7 +190,7 @@ def test_patch_predictor_api( inputs, device="cpu", ) - assert sorted(output.keys()) == ["probabilities"] + assert sorted(output.keys()) == ["predictions", "probabilities"] assert len(output["probabilities"]) == 2 shutil.rmtree(save_dir_path, ignore_errors=True) @@ -151,7 +200,7 @@ def test_patch_predictor_api( labels=["1", "a"], return_labels=True, ) - assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) assert len(output["probabilities"]) == len(output["labels"]) assert output["labels"].tolist() == ["1", "a"] shutil.rmtree(save_dir_path, ignore_errors=True) @@ -187,7 +236,7 @@ def test_patch_predictor_api( return_labels=True, ioconfig=ioconfig, ) - assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) assert len(output["probabilities"]) == len(output["labels"]) assert output["labels"].tolist() == [1, 2] @@ -230,58 +279,25 @@ def test_wsi_predictor_api( **_kwargs, ) - wsi_pred = zarr.open(str(output[mini_wsi_svs]), mode="r") - tile_pred = zarr.open(str(output[mini_wsi_jpg]), mode="r") - diff = tile_pred["probabilities"][:] == wsi_pred["probabilities"][:] - accuracy = np.sum(diff) / np.size(wsi_pred["probabilities"][:]) + wsi_out = zarr.open(str(output[mini_wsi_svs]), mode="r") + tile_out = zarr.open(str(output[mini_wsi_jpg]), mode="r") + diff = tile_out["probabilities"][:] == wsi_out["probabilities"][:] + accuracy = np.sum(diff) / np.size(wsi_out["probabilities"][:]) assert accuracy > 0.99, np.nonzero(~diff) - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - + diff = tile_out["predictions"][:] == wsi_out["predictions"][:] + accuracy = np.sum(diff) / np.size(wsi_out["predictions"][:]) + assert accuracy > 0.99, np.nonzero(~diff) -def _test_predictor_output( - inputs: list, - model: str, - probabilities_check: list | None = None, - predictions_check: list | None = None, -) -> None: - """Test the predictions of multiple models included in tiatoolbox.""" - predictor = PatchPredictor( - model=model, - batch_size=32, - verbose=False, - ) - # don't run test on GPU - output = predictor.run( - inputs, - return_probabilities=True, - return_labels=False, - device=device, - ) - predictions = output["probabilities"] - for idx, probabilities_ in enumerate(predictions): - probabilities_max = max(probabilities_) - assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - model, - probabilities_max, - probabilities_check[idx], - probabilities_, - predictions_check[idx], - ) - assert np.argmax(probabilities_) == predictions_check[idx], ( - model, - probabilities_max, - probabilities_check[idx], - probabilities_, - predictions_check[idx], - ) + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) def test_patch_predictor_kather100k_output( sample_patch1: Path, sample_patch2: Path, + tmp_path: Path, ) -> None: - """Test the output of patch prediction models on Kather100K dataset.""" + """Test the output of patch classification models on Kather100K dataset.""" inputs = [Path(sample_patch1), Path(sample_patch2)] pretrained_info = { "alexnet-kather100k": [1.0, 0.9999735355377197], @@ -307,26 +323,64 @@ def test_patch_predictor_kather100k_output( inputs, model, probabilities_check=expected_prob, - predictions_check=[6, 3], + classification_check=[6, 3], ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break + # cache mode + for model, expected_prob in pretrained_info.items(): + _test_predictor_output( + inputs, + model, + probabilities_check=expected_prob, + classification_check=[6, 3], + tmp_path=tmp_path, + ) -def _validate_probabilities(predictions: list | dict) -> bool: + +def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: + """Helper function to extract probabilities from Annotation Store.""" + con = sqlite3.connect(dbfile) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + output = {"probabilities": [], "predictions": []} + + for item in annotations_properties: + for json_str in item: + probs_dict = json.loads(json_str) + if "proba_0" in probs_dict: + output["probabilities"].append(probs_dict.pop("prob_0")) + output["predictions"].append(probs_dict.pop("type")) + + return output + + +def _validate_probabilities(output: list | dict | zarr.group) -> bool: """Helper function to test if the probabilities value are valid.""" - if isinstance(predictions, dict): - return all(0 <= probability <= 1 for _, probability in predictions.items()) + probabilities = np.array([0.5]) + + if "probabilities" in output: + probabilities = output["probabilities"] + + predictions = output["predictions"] + if isinstance(probabilities, dict): + return all(0 <= probability <= 1 for _, probability in probabilities.items()) + + predictions = np.array(get_zarr_array(predictions)).astype(int) + probabilities = get_zarr_array(probabilities) + + if not np.all(np.array(probabilities) <= 1): + return False + + if not np.all(np.array(probabilities) >= 0): + return False - for row in predictions: - for element in row: - if not (0 <= element <= 1): - return False - return True + return np.all(predictions[:][0:5] == [7, 3, 2, 3, 3]) -def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None: +def test_wsi_predictor_zarr( + sample_wsi_dict: dict, tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: """Test normal run of patch predictor for WSIs.""" mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) @@ -354,12 +408,43 @@ def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None: # number of patches x [start_x, start_y, end_x, end_y] assert output_["coordinates"].shape == (70, 4) assert output_["coordinates"].ndim == 2 - assert _validate_probabilities(predictions=output_["probabilities"]) + # prediction for each patch + assert output_["predictions"].shape == (70,) + assert output_["predictions"].ndim == 1 + assert _validate_probabilities(output=output_) + assert "Output file saved at " in caplog.text + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check_no_probabilities", + ) -def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> None: - """Test normal run of patch predictor for WSIs.""" - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert "probabilities" not in output_ + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (70, 4) + assert output_["coordinates"].ndim == 2 + # prediction for each patch + assert output_["predictions"].shape == (70,) + assert output_["predictions"].ndim == 1 + assert _validate_probabilities(output=output_) + assert "Output file saved at " in caplog.text + + +def test_patch_predictor_patch_mode_annotation_store( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test the output of patch classification models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] predictor = PatchPredictor( model="alexnet-kather100k", @@ -368,46 +453,67 @@ def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> N ) # don't run test on GPU output = predictor.run( - images=[mini_wsi_svs], + images=inputs, return_probabilities=True, return_labels=False, device=device, - patch_mode=False, - save_dir=tmp_path / "wsi_out_check", - units="baseline", - resolution=1.0, + patch_mode=True, + save_dir=tmp_path / "patch_out_check", + output_type="annotationstore", ) - assert output[mini_wsi_svs].exists() + assert output.exists() + output = _extract_probabilities_from_annotation_store(output) + assert np.all(output["predictions"] == [6, 3]) + assert np.all(np.array(output["probabilities"]) <= 1) + assert np.all(np.array(output["probabilities"]) >= 0) - output_ = zarr.open(output[mini_wsi_svs]) - assert output_["probabilities"].shape == (244, 9) # number of patches x classes - assert output_["probabilities"].ndim == 2 - # number of patches x [start_x, start_y, end_x, end_y] - assert output_["coordinates"].shape == (244, 4) - assert output_["coordinates"].ndim == 2 - assert _validate_probabilities(predictions=output_["probabilities"]) +def test_patch_predictor_patch_mode_no_probabilities( + sample_patch1: Path, + sample_patch2: Path, + tmp_path: Path, +) -> None: + """Test the output of patch classification models on Kather100K dataset.""" + inputs = [Path(sample_patch1), Path(sample_patch2)] + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + output = predictor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + ) -def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: - """Helper function to extract probabilities from Annotation Store.""" - probs_dict = {} - con = sqlite3.connect(dbfile) - cur = con.cursor() - annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + assert "probabilities" not in output - for item in annotations_properties: - for json_str in item: - probs_dict = json.loads(json_str) - probs_dict.pop("prob_0") + # don't run test on GPU + output = predictor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "patch_out_check", + output_type="annotationstore", + ) - return probs_dict + assert output.exists() + output = _extract_probabilities_from_annotation_store(output) + assert np.all(output["predictions"] == [6, 3]) + assert output["probabilities"] == [] def test_engine_run_wsi_annotation_store( sample_wsi_dict: dict, tmp_path: Path, + caplog: pytest.LogCaptureFixture, ) -> None: """Test the engine run for Whole slide images.""" # convert to pathlib Path to prevent wsireader complaint @@ -433,6 +539,7 @@ def test_engine_run_wsi_annotation_store( masks=[mini_wsi_msk], patch_mode=False, output_type="AnnotationStore", + batch_size=4, **kwargs, ) @@ -440,50 +547,63 @@ def test_engine_run_wsi_annotation_store( assert output_.exists() assert output_.suffix == ".db" - predictions = _extract_probabilities_from_annotation_store(output_) - assert _validate_probabilities(predictions) + output_ = _extract_probabilities_from_annotation_store(output_) + + # prediction for each patch + assert np.array(output_["predictions"]).shape == (69,) + assert _validate_probabilities(output_) + + assert "Output file saved at " in caplog.text shutil.rmtree(save_dir) -def test_engine_run_wsi_annotation_store_power( - sample_wsi_dict: dict, +# -------------------------------------------------------------------------------------- +# torch.compile +# -------------------------------------------------------------------------------------- +def test_patch_predictor_torch_compile( + sample_patch1: Path, + sample_patch2: Path, tmp_path: Path, ) -> None: - """Test the engine run for Whole slide images.""" - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - eng = PatchPredictor(model="alexnet-kather100k") - - patch_size = np.array([224, 224]) - save_dir = f"{tmp_path}/model_wsi_output" - - kwargs = { - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 20, - "save_dir": save_dir, - "units": "power", - } - - output = eng.run( - images=[mini_wsi_svs], - masks=[mini_wsi_msk], - patch_mode=False, - output_type="AnnotationStore", - **kwargs, + """Test PatchPredictor with torch.compile functionality. + + Args: + sample_patch1 (Path): Path to sample patch 1. + sample_patch2 (Path): Path to sample patch 2. + tmp_path (Path): Path to temporary directory. + + """ + torch_compile_mode = rcParam["torch_compile_mode"] + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "default" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, ) - - output_ = output[mini_wsi_svs] - - assert output_.exists() - assert output_.suffix == ".db" - predictions = _extract_probabilities_from_annotation_store(output_) - assert _validate_probabilities(predictions) - - shutil.rmtree(save_dir) + logger.info("torch.compile default mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "reduce-overhead" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile reduce-overhead mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "max-autotune" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile max-autotune mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = torch_compile_mode # ------------------------------------------------------------------------------------- @@ -491,50 +611,6 @@ def test_engine_run_wsi_annotation_store_power( # ------------------------------------------------------------------------------------- -def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI file not found error.""" - runner = CliRunner() - model_file_not_found_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs)[:-1], - "--file-types", - '"*.ndpi, *.svs"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert model_file_not_found_result.output == "" - assert model_file_not_found_result.exit_code == 1 - assert isinstance(model_file_not_found_result.exception, FileNotFoundError) - - -def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI mode not in wsi, tile.""" - runner = CliRunner() - mode_not_in_wsi_tile_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--file-types", - '"*.ndpi, *.svs"', - "--patch-mode", - '"patch"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert "Invalid value for '--patch-mode'" in mode_not_in_wsi_tile_result.output - assert mode_not_in_wsi_tile_result.exit_code != 0 - assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) - - def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: """Test for models CLI single file.""" runner = CliRunner() diff --git a/tests/models/test_models_abc.py b/tests/models/test_models_abc.py index 8167a86ee..bfe6a62d8 100644 --- a/tests/models/test_models_abc.py +++ b/tests/models/test_models_abc.py @@ -168,3 +168,15 @@ def test_model_to() -> None: model = torch_models.resnet18() model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) assert isinstance(model, nn.Module) + + +def test_get_pretrained_model_not_str() -> None: + """Test TypeError is raised if input is not str.""" + with pytest.raises(TypeError, match="pretrained_model must be a string."): + _ = get_pretrained_model(1) + + +def test_get_pretrained_model_not_in_info() -> None: + """Test ValueError is raised if input is not in info.""" + with pytest.raises(ValueError, match="Pretrained model `alexnet` does not exist."): + _ = get_pretrained_model("alexnet") diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 6a278d01f..4b75400d1 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -342,7 +342,7 @@ def cli_merge_predictions( def cli_return_labels( usage_help: str = "Whether to return raw model output as labels.", *, - default: bool = True, + default: bool = False, ) -> Callable: """Enables --return-labels option for cli.""" return click.option( diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 263809146..7f22acb0c 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -2,6 +2,8 @@ from __future__ import annotations +import click + from tiatoolbox.cli.common import ( cli_batch_size, cli_device, @@ -14,6 +16,8 @@ cli_output_type, cli_patch_mode, cli_resolution, + cli_return_labels, + cli_return_probabilities, cli_units, cli_verbose, cli_weights, @@ -31,7 +35,6 @@ @cli_file_type( default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", ) -@cli_patch_mode(default=False) @cli_model(default="resnet18-kather100k") @cli_weights() @cli_device(default="cpu") @@ -40,7 +43,13 @@ @cli_units(default="mpp") @cli_masks(default=None) @cli_num_loader_workers(default=0) -@cli_output_type(default="AnnotationStore") +@cli_output_type( + default="AnnotationStore", + input_type=click.Choice(["zarr", "AnnotationStore"], case_sensitive=False), +) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_return_labels(default=False) @cli_verbose(default=True) def patch_predictor( model: str, @@ -56,6 +65,8 @@ def patch_predictor( device: str, output_type: str, *, + return_probabilities: bool, + return_labels: bool, patch_mode: bool, verbose: bool, ) -> None: @@ -85,6 +96,7 @@ def patch_predictor( units=units, device=device, save_dir=output_path, - save_output=True, output_type=output_type, + return_probabilities=return_probabilities, + return_labels=return_labels, ) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index 9df4dd56f..2ec47d99d 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from typing import Callable, NoReturn +from typing import NoReturn import numpy as np import torch @@ -41,7 +41,7 @@ def compile_model( model: nn.Module | None = None, *, mode: str = "default", -) -> Callable: +) -> nn.Module: """A decorator to compile a model using torch-compile. Args: @@ -60,7 +60,7 @@ def compile_model( CUDA graphs Returns: - Callable: + torch.nn.Module: Compiled model. """ @@ -71,7 +71,7 @@ def compile_model( is_torch_compile_compatible() # This check will be removed when torch.compile is supported in Python 3.12+ - if sys.version_info >= (3, 12): # pragma: no cover + if sys.version_info > (3, 12): # pragma: no cover logger.warning( ("torch-compile is currently not supported in Python 3.12+. ",), ) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 465230116..6e33da409 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -4,7 +4,7 @@ import copy import shutil -from abc import ABC, abstractmethod +from abc import ABC from pathlib import Path from typing import TYPE_CHECKING, TypedDict @@ -15,8 +15,9 @@ from torch import nn from typing_extensions import Unpack -from tiatoolbox import DuplicateFilter, logger +from tiatoolbox import DuplicateFilter, logger, rcParam from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.architecture.utils import compile_model from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( @@ -120,10 +121,6 @@ class EngineABCRunParams(TypedDict, total=False): Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. return_labels (bool): Whether to return the labels with the predictions. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map into a single file from a WSI. - This is only applicable if `patch_mode` is False in inference. num_loader_workers (int): Number of workers used in :class:`torch.utils.data.DataLoader`. num_post_proc_workers (int): @@ -138,8 +135,6 @@ class EngineABCRunParams(TypedDict, total=False): resolution (Resolution): Resolution used for reading the image. Please see :class:`WSIReader` for details. - return_labels (bool): - Whether to return the output labels. scale_factor (tuple[float, float]): The scale factor to use when loading the annotations. All coordinates will be multiplied by this factor to allow @@ -165,7 +160,6 @@ class EngineABCRunParams(TypedDict, total=False): class_dict: dict device: str ioconfig: ModelIOConfigABC - merge_predictions: bool num_loader_workers: int num_post_proc_workers: int output_file: str @@ -178,7 +172,7 @@ class EngineABCRunParams(TypedDict, total=False): verbose: bool -class EngineABC(ABC): +class EngineABC(ABC): # noqa: B024 """Abstract base class for TIAToolbox deep learning engines to run CNN models. Args: @@ -248,10 +242,6 @@ class EngineABC(ABC): Runtime ioconfig. return_labels (bool): Whether to return the labels with the predictions. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable if `patch_mode` is False in inference. - Default is False. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. @@ -293,8 +283,6 @@ class EngineABC(ABC): Number of workers to postprocess the results of the model. return_labels (bool): Whether to return the output labels. Default value is False. - merge_predictions (bool): - Whether to merge WSI predictions into a single file. Default value is False. resolution (Resolution): Resolution used for reading the image. Please see :class:`WSIReader` for details. @@ -368,13 +356,18 @@ def __init__( weights=weights, ) self.model.to(device=self.device) + self.model = ( + compile_model( # for runtime, such as after wrapping with nn.DataParallel + self.model, + mode=rcParam["torch_compile_mode"], + ) + ) self._ioconfig = self.ioconfig # runtime ioconfig self.batch_size = batch_size self.cache_mode: bool = False self.cache_size: int = self.batch_size if self.batch_size else 10000 self.labels: list | None = None - self.merge_predictions: bool = False self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers self.patch_input_shape: IntPair | None = None @@ -453,7 +446,7 @@ def get_dataloader( images (list of str or :class:`Path` or :class:`numpy.ndarray`): A list of image patches in NHWC format as a numpy array or a list of str/paths to WSIs. When `patch_mode` is False - the function expects path to a single WSI. + the function expects list of str/paths to WSIs. masks (list | None): List of masks. Only utilised when patch_mode is False. Patches are only generated within a masked area. @@ -522,6 +515,13 @@ def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict: return raw_predictions + def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray: + """Helper function to collect coordinates for AnnotationStore.""" + if self.patch_mode: + coordinates = [0, 0, *batch_data["image"].shape[1:3]] + return np.tile(coordinates, reps=(batch_data["image"].shape[0], 1)) + return batch_data["coords"].numpy() + def infer_patches( self: EngineABC, dataloader: DataLoader, @@ -579,7 +579,7 @@ def infer_patches( device=self.device, ) if return_coordinates: - batch_output["coordinates"] = batch_data["coords"].numpy() + batch_output["coordinates"] = self._get_coordinates(batch_data) if self.return_labels: # be careful of `s` if isinstance(batch_data["label"], torch.Tensor): @@ -631,7 +631,7 @@ def post_process_patches( saved zarr file if `cache_mode` is True. """ - _ = kwargs.get("probabilities") # Key values required for post-processing + _ = kwargs.get("return_labels") # Key values required for post-processing if self.cache_mode: # cache mode _ = zarr.open(raw_predictions, mode="w") @@ -710,74 +710,51 @@ def save_predictions( else processed_predictions ) - @abstractmethod def infer_wsi( self: EngineABC, - dataloader: torch.utils.data.DataLoader, - save_path: Path | str, - **kwargs: dict, - ) -> dict | Path: + dataloader: DataLoader, + save_path: Path, + **kwargs: EngineABCRunParams, + ) -> Path: """Model inference on a WSI. - This function must be implemented by subclasses. + Args: + dataloader (DataLoader): + A torch dataloader to process WSIs. + + save_path (Path): + Path to save the intermediate output. The intermediate output is saved + in a zarr file. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. + + Returns: + save_path (Path): + Path to zarr file where intermediate output is saved. """ - # return coordinates of patches processed within a tile / whole-slide image - raise NotImplementedError + _ = kwargs.get("patch_mode", False) + return self.infer_patches( + dataloader=dataloader, + save_path=save_path, + return_coordinates=True, + ) - @abstractmethod - def post_process_wsi( + # This is not a static model for child classes. + def post_process_wsi( # skipcq: PYL-R0201 self: EngineABC, raw_predictions: dict | Path, **kwargs: Unpack[EngineABCRunParams], ) -> dict | Path: - """Post process WSI output.""" - _ = kwargs.get("probabilities") # Key values required for post-processing - return raw_predictions + """Post process WSI output. - @abstractmethod - def save_wsi_output( - self: EngineABC, - processed_output: Path, - output_type: str, - **kwargs: Unpack[EngineABCRunParams], - ) -> Path: - """Aggregate the output at the WSI level and save to file. - - Args: - processed_output (Path): - Path to Zarr file with intermediate results. - output_type (str): - The desired output type for resulting patch dataset. - **kwargs (EngineABCRunParams): - Keyword Args to update setup_patch_dataset() method attributes. - - Returns: (AnnotationStore or Path): - If the output_type is "AnnotationStore", the function returns the patch - predictor output as an SQLiteStore containing Annotations stored in a `.db` - file. Otherwise, the function defaults to returning patch predictor output - stored in a `.zarr` file. + Takes the raw output from patch predictions and post-processes it to improve the + results e.g., using information from neighbouring patches. """ - if output_type.lower() == "zarr": - msg = "Output file saved at %s.", processed_output - logger.info(msg=msg) - return processed_output - - if output_type.lower() == "annotationstore": - save_path = Path(kwargs.get("output_file", processed_output.stem + ".db")) - # scale_factor set from kwargs - scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) - # Read zarr file to a dict - raw_output_dict = zarr.open(str(processed_output), mode="r") - - # class_dict set from kwargs - class_dict = kwargs.get("class_dict") - - return dict_to_store(raw_output_dict, scale_factor, class_dict, save_path) - - msg = "Only supports zarr and AnnotationStore as output_type." - raise ValueError(msg) + _ = kwargs.get("return_labels") # Key values required for post-processing + return raw_predictions def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: """Helper function to load ioconfig. @@ -1015,6 +992,7 @@ def _run_patch_mode( raw_predictions = self.infer_patches( dataloader=dataloader, save_path=save_path, + return_coordinates=output_type == "annotationstore", ) processed_predictions = self.post_process_patches( raw_predictions=raw_predictions, @@ -1022,13 +1000,20 @@ def _run_patch_mode( ) logger.removeFilter(duplicate_filter) - return self.save_predictions( + out = self.save_predictions( processed_predictions=processed_predictions, output_type=output_type, save_dir=save_dir, **kwargs, ) + if save_dir: + msg = f"Output file saved at {out}." + logger.info(msg=msg) + return out + + return out + @staticmethod def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, float]: """Calculates scale factor for final output. @@ -1058,18 +1043,18 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa if dataloader_units == "mpp": slide_resolution = wsimeta_dict[dataloader_units] - scale_factor = np.divide(slide_resolution, dataloader_resolution) + scale_factor = np.divide(dataloader_resolution, slide_resolution) return scale_factor[0], scale_factor[1] if dataloader_units == "level": downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution] - return 1.0 / downsample_ratio, 1.0 / downsample_ratio + return downsample_ratio, downsample_ratio if dataloader_units == "power": slide_objective_power = wsimeta_dict["objective_power"] return ( - dataloader_resolution / slide_objective_power, - dataloader_resolution / slide_objective_power, + slide_objective_power / dataloader_resolution, + slide_objective_power / dataloader_resolution, ) return dataloader_resolution @@ -1126,6 +1111,8 @@ def _run_wsi_mode( **kwargs, ) logger.removeFilter(duplicate_filter) + msg = f"Output file saved at {out[image]}." + logger.info(msg=msg) return out @@ -1194,8 +1181,6 @@ def run( - img_path: path of the input image. - raw: path to save location for raw prediction, saved in .json. - - merged: path to .npy contain merged - predictions if `merge_predictions` is `True`. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index b98c6676d..35b2c7d56 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -1,9 +1,11 @@ -"""Defines Abstract Base Class for TIAToolbox Model Engines.""" +"""Defines PatchPredictor Engine.""" from __future__ import annotations +import math from typing import TYPE_CHECKING +import zarr from typing_extensions import Unpack from .engine_abc import EngineABC, EngineABCRunParams @@ -13,19 +15,80 @@ from pathlib import Path import numpy as np - from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import ModelIOConfigABC from tiatoolbox.models.models_abc import ModelABC - from tiatoolbox.wsicore.wsireader import WSIReader + from tiatoolbox.wsicore import WSIReader - from .io_config import ModelIOConfigABC + +class PredictorRunParams(EngineABCRunParams): + """Class describing the input parameters for the :func:`EngineABC.run()` method. + + Attributes: + batch_size (int): + Number of image patches to feed to the model in a forward pass. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. + ioconfig (ModelIOConfigABC): + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. + return_labels (bool): + Whether to return the labels with the predictions. + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + output_file (str): + Output file name to save "zarr" or "db". If None, path to output is + returned by the engine. + patch_input_shape (tuple): + Shape of patches input to the model as tuple of height and width (HW). + Patches are requested at read resolution, not with respect to level 0, + and must be positive. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + verbose (bool): + Whether to output logging information. + + """ + + return_probabilities: bool class PatchPredictor(EngineABC): - r"""Patch level predictor for digital histology images. + r"""Patch level prediction for digital histology images. - The models provided by tiatoolbox should give the following results: + The models provided by TIAToolbox should give the following results: .. list-table:: PatchPredictor performance on the Kather100K dataset [1] :widths: 15 15 @@ -130,7 +193,7 @@ class PatchPredictor(EngineABC): weights (str or Path): Path to the weight of the corresponding `model`. - >>> engine = EngineABC( + >>> engine = PatchPredictor( ... model="pretrained-model", ... weights="/path/to/pretrained-local-weights.pth" ... ) @@ -176,10 +239,6 @@ class PatchPredictor(EngineABC): Runtime ioconfig. return_labels (bool): Whether to return the labels with the predictions. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable if `patch_mode` is False in inference. - Default is False. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. @@ -221,8 +280,6 @@ class PatchPredictor(EngineABC): Number of workers to postprocess the results of the model. return_labels (bool): Whether to return the output labels. Default value is False. - merge_predictions (bool): - Whether to merge WSI predictions into a single file. Default value is False. resolution (Resolution): Resolution used for reading the image. Please see :class:`WSIReader` for details. @@ -249,7 +306,7 @@ class PatchPredictor(EngineABC): >>> # array of list of 2 image patches as input >>> data = np.array([img1, img2]) >>> predictor = PatchPredictor(model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> output = predictor.run(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] @@ -299,83 +356,90 @@ def __init__( verbose=verbose, ) - def get_dataloader( + def post_process_cache_mode( self: PatchPredictor, - images: Path, - masks: Path | None = None, - labels: list | None = None, - ioconfig: ModelIOConfigABC | None = None, - *, - patch_mode: bool = True, - ) -> DataLoader: - """Pre-process images and masks and return dataloader for inference. + raw_predictions: Path, + **kwargs: Unpack[PredictorRunParams], + ) -> Path: + """Returns an array from raw predictions.""" + return_probabilities = kwargs.get("return_probabilities") + zarr_group = zarr.open(str(raw_predictions), mode="r+") + + num_iter = math.ceil(len(zarr_group["probabilities"]) / self.batch_size) + start = 0 + for _ in range(num_iter): + # Probabilities for post-processing + probabilities = zarr_group["probabilities"][start : start + self.batch_size] + start = start + self.batch_size + predictions = self.model.postproc_func( + probabilities, + ) + if "predictions" in zarr_group: + zarr_group["predictions"].append(predictions) + continue + + zarr_dataset = zarr_group.create_dataset( + name="predictions", + shape=predictions.shape, + compressor=zarr_group["probabilities"].compressor, + ) + zarr_dataset[:] = predictions + + if return_probabilities is not False: + return raw_predictions + + del zarr_group["probabilities"] + + return raw_predictions + + def post_process_patches( + self: PatchPredictor, + raw_predictions: dict | Path, + **kwargs: Unpack[PredictorRunParams], + ) -> dict | Path: + """Post-process raw patch predictions from inference. + + The output of :func:`infer_patches()` with patch prediction information will be + post-processed using this function. The processed output will be saved in the + respective input format. If `cache_mode` is True, the function processes the + input using zarr group with size specified by `cache_size`. Args: - images (list of str or :class:`Path` or :class:`numpy.ndarray`): - A list of image patches in NHWC format as a numpy array - or a list of str/paths to WSIs. When `patch_mode` is False - the function expects list of str/paths to WSIs. - masks (list | None): - List of masks. Only utilised when patch_mode is False. - Patches are only generated within a masked area. - If not provided, then a tissue mask will be automatically - generated for whole slide images. - labels (list | None): - List of labels. Only a single label per image is supported. - ioconfig (ModelIOConfigABC): - A :class:`ModelIOConfigABC` object. - patch_mode (bool): - Whether to treat input image as a patch or WSI. + raw_predictions (dict | Path): + A dictionary or path to zarr with patch prediction information. + **kwargs (PredictorRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`PredictorRunParams` for accepted keyword arguments. Returns: - DataLoader: - :class:`DataLoader` for inference. - + dict or Path: + Returns patch based output after post-processing. Returns path to + saved zarr file if `cache_mode` is True. """ - return super().get_dataloader( - images, - masks, - labels, - ioconfig, - patch_mode=patch_mode, - ) + return_probabilities = kwargs.get("return_probabilities") + if self.cache_mode: + return self.post_process_cache_mode(raw_predictions, **kwargs) - def infer_wsi( - self: EngineABC, - dataloader: DataLoader, - save_path: Path, - **kwargs: EngineABCRunParams, - ) -> Path: - """Model inference on a WSI. + probabilities = raw_predictions.get("probabilities") - Args: - dataloader (DataLoader): - A torch dataloader to process WSIs. + predictions = self.model.postproc_func( + probabilities, + ) - save_path (Path): - Path to save the intermediate output. The intermediate output is saved - in a zarr file. - **kwargs (EngineABCRunParams): - Keyword Args to update setup_patch_dataset() method attributes. See - :class:`EngineRunParams` for accepted keyword arguments. + raw_predictions["predictions"] = predictions - Returns: - save_path (Path): - Path to zarr file where intermediate output is saved. + if return_probabilities is not False: + return raw_predictions - """ - _ = kwargs.get("patch_mode", False) - return self.infer_patches( - dataloader=dataloader, - save_path=save_path, - return_coordinates=True, - ) + del raw_predictions["probabilities"] + + return raw_predictions def post_process_wsi( - self: EngineABC, + self: PatchPredictor, raw_predictions: dict | Path, - **kwargs: Unpack[EngineABCRunParams], + **kwargs: Unpack[PredictorRunParams], ) -> dict | Path: """Post process WSI output. @@ -383,42 +447,10 @@ def post_process_wsi( results e.g., using information from neighbouring patches. """ - return super().post_process_wsi( - raw_predictions=raw_predictions, - **kwargs, - ) - - def save_wsi_output( - self: EngineABC, - processed_output: Path, - output_type: str, - **kwargs: Unpack[EngineABCRunParams], - ) -> Path: - """Aggregate the output at the WSI level and save to file. - - Args: - processed_output (Path): - Path to Zarr file with intermediate results. - output_type (str): - The desired output type for resulting patch dataset. - **kwargs (EngineABCRunParams): - Keyword Args to update setup_patch_dataset() method attributes. - - Returns: (AnnotationStore or Path): - If the output_type is "AnnotationStore", the function returns the patch - predictor output as an SQLiteStore containing Annotations stored in a `.db` - file. Otherwise, the function defaults to returning patch predictor output - stored in a `.zarr` file. - - """ - return super().save_wsi_output( - processed_output=processed_output, - output_type=output_type, - **kwargs, - ) + return self.post_process_cache_mode(raw_predictions, **kwargs) def run( - self: EngineABC, + self: PatchPredictor, images: list[os | Path | WSIReader] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, labels: list | None = None, @@ -428,7 +460,7 @@ def run( save_dir: os | Path | None = None, # None will not save output overwrite: bool = False, output_type: str = "dict", - **kwargs: Unpack[EngineABCRunParams], + **kwargs: Unpack[PredictorRunParams], ) -> AnnotationStore | Path | str | dict: """Run the engine on input images. @@ -466,7 +498,7 @@ def run( then the output will be intermediately saved as zarr but converted to :class:`AnnotationStore` and saved as a `.db` file at the end of the loop. - **kwargs (EngineABCRunParams): + **kwargs (PredictorRunParams): Keyword Args to update :class:`EngineABC` attributes during runtime. Returns: @@ -482,8 +514,6 @@ def run( - img_path: path of the input image. - raw: path to save location for raw prediction, saved in .json. - - merged: path to .npy contain merged - predictions if `merge_predictions` is `True`. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 5a3c8c7fe..e0bf0c077 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1211,9 +1211,11 @@ def patch_predictions_as_annotations( ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - for i, probs in enumerate(class_probs): + for i, _ in enumerate(patch_coords): if "probabilities" in keys: - props = {f"prob_{class_dict[j]}": probs[j] for j in classes_predicted} + props = { + f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted + } else: props = {} if "labels" in keys: @@ -1225,12 +1227,12 @@ def patch_predictions_as_annotations( return annotations -def _get_zarr_array(zarr_array: zarr.core.Array | np.ndarray) -> np.ndarray: +def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray: """Converts a zarr array into a numpy array.""" if isinstance(zarr_array, zarr.core.Array): return zarr_array[:] - return zarr_array + return np.array(zarr_array).astype(float) def dict_to_store( @@ -1269,12 +1271,13 @@ def dict_to_store( raise ValueError(msg) # get relevant keys - class_probs = _get_zarr_array(patch_output.get("probabilities", [])) - preds = _get_zarr_array(patch_output.get("predictions", [])) + class_probs = get_zarr_array(patch_output.get("probabilities", [])) + preds = get_zarr_array(patch_output.get("predictions", [])) patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp + patch_coords = patch_coords.astype(float) labels = patch_output.get("labels", []) # get classes to consider if len(class_probs) == 0: @@ -1285,9 +1288,9 @@ def dict_to_store( if class_dict is None: # if no class dict create a default one if len(class_probs) == 0: - class_dict = {i: i for i in np.unique(preds + labels).tolist()} + class_dict = {i: i for i in np.unique(np.append(preds, labels)).tolist()} else: - class_dict = {i: i for i in range(len(class_probs))} + class_dict = {i: i for i in range(len(class_probs[0]))} # find what keys we need to save keys = ["predictions"] From 819e1388368c6bf0368bef5e5000fd5bba319b09 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:00:23 +0000 Subject: [PATCH 32/36] :twisted_rightwards_arrows: Undo unwanted changes during merge. --- tests/models/test_models_abc.py | 13 +++++-------- tests/test_utils.py | 11 +++++++++++ tiatoolbox/cli/common.py | 20 ++++---------------- tiatoolbox/cli/patch_predictor.py | 3 --- 4 files changed, 20 insertions(+), 27 deletions(-) diff --git a/tests/models/test_models_abc.py b/tests/models/test_models_abc.py index bfe6a62d8..7451ae9ac 100644 --- a/tests/models/test_models_abc.py +++ b/tests/models/test_models_abc.py @@ -6,15 +6,15 @@ import pytest import torch +import torchvision.models as torch_models from torch import nn -import tiatoolbox.models from tiatoolbox import rcParam, utils from tiatoolbox.models.architecture import ( fetch_pretrained_weights, get_pretrained_model, ) -from tiatoolbox.models.models_abc import ModelABC +from tiatoolbox.models.models_abc import ModelABC, model_to from tiatoolbox.utils import env_detection as toolbox_env if TYPE_CHECKING: @@ -154,19 +154,16 @@ def test_model_abc() -> None: def test_model_to() -> None: """Test for placing model on device.""" - import torchvision.models as torch_models - from torch import nn - # Test on GPU - # no GPU on Travis so this will crash + # no GPU on GitHub Actions so this will crash if not utils.env_detection.has_gpu(): model = torch_models.resnet18() with pytest.raises((AssertionError, RuntimeError)): - _ = tiatoolbox.models.models_abc.model_to(device="cuda", model=model) + _ = model_to(device="cuda", model=model) # Test on CPU model = torch_models.resnet18() - model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) + model = model_to(device="cpu", model=model) assert isinstance(model, nn.Module) diff --git a/tests/test_utils.py b/tests/test_utils.py index db72fa728..4708b5afe 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1669,6 +1669,17 @@ def test_patch_pred_store() -> None: with pytest.raises(ValueError, match="coordinates"): misc.dict_to_store(patch_output, (1.0, 1.0)) + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "other": "other", + } + + store = misc.dict_to_store(patch_output, (1.0, 1.0)) + + # Check that it is an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + def test_patch_pred_store_cdict() -> None: """Test patch_pred_store with a class dict.""" diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 4b75400d1..02a5a8c86 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -94,8 +94,10 @@ def cli_output_type( input_type: click.Choice | None = None, ) -> Callable: """Enables --file-types option for cli.""" - if input_type is None: - input_type = click.Choice(["zarr", "AnnotationStore"], case_sensitive=False) + click_choices = click.Choice( + choices=["zarr", "AnnotationStore"], case_sensitive=False + ) + input_type = click_choices if input_type is None else input_type return click.option( "--output-type", help=add_default_to_usage_help(usage_help, default), @@ -410,20 +412,6 @@ def cli_yaml_config_path( ) -def cli_on_gpu( - usage_help: str = "Run the model on GPU.", - *, - default: bool = False, -) -> Callable: - """Enables --on-gpu option for cli.""" - return click.option( - "--on-gpu", - type=bool, - default=default, - help=add_default_to_usage_help(usage_help, default), - ) - - def cli_num_loader_workers( usage_help: str = "Number of workers to load the data. Please note that they will " "also perform preprocessing.", diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 7f22acb0c..a534c8f50 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -2,8 +2,6 @@ from __future__ import annotations -import click - from tiatoolbox.cli.common import ( cli_batch_size, cli_device, @@ -45,7 +43,6 @@ @cli_num_loader_workers(default=0) @cli_output_type( default="AnnotationStore", - input_type=click.Choice(["zarr", "AnnotationStore"], case_sensitive=False), ) @cli_patch_mode(default=False) @cli_return_probabilities(default=True) From 92373689484756f0d736a6fc57fd684492d6fa0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 18:02:52 +0000 Subject: [PATCH 33/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/__init__.py | 22 +++++++++++----------- tiatoolbox/models/dataset/__init__.py | 8 ++++---- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 9b2dac774..ab52740ed 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -23,27 +23,27 @@ from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor __all__ = [ - "architecture", - "dataset", - "engine", - "models_abc", "SCCNN", + "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", "IDaRS", + "IOInstanceSegmentorConfig", + "IOPatchPredictorConfig", + "IOSegmentorConfig", "MapDe", "MicroNet", + "ModelIOConfigABC", "MultiTaskSegmentor", "NuClick", "NucleusInstanceSegmentor", + "PatchDataset", "PatchPredictor", "SemanticSegmentor", - "IOPatchPredictorConfig", - "IOSegmentorConfig", - "IOInstanceSegmentorConfig", - "ModelIOConfigABC", - "DeepFeatureExtractor", - "WSIStreamDataset", "WSIPatchDataset", - "PatchDataset", + "WSIStreamDataset", + "architecture", + "dataset", + "engine", + "models_abc", ] diff --git a/tiatoolbox/models/dataset/__init__.py b/tiatoolbox/models/dataset/__init__.py index 49d59a61a..16c80fd18 100644 --- a/tiatoolbox/models/dataset/__init__.py +++ b/tiatoolbox/models/dataset/__init__.py @@ -11,11 +11,11 @@ from .info import DatasetInfoABC, KatherPatchDataset __all__ = [ - "predefined_preproc_func", + "DatasetInfoABC", + "KatherPatchDataset", + "PatchDataset", "PatchDatasetABC", "WSIPatchDataset", - "PatchDataset", "WSIStreamDataset", - "DatasetInfoABC", - "KatherPatchDataset", + "predefined_preproc_func", ] From e93d98ac13e5d6f45ef20c8f32c26683fbbbcb05 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:37:30 +0000 Subject: [PATCH 34/36] :bug: Fix PLC0206 (#907) * :bug: Fix PLC0206 - Fix PLC0206 Extracting value from dictionary without calling `.items()` --- tiatoolbox/models/engine/engine_abc.py | 8 +++----- tiatoolbox/utils/misc.py | 11 +++++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 6e33da409..8f0adc310 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -505,13 +505,11 @@ def get_dataloader( @staticmethod def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict: """Helper function to append raw output during inference.""" - for key in raw_output: + for key, value in raw_output.items(): if raw_predictions[key] is None: - raw_predictions[key] = raw_output[key] + raw_predictions[key] = value else: - raw_predictions[key] = np.append( - raw_predictions[key], raw_output[key], axis=0 - ) + raw_predictions[key] = np.append(raw_predictions[key], value, axis=0) return raw_predictions diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 473fa94c0..b786517a7 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1490,20 +1490,19 @@ def write_to_zarr_in_cache_mode( # case 1 - new zarr group if not zarr_group: - for key in output_data_to_save: - data_to_save = output_data_to_save[key] + for key, value in output_data_to_save.items(): # populate the zarr group for the first time zarr_dataset = zarr_group.create_dataset( name=key, - shape=data_to_save.shape, + shape=value.shape, compressor=compressor, ) - zarr_dataset[:] = data_to_save + zarr_dataset[:] = value return zarr_group # case 2 - append to existing zarr group - for key in output_data_to_save: - zarr_group[key].append(output_data_to_save[key]) + for key, value in output_data_to_save.items(): + zarr_group[key].append(value) return zarr_group From 648d02ac2eeb604870e3d65d297d0478d23fe5c0 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 5 Feb 2025 11:48:18 +0000 Subject: [PATCH 35/36] :bug: Fix tiatoolbox type_hints import --- tiatoolbox/models/engine/semantic_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index f7bc02290..eb7d17b75 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -29,7 +29,7 @@ from .io_config import IOSegmentorConfig if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.type_hints import IntPair, Resolution, Units def _estimate_canvas_parameters( From 8571e141b17ed26a669c3538517501c68b870ecd Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 5 Feb 2025 12:03:04 +0000 Subject: [PATCH 36/36] :bug: Fix deepsource bugs --- tests/engines/test_engine_abc.py | 4 ++-- tests/models/test_models_abc.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index acfcfc888..aec2bf310 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -355,7 +355,7 @@ def test_engine_run() -> NoReturn: def test_engine_run_with_verbose() -> NoReturn: """Test engine run with verbose.""" - """Run pytest with `-rP` option to view progress bar on the captured stderr call""" + # Run pytest with `-rP` option to view progress bar on the captured stderr call. eng = TestEngineABC(model="alexnet-kather100k", verbose=True) out = eng.run( @@ -401,7 +401,7 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: ) assert Path.exists(out), "Zarr output file does not exist" - """ test custom zarr output file name""" + # Test custom zarr output file name eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), diff --git a/tests/models/test_models_abc.py b/tests/models/test_models_abc.py index 7451ae9ac..f5d744f0f 100644 --- a/tests/models/test_models_abc.py +++ b/tests/models/test_models_abc.py @@ -70,7 +70,6 @@ def forward(self: Proto) -> None: # skipcq def infer_batch() -> None: """Define infer batch.""" - pass # base class definition pass # noqa: PIE790 @pytest.mark.skipif(