diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index a87424dfd..84d8bdf3d 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -4,7 +4,7 @@ import pytest import torch -from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel +from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel, infer_batch from tiatoolbox.models.models_abc import model_to ON_GPU = False @@ -45,7 +45,7 @@ def test_functional() -> None: for backbone in backbones: model = CNNModel(backbone, num_classes=1) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, device=device) + infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc @@ -72,7 +72,7 @@ def test_timm_functional() -> None: for backbone in backbones: model = TimmModel(backbone=backbone, num_classes=1, pretrained=False) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, device=device) + infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index cb487ec53..c5c41f912 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -79,6 +79,40 @@ def _get_architecture( return model.features +def infer_batch( + model: nn.Module, + batch_data: torch.Tensor, + *, + device: str = "cpu", +) -> dict[str, np.ndarray]: + """Run inference on an input batch. + + Contains logic for forward operation as well as i/o aggregation. + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (torch.Tensor): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + """ + img_patches_device = batch_data.to(device).type( + torch.float32, + ) # to NCHW + img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() + + # Inference mode + model.eval() + # Do not compute the gradient (not training) + with torch.inference_mode(): + output = model(img_patches_device) + # Output should be a single tensor or scalar + return {"probabilities": output.cpu().numpy()} + + def _get_timm_architecture( arch_name: str, *, diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8f0adc310..b04d4aded 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -18,6 +18,7 @@ from tiatoolbox import DuplicateFilter, logger, rcParam from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.architecture.utils import compile_model +from tiatoolbox.models.architecture.vanilla import infer_batch from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( @@ -571,7 +572,7 @@ def infer_patches( zarr_group = zarr.open(save_path, mode="w") for _, batch_data in enumerate(dataloader): - batch_output = self.model.infer_batch( + batch_output = infer_batch( self.model, batch_data["image"], device=self.device, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 0e2b7d81f..5916edfe6 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -77,30 +77,6 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover - @staticmethod - @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict: - """Run inference on an input batch. - - Contains logic for forward operation as well as I/O aggregation. - - Args: - model (nn.Module): - PyTorch defined model. - batch_data (np.ndarray): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - - Returns: - dict: - Returns a dictionary of predictions and other expected outputs - depending on the network architecture. - - """ - ... # pragma: no cover - @staticmethod def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of model."""