diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index cd88ec073..8f62f5037 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -446,6 +446,46 @@ def test_engine_run_wsi_annotation_store( shutil.rmtree(save_dir) +def test_engine_run_wsi_annotation_store_power( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 20, + "save_dir": save_dir, + "units": "power", + } + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + predictions = _extract_probabilities_from_annotation_store(output_) + assert _validate_probabilities(predictions) + + shutil.rmtree(save_dir) + + # ------------------------------------------------------------------------------------- # Command Line Interface # ------------------------------------------------------------------------------------- diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index e8f3c63f6..465230116 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1053,7 +1053,7 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa # in this case dataloader resolution / slide resolution will be # equal to dataloader resolution. - if dataloader_units in ["mpp", "level", "objective_power"]: + if dataloader_units in ["mpp", "level", "power"]: wsimeta_dict = dataloader.dataset.reader.info.as_dict() if dataloader_units == "mpp": @@ -1065,8 +1065,8 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution] return 1.0 / downsample_ratio, 1.0 / downsample_ratio - if dataloader_resolution == "objective_power": - slide_objective_power = wsimeta_dict["power"] + if dataloader_units == "power": + slide_objective_power = wsimeta_dict["objective_power"] return ( dataloader_resolution / slide_objective_power, dataloader_resolution / slide_objective_power,