Minor fixes (#1182)
* squeeze output map

* add resize from dataset config, openvino warning

* fix pre-commit

* OpenVINO log warning

* pre-commit

* squeeze maps and mask

* quantile element reduction

* docstring, fixes

* f strings

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
alexriedel1 and samet-akcay authored Jul 20, 2023
1 parent d3ebced · commit 2083c51
Showing 5 changed files with 39 additions and 18 deletions.

src/anomalib/data/utils/transform.py (7 additions, 1 deletion)

@@ -86,10 +86,16 @@ def get_transforms(
     if isinstance(config, DictConfig):
         logger.info("Loading transforms from config File")
         transforms_list = []
+
+        if "Resize" not in config.keys() and image_size is not None:
+            resize_height, resize_width = get_image_height_and_width(image_size)
+            transforms_list.append(A.Resize(height=resize_height, width=resize_width, always_apply=True))
+            logger.info("Resize %s added!", (resize_height, resize_width))
+
         for key, value in config.items():
             if hasattr(A, key):
                 transform = getattr(A, key)(**value)
-                logger.info(f"Transform {transform} added!")
+                logger.info("Transform %s added!", transform)
                 transforms_list.append(transform)
             else:
                 raise ValueError(f"Transformation {key} is not part of albumentations")
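
The effect of the new branch: when the transform config has no explicit Resize and an image_size is supplied, a Resize is prepended so every image is scaled to the dataset's configured size. A minimal sketch of that behavior, with a hypothetical config (names follow the diff):

    import albumentations as A
    from omegaconf import DictConfig

    config = DictConfig({"HorizontalFlip": {"p": 0.5}})  # no explicit Resize key
    image_size = (256, 256)

    transforms_list = []
    if "Resize" not in config.keys() and image_size is not None:
        height, width = image_size
        transforms_list.append(A.Resize(height=height, width=width, always_apply=True))
    for key, value in config.items():
        transforms_list.append(getattr(A, key)(**value))
    # transforms_list is now [Resize(256, 256), HorizontalFlip(p=0.5)]
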
src/anomalib/deploy/inferencers/openvino_inferencer.py (4 additions, 1 deletion)

@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import logging
 from importlib.util import find_spec
 from pathlib import Path
 from typing import Any
@@ -18,10 +19,12 @@
 
 from .base_inferencer import Inferencer
 
+logger = logging.getLogger("anomalib")
+
 if find_spec("openvino") is not None:
     from openvino.runtime import Core
 else:
-    raise ImportError("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")
+    logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")
 
 
 class OpenVINOInferencer(Inferencer):
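
The softer import matters for users who never touch OpenVINO: the module now imports cleanly and only warns. The trade-off is that Core stays undefined when OpenVINO is missing, so the failure surfaces on first use rather than at import time. A hedged sketch of that pattern (load_model is a hypothetical caller, not part of the diff):

    import logging
    from importlib.util import find_spec

    logger = logging.getLogger("anomalib")

    if find_spec("openvino") is not None:
        from openvino.runtime import Core
    else:
        logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")

    def load_model(path: str):
        ie = Core()  # raises NameError here if OpenVINO is absent, not at import time
        return ie.read_model(path)
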
src/anomalib/models/components/base/anomaly_module.py (1 addition, 1 deletion)

@@ -179,7 +179,7 @@ def _collect_outputs(
             image_metric.update(output["pred_scores"], output["label"].int())
             if "mask" in output.keys() and "anomaly_maps" in output.keys():
                 pixel_metric.cpu()
-                pixel_metric.update(output["anomaly_maps"], output["mask"].int())
+                pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
 
     @staticmethod
     def _post_process(outputs: STEP_OUTPUT) -> None:
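
Why the squeeze: anomaly maps usually leave the model as (N, 1, H, W) while ground-truth masks are (N, H, W), and the torchmetrics pixel metrics expect matching shapes. A small shape check (sizes assumed for illustration):

    import torch

    anomaly_maps = torch.rand(8, 1, 256, 256)    # (N, 1, H, W) model output
    mask = torch.randint(0, 2, (8, 256, 256))    # (N, H, W) ground truth

    assert torch.squeeze(anomaly_maps).shape == torch.squeeze(mask.int()).shape
    # both are torch.Size([8, 256, 256]) after squeezing
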
src/anomalib/models/efficient_ad/lightning_model.py (3 additions, 14 deletions)

@@ -24,7 +24,7 @@
 from anomalib.data.utils import DownloadInfo, download_and_extract
 from anomalib.models.components import AnomalyModule
 
-from .torch_model import EfficientAdModel, EfficientAdModelSize
+from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems
 
 logger = logging.getLogger(__name__)
 
@@ -192,19 +192,8 @@ def _get_quantiles_of_maps(self, maps: list[Tensor]) -> tuple[Tensor, Tensor]:
         Returns:
             tuple[Tensor, Tensor]: Two scalars - the 90% and the 99.5% quantile.
         """
-        maps_flat = torch.flatten(torch.cat(maps))
-        # torch.quantile only works with input size up to 2**24 elements, see
-        # https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291
-        # if we have more elements we need to decrease the size
-        # we do this by sampling random elements of maps_flat because then
-        # the locations of the quantiles (90% and 99.5%) will still be
-        # valid even though they might not be the exact quantiles.
-        max_input_size = 2**24
-        if len(maps_flat) > max_input_size:
-            # select a random subset with max_input_size elements.
-            perm = torch.randperm(len(maps_flat), device=self.device)
-            idx = perm[:max_input_size]
-            maps_flat = maps_flat[idx]
+
+        maps_flat = reduce_tensor_elems(torch.cat(maps))
         qa = torch.quantile(maps_flat, q=0.9).to(self.device)
         qb = torch.quantile(maps_flat, q=0.995).to(self.device)
         return qa, qb
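
For context, torch.quantile rejects inputs larger than 2**24 elements (the check linked in the old comment). Sampling a random subset keeps the 90% and 99.5% quantile estimates approximately valid while staying under the limit. A sketch of the failure mode and the workaround the new helper encapsulates:

    import torch

    big = torch.rand(2**24 + 1)
    # torch.quantile(big, 0.9)  # would raise: input tensor too large for quantile

    subset = big[torch.randperm(len(big))[: 2**24]]  # random subset, at most 2**24 elements
    qa = torch.quantile(subset, 0.9)                 # approximate 90% quantile
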
src/anomalib/models/efficient_ad/torch_model.py (24 additions, 1 deletion)

@@ -24,6 +24,29 @@ def imagenet_norm_batch(x):
     return x_norm
 
 
+def reduce_tensor_elems(tensor: torch.Tensor, m=2**24) -> torch.Tensor:
+    """Flattens an n-dimensional tensor, selects at most m elements from it,
+    and returns the selected elements as a tensor. It is used to select
+    at most 2**24 elements for the torch.quantile operation, as that is the
+    maximum supported number of elements:
+    https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291
+
+    Args:
+        tensor (torch.Tensor): input tensor from which elements are selected
+        m (int): maximum number of tensor elements. Default: 2**24
+
+    Returns:
+        Tensor: reduced tensor
+    """
+    tensor = torch.flatten(tensor)
+    if len(tensor) > m:
+        # select a random subset with m elements.
+        perm = torch.randperm(len(tensor), device=tensor.device)
+        idx = perm[:m]
+        tensor = tensor[idx]
+    return tensor
+
+
 class EfficientAdModelSize(str, Enum):
     """Supported EfficientAd model sizes"""
 
@@ -123,7 +146,6 @@ class Decoder(nn.Module):
     def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.img_size = img_size
-        self.last_upsample = 64 if padding else 56
         self.last_upsample = int(img_size / 4) if padding else int(img_size / 4) - 8
         self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
         self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
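
The removed line was a stale hard-coded default that the following line immediately overwrote; only the img_size-dependent value survives. At the default 256x256 input the surviving expression reproduces the old constants, since int(256 / 4) = 64 with padding and 64 - 8 = 56 without, while other input sizes now scale accordingly:

    for img_size in (256, 224):
        with_pad, without_pad = int(img_size / 4), int(img_size / 4) - 8
        print(img_size, with_pad, without_pad)  # 256 -> 64, 56; 224 -> 56, 48
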
@@ -279,6 +301,7 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict
 
         if self.training:
             # Student loss
+            distance_st = reduce_tensor_elems(distance_st)
             d_hard = torch.quantile(distance_st, 0.999)
             loss_hard = torch.mean(distance_st[distance_st >= d_hard])
             student_output_penalty = self.student(batch_imagenet)[:, : self.teacher_out_channels, :, :]
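
Reducing distance_st before the quantile call keeps the student hard-loss computation inside torch.quantile's size limit: d_hard is the 99.9% quantile of the (possibly subsampled) distances, and the loss averages only that hardest 0.1%. A self-contained sketch with assumed tensor sizes:

    import torch

    distance_st = torch.rand(4, 384, 64, 64) ** 2  # assumed squared teacher-student distances
    flat = torch.flatten(distance_st)
    if len(flat) > 2**24:                          # mirrors reduce_tensor_elems
        flat = flat[torch.randperm(len(flat))[: 2**24]]
    d_hard = torch.quantile(flat, 0.999)           # hard-example threshold
    loss_hard = torch.mean(flat[flat >= d_hard])   # mean over the hardest 0.1%
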
