diff --git a/src/anomalib/data/utils/transform.py b/src/anomalib/data/utils/transform.py
index 38bd9e2bc9..a46d8a2c2e 100644
--- a/src/anomalib/data/utils/transform.py
+++ b/src/anomalib/data/utils/transform.py
@@ -86,10 +86,16 @@ def get_transforms(
         if isinstance(config, DictConfig):
             logger.info("Loading transforms from config File")
             transforms_list = []
+
+            if "Resize" not in config.keys() and image_size is not None:
+                resize_height, resize_width = get_image_height_and_width(image_size)
+                transforms_list.append(A.Resize(height=resize_height, width=resize_width, always_apply=True))
+                logger.info("Resize %s added!", (resize_height, resize_width))
+
             for key, value in config.items():
                 if hasattr(A, key):
                     transform = getattr(A, key)(**value)
-                    logger.info(f"Transform {transform} added!")
+                    logger.info("Transform %s added!", transform)
                     transforms_list.append(transform)
                 else:
                     raise ValueError(f"Transformation {key} is not part of albumentations")
diff --git a/src/anomalib/deploy/inferencers/openvino_inferencer.py b/src/anomalib/deploy/inferencers/openvino_inferencer.py
index 745b682aae..4fd9ecc79d 100644
--- a/src/anomalib/deploy/inferencers/openvino_inferencer.py
+++ b/src/anomalib/deploy/inferencers/openvino_inferencer.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import logging
 from importlib.util import find_spec
 from pathlib import Path
 from typing import Any
@@ -18,10 +19,12 @@
 
 from .base_inferencer import Inferencer
 
+logger = logging.getLogger("anomalib")
+
 if find_spec("openvino") is not None:
     from openvino.runtime import Core
 else:
-    raise ImportError("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")
+    logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")
 
 
 class OpenVINOInferencer(Inferencer):
diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py
index 3c6c2a5d49..b9a7368337 100644
--- a/src/anomalib/models/components/base/anomaly_module.py
+++ b/src/anomalib/models/components/base/anomaly_module.py
@@ -179,7 +179,7 @@ def _collect_outputs(
             image_metric.update(output["pred_scores"], output["label"].int())
             if "mask" in output.keys() and "anomaly_maps" in output.keys():
                 pixel_metric.cpu()
-                pixel_metric.update(output["anomaly_maps"], output["mask"].int())
+                pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
 
     @staticmethod
     def _post_process(outputs: STEP_OUTPUT) -> None:
diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py
index 0927abed35..0ec0f908e2 100644
--- a/src/anomalib/models/efficient_ad/lightning_model.py
+++ b/src/anomalib/models/efficient_ad/lightning_model.py
@@ -24,7 +24,7 @@
 from anomalib.data.utils import DownloadInfo, download_and_extract
 from anomalib.models.components import AnomalyModule
 
-from .torch_model import EfficientAdModel, EfficientAdModelSize
+from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems
 
 logger = logging.getLogger(__name__)
 
@@ -192,19 +192,8 @@ def _get_quantiles_of_maps(self, maps: list[Tensor]) -> tuple[Tensor, Tensor]:
         Returns:
             tuple[Tensor, Tensor]: Two scalars - the 90% and the 99.5% quantile.
""" - maps_flat = torch.flatten(torch.cat(maps)) - # torch.quantile only works with input size up to 2**24 elements, see - # https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291 - # if we have more elements we need to decrease the size - # we do this by sampling random elements of maps_flat because then - # the locations of the quantiles (90% and 99.5%) will still be - # valid even though they might not be the exact quantiles. - max_input_size = 2**24 - if len(maps_flat) > max_input_size: - # select a random subset with max_input_size elements. - perm = torch.randperm(len(maps_flat), device=self.device) - idx = perm[:max_input_size] - maps_flat = maps_flat[idx] + + maps_flat = reduce_tensor_elems(torch.cat(maps)) qa = torch.quantile(maps_flat, q=0.9).to(self.device) qb = torch.quantile(maps_flat, q=0.995).to(self.device) return qa, qb diff --git a/src/anomalib/models/efficient_ad/torch_model.py b/src/anomalib/models/efficient_ad/torch_model.py index 13aa56cdad..acda83f43c 100644 --- a/src/anomalib/models/efficient_ad/torch_model.py +++ b/src/anomalib/models/efficient_ad/torch_model.py @@ -24,6 +24,29 @@ def imagenet_norm_batch(x): return x_norm +def reduce_tensor_elems(tensor: torch.Tensor, m=2**24) -> torch.Tensor: + """Flattens n-dimensional tensors, selects m elements from it + and returns the selected elements as tensor. It is used to select + at most 2**24 for torch.quantile operation, as it is the maximum + supported number of elements. + https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291 + + Args: + tensor (torch.Tensor): input tensor from which elements are selected + m (int): number of maximum tensor elements. Default: 2**24 + + Returns: + Tensor: reduced tensor + """ + tensor = torch.flatten(tensor) + if len(tensor) > m: + # select a random subset with m elements. + perm = torch.randperm(len(tensor), device=tensor.device) + idx = perm[:m] + tensor = tensor[idx] + return tensor + + class EfficientAdModelSize(str, Enum): """Supported EfficientAd model sizes""" @@ -123,7 +146,6 @@ class Decoder(nn.Module): def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.img_size = img_size - self.last_upsample = 64 if padding else 56 self.last_upsample = int(img_size / 4) if padding else int(img_size / 4) - 8 self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) @@ -279,6 +301,7 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict if self.training: # Student loss + distance_st = reduce_tensor_elems(distance_st) d_hard = torch.quantile(distance_st, 0.999) loss_hard = torch.mean(distance_st[distance_st >= d_hard]) student_output_penalty = self.student(batch_imagenet)[:, : self.teacher_out_channels, :, :]