
🎨 Refactor ModelABC to Help Use Default Torch Models #867

Draft. Wants to merge 8 commits into base branch `dev-define-engines-abc`.
4 changes: 2 additions & 2 deletions tests/models/test_arch_vanilla.py
@@ -4,7 +4,7 @@
 import pytest
 import torch

-from tiatoolbox.models.architecture.vanilla import CNNModel
+from tiatoolbox.models.architecture.vanilla import CNNModel, infer_batch
 from tiatoolbox.models.models_abc import model_to
 from tiatoolbox.utils.misc import select_device
@@ -46,7 +46,7 @@ def test_functional() -> None:
         for backbone in backbones:
             model = CNNModel(backbone, num_classes=1)
             model_ = model_to(device=device, model=model)
-            model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU))
+            infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU))
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
103 changes: 34 additions & 69 deletions tiatoolbox/models/architecture/vanilla.py
@@ -78,6 +78,40 @@ def _get_architecture(
     return model.features


+def infer_batch(
+    model: nn.Module,
+    batch_data: torch.Tensor,
+    *,
+    device: str = "cpu",
+) -> dict[str, np.ndarray]:
+    """Run inference on an input batch.
+
+    Contains logic for the forward operation as well as I/O aggregation.
+
+    Args:
+        model (nn.Module):
+            PyTorch defined model.
+        batch_data (torch.Tensor):
+            A batch of data generated by
+            `torch.utils.data.DataLoader`.
+        device (str):
+            Transfers model to the specified device. Default is "cpu".
+
+    Returns:
+        dict:
+            A dictionary with a "probabilities" key holding the model
+            output as a `numpy.ndarray`.
+
+    """
+    # Convert the NHWC input batch to NCHW float32 for the torch model.
+    img_patches_device = batch_data.to(device).type(torch.float32)
+    img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
+
+    # Inference mode
+    model.eval()
+    # Do not compute the gradient (not training)
+    with torch.inference_mode():
+        output = model(img_patches_device)
+    # Output should be a single tensor or scalar
+    return {"probabilities": output.cpu().numpy()}
Contributor commented:
In the current develop branch, neither CNNModel nor CNNBackbone returned dictionaries from their infer_batch() methods. Also, CNNModel currently returns an array, while CNNBackbone returns a list containing the array. It might be fine, just wanted to highlight this.

CNNModel:

    return output.cpu().numpy()

CNNBackbone:

    return [output.cpu().numpy()]

Member Author replied:

Thanks. We are aware of this. Our preference is to use torch nn models, but to generalise for multi-modal output we may need dictionaries. This PR is to check whether we can move to generic torch models or whether we will need a subclass.

Contributor replied:

Makes sense.
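
To make the new dictionary-output contract discussed above concrete, here is a minimal usage sketch of the module-level infer_batch. This is an illustration, not part of the diff: the backbone name, batch shape, and class count are assumptions.

```python
import numpy as np
import torch

from tiatoolbox.models.architecture.vanilla import CNNModel, infer_batch

# A dummy batch of four 224x224 RGB patches in NHWC layout, as produced
# by the toolbox's patch dataloaders.
samples = torch.from_numpy(
    np.random.randint(0, 256, size=(4, 224, 224, 3), dtype=np.uint8),
)

model = CNNModel("resnet18", num_classes=2)
output = infer_batch(model, samples, device="cpu")
print(output["probabilities"].shape)  # expected: (4, 2)
```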



 class CNNModel(ModelABC):
     """Retrieve the model backbone and attach an extra FCN to perform classification.

@@ -137,40 +171,6 @@ def postproc(image: np.ndarray) -> np.ndarray:
         """
         return np.argmax(image, axis=-1)

-    @staticmethod
-    def infer_batch(
-        model: nn.Module,
-        batch_data: torch.Tensor,
-        *,
-        device: str = "cpu",
-    ) -> dict[str, np.ndarray]:
-        """Run inference on an input batch.
-
-        Contains logic for forward operation as well as i/o aggregation.
-
-        Args:
-            model (nn.Module):
-                PyTorch defined model.
-            batch_data (torch.Tensor):
-                A batch of data generated by
-                `torch.utils.data.DataLoader`.
-            device (str):
-                Transfers model to the specified device. Default is "cpu".
-
-        """
-        img_patches_device = batch_data.to(device).type(
-            torch.float32,
-        )  # to NCHW
-        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
-
-        # Inference mode
-        model.eval()
-        # Do not compute the gradient (not training)
-        with torch.inference_mode():
-            output = model(img_patches_device)
-        # Output should be a single tensor or scalar
-        return {"probabilities": output.cpu().numpy()}


 class CNNBackbone(ModelABC):
     """Retrieve the model backbone and strip the classification layer.

@@ -233,38 +233,3 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
         feat = self.feat_extract(imgs)
         gap_feat = self.pool(feat)
         return torch.flatten(gap_feat, 1)

-    @staticmethod
-    def infer_batch(
-        model: nn.Module,
-        batch_data: torch.Tensor,
-        *,
-        device: str,
-    ) -> dict[str, np.ndarray]:
-        """Run inference on an input batch.
-
-        Contains logic for forward operation as well as i/o aggregation.
-
-        Args:
-            model (nn.Module):
-                PyTorch defined model.
-            batch_data (torch.Tensor):
-                A batch of data generated by
-                `torch.utils.data.DataLoader`.
-            device (str):
-                Transfers model to the specified device. Default is "cpu".
-
-        """
-        img_patches_device = batch_data.to(device).type(
-            torch.float32,
-        )  # to NCHW
-        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
-
-        # Inference mode
-        model.eval()
-        # Do not compute the gradient (not training)
-        with torch.inference_mode():
-            output = model(img_patches_device)
-
-        # Output should be a single tensor or scalar
-        return {"probabilities": output.cpu().numpy()}
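
After this change, CNNBackbone also goes through the shared helper rather than its own staticmethod. A hypothetical feature-extraction sketch (backbone name and shapes are assumptions; note the features land under the "probabilities" key, as the Contributor's comment above anticipates):

```python
import numpy as np
import torch

from tiatoolbox.models.architecture.vanilla import CNNBackbone, infer_batch

samples = torch.from_numpy(
    np.random.randint(0, 256, size=(4, 224, 224, 3), dtype=np.uint8),
)

# The backbone strips the classifier; its output is a pooled feature vector.
model = CNNBackbone("resnet18")
features = infer_batch(model, samples, device="cpu")
# For resnet18 the pooled features have 512 channels: (4, 512).
print(features["probabilities"].shape)
```
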
3 changes: 2 additions & 1 deletion tiatoolbox/models/engine/engine_abc.py
@@ -17,6 +17,7 @@

 from tiatoolbox import DuplicateFilter, logger
 from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.vanilla import infer_batch
 from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.models_abc import load_torch_model
 from tiatoolbox.utils.misc import (
@@ -573,7 +574,7 @@ def infer_patches(
         zarr_group = zarr.open(save_path, mode="w")

         for _, batch_data in enumerate(dataloader):
-            batch_output = self.model.infer_batch(
+            batch_output = infer_batch(
                 self.model,
                 batch_data["image"],
                 device=self.device,
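
With the engine calling the module-level helper instead of self.model.infer_batch, any plain nn.Module can in principle be driven by the same loop. A hypothetical sketch of what that enables (torchvision's resnet18 stands in for "a generic torch model"; it is not part of this PR):

```python
import torch
import torchvision.models as tvm

from tiatoolbox.models.architecture.vanilla import infer_batch

# A plain torch model: no ModelABC subclass and no infer_batch()
# staticmethod of its own.
model = tvm.resnet18(weights=None)

# NHWC float batch, matching what the patch datasets yield.
batch = torch.rand(2, 224, 224, 3)

out = infer_batch(model, batch, device="cpu")
# Here "probabilities" holds raw logits of shape (2, 1000), since a
# vanilla resnet18 applies no softmax.
print(out["probabilities"].shape)
```
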
24 changes: 0 additions & 24 deletions tiatoolbox/models/models_abc.py
@@ -73,30 +73,6 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
         """Torch method, this contains logic for using layers defined in init."""
         ... # pragma: no cover

-    @staticmethod
-    @abstractmethod
-    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict:
-        """Run inference on an input batch.
-
-        Contains logic for forward operation as well as I/O aggregation.
-
-        Args:
-            model (nn.Module):
-                PyTorch defined model.
-            batch_data (np.ndarray):
-                A batch of data generated by
-                `torch.utils.data.DataLoader`.
-            device (str):
-                Transfers model to the specified device. Default is "cpu".
-
-        Returns:
-            dict:
-                Returns a dictionary of predictions and other expected outputs
-                depending on the network architecture.
-
-        """
-        ... # pragma: no cover

     @staticmethod
     def preproc(image: np.ndarray) -> np.ndarray:
         """Define the pre-processing of this class of model."""
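
With the abstract infer_batch removed, the visible diff suggests forward() is the only remaining abstract method on ModelABC. Under that assumption, a minimal subclass would look like the following (TinyModel is hypothetical, illustrative only, not toolbox code):

```python
import torch
from torch import nn

from tiatoolbox.models.models_abc import ModelABC


class TinyModel(ModelABC):
    """Hypothetical minimal subclass: only forward() is implemented."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.LazyLinear(2)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # Flatten everything after the batch dimension and classify.
        return self.fc(imgs.flatten(1))
```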